pytorch 批次遍历数据集打印数据的例子
发布时间:2020-12-17 17:34:23 所属栏目:Python 来源:网络整理
导读:我就废话不多说了,直接上代码吧! from os import listdir; import os; from time import time; import torch.utils.data as data; import torchvision.transforms as transforms; from torch.utils.data import DataLoader; def printProgressBar(iteration, total, pref…
我就废话不多说了,直接上代码吧! from os import listdir import os from time import time import torch.utils.data as data import torchvision.transforms as transforms from torch.utils.data import DataLoader def printProgressBar(iteration,total,prefix='',suffix='',decimals=1,length=100,fill='=',empty=' ',tip='>',begin='[',end=']',done="[DONE]",clear=True): percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total))) filledLength = int(length * iteration // total) bar = fill * filledLength if iteration != total: bar = bar + tip bar = bar + empty * (length - filledLength - len(tip)) display = 'r{prefix}{begin}{bar}{end} {percent}%{suffix}' .format(prefix=prefix,begin=begin,bar=bar,end=end,percent=percent,suffix=suffix) print(display,end=''),# comma after print() required for python 2 if iteration == total: # print with newline on complete if clear: # display given complete message with spaces to 'erase' previous progress bar finish = 'r{prefix}{done}'.format(prefix=prefix,done=done) if hasattr(str,'decode'): # handle python 2 non-unicode strings for proper length measure finish = finish.decode('utf-8') display = display.decode('utf-8') clear = ' ' * max(len(display) - len(finish),0) print(finish + clear) else: print('') class DatasetFromFolder(data.Dataset): def __init__(self,image_dir): super(DatasetFromFolder,self).__init__() self.photo_path = os.path.join(image_dir,"a") self.sketch_path = os.path.join(image_dir,"b") self.image_filenames = [x for x in listdir(self.photo_path) if is_image_file(x)] transform_list = [transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5))] self.transform = transforms.Compose(transform_list) def __getitem__(self,index): # Load Image input = load_img(os.path.join(self.photo_path,self.image_filenames[index])) input = self.transform(input) target = load_img(os.path.join(self.sketch_path,self.image_filenames[index])) target = self.transform(target) return input,target def __len__(self): return len(self.image_filenames) if 
__name__ == '__main__': dataset = DatasetFromFolder("./dataset/facades/train") dataloader = DataLoader(dataset=dataset,num_workers=8,batch_size=1,shuffle=True) total = len(dataloader) for epoch in range(20): t0 = time() for i,batch in enumerate(dataloader): real_a,real_b = batch[0],batch[1] printProgressBar(i + 1,total + 1,length=20,prefix='Epoch %s ' % str(1),suffix=',d_loss: %d' % 1) printProgressBar(total,done='Epoch [%s] ' % str(epoch) + ',time: %.2f s' % (time() - t0) ) 以上这篇pytorch 批次遍历数据集打印数据的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。 (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |