python – Tensorflow LinearRegressor功能不能具有等级0
发布时间:2020-12-20 12:05:10 所属栏目:Python 来源:网络整理
导读:我正在学习本教程,但未能为在y = x之上生成的数据集构建线性回归量. 这是我的代码的最后一部分,如果你想重现我的错误,你可以在这里找到 complete source code: _CSV_COLUMN_DEFAULTS = [[0],[0]]_CSV_COLUMNS = ['x','y']def input_fn(data_file): def pars
我正在学习本教程,但未能为在y = x之上生成的数据集构建线性回归量.
这是我的代码的最后一部分,如果你想重现我的错误,你可以在这里找到 complete source code: _CSV_COLUMN_DEFAULTS = [[0],[0]] _CSV_COLUMNS = ['x','y'] def input_fn(data_file): def parse_csv(value): print('Parsing',data_file) columns = tf.decode_csv(value,record_defaults=_CSV_COLUMN_DEFAULTS) features = dict(zip(_CSV_COLUMNS,columns)) labels = features.pop('y') return features,labels # Extract lines from input files using the Dataset API. dataset = tf.data.TextLineDataset(data_file) dataset = dataset.map(parse_csv) iterator = dataset.make_one_shot_iterator() features,labels = iterator.get_next() return features,labels x = tf.feature_column.numeric_column('x') base_columns = [x] model_dir = tempfile.mkdtemp() model = tf.estimator.LinearRegressor(model_dir=model_dir,feature_columns=base_columns) model = model.train(input_fn=lambda: input_fn(data_file=file_path)) 不知何故,此代码将失败并显示错误消息 ValueError: Feature (key: x) cannot have rank 0. Give: Tensor("IteratorGetNext:0",shape=(),dtype=int32,device=/device:CPU:0) 由于tensorflow的性质,我发现基于给定的消息检查它真正出错的地方有点困难.任何帮助将不胜感激,谢谢! 解决方法
据我所知,值的第一个维度是batch_size.因此,当input_fn返回数据时,它应该作为批处理返回数据.
一旦您将数据作为批处理返回,它就可以工作,例如: dataset = tf.data.TextLineDataset(data_file) dataset = dataset.map(parse_csv) dataset = dataset.batch(10) # or any other batch size (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |