2017-09-24 24 views
8

मैं टेंसरफ्लो के साथ कुछ प्रयोग कर रहा हूं, और एक झटके में भाग गया हूं। मैं एक मॉडल में बदलाव को evalute करने के लिए टीएफ का उपयोग करने की कोशिश कर रहा हूँ, फिर या तो हानि समारोह में परिणामी परिवर्तन के आधार पर, मॉडल को बरकरार या वापस लाएं। मुझे मुश्किल हिस्सा (सशर्त नियंत्रण) मिला है, लेकिन मैं कुछ ऐसी चीज पर फंस गया हूं जो काफी सरल होना चाहिए: मैं एक पुनरावृत्ति के लिए tf.trainable_variables स्टोर नहीं कर सकता, फिर आवश्यक होने पर इसे पुनर्स्थापित कर सकता हूं।डिस्क पर मूल्य को सहेजे बिना, मैं पिछले मूल्य पर टेंसर को कैसे पुनर्स्थापित कर सकता हूं?

के निर्माण का कहना है कि एक Op करते हैं:

... 
store_trainable_vars = [] 

for v in tf.trainable_variables(): 

    store_trainable_vars.append(v) 
... 

फिर बाद में, मैं यह है कि वह था जब इस Op पिछले रन था tf.trainable_variables पुनर्स्थापित करना चाहते हैं। ,

def reject_move(): 

    revert_state = [] 

    for (v, s) in zip(tf.trainable_variables(), store_trainable_vars): 

     revert_state.append(tf.assign(v, s, name="revert_state")) 

    return(revert_state) 

जाहिर है, इस फिर से मूल्यांकन होगा store_trainable_vars जो tf.trainable_variables() के वर्तमान मूल्य, revert_state Op obviating के लिए बारी लिंक में: मैं की तरह कुछ करना चाहता हूँ चाहते हैं। मुझे उन टेंसर के वर्तमान मूल्य पर वापस कॉल किए बिना टेंसर के मूल्य को स्टोर और पुनर्प्राप्त करने के कुछ तरीके की आवश्यकता है। जैसे

... 
store_trainable_vars = [] 

for v in tf.trainable_variables(): 

    store_trainable_vars.append(v.value_right_now()) 
... 

जहां v.value_right_now() एक निरंतर कि ओवरराइट जब तक परिवर्तन नहीं होगा रिटर्न कुछ।

मुझे पता है कि मैं सेवर का उपयोग कर सकता हूं, लेकिन वह समाधान डिस्क पर लिखता है, जो इस एप्लिकेशन के लिए स्वीकार्य नहीं है क्योंकि यह एक प्रशिक्षण लूप के अंदर चलाएगा।

मुझे शायद कुछ स्पष्ट याद आ रहा है - किसी भी मार्गदर्शन की सराहना की जाएगी।

उत्तर

1

यह मेरा मूल उद्देश्य इस प्रश्न का उत्तर देने का मेरा इरादा नहीं था, लेकिन मैं एक ऐसी विधि के साथ आया हूं जो काफी अच्छी तरह से काम करता है। तो, मैंने सोचा कि मैं इसे साझा करूंगा। मुख्य अंतर्दृष्टि this बहुत चालाक उत्तर से आई थी। दृष्टिकोण इनलाइन वैरिएबल असाइनमेंट के लिए बनाए गए असाइनमेंट नोड्स का पुन: उपयोग करना है। उस दृष्टिकोण को लागू करने वाली एक पूर्ण कक्षा नीचे दी गई है।

import tensorflow as tf 


class TensorFlowState(object): 

    def __init__(self): 

     # Get the graph. 
     graph = tf.get_default_graph() 

     # Extract the global varibles from the graph. 
     self.gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 

     # Exract the Assign operations for later use. 
     self.assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") 
          for v in self.gvars] 

     # Extract the initial value ops from each Assign op for later use. 
     self.init_values = [op.inputs[1] for op in self.assign_ops] 

    def start(self, sess): 

     self.sess = sess 

    def store(self): 

     # Record the current state of the TF global varaibles 
     self.state = self.sess.run(self.gvars) 

    def restore(self): 
    # Create a dictionary of the iniailizers and stored state of globals. 
    feed_dict = {init_value: val 
       for init_value, val in zip(self.init_values, self.state)} 

    # Use the initializer ops for each variable to load the stored values. 
    return(self.sess.run(self.assign_ops, feed_dict=feed_dict)) 

का उपयोग करने के लिए बस वर्ग का दृष्टांत, फोन start विधि एक tf.Session पारित करने के लिए, और के रूप में अपने जरूरी प्रशिक्षण पाश अंदर की जरूरत फोन store और restore तरीकों। मैंने ऑप्टिमाइज़र बनाने के लिए इस कार्यान्वयन का उपयोग किया है, जो टेंसरफ्लो के साथ ढाल वाले मूल अनुकूलक के रूप में तेज़ी से चलता है।

5

मैन्युअल रूप से एक ग्राफ राज्य पुनर्स्थापित करने के लिए आप tf.tuple या tf.group आपरेशन, कि एक थोक बदलाव के लिए प्रवाह को संशोधित करेगा उपयोग करने की आवश्यकता:

