2016-08-15 8 views
7

मैंने के साथ Tensorflow Tutorial में उदाहरण के समान tf.nn.seq2seq.model_with_buckets प्रशिक्षित किया।टेन्सफोर्लो: model_with_buckets मॉडल में freeze_graph.py के लिए "output_node_names" क्या हैं?

अब मैं freeze_graph.py का उपयोग करके ग्राफ को फ्रीज करना चाहता हूं। मैं अपने मॉडल में "output_node_names" कैसे ढूंढ सकता हूं?

उत्तर

3

आप वैकल्पिक name="myname" तर्क को पारित करके अपने मॉडल में नोड्स के नामों का चयन कर सकते हैं जो कि किसी नोड का निर्माण करने वाले किसी भी Tensorflow ऑपरेटर को तर्क देता है। यदि आप उन्हें निर्दिष्ट नहीं करते हैं तो Tensorflow स्वचालित रूप से ग्राफ नोड्स के लिए नाम चुनेंगे, लेकिन यदि आप उन नोड्स को freeze_graph.py जैसे टूल पर पहचानना चाहते हैं, तो अपने नामों को चुनना सबसे अच्छा है। वे नाम हैं जिन्हें आप output_node_names पर पास करते हैं।

+0

धन्यवाद, मुझे लगता है कि मुझे अब मूल अवधारणा मिल गई है, लेकिन मैं अभी भी model_with_buckets के साथ संघर्ष कर रहा हूं ... मेरे पास 4 बाल्टी हैं और प्रत्येक में आरएनएन है, फिर एम्बेडिंग_एटिशन_डिक्डर और फिर अनुक्रम_लॉस, प्रत्येक में कई सेल्स शामिल हैं, और मैं उन्हें सभी नाम दे सकता हूं। लेकिन मैं एक आउटपुट नोड नहीं देख सकता, कि मैं "output_node_names" को पास कर सकता हूं। क्या मुझे seq2seq.py कोड के कोड में अंतिम नोड की तरह जोड़ने की आवश्यकता है? और क्या मेरे पास "output_node_names" या चार के लिए एक आउटपुट नोड है? – WS91

+0

पास करने के लिए कौन से नोड्स इस बात पर निर्भर करते हैं कि आप गणना करना चाहते हैं। उदाहरण के लिए, एक अनुमान ग्राफ के लिए, आप आमतौर पर केवल एक नोड का उपयोग करेंगे जो मॉडल के आउटपुट का प्रतिनिधित्व करता है। असल में --- आपको उस नोड का उपयोग करना चाहिए जिसे आप session.run() पर अपने ग्राफ को चलाने के दौरान पास करना चाहते हैं। (आपको अपने ग्राफ में सबकुछ नाम देने की आवश्यकता नहीं है, केवल एक या दो नोड जिनके मूल्य आप जानना चाहते हैं।) –

+0

यह अभी भी स्पष्ट नहीं है कि इस विशेष seq2seq आलेख को कैसे फ्रीज़ करें –

1

आप की तरह कुछ के साथ अपने मॉडल में नोड नाम के सभी प्राप्त कर सकते हैं:

node_names = [node.name for node in tf.get_default_graph.as_graph_def().node] 

या ग्राफ बहाल करने के साथ:

saver = tf.train.import_meta_graph(/path/to/meta/graph) 
sess = tf.Session() 
saver.resore(sess, /path/to/checkpoints) 
graph = sess.graph 
print([node.name for node in graph.as_graph_def().node]) 

आप इन फिल्टर करने के लिए केवल प्राप्त करने के लिए आवश्यकता हो सकती है आपके आउटपुट नोड्स, या केवल नोड्स जो आप चाहते हैं, लेकिन इससे कम से कम आपको उस ग्राफ के नाम प्राप्त करने में सहायता मिल सकती है जिसे आपने पहले ही प्रशिक्षित किया है और प्रत्येक नोड के लिए name='some_name' के साथ परिभाषित करने का जोखिम नहीं उठा सकता है।

आदर्श रूप से, आप प्रत्येक ऑपरेशन या टेंसर के लिए name पैरामीटर को परिभाषित करना चाहते हैं जिसे आप बाद में एक्सेस करना चाहते हैं।

+1

ठीक है, seq2seq मॉडल में हजारों हैं उन्हें। और वे बाल्टी पर निर्भर करते हैं। –

+0

यह ग्राफ़ में प्रत्येक नोड नाम उत्पन्न करता है, आवश्यक रूप से आउटपुट नहीं। – rambossa

संबंधित मुद्दे