python - How do I set TensorFlow RNN state when state_is_tuple=True?
I have written an RNN language model using TensorFlow. The model is implemented as an RNN class. The graph structure is built in the constructor, while the RNN.train and RNN.test methods run it.
I want to be able to reset the RNN state when I move to a new document in the training set, or when I want to run a validation set during training. I do this by managing the state inside the training loop, passing it into the graph via a feed dictionary.
In the constructor I define the RNN like so:
    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
    # Zero state used to reset the RNN at document boundaries.
    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
    # Placeholder through which the training loop feeds the current state.
    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input,
                                                      time_major=True,
                                                      initial_state=self.state)
The training loop looks like this:
    for document in documents:
        # Start every document from the zero state.
        state = session.run(self.reset_state)
        for x, y in document:
            _, state = session.run([self.train_step, self.next_state],
                                   feed_dict={self.x: x, self.y: y, self.state: state})
x and y are batches of training data in a document. The idea is that I pass the latest state along after each batch, except when I start a new document, when I zero out the state by running self.reset_state.
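The state handling described above can be sketched without TensorFlow. This is only a toy illustration: zero_state and run_step are hypothetical stand-ins for session.run(self.reset_state) and the session.run call on the training step, and the "state" is a plain integer rather than an LSTM state.

```python
# Toy stand-in for the training loop's state handling; no TensorFlow involved.
# zero_state and run_step are hypothetical names, not part of any library.

def zero_state():
    """Return a fresh state, playing the role of session.run(reset_state)."""
    return 0


def run_step(x, y, state):
    """Pretend training step: the new state is the old state plus the batch length."""
    return state + len(x)


def train(documents):
    """Carry the state across batches, resetting it at each document boundary."""
    final_states = []
    for document in documents:
        state = zero_state()               # new document: start from zeros
        for x, y in document:
            state = run_step(x, y, state)  # feed the latest state back in
        final_states.append(state)
    return final_states


documents = [
    [([1, 2], [2, 3]), ([3, 4], [4, 5])],  # document with two batches
    [([5, 6, 7], [6, 7, 8])],              # document with one batch
]
final_states = train(documents)
print(final_states)
```

Each document ends with a state that depends only on its own batches, which is the behaviour the reset is meant to guarantee.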
This all works. Now I want to change my RNN to use the recommended state_is_tuple=True. However, I don't know how to pass the more complicated LSTM state object via a feed dictionary. I also don't know what arguments to pass to the self.state = tf.placeholder(...) line in my constructor.
What is the correct strategy here? There still isn't much example code or documentation for dynamic_rnn available.
TensorFlow issues 2695 and 2838 appear relevant.
A blog post on WildML addresses these issues but doesn't directly spell out the answer.
See also TensorFlow: Remember LSTM state for next batch (stateful LSTM).
One problem with a TensorFlow placeholder is that you can only feed it with a Python list or NumPy array (I think). So you can't save the state between runs in tuples of LSTMStateTuple.
I solved this by saving the state in a tensor like this:
    initial_state = np.zeros((num_layers, 2, batch_size, state_size))
You have two components in an LSTM layer, the cell state and the hidden state; that's where the "2" comes from. (This article is great: https://arxiv.org/pdf/1506.00019.pdf)
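To make the layout concrete, here is a small NumPy-only sketch of how one packed array maps onto per-layer (cell state, hidden state) pairs. The namedtuple is just a stand-in for tf.nn.rnn_cell.LSTMStateTuple (which has the fields c and h), and the sizes are made up for illustration.

```python
from collections import namedtuple

import numpy as np

# Stand-in for tf.nn.rnn_cell.LSTMStateTuple, which has the fields c and h.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

# Illustrative sizes only.
num_layers, batch_size, state_size = 3, 4, 5

# One array holds everything: axis 0 is the layer, axis 1 is (c, h).
packed = np.zeros((num_layers, 2, batch_size, state_size), dtype=np.float32)
packed[1, 0] = 1.0  # mark the cell state of layer 1
packed[1, 1] = 2.0  # mark the hidden state of layer 1

# Slice the array back into one (c, h) pair per layer.
per_layer = tuple(LSTMStateTuple(c=layer[0], h=layer[1]) for layer in packed)

print(len(per_layer), per_layer[1].c.shape)
```

So packed[i, 0] is the cell state of layer i and packed[i, 1] is its hidden state, each of shape (batch_size, state_size).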
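When building the graph you unpack it and create the tuple state like this:
    state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
    # Split into one [2, batch_size, state_size] tensor per layer.
    # Note: tf.unpack was renamed to tf.unstack in TensorFlow 1.0.
    l = tf.unpack(state_placeholder, axis=0)
    rnn_tuple_state = tuple(
        [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
         for idx in range(num_layers)]
    )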
Then you get the new state the usual way:
    cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
    cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
    outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input,
                                       initial_state=rnn_tuple_state)
It shouldn't be like this... perhaps they are working on a solution.
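One step this approach leaves implicit is how the state returned by one run gets back into the placeholder on the next run. A possible sketch, assuming the fetched state comes back as a tuple of per-layer (c, h) pairs of NumPy arrays: np.asarray stacks that nested tuple into the [num_layers, 2, batch_size, state_size] layout again. Here fake_run is a hypothetical stand-in for session.run, and the namedtuple again stands in for LSTMStateTuple.

```python
from collections import namedtuple

import numpy as np

# Stand-in for tf.nn.rnn_cell.LSTMStateTuple (fields c and h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

# Illustrative sizes only.
num_layers, batch_size, state_size = 2, 3, 4


def fake_run(packed_state):
    """Hypothetical stand-in for session.run(state, feed_dict={...}).

    Takes the packed array and returns a tuple of per-layer (c, h) pairs,
    with every value incremented by one to mimic a state update.
    """
    return tuple(
        LSTMStateTuple(c=layer[0] + 1.0, h=layer[1] + 1.0)
        for layer in packed_state
    )


# Zero state at the start of a document.
packed = np.zeros((num_layers, 2, batch_size, state_size), dtype=np.float32)

for _ in range(3):  # three batches of one document
    state_tuple = fake_run(packed)
    # Stack the nested tuple back into a single array for the next feed.
    packed = np.asarray(state_tuple, dtype=np.float32)

print(packed.shape)
```

At a document boundary you would simply assign a fresh np.zeros array of the same shape instead of reusing the returned state.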