加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 大数据 > 正文

Tensorflow学习笔记-输入数据处理框架

发布时间:2020-12-14 04:54:14 所属栏目:大数据 来源:网络整理
导读:通过前面几节的总结,Tensorflow关于TFRecord格式文件的处理、模型的训练的架构为: 1、获取文件列表、创建文件队列:http://www.voidcn.com/article/p-dvnkeclz-brg.html 2、图像预处理:http://www.voidcn.com/article/p-vzjvdssl-brh.html 3、合成Batch:

  通过前面几节的总结,Tensorflow关于TFRecord格式文件的处理、模型的训练的架构为:
  1、获取文件列表、创建文件队列:http://www.voidcn.com/article/p-dvnkeclz-brg.html
  2、图像预处理:http://www.voidcn.com/article/p-vzjvdssl-brh.html
  3、合成Batch:
  4、设计损失函数、梯度下降算法:http://www.voidcn.com/article/p-unavptag-bqy.html

Created with Rapha?l 2.1.0 获取输入文件列表 创建输入文件队列 从文件队列读取数据 整理成Batch作为神经网络的输入 设计损失函数 选择梯度下降法 训练

  对应的代码流程如下:

# 创建文件列表,并通过文件列表来创建文件队列。在调用输入数据处理流程前,需要统一
    # 所有的原始数据格式,并将它们存储到TFRecord文件中
    # match_filenames_once 获取符合正则表达式的所有文件
    files = tf.train.match_filenames_once('path/to/file-*-*')
    # 将文件列表生成文件队列
    filename_queue = tf.train.string_input_producer(files,shuffle=True)

    reader = tf.TFRecordReader()
    _,serialized_example = reader.read(filename_queue)
    # image:存储图像中的原始数据
    # label该样本所对应的标签
    # width,height,channel
    features = tf.parse_single_example(serialized_example,features={
        'image' : tf.FixedLenFeature([],tf.string),'label': tf.FixedLenFeature([],tf.int64),'width': tf.FixedLenFeature([],'heigth': tf.FixedLenFeature([],'channel': tf.FixedLenFeature([],tf.int64)
    })

    image,label = features['image'],features['label']
    width,height = features['width'],features['height']
    channel = features['channel']
    # 将原始图像数据解析出像素矩阵,并根据图像尺寸还原糖图像。
    decode_image = tf.decode_raw(image)
    decode_image.set_shape([width,height,channel])
    # 神经网络的输入大小
    image_size = 299
    # 对图像进行预处理操作,比对亮度、对比度、随机裁剪等操作
    distorted_image = propocess_train(decode_image,image_size,None)

    # shuffle_batch中的参数
    min_after_dequeue = 1000
    batch_size = 100
    capacity = min_after_dequeue + 3*batch_size
    image_batch,label_batch = tf.train.shuffle_batch([distorted_image,label],batch_size=batch_size,capacity=capacity,min_after_dequeue=min_after_dequeue)

    logit = inference(image_batch)
    loss = cal_loss(logit,label_batch)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

    with tf.Session() as sess:
        # 变量初始化
        tf.global_variables_initializer().run()
        # 线程初始化和启动
        coord = tf.train.Coordinator()
        theads = tf.train.start_queue_runners(sess=sess,coord=coord)

        for i in range(STEPS):
            sess.run(train_step)
        # 停止所有线程
        coord.request_stop()
        coord.join(threads)

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读