लघु जवाब:
- आप अनुकूलक आप mon_sess.run को पारित करने के लिए वैश्विक कदम पारित करने के लिए की जरूरत है। यह सहेजे गए चेकपॉइंट्स को सहेजने और पुनर्प्राप्त करने के लिए दोनों संभव बनाता है।
- एक ही मॉनीटरर्ड ट्रेनिंग सत्र के माध्यम से एक साथ प्रशिक्षण + क्रॉस सत्यापन सत्र चलाने के लिए संभव है। सबसे पहले, आपको एक ही ग्राफ की अलग-अलग धाराओं के माध्यम से प्रशिक्षण बैचों और क्रॉस सत्यापन बैचों से गुजरना होगा (मुझे सलाह है कि यह कैसे करें इसे करने के लिए this guide देखें)। दूसरा, आपको - mon_sess.run() - प्रशिक्षण स्ट्रीम के लिए एक अनुकूलक पास करना होगा, साथ ही क्रॉस सत्यापन स्ट्रीम के नुकसान (/ पैरामीटर जिसे आप ट्रैक करना चाहते हैं) के पैरामीटर को पास करना होगा। यदि आप प्रशिक्षण से अलग से एक परीक्षण सत्र चलाने के लिए चाहते हैं, तो बस ग्राफ के माध्यम से केवल परीक्षण सेट चलाएं, और ग्राफ के माध्यम से केवल test_loss (/ अन्य पैरामीटर जिन्हें आप ट्रैक करना चाहते हैं) चलाएं। यह कैसे किया जाता है के बारे में अधिक जानकारी के लिए, नीचे देखें।
लांग जवाब:
मैं (मेरा उत्तर को अपडेट करते ही मैं अपने आप को क्या tf.train.MonitoredSession साथ किया जा सकता का एक बेहतर दृश्य मिल जाएगा tf.train.MonitoredTrainingSession बस के एक विशेष संस्करण पैदा कर रही है tf.train.MonitoredSession, जैसा कि source code में देखा जा सकता है)।
निम्नलिखित उदाहरण कोड दिखा रहा है कि आप प्रत्येक 5 सेकंड में'/ckpt_dir 'पर चेकपॉइंट कैसे सहेज सकते हैं।जब बाधित, यह अपने सहेजा गया अंतिम चौकी पर पुन: प्रारंभ होगा:
- tf.train.MonitoredTrainingSession एक tf.train.Scaffold बनाता है:
def train(inputs, labels_onehot, global_step):
out = tf.contrib.layers.fully_connected(
inputs,
num_outputs=10,
activation_fn=tf.nn.sigmoid)
loss = tf.reduce_mean(
tf.reduce_sum(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=out,
labels=labels_onehot), axis=1))
train_op = opt.minimize(loss, global_step=global_step)
return train_op
with tf.Graph().as_default():
global_step = tf.train.get_or_create_global_step()
inputs = ...
labels_onehot = ...
train_op = train(inputs, labels_onehot, global_step)
with tf.train.MonitoredTrainingSession(
checkpoint_dir='./ckpt_dir',
save_checkpoint_secs=5,
hooks=[ ... ] # Choose your hooks
) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
क्या आदेश प्राप्त करने के लिए यह वास्तव में तीन बातें है में MonitoredTrainingSession में हो रहा है वस्तु, जो वेब में एक मकड़ी की तरह काम करता है; यह मॉडल को प्रशिक्षित करने, सहेजने और लोड करने के लिए आवश्यक टुकड़ों को इकट्ठा करता है।
- यह tf.train.ChiefSessionCreator ऑब्जेक्ट बनाता है। इसका मेरा ज्ञान सीमित है, लेकिन इसकी समझ से, इसका उपयोग तब किया जाता है जब आपके टीएफ एल्गोरिदम कई सर्वरों में फैलता है। मेरा लेना यह है कि यह कंप्यूटर को फ़ाइल चला रहा है कि यह मुख्य कंप्यूटर है, और यह है कि चेकपॉइंट निर्देशिका को सहेजा जाना चाहिए, और लॉगर्स को यहां अपना डेटा लॉग करना चाहिए, आदि
- यह एक बनाता है tf.train.CheckpointSaverHook, जो चेकपॉइंट को बचाने के लिए उपयोग किया जाता है।
इसे काम करने के लिए, tf.train.CheckpointSaverHook और tf.train.ChiefSessionCreator को चेकपॉइंट निर्देशिका और मचान के समान संदर्भों को पारित किया जाना चाहिए। उदाहरण में अपने मानकों के साथ tf.train.MonitoredTrainingSession ऊपर ऊपर 3 घटकों के साथ लागू किया जाना थे, तो यह कुछ इस तरह दिखेगा:
checkpoint_dir = './ckpt_dir'
scaffold = tf.train.Scaffold()
saverhook = tf.train.CheckpointSaverHook(
checkpoint_dir=checkpoint_dir,
save_secs=5
scaffold=scaffold
)
session_creator = tf.train.ChiefSessionCreator(
scaffold=scaffold,
checkpoint_dir=checkpoint_dir
)
with tf.train.MonitoredSession(
session_creator=session_creator,
hooks=[saverhook]) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
आदेश एक ट्रेन + पार सत्यापन सत्र करने के लिए, आप बस (ऊपर जबकि पाश में) एक ही ग्राफ के माध्यम से दो अलग सेट पारित करने के लिए, और फिर से चलाने की जरूरत है:
mon_sess.run([train_op, cross_validation_loss])
यह सत्यापन के लिए प्रशिक्षण सेट के लिए प्रशिक्षण अनुकूलक, साथ ही validation_loss पैरामीटर चलाता है सेट। यदि आपका ग्राफ सही ढंग से कार्यान्वित किया गया है, तो इसका मतलब है कि ग्राफ केवल प्रशिक्षण सेट पर प्रशिक्षित किया जाएगा, और केवल क्रॉस सत्यापन सेट पर मान्य होगा।
स्रोत
2018-02-08 22:15:04