或者
# Variant 1: TF 1.x contrib LSTMCell with a concatenated (non-tuple) state,
# stepped manually for 20 iterations with a Dense layer transforming the
# state between steps.
import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell

cell = LSTMCell(128, state_is_tuple=False)
# With state_is_tuple=False the cell expects its state as the concatenation
# [c, h], i.e. width 2 * 128 = 256 (cell.state_size), not 128.
init_state = tf.random_normal([16, 128])  # could be the previous model's output
# Seed both c and h with the same vector so the state has the width the
# cell slices; a 128-wide state makes slicing out h (columns 128..255) fail.
initial_state = tf.concat([init_state, init_state], axis=1)  # (16, 256)
output, new_state = cell(inputs=init_state, state=initial_state)  # new_state: (16, 256)
# Project back to the full state width; a 128-unit projection would be too
# narrow for the next cell call.
mlp = tf.keras.layers.Dense(units=cell.state_size)
new_state = mlp(new_state)
for _ in range(20):
    output, new_state = cell(inputs=output, state=new_state)
    new_state = mlp(new_state)
print(output.shape, new_state.shape)
# Or, alternatively:
# Variant 2: tf.keras LSTMCell stepped manually, feeding each step's output
# and state back into the next step.
import tensorflow as tf

cell = tf.keras.layers.LSTMCell(128)
init_state = tf.random_normal([16, 1, 128])  # could be the previous model's output
# LSTMCell processes a single timestep, so its inputs must be
# (batch, features): drop the length-1 time axis.
step_input = tf.squeeze(init_state, axis=1)  # (16, 128)
# `states` must be a list [h, c], each (batch, units). A raw tensor would be
# indexed as states[0] / states[1], i.e. along the batch axis.
output, new_state = cell(inputs=step_input, states=[step_input, step_input])
for _ in range(20):
    # Carry the updated [h, c] forward; re-feeding the initial state would
    # reset the recurrence on every step.
    output, new_state = cell(inputs=output, states=new_state)
print(output.shape, [s.shape for s in new_state])
# Or, alternatively:
# Variant 3: run an LSTMCell over a whole sequence by wrapping it in
# tf.keras.layers.RNN, which performs the per-timestep loop.
import tensorflow as tf

cell = tf.keras.layers.LSTMCell(128)
# (batch, time, features) = (16, 20, 128); could be the previous model's output.
init_state = tf.random_normal([16, 20, 128])
# Calling the bare cell on a 3-D tensor does not iterate over the 20 steps;
# the RNN wrapper does. return_state=True also returns the final h and c.
rnn = tf.keras.layers.RNN(cell, return_sequences=True, return_state=True)
# To seed the recurrence from another model, call
# rnn(init_state, initial_state=[h0, c0]) with h0 and c0 each (16, 128);
# without it the RNN layer starts from a zero state.
output, last_h, last_c = rnn(init_state)  # output: (16, 20, 128)
new_state = [last_h, last_c]
print(output.shape, last_h.shape, last_c.shape)