यह tensors मान जैसे ही तर्क के साथ tensors के एक टपल बनाता है , सिवाय इसके कि प्रत्येक टेंसर का मान केवल के बाद लौटाया गया है, सभी टेंसर के मानों की गणना की गई है।

[अपडेट] यहाँ कैसे मैं यह कर देगा:

import numpy as np 
import tensorflow as tf 

x = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='x') 
W = tf.Variable(np.zeros([5, 5]), dtype=tf.float32, name='W') 
b = tf.Variable(np.zeros([5]), dtype=tf.float32, name='b') 
y = tf.add(tf.matmul(x, W), b) 

with tf.Session() as session: 
    batch = np.ones([2, 5]) 
    session.run(tf.global_variables_initializer()) 
    print session.run(y, feed_dict={x: batch})  # prints [2, 5] zeros 

    # store the current value 
    store = {v.name: v.eval(session) for v in tf.trainable_variables()} 
    print store          # prints [5, 5] and [5] zeros 

    # update 
    new = {'W:0': np.ones([5, 5]), 'b:0': np.ones([5])} 
    session.run(tf.tuple([tf.assign(var, new[var.name]) for var in tf.trainable_variables()])) 
    print session.run(y, feed_dict={x: batch})  # prints [2, 5] sixes 

    # restore 
    session.run(tf.tuple([tf.assign(var, store[var.name]) for var in tf.trainable_variables()])) 
    print session.run(y, feed_dict={x: batch})  # prints [2, 5] zeros again 

लेकिन मैं वास्तव में, लगता है कि आप अपने निर्णय के बारे में Saver पर पुनर्विचार करना चाहिए, क्योंकि यह एक प्रशिक्षण पाश अंदर के रूप में अच्छी तरह से इस्तेमाल किया जा डिजाइन किया गया था । आंतरिक रूप से, Saver आपके लिए सभी कठिन काम करता है (विशेष रूप से, यदि आवश्यक हो तो यह ओप कॉल tf.group और tf.control_dependencies को पुनर्स्थापित करता है), जो अन्यथा बहुत खराब बग का स्रोत बन सकता है। इसके अलावा, डिस्क आपके जीपीयू और मुख्य मेमोरी से हमेशा (लगभग) बड़ी होती है, इसलिए यदि आप मॉडल को मेमोरी में स्टोर कर सकते हैं, तो आपको डिस्क पर भी स्टोर करने में सक्षम होना चाहिए।

यहाँ some parameters कि डिस्क पर चौकी फ़ाइलों के प्रसार को नियंत्रित करने में मदद कर रहे हैं:

  • max_to_keep रखने के लिए हाल ही में चौकी फ़ाइलों की अधिकतम संख्या को इंगित करता है। जैसे-जैसे नई फाइलें बनाई जाती हैं, पुरानी फाइलें हटा दी जाती हैं। यदि कोई नहीं या 0, सभी चेकपॉइंट फाइलें रखी जाती हैं। 5 तक डिफ़ॉल्ट (यानी, 5 सबसे हालिया चेकपॉइंट फ़ाइलों को रखा जाता है)।
  • keep_checkpoint_every_n_hours: हालिया max_to_keep चेकपॉइंट फ़ाइलों को रखने के अलावा, आप प्रत्येक एन घंटे के प्रशिक्षण के लिए एक चेकपॉइंट फ़ाइल रखना चाह सकते हैं। यह उपयोगी हो सकता है यदि आप बाद में विश्लेषण करना चाहते हैं कि एक लंबे प्रशिक्षण सत्र के दौरान मॉडल कैसे प्रगति करता है। उदाहरण के लिए, keep_checkpoint_every_n_hours=2 गुजरने से यह सुनिश्चित होता है कि आप प्रशिक्षण के हर 2 घंटे के लिए एक चेकपॉइंट फ़ाइल रखें। 10,000 घंटे का डिफ़ॉल्ट मान प्रभावी रूप से सुविधा को अक्षम करता है।

[अपडेट] टिप्पणी में स्पष्ट किया के रूप में, मुख्य चिंता डिस्क विलंबता, कि नीचे प्रशिक्षण को धीमा कर सकता है, तो भी अक्सर पहुँचा है। यदि आप लिनक्स का उपयोग कर रहे हैं, तो यह caches अक्सर डिस्क पेज, Windows does it भी उपयोग किया जाता है। लेकिन अगर आप पूरी तरह से सुनिश्चित होना चाहते हैं, तो tmpfs का उपयोग करने पर विचार करें।

+0

मुझे स्पष्ट करना चाहिए: जब मैंने कहा कि मैं डिस्क पर लिखना नहीं चाहता था, तो ऐसा नहीं था क्योंकि मैं अंतरिक्ष के बारे में चिंतित था। यह भंडारण और बहाली, हर पुनरावृत्ति पर, सबसे बुरे मामले में होगी। डिस्क पर वापस पहुंचने के लिए यह चलने का समय जुर्माना है जिसे मैं टालने की कोशिश कर रहा हूं। क्या आप ग्राफ जवाब बहाल करने के लिए 'tf.group' के एक छोटे से उपयोग को प्रदर्शित करने के बजाय अपना उत्तर संपादित कर सकते हैं? (या सिर्फ इस तरह के एक उदाहरण से लिंक) –

संबंधित मुद्दे

 संबंधित मुद्दे