加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

如何在TensorFlow中使用parallel_interleave

发布时间:2020-12-20 11:04:12 所属栏目:Python 来源:网络整理
导读:我正在阅读TensorFlow benchmarks repo中的代码.以下代码是从TFRecord文件创建TensorFlow数据集的部分: ds = tf.data.TFRecordDataset.list_files(tfrecord_file_names)ds = ds.apply(interleave_ops.parallel_interleave(tf.data.TFRecordDataset,cycle_le
我正在阅读TensorFlow benchmarks repo中的代码.以下代码是从TFRecord文件创建TensorFlow数据集的部分:

ds = tf.data.TFRecordDataset.list_files(tfrecord_file_names)
ds = ds.apply(interleave_ops.parallel_interleave(tf.data.TFRecordDataset,cycle_length=10))

我试图更改此代码以直接从JPEG图像文件创建数据集:

ds = tf.data.Dataset.from_tensor_slices(jpeg_file_names)
ds = ds.apply(interleave_ops.parallel_interleave(?,cycle_length=10))

我不知道写什么?地点. parallel_interleave()中的map_func是TF_cord文件的tf.data.TFRecordDataset类的__init __(),但我不知道要为JPEG文件写什么.

我们不需要在这里进行任何转换.因为我们将压缩两个数据集,然后再进行转换.代码如下:

counter = tf.data.Dataset.range(batch_size)
ds = tf.data.Dataset.zip((ds,counter))
ds = ds.apply( 
     batching.map_and_batch( 
     map_func=preprocess_fn,
     batch_size=batch_size,
     num_parallel_batches=num_splits))

因为我们不需要改造吗?地方,我试图使用一个空的map_func,但有错误“map_funcmust返回aDataset`对象”.我也尝试使用tf.data.Dataset,但是输出说Dataset是一个不允许放在那里的抽象类.

任何人都可以帮忙吗?非常感谢.

解决方法

当您具有将源数据集的每个元素转换为多个元素到目标数据集的转换时,parallel_interleave非常有用.我不确定为什么他们会在基准测试报告中使用它,当他们可以使用并行调用的地图时.

以下是我建议使用parallel_interleave从多个目录中读取图像的方法,每个目录包含一个类:

classes = sorted(glob(directory + '/*/')) # final slash selects directories only
num_classes = len(classes)

labels = np.arange(num_classes,dtype=np.int32)

dirs = DS.from_tensor_slices((classes,labels))               # 1
files = dirs.apply(tf.contrib.data.parallel_interleave(
    get_files,cycle_length=num_classes,block_length=4,# 2
    sloppy=False)) # False is important ! Otherwise it mixes labels
files = files.cache()
imgs = files.map(read_decode,num_parallel_calls=20).        # 3
            .apply(tf.contrib.data.shuffle_and_repeat(100))
            .batch(batch_size)
            .prefetch(5)

有三个步骤.首先,我们获取目录及其标签列表(#1).

然后,我们将这些映射到文件的数据集.但是如果我们做一个简单的.flatmap(),我们最终会得到标签0的所有文件,然后是标签1的所有文件,然后是2等…然后我们需要非常大的shuffle缓冲区才能得到一个有意义的洗牌.

因此,我们应用parallel_interleave(#2).这是get_files():

def get_files(dir_path,label):
    globbed = tf.string_join([dir_path,'*.jpg'])
    files = tf.matching_files(globbed)

    num_files = tf.shape(files)[0] # in the directory
    labels = tf.tile([label],[num_files,]) # expand label to all files
    return DS.from_tensor_slices((files,labels))

使用parallel_interleave可确保每个目录的list_files并行运行,因此在第一个目录中列出第一个block_length文件时,第二个目录中的第一个block_length文件也将可用(也可以从第3个,第4个等).此外,结果数据集将包含每个标签的交错块,例如,1 1 1 1 2 2 2 2 3 3 3 3 3 1 1 1 1 …(3类,block_length = 4)

最后,我们从文件列表中读取图像(#3).这是read_and_decode():

def read_decode(path,label):
    img = tf.image.decode_image(tf.read_file(path),channels=3)
    img = tf.image.resize_bilinear(tf.expand_dims(img,axis=0),target_size)
    img = tf.squeeze(img,0)
    img = preprocess_fct(img) # should work with Tensors !

    label = tf.one_hot(label,num_classes)
    img = tf.Print(img,[path,label],'Read_decode')
    return (img,label)

此函数采用图像路径及其标签,并为每个:路径的图像张量和标签的one_hot编码返回张量.这也是您可以对图像进行所有转换的地方.在这里,我做了调整大小和基本的预处理.

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读