2016-03-24 15 views
34

द्वारा प्रशिक्षित मॉडल में कुछ वजन का मूल्य प्राप्त करें मैंने टेंसरफ्लो के साथ एक कॉन्वनेट मॉडल को प्रशिक्षित किया है, और मैं परत में एक विशेष वजन प्राप्त करना चाहता हूं। उदाहरण के लिए torch7 में मैं बस model.modules[2].weights तक पहुंच सकता हूं। परत 2 के वजन प्राप्त करने के लिए। मैं टेंसरफ्लो में एक ही चीज़ कैसे करूं?टेंसरफ्लो

उत्तर

54

टेंसरफ्लो में, प्रशिक्षित वजन tf.Variable वस्तुओं द्वारा दर्शाया जाता है। यदि आपने tf.Variable — उदा। v — स्वयं कहा जाता है, तो आप sess.run(v) पर कॉल करके न्यूमपी सरणी के रूप में अपना मान प्राप्त कर सकते हैं (जहां sesstf.Session है)।

यदि आपके पास वर्तमान में tf.Variable पर कोई पॉइंटर नहीं है, तो आप tf.trainable_variables() पर कॉल करके वर्तमान ग्राफ में प्रशिक्षित चर की एक सूची प्राप्त कर सकते हैं। यह फ़ंक्शन वर्तमान ग्राफ़ में सभी ट्रेन करने योग्य tf.Variable ऑब्जेक्ट्स की एक सूची देता है, और आप v.name संपत्ति से मेल खाते हुए एक को चुन सकते हैं। उदाहरण के लिए:

# Desired variable is called "tower_2/filter:0". 
var = [v for v in tf.trainable_variables() if v.name == "tower_2/filter:0"][0] 
+0

बहुत धन्यवाद @mrry, अगर मैं Tensorflow द्वारा किसी भी मॉडल चिड़ियाघर समर्थन से संबंधित मॉडल लोड मैं एक ही समारोह मैंने कोशिश की के साथ अपने trainable मापदंडों का उपयोग कर सकते हैं, लेकिन यह खाली मैट्रिक्स लौट आए। कोई जवाब कृपया –

+3

यह मॉडल लोड करने के लिए उपयोग की गई तंत्र पर निर्भर करता है। यदि आप नए 'tf.train.import_meta_graph() 'का उपयोग करते हैं तो' tf.trainable_variables()' को काम करना चाहिए। यदि आप निचले स्तर 'tf.import_graph_def()' फ़ंक्शन का उपयोग करते हैं, तो आपको 'return_elements' वैकल्पिक तर्क में चर के नाम को पास करना होगा, और एक टेंसर वापस कर दिया जाएगा (जिसे आप 'sess.run () ' – mrry

+0

बहुत सारे धन्यवाद –

संबंधित मुद्दे