Tensorflow 学习笔记:Input Pipeline - Dataset
Dataset是Tensorflow里面一个比较重要的概念,我们知道机器学习算法需要大概的数据来训练data model. 所以Dataset就是用来做这么一件重要的事情:定义数据pipline,为学习算法提供训练数据。 其实我们也可以将Dataset理解成一个数据源,指向某些包含训练数据的文件列表,或者是内存里面已有的数据结构(比如Tensor objects)。 Dataset 数据结构组成Dataset的基本单元是element。每个element必需有相同的数据结构,其中每个element包含多个Tensor objects。比如: # 创建一个dataset,里面包含一个2-Dimension (4x10) Tensor对象
dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4,10]))
# 创建一个dataset,里面包含两个Tensor,tensor1的shape为(4x3),tensor2的shape为(4x5)
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4,3]),tf.random_uniform([4,5])))
创建Dataset前面说了Dataset可以理解成数据源, 那么怎么创建一个Dataset并使它跟多个数据源关联呢?Tensorflow Dataset API提供了两种方式:
读取Dataset从前面Dataset的定义以及结构可以看出,Dataset其实是对Tensor提供了一层封装,而Tensor又是对真实的训练数据的封装,这些数据可能是一个N-Dimension matrix,或者是指向一批数据文件的向量。其实我们可以会问为什么要设计的这么复杂,又是matrix,又是Tensor的,直接用Tensor/Matrix的API来读取训练数据不就行了么? 我觉得可以从下面几个方向来思考:
回到正题,Dataset提供Iterator.get_next() API来读取它的每一个element,这个element包含一个或者多个我们需要的Tensor objects。 至于每次调用get_next()返回多少个element,则取决于batch size的大小。或者你可以认为batch size就是决定每次读取多少个训练数据,一个训练数据就是一个element。 Iterator的调用步骤:
这里需要特别提一下one shot iterator,它每次只读取一个element,而且 这种Iterator不需要初始化,也就是上面的第3步不需要显式地调用。但是只有当Dataset不包含任何参数时才可以为它创建one shot iterator, 前面例子里的Dataset都不能创建one shot iterator。 dataset2 = tf.data.Dataset.from_tensor_slices(tf.constant([[1,2,3],[2,4,6],[3,6,9]]))
iter2 = dataset2.make_one_shot_iterator()
用Dataset读取文件前面的例子里很多的都是从Ternsor对象中创建Dataset, 所以用Iterator读取到的可能是一些常量数据,比如文件名,数组之类的。但是在真实的世界中,训练数据都是存放在文件中的,比如CSV,JPG,所以我们关心的其实并不是这些文件名本身,还是其中的内容。那么如果我的Tensor中存放的是一些文件名字,怎么用Dataset来读取其中的数据呢? Dataset提供了一个数据预处理的API map()。 预处理的意思是可以对每一个element进行transformation,Iterator的get_next()拿到的可能是一个字符串代表某个文件名或者CSV文件里的一行,然后transformation的时候将这个文件的内容读取出来并保存在内存的Tensor对象。 读取文本文件这里用TextLineDataset读取csv文件: def readTextFile(filename):
_CSV_COLUMN_DEFAULTS = [[1],[0],[''],['']]
_CSV_COLUMNS = [
'age','workclass','education','education_num','marital_status','occupation','income_bracket'
]
dataset = tf.data.TextLineDataset(filename)
iterator = dataset.make_one_shot_iterator()
textline = iterator.get_next()
with tf.Session() as sess:
print(textline.eval())
# convert text to list of tensors for each column
def parseCSVLine(value):
columns = tf.decode_csv(value,_CSV_COLUMN_DEFAULTS)
features = dict(zip(_CSV_COLUMNS,columns))
return features
dataset2 = dataset.map(parseCSVLine)
iterator2 = dataset2.make_one_shot_iterator()
textline2 = iterator2.get_next()
with tf.Session() as sess:
print(textline2)
这里parseCSVLine 将从csv读取到的每一行进行decode 处理(tf.decode_csv), 从而将每一列转成对应的Tensor object。 读取图片文件# Reads an image from a file,decodes it into a dense tensor,and resizes it
# to a fixed shape.
def _parse_function(filename,label):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_image(image_string)
image_resized = tf.image.resize_images(image_decoded,[28,28])
return image_resized,label
# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg","/var/data/image2.jpg",...])
# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0,37,...])
dataset = tf.data.Dataset.from_tensor_slices((filenames,labels))
dataset = dataset.map(_parse_function) (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |