2016-08-08 16 views
10

मैं बस ग्राफ को सहेजने और पुनर्स्थापित करने का प्रयास करता हूं, लेकिन सबसे सरल उदाहरण अपेक्षा के अनुसार काम नहीं करता है (यह संस्करण 0.9.0 या 0.10.0 का उपयोग लिनक्स 64 पर पाइथन 2.7 या 3.5.2 का उपयोग कर CUDA के बिना किया जाता है)tensorflow.train.import_meta_graph काम नहीं करता है?

सबसे पहले मैं इस तरह ग्राफ बचाने:

import tensorflow as tf 
v1 = tf.placeholder('float32') 
v2 = tf.placeholder('float32') 
v3 = tf.mul(v1,v2) 
c1 = tf.constant(22.0) 
v4 = tf.add(v3,c1) 
sess = tf.Session() 
result = sess.run(v4,feed_dict={v1:12.0, v2:3.3}) 
g1 = tf.train.export_meta_graph("file") 
## alternately I also tried: 
## g1 = tf.train.export_meta_graph("file",collection_list=["v4"]) 

यह एक फ़ाइल "फ़ाइल" कि गैर खाली है बनाता है और भी कुछ है कि एक उचित ग्राफ परिभाषा की तरह लग रहा करने के लिए G1 सेट।

तब मैं इस ग्राफ बहाल करने की कोशिश:

import tensorflow as tf 
g=tf.train.import_meta_graph("file") 

यह एक त्रुटि के बिना काम करता है, लेकिन सभी में कुछ भी वापस नहीं करता है।

क्या कोई भी "v4" के लिए ग्राफ को बस सहेजने के लिए आवश्यक कोड प्रदान कर सकता है और इसे पूरी तरह से पुनर्स्थापित कर सकता है ताकि इसे नए सत्र में चलाने से एक ही परिणाम मिलेगा?

उत्तर

27

MetaGraphDef का पुन: उपयोग करने के लिए, आपको अपने मूल ग्राफ में दिलचस्प टेंसर के नाम रिकॉर्ड करने की आवश्यकता होगी। उदाहरण के लिए, पहले कार्यक्रम में, v1, v2 और v4 की परिभाषा में स्पष्ट name तर्क सेट:

v1 = tf.placeholder(tf.float32, name="v1") 
v2 = tf.placeholder(tf.float32, name="v2") 
# ... 
v4 = tf.add(v3, c1, name="v4") 

उसके बाद, आप tensors की स्ट्रिंग नाम मूल ग्राफ में अपने कॉल में sess.run() का उपयोग कर सकते । उदाहरण के लिए, निम्नलिखित स्निपेट काम करना चाहिए:

import tensorflow as tf 
_ = tf.train.import_meta_graph("./file") 

sess = tf.Session() 
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3}) 

वैकल्पिक रूप से, आप tf.get_default_graph().get_tensor_by_name() उपयोग कर सकते हैं ब्याज की tensors, जिसे फिर आप sess.run() को पारित कर सकते हैं के लिए tf.Tensor वस्तुओं को पाने के लिए:

import tensorflow as tf 
_ = tf.train.import_meta_graph("./file") 
g = tf.get_default_graph() 

v1 = g.get_tensor_by_name("v1:0") 
v2 = g.get_tensor_by_name("v2:0") 
v4 = g.get_tensor_by_name("v4:0") 

sess = tf.Session() 
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3}) 

अद्यतन: टिप्पणियों में चर्चा के आधार पर, यहां परिवर्तनीय सामग्री को सहेजने सहित सहेजने और लोड करने का एक पूरा उदाहरण है। यह एक अलग ऑपरेशन में परिवर्तनीय vx के मान को दोगुना करके एक चर की बचत को दर्शाता है।

सहेजा जा रहा है:

import tensorflow as tf 
v1 = tf.placeholder(tf.float32, name="v1") 
v2 = tf.placeholder(tf.float32, name="v2") 
v3 = tf.mul(v1, v2) 
vx = tf.Variable(10.0, name="vx") 
v4 = tf.add(v3, vx, name="v4") 
saver = tf.train.Saver([vx]) 
sess = tf.Session() 
sess.run(tf.initialize_all_variables()) 
sess.run(vx.assign(tf.add(vx, vx))) 
result = sess.run(v4, feed_dict={v1:12.0, v2:3.3}) 
print(result) 
saver.save(sess, "./model_ex1") 

पुनर्स्थापित कर रहा है:

import tensorflow as tf 
saver = tf.train.import_meta_graph("./model_ex1.meta") 
sess = tf.Session() 
saver.restore(sess, "./model_ex1") 
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3}) 
print(result) 

लब्बोलुआब यह है कि, ताकि किसी सहेजे गए मॉडल का उपयोग करने के लिए, आप नोड्स के कम से कम कुछ के नाम याद रखना चाहिए (उदाहरण के लिए एक प्रशिक्षण सेशन, एक इनपुट प्लेसहोल्डर, एक मूल्यांकन टेंसर, आदि)। MetaGraphDef मॉडल में निहित चर की सूची संग्रहीत करता है, और इन्हें चेकपॉइंट से पुनर्स्थापित करने में मदद करता है, लेकिन आपको खुद को मॉडल के प्रशिक्षण/मूल्यांकन में उपयोग किए जाने वाले टेंसर/संचालन का पुनर्निर्माण करना होगा।

+0

आपको बहुत धन्यवाद, मुझे लगता है कि अब समझें कि import_meta_graph फ़ंक्शन केवल डिफ़ॉल्ट ग्राफ को अपडेट करता है और इसे कुछ भी उपयोगी नहीं करना चाहिए। साथ ही, डिफ़ॉल्ट ग्राफ से मुझे जो भी चाहिए, उसे एक्सेस करने का कोई तरीका नहीं है जब तक कि मैंने इसे मूल रूप से नाम नहीं दिया। और जाहिर है, पुनर्स्थापित करने के बाद उपयोग किए गए नाम में ": 0" किसी भी तरह से ऑपरेशन को अपने आउटपुट से डिस्टिंग करने के लिए उपयोग किया जाता है। – Johsm

+0

यह सही है।'Import_meta_graph()' से वापसी मान एक 'tf.train.Saver' है, जो केवल तभी उपयोगी होता है जब आपके ग्राफ़ में वेरिएबल्स हैं जिन्हें आप पुनर्स्थापित करना चाहते हैं। – mrry

+0

आह सही, तो मॉडल के भीतर चर के मान सहेजे नहीं जा रहे हैं और इसके द्वारा स्वचालित रूप से बहाल नहीं हो रहे हैं? मानते हैं कि v4 प्रशिक्षित वैराइबल्स की अज्ञात संख्या पर भी निर्भर करेगा, उन्हें स्टोर करने के लिए कोड क्या होगा और फिर बाद में उन्हें पुनर्स्थापित भी किया जाएगा? मेरा उदाहरण कोड बस आसान होना था, लेकिन मैं सिर्फ एक प्रशिक्षित मॉडल को सहेजना चाहता हूं, फिर इसका उपयोग करें। तो प्रशिक्षित मॉडल को बचाने का मतलब यह होगा कि मैं ग्राफ को सहेजने के समय उनके पास मौजूद सभी चर मूल्यों को सहेजना चाहता हूं और फिर बाद में इसे पुनर्स्थापित करना चाहता हूं। – Johsm

संबंधित मुद्दे