LSTM理解与pytorch使用

    技术2022-07-10  177

    LSTM理解与pytorch使用

    引言LSTM结构总体结构详细结构 Pytorch用法参数介绍使用实例获取中间各层的隐藏层信息关于变长输入

    引言

    LSTM应该说是每一个做机器学习的人都绕不开的东西,它的结构看起来复杂,但是充分体现着人脑在记忆过程中的特征,下面本文将介绍一下LSTM的结构以及pytorch的用法。

    LSTM结构

    总体结构

    首先,LSTM主要用来处理带有时序信息的数据,包括视频、句子,它将人脑的对于不同time step的记忆过程理解为一连串的cell分别对不同的时刻输入信息的处理。

    详细结构

    一个典型的 LSTM 结构可以分别从输入、处理和输出三个角度来解析:

    输入: 输入包含三个部分,分别是 cell 的信息𝐶t-1,它代表历史的记忆细胞(cell)状态信息的汇总;隐藏层的信息ht-1, 它是提取到的上个时刻的特征信息; 以及当前的输入𝑥t。处理: 处理部分主要是由遗忘门、输入门、输出门组成。遗忘门由当前的输入和隐藏层信息控制对于历史的 cell 信息的遗忘程度;输入门是决定当前的输入和隐藏 层信息的利用程度;输出门是由当前的 cell 状态和输入决定输出。输出: 分别是当前的 cell 状态𝐶’和当前的隐藏层信息h’。

    遗忘门: 输入门: 细胞状态更新: 输出门:

    Pytorch用法

    参数介绍

    class torch.nn.LSTM(*args, **kwargs)

    参数:

    input_size:输入的特征维度hidden_size:隐藏层的特征维度(即输出的特征维度)num_layers:LSTM隐层的层数,默认为1bias:False则bih=0和bhh=0. 默认为Truebatch_first:True则输入输出的数据格式为 (batch, seq, feature)dropout:除最后一层,每一层的输出都进行dropout,默认为: 0bidirectional:True则为双向LSTM,默认为False

    输入:input, (h0, c0) 输入数据格式: input(seq_len, batch, input_size) seq_len可以理解为一个视频有多少帧或者一个句子有多少单词,input_size就是一个帧或者一个单词可以用多少维的特征向量表示。 h0(num_layers * num_directions, batch, hidden_size) c0(num_layers * num_directions, batch, hidden_size)

    输出:output, (hn, cn) 输出数据格式: output(seq_len, batch, hidden_size * num_directions) hn(num_layers * num_directions, batch, hidden_size) cn(num_layers * num_directions, batch, hidden_size)

    使用实例

    rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)#(input_size,hidden_size,num_layers) input = torch.randn(5, 3, 10)#(seq_len, batch, input_size) h0 = torch.randn(2, 3, 20) #(num_layers,batch,output_size) c0 = torch.randn(2, 3, 20) #(num_layers,batch,output_size) output, (hn, cn) = rnn(input, (h0, c0))

    output.shape #(seq_len, batch, output_size2) torch.Size([5, 3, 40]) hn.shape #(num_layers2, batch, output_size) torch.Size([2, 3, 20])

    获取中间各层的隐藏层信息

    lstm = nn.LSTM(3, 3) inputs = [torch.randn(1, 3) for _ in range(5)] # 这里的inputs的大小是一个含有5个1*3的tensor的列表,可以理解为一个5*1*3维的输入,其中5是seq_len,1是batch_size,3是input_size的大小 # 初始化隐藏状态 hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3)) for i in inputs: # 将序列的元素逐个输入到LSTM,经过每步操作,hidden 的值包含了隐藏状态的信息 out, hidden = lstm(i.view(1, 1, -1), hidden)

    关于变长输入

    由于在视觉里变长输入的情况较少,这里只给出几个链接: 1.pytorch中如何处理RNN输入变长序列padding

    https://zhuanlan.zhihu.com/p/34418001

    2.教你几招搞定 LSTMs 的独门绝技

    https://zhuanlan.zhihu.com/p/40391002
    Processed: 0.017, SQL: 9