2016-08-24 14 views
10

मैंने RNN language model using TensorFlow लिखा है। मॉडल को RNN वर्ग के रूप में कार्यान्वित किया गया है। ग्राफ संरचना कन्स्ट्रक्टर में बनाई गई है, जबकि RNN.train और RNN.test विधियां इसे चलाती हैं।जब मैं state_is_tuple = True करता हूं तो मैं टेंसरफ्लो आरएनएन स्थिति कैसे सेट करूं?

जब मैं प्रशिक्षण सेट में किसी नए दस्तावेज़ में जाता हूं, या जब मैं प्रशिक्षण के दौरान सत्यापन सेट चलाता हूं तो मैं आरएनएन स्थिति को रीसेट करने में सक्षम होना चाहता हूं। मैं इसे प्रशिक्षण लूप के अंदर राज्य के प्रबंधन से करता हूं, इसे फीड डिक्शनरी के माध्यम से ग्राफ में पास करता हूं।

निर्माता में मैं जैसे RNN को परिभाषित तो

cell = tf.nn.rnn_cell.LSTMCell(hidden_units) 
    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers) 
    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32) 
    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state") 
    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True, 
                initial_state=self.state) 

प्रशिक्षण पाश की तरह इस

for document in document: 
    state = session.run(self.reset_state) 
    for x, y in document: 
      _, state = session.run([self.train_step, self.next_state], 
           feed_dict={self.x:x, self.y:y, self.state:state}) 

x और y एक दस्तावेज़ में प्रशिक्षण डेटा के बैच हैं लग रहा है। विचार यह है कि मैं प्रत्येक बैच के बाद नवीनतम स्थिति पास करता हूं, सिवाय जब मैं एक नया दस्तावेज़ शुरू करता हूं, जब मैं self.reset_state चलाकर राज्य को शून्य करता हूं।

यह सब काम करता है। अब मैं अनुशंसित state_is_tuple=True का उपयोग करने के लिए अपना आरएनएन बदलना चाहता हूं। हालांकि, मुझे नहीं पता कि एक फीड डिक्शनरी के माध्यम से अधिक जटिल एलएसटीएम राज्य वस्तु को कैसे पारित किया जाए। इसके अलावा मुझे नहीं पता कि मेरे कन्स्ट्रक्टर में self.state = tf.placeholder(...) लाइन पर कौन से तर्क पारित किए गए हैं।

यहां सही रणनीति क्या है? dynamic_rnn के लिए अभी भी बहुत उदाहरण कोड या दस्तावेज उपलब्ध नहीं है।


TensorFlow मुद्दों 2695 और 2838 प्रासंगिक लगें।

WILDML पर blog post इन मुद्दों को संबोधित करता है लेकिन सीधे जवाब का वर्णन नहीं करता है।

TensorFlow: Remember LSTM state for next batch (stateful LSTM) भी देखें।

+0

'rnn_cell._unpacked_state' और' rnn_cell._packed_state' देखें। लूप फ़ंक्शन में तर्क टेंसर की सूची के रूप में राज्य को पास करने के लिए इन्हें 'rnn._dynamic_rnn_loop()' में उपयोग किया जाता है। – JunkMechanic

+0

मुझे नवीनतम TensorFlow स्रोत में तारों को _ _unpacked_state' और '_packed_state' दिखाई नहीं देता है। क्या ये नाम बदल गए हैं? –

+0

हम्म। उन्हें हटा दिया गया है। इसके बजाए, एक नया मॉड्यूल 'tf.python.util.nest' अनुरूप 'flatten' और' pack_sequence_as' के साथ पेश किया गया है। – JunkMechanic

उत्तर

13

टेंसफोर्लो प्लेसहोल्डर के साथ एक समस्या यह है कि आप इसे केवल पाइथन सूची या नम्पी सरणी (मुझे लगता है) के साथ ही खिला सकते हैं। तो आप LSTMStateTuple के tuples में रनों के बीच राज्य को सहेज नहीं सकते हैं।

मैं इस

initial_state = np.zeros((num_layers, 2, batch_size, state_size))

आप एक LSTM परत में दो घटक है की तरह एक टेन्सर में राज्य को बचाने के द्वारा इस हल, सेल राज्य और छिपा राज्य, thats क्या "2 " से आता है।

जब ग्राफ आप खोल के निर्माण और इस तरह टपल राज्य बनाने: (https://arxiv.org/pdf/1506.00019.pdf इस लेख बहुत अच्छा है): तो फिर

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size]) 
l = tf.unpack(state_placeholder, axis=0) 
rnn_tuple_state = tuple(
     [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1]) 
      for idx in range(num_layers)] 
) 

आप नए राज्य हमेशा की तरह मिलता है

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True) 
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True) 

outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state) 

यह ऐसा नहीं होना चाहिए ... शायद वे एक समाधान पर काम कर रहे हैं।

+0

यदि आपके पास केवल एक परत है, तो यह 'state_placeholder = tf.placeholder (tf.float32, [2, batch_size, state_size]) 'और 'itial_state = np.zeros ((2, batch_size, state_size) बन जाता है)'? – Lukeyb

1

आरएनएन राज्य में खिलाने का एक आसान तरीका बस राज्य के दोनों घटकों में खिलाना है।

# Constructing the graph 
self.state = rnn_cell.zero_state(...) 
self.output, self.next_state = tf.nn.dynamic_rnn(
    rnn_cell, 
    self.input, 
    initial_state=self.state) 

# Running with initial state 
output, state = sess.run([self.output, self.next_state], feed_dict={ 
    self.input: input 
}) 

# Running with subsequent state: 
output, state = sess.run([self.output, self.next_state], feed_dict={ 
    self.input: input, 
    self.state[0]: state[0], 
    self.state[1]: state[1] 
}) 
संबंधित मुद्दे

 संबंधित मुद्दे