加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

分层注意网络中的输入层代表什么

发布时间:2020-12-17 17:41:10 所属栏目:Python 来源:网络整理
导读:我正在尝试掌握分层注意力网络(HAN)的概念,我在网上找到的大多数代码或多或少与此处的代码相似:https://medium.com/jatana/report-on-text-classification-using-cnn-rnn-han-f0e887214d5f: embedding_layer=Embedding(len(word_index)+1,EMBEDDING_DIM,we

我正在尝试掌握分层注意力网络(HAN)的概念,我在网上找到的大多数代码或多或少与此处的代码相似:https://medium.com/jatana/report-on-text-classification-using-cnn-rnn-han-f0e887214d5f:

embedding_layer=Embedding(len(word_index)+1,EMBEDDING_DIM,weights=[embedding_matrix],input_length=MAX_SENT_LENGTH,trainable=True)
sentence_input = Input(shape=(MAX_SENT_LENGTH,),dtype='int32',name='input1')
embedded_sequences = embedding_layer(sentence_input)
l_lstm = Bidirectional(LSTM(100))(embedded_sequences)
sentEncoder = Model(sentence_input,l_lstm)

review_input = Input(shape=(MAX_SENTS,MAX_SENT_LENGTH),name='input2')
review_encoder = TimeDistributed(sentEncoder)(review_input)
l_lstm_sent = Bidirectional(LSTM(100))(review_encoder)
preds = Dense(len(macronum),activation='softmax')(l_lstm_sent)
model = Model(review_input,preds)

我的问题是:这里的输入层代表什么?我猜测input1代表用嵌入层包装的句子,但是在这种情况下,input2是什么?它是sendEncoder的输出吗?在这种情况下,它应该是浮点数,或者如果它是嵌入单词的另一层,那么它也应该被嵌入层包裹.

最佳答案
HAN模型以层次结构处理文本:它采用已经拆分为句子的文档(这就是input2的形状为(MAX_SENTS,MAX_SENT_LENGTH)的原因);然后使用sendEncoder模型独立处理每个句子(这就是input1的形状为(MAX_SENT_LENGTH,)的原因),最后将所有编码的句子一起处理.

因此,在您的代码中,整个模型都存储在model中,其输入层是input2,您将输入已拆分为句子且其词已进行整数编码(使其与嵌入层兼容)的文档.另一个输入层属于sendEncoder模型,该模型在模型内部使用(并非直接由您使用):

review_encoder = TimeDistributed(sentEncoder)(review_input)

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读