2016-07-06 8 views
11

मैं उन चरों को देखना चाहता हूं जो उनके मूल्यों के साथ एक tensorflow चेकपॉइंट में सहेजे गए हैं। मैं उन परिवर्तनीय नामों को कैसे ढूंढ सकता हूं जो एक tensorflow चेकपॉइंट में सहेजे गए हैं?एक tensorflow चेकपॉइंट में सहेजे गए चर नामों को कैसे ढूंढें?

संपादित करें:

मैं tf.train.NewCheckpointReader जो here समझाया गया है इस्तेमाल किया। लेकिन, यह tensorflow के प्रलेखन में नहीं दिया गया है। क्या कोई और तरीका है?

`

import tensorflow as tf 
    v0 = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype=tf.float32, name="v0") 
    v1 = tf.Variable([[[1], [2]], [[3], [4]], [[5], [6]]], dtype=tf.float32, 
        name="v1") 
    init_all_op = tf.initialize_all_variables() 
    save = tf.train.Saver({"v0": v0, "v1": v1}) 
    checkpoint_path = os.path.join(model_dir, "model.ckpt")  

    with tf.Session() as sess: 
     sess.run(init_all_op) 
     # Saves a checkpoint.  
     save.save(sess, checkpoint_path) 

     # Creates a reader. 
     reader = tf.train.NewCheckpointReader(checkpoint_path) 
     print('reder:\n', reader) 

     # Verifies that the tensors exist. 
     print('is exist v0?', reader.has_tensor("v0")) 
     print('is exist v1?', reader.has_tensor("v1")) 

     # Verifies that debug string contains the right strings. 
     debug_string = reader.debug_string() 
     print('\n All Variables: \n', debug_string) 

     # Verifies get_variable_to_shape_map() returns the correct information. 
     var_map = reader.get_variable_to_shape_map() 
     print('\n All Variables information :\n', var_map) 

     # Verifies get_tensor() returns the tensor value. 
     v0_tensor = reader.get_tensor("v0") 
     v1_tensor = reader.get_tensor("v1") 
     print('\n returns the v0 tensor value:\n', v0_tensor) 
     print('\n returns the v1 tensor value:\n', v1_tensor) 

`

+0

मैंने देखा कि आपने जवाब स्वीकार कर लिया है। इस प्रकार, कोड 'print_tensors_in_checkpoint_file' चलाने के लिए आपने जो कोड लिखा था, वह मैं क्या उपयोग करने की कोशिश कर रहा था, लेकिन जब भी मैं 'tf.python.tools.inspect_checkpoint.print_tensors_in_checkpoint_file' पायथन कहता हूं कि मॉड्यूल' tensorflow.python' में कोई नहीं है गुण 'उपकरण'। मुझे लगता है कि अगर आप इस फ़ंक्शन को चलाने के तरीके की एक छोटी उदाहरण स्क्रिप्ट प्रदान करते हैं तो यह बहुत उपयोगी होगा (क्योंकि वह फ़ाइल या तो उदाहरण प्रदान नहीं करती है), विशेष रूप से जब आपने उत्तर स्वीकार कर लिया है तो मुझे लगता है कि आपके लिए कुछ काम किया गया है। – Pinocchio

उत्तर

4

आप inspect_checkpoint.py उपकरण का उपयोग कर सकते हैं।

+2

मैं इसका उपयोग करने की कोशिश कर रहा था लेकिन जब भी मैं 'tf.python.tools.inspect_checkpoint.print_tensors_in_checkpoint_file' पायथन कहता हूं कि मॉड्यूल' tensorflow.python' में कोई विशेषता 'टूल' नहीं है। मुझे लगता है कि अगर आप इस फ़ंक्शन को चलाने के तरीके की एक छोटी उदाहरण स्क्रिप्ट प्रदान करते हैं तो टीआई बेहद सहायक होगी (क्योंकि वह फ़ाइल या तो उदाहरण प्रदान नहीं करती है) – Pinocchio

19

उदाहरण उपयोग:

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file 
checkpoint_path = os.path.join(model_dir, "model.ckpt") 

# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80] 
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='') 

# List contents of v0 tensor. 
# Example output: tensor_name: v0 [[[[ 9.27958265e-02 7.40226209e-02 4.52989563e-02 3.15700471e-02 
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0') 

# List contents of v1 tensor. 
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1') 

अद्यतन:all_tensors तर्क print_tensors_in_checkpoint_file को जोड़ा गया है के बाद से Tensorflow 0.12.0-rc0 ताकि आप all_tensors=False या all_tensors=True जोड़ने के लिए यदि आवश्यक हो तो पड़ सकता है।

वैकल्पिक विधि:

from tensorflow.python import pywrap_tensorflow 
checkpoint_path = os.path.join(model_dir, "model.ckpt") 
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) 
var_to_shape_map = reader.get_variable_to_shape_map() 
for key in var_to_shape_map: 
    print("tensor_name: ", key) 
    print(reader.get_tensor(key)) # Remove this is you want to print only variable names 

आशा है कि यह मदद करता है।

+0

वास्तव में सहायक, धन्यवाद! – allen

1

ऊपर जवाब देने के लिए जोड़ा जा रहा है:

मॉडल V2 प्रारूप का उपयोग कर बचाया गया है, तो

model-10000.data-00000-of-00001 
model-10000.index 
model-10000.meta 

आपका चौकी इनपुट नाम केवल उपसर्ग

print_tensors_in_checkpoint_file(file_name='/home/RNN/models/model_10000', tensor_name='',all_tensors=True) 

स्रोत होना चाहिए: @LingjiaDeng द्वारा https://github.com/tensorflow/tensorflow/issues/7696 पर

संबंधित मुद्दे