加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

pytorch 批次遍历数据集打印数据的例子

发布时间:2020-12-17 17:34:23 所属栏目:Python 来源:网络整理
导读:我就废话不多说了,直接上代码吧! from os import listdirimport osfrom time import timeimport torch.utils.data as dataimport torchvision.transforms as transformsfrom torch.utils.data import DataLoaderdef printProgressBar(iteration,total,pref

我就废话不多说了,直接上代码吧!

from os import listdir
import os
from time import time

import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

def printProgressBar(iteration,total,prefix='',suffix='',decimals=1,length=100,fill='=',empty=' ',tip='>',begin='[',end=']',done="[DONE]",clear=True):
  percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
  filledLength = int(length * iteration // total)
  bar = fill * filledLength
  if iteration != total:
    bar = bar + tip
  bar = bar + empty * (length - filledLength - len(tip))
  display = 'r{prefix}{begin}{bar}{end} {percent}%{suffix}' 
    .format(prefix=prefix,begin=begin,bar=bar,end=end,percent=percent,suffix=suffix)
  print(display,end=''),# comma after print() required for python 2
  if iteration == total: # print with newline on complete
    if clear: # display given complete message with spaces to 'erase' previous progress bar
      finish = 'r{prefix}{done}'.format(prefix=prefix,done=done)
      if hasattr(str,'decode'): # handle python 2 non-unicode strings for proper length measure
        finish = finish.decode('utf-8')
        display = display.decode('utf-8')
      clear = ' ' * max(len(display) - len(finish),0)
      print(finish + clear)
    else:
      print('')

class DatasetFromFolder(data.Dataset):
  def __init__(self,image_dir):
    super(DatasetFromFolder,self).__init__()
    self.photo_path = os.path.join(image_dir,"a")
    self.sketch_path = os.path.join(image_dir,"b")
    self.image_filenames = [x for x in listdir(self.photo_path) if is_image_file(x)]

    transform_list = [transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5))]

    self.transform = transforms.Compose(transform_list)

  def __getitem__(self,index):
    # Load Image
    input = load_img(os.path.join(self.photo_path,self.image_filenames[index]))
    input = self.transform(input)
    target = load_img(os.path.join(self.sketch_path,self.image_filenames[index]))
    target = self.transform(target)

    return input,target

  def __len__(self):
    return len(self.image_filenames)

if __name__ == '__main__':
  dataset = DatasetFromFolder("./dataset/facades/train")
  dataloader = DataLoader(dataset=dataset,num_workers=8,batch_size=1,shuffle=True)
  total = len(dataloader)
  for epoch in range(20):
    t0 = time()
    for i,batch in enumerate(dataloader):
      real_a,real_b = batch[0],batch[1]
      printProgressBar(i + 1,total + 1,length=20,prefix='Epoch %s ' % str(1),suffix=',d_loss: %d' % 1)
    printProgressBar(total,done='Epoch [%s] ' % str(epoch) +
               ',time: %.2f s' % (time() - t0)
             )

以上这篇pytorch 批次遍历数据集打印数据的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读