2017-03-29 12 views
6

टेन्सफोर्लो में, हम वितरित प्रशिक्षण के लिए Between-graph Replication का उपयोग करके कई टेन्सफोर्लो सत्र बना सकते हैं और बना सकते हैं। MonitoredTrainingSession() एकाधिक Tensorflow सत्रों का समन्वय करता है, और Tensorflow सत्र/ग्राफ़ को पुनर्स्थापित करने के लिए MonitoredTrainingSession() के लिए checkpoint_dir पर एक तर्क है।`पुनर्निर्देशित" और "परीक्षण मोड" के साथ `MonitoredTrainingSession()` कैसे काम करता है?

  1. हम आम तौर पर saver.restore(...) द्वारा Tensorflow रेखांकन बहाल करने के लिए tf.train.Saver() की वस्तु का उपयोग करें: अब मैं निम्नलिखित प्रश्न हैं। लेकिन MonitoredTrainingSession() का उपयोग करके हम उन्हें कैसे पुनर्स्थापित कर सकते हैं?
  2. चूंकि हम कई प्रक्रियाएं चलाते हैं और प्रत्येक प्रक्रिया प्रशिक्षण के लिए एक टेन्सफोर्लो सत्र बनाता है और बनाता है, मुझे आश्चर्य है कि हमें प्रशिक्षण के बाद परीक्षण (या भविष्यवाणी) के लिए कई प्रक्रियाएं भी चलाना पड़ता है। दूसरे शब्दों में, MonitoredTrainingSession() परीक्षण (या भविष्यवाणी) मोड के साथ कैसे काम करता है?

मैंने टेन्सफोर्लो दस्तावेज़ पढ़ा, लेकिन इन 2 प्रश्नों के उत्तर नहीं मिला। अगर किसी के पास समाधान है तो मैं वास्तव में सराहना करता हूं। धन्यवाद!

उत्तर

-1
  1. ऐसा लगता है कि बहाली आपके लिए संभाली जाती है। इंट वह डॉक्स यह कहना है कि फोन करने MonitoredTrainingSession MonitoredSession का एक उदाहरण है जो निर्माण पर रिटर्न "... पुनर्स्थापित करता चर अगर एक चौकी मौजूद है ..."

  2. बाहर चेक tf.contrib.learn.Estimator(..).predict(..) और अधिक विशेष रूप tf.contrib.learn.Estimator(..)._infer_model(..) तरीकों here और here API। वे वहां एक निगरानी सत्र भी बनाते हैं।

0

लघु जवाब:

  1. आप अनुकूलक आप mon_sess.run को पारित करने के लिए वैश्विक कदम पारित करने के लिए की जरूरत है। यह सहेजे गए चेकपॉइंट्स को सहेजने और पुनर्प्राप्त करने के लिए दोनों संभव बनाता है।
  2. एक ही मॉनीटरर्ड ट्रेनिंग सत्र के माध्यम से एक साथ प्रशिक्षण + क्रॉस सत्यापन सत्र चलाने के लिए संभव है। सबसे पहले, आपको एक ही ग्राफ की अलग-अलग धाराओं के माध्यम से प्रशिक्षण बैचों और क्रॉस सत्यापन बैचों से गुजरना होगा (मुझे सलाह है कि यह कैसे करें इसे करने के लिए this guide देखें)। दूसरा, आपको - mon_sess.run() - प्रशिक्षण स्ट्रीम के लिए एक अनुकूलक पास करना होगा, साथ ही क्रॉस सत्यापन स्ट्रीम के नुकसान (/ पैरामीटर जिसे आप ट्रैक करना चाहते हैं) के पैरामीटर को पास करना होगा। यदि आप प्रशिक्षण से अलग से एक परीक्षण सत्र चलाने के लिए चाहते हैं, तो बस ग्राफ के माध्यम से केवल परीक्षण सेट चलाएं, और ग्राफ के माध्यम से केवल test_loss (/ अन्य पैरामीटर जिन्हें आप ट्रैक करना चाहते हैं) चलाएं। यह कैसे किया जाता है के बारे में अधिक जानकारी के लिए, नीचे देखें।

लांग जवाब:

मैं (मेरा उत्तर को अपडेट करते ही मैं अपने आप को क्या tf.train.MonitoredSession साथ किया जा सकता का एक बेहतर दृश्य मिल जाएगा tf.train.MonitoredTrainingSession बस के एक विशेष संस्करण पैदा कर रही है tf.train.MonitoredSession, जैसा कि source code में देखा जा सकता है)।

निम्नलिखित उदाहरण कोड दिखा रहा है कि आप प्रत्येक 5 सेकंड में'/ckpt_dir 'पर चेकपॉइंट कैसे सहेज सकते हैं।जब बाधित, यह अपने सहेजा गया अंतिम चौकी पर पुन: प्रारंभ होगा:

  1. tf.train.MonitoredTrainingSession एक tf.train.Scaffold बनाता है:

    def train(inputs, labels_onehot, global_step): 
        out = tf.contrib.layers.fully_connected(
              inputs, 
              num_outputs=10, 
              activation_fn=tf.nn.sigmoid) 
        loss = tf.reduce_mean(
          tf.reduce_sum(
           tf.nn.sigmoid_cross_entropy_with_logits(
              logits=out, 
              labels=labels_onehot), axis=1)) 
        train_op = opt.minimize(loss, global_step=global_step) 
        return train_op 
    
    with tf.Graph().as_default(): 
        global_step = tf.train.get_or_create_global_step() 
        inputs = ... 
        labels_onehot = ... 
        train_op = train(inputs, labels_onehot, global_step) 
    
        with tf.train.MonitoredTrainingSession(
         checkpoint_dir='./ckpt_dir', 
         save_checkpoint_secs=5, 
         hooks=[ ... ] # Choose your hooks 
        ) as mon_sess: 
         while not mon_sess.should_stop(): 
          mon_sess.run(train_op) 
    

    क्या आदेश प्राप्त करने के लिए यह वास्तव में तीन बातें है में MonitoredTrainingSession में हो रहा है वस्तु, जो वेब में एक मकड़ी की तरह काम करता है; यह मॉडल को प्रशिक्षित करने, सहेजने और लोड करने के लिए आवश्यक टुकड़ों को इकट्ठा करता है।

  2. यह tf.train.ChiefSessionCreator ऑब्जेक्ट बनाता है। इसका मेरा ज्ञान सीमित है, लेकिन इसकी समझ से, इसका उपयोग तब किया जाता है जब आपके टीएफ एल्गोरिदम कई सर्वरों में फैलता है। मेरा लेना यह है कि यह कंप्यूटर को फ़ाइल चला रहा है कि यह मुख्य कंप्यूटर है, और यह है कि चेकपॉइंट निर्देशिका को सहेजा जाना चाहिए, और लॉगर्स को यहां अपना डेटा लॉग करना चाहिए, आदि
  3. यह एक बनाता है tf.train.CheckpointSaverHook, जो चेकपॉइंट को बचाने के लिए उपयोग किया जाता है।

इसे काम करने के लिए, tf.train.CheckpointSaverHook और tf.train.ChiefSessionCreator को चेकपॉइंट निर्देशिका और मचान के समान संदर्भों को पारित किया जाना चाहिए। उदाहरण में अपने मानकों के साथ tf.train.MonitoredTrainingSession ऊपर ऊपर 3 घटकों के साथ लागू किया जाना थे, तो यह कुछ इस तरह दिखेगा:

checkpoint_dir = './ckpt_dir' 

scaffold = tf.train.Scaffold() 
saverhook = tf.train.CheckpointSaverHook(
    checkpoint_dir=checkpoint_dir, 
    save_secs=5 
    scaffold=scaffold 
) 
session_creator = tf.train.ChiefSessionCreator(
    scaffold=scaffold, 
    checkpoint_dir=checkpoint_dir 
) 

with tf.train.MonitoredSession(
    session_creator=session_creator, 
    hooks=[saverhook]) as mon_sess: 
     while not mon_sess.should_stop(): 
      mon_sess.run(train_op) 

आदेश एक ट्रेन + पार सत्यापन सत्र करने के लिए, आप बस (ऊपर जबकि पाश में) एक ही ग्राफ के माध्यम से दो अलग सेट पारित करने के लिए, और फिर से चलाने की जरूरत है:

mon_sess.run([train_op, cross_validation_loss]) 

यह सत्यापन के लिए प्रशिक्षण सेट के लिए प्रशिक्षण अनुकूलक, साथ ही validation_loss पैरामीटर चलाता है सेट। यदि आपका ग्राफ सही ढंग से कार्यान्वित किया गया है, तो इसका मतलब है कि ग्राफ केवल प्रशिक्षण सेट पर प्रशिक्षित किया जाएगा, और केवल क्रॉस सत्यापन सेट पर मान्य होगा।

संबंधित मुद्दे