7

का उपयोग करते समय tensorflow में प्रशिक्षण के दौरान नेटवर्क का परीक्षण कैसे करें मैं नीचे दिए गए कोड का उपयोग करके अपने नेटवर्क पर अपने प्रशिक्षण उदाहरणों को खिलाने के लिए एक कतार का उपयोग कर रहा हूं, और यह ठीक से काम करता है।कतार

हालांकि, मैं कुछ परीक्षण डेटा हर n पुनरावृत्तियों को खिलाने के लिए सक्षम होने के लिए चाहते हैं, लेकिन मैं वास्तव में नहीं जानता कि मैं कैसे आगे बढ़ना चाहिए। क्या मुझे क्षणिक रूप से कतार बंद करनी चाहिए और परीक्षण डेटा को मैन्युअल रूप से फ़ीड करना चाहिए? क्या मुझे डेटा परीक्षण के लिए एक और कतार बनाना चाहिए?

संपादित करें: यह ऐसा करने का सही तरीके से एक अलग फाइल बनाने के लिए है है, eval.py कहते हैं, कि लगातार पिछले चौकी पढ़ता है और नेटवर्क का मूल्यांकन करता है? इस प्रकार वे इसे CIFAR10 उदाहरण में करते हैं।

batch = 128 # size of the batch 
x = tf.placeholder("float32", [None, n_steps, n_input]) 
y = tf.placeholder("float32", [None, n_classes]) 

queue = tf.RandomShuffleQueue(capacity=4*batch, 
         min_after_dequeue=3*batch, 
         dtypes=[tf.float32, tf.float32], 
         shapes=[[n_steps, n_input], [n_classes]]) 
enqueue_op = queue.enqueue_many([x, y]) 
X_batch, Y_batch = queue.dequeue_many(batch) 

sess = tf.Session() 

def load_and_enqueue(data): 
    while True: 
     X, Y = data.get_next_batch(batch) 
     sess.run(enqueue_op, feed_dict={x: X, y: Y}) 

train_thread = threading.Thread(target=load_and_enqueue, args=(data)) 
train_thread.daemon = True 
train_thread.start() 

for _ in xrange(max_iter): 
    sess.run(train_op) 
+0

इसके लिए कुछ अच्छे उच्च स्तरीय कार्य हैं जिन्हें हाल ही में [github repository] में जोड़ा गया है (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/ अजगर/स्लिम/evaluation.py)। वे एक अलग निष्पादन योग्य के साथ चल रहे मूल्यांकन पर आधारित हैं जो प्रशिक्षण द्वारा बनाई गई चेकपॉइंट फ़ाइलों को पढ़ता है। – user728291

+0

@ user728291, क्या एक ही स्क्रिप्ट के भीतर ऐसा करने के लिए कोई उदाहरण है? ऐसा लगता है कि कैफे जैसे अन्य टूल्स इसे पसंद करते हैं। –

+0

दो कतारों (या एक कतार और खिलाया गया प्लेसहोल्डर) का उपयोग करने के बारे में, और यह निर्धारित करने के लिए 'tf.where' का उपयोग करें कि इन दो स्रोतों में से कौन सा नेटवर्क नेटवर्क को खिलाने के लिए उपयोग किया जाता है? –

उत्तर

-1

आप अपने कोड में eval_op जोड़ सकते हैं, और फिर प्रत्येक एन (एन = 1000) पुनरावृत्तियों में मूल्यांकन कर सकते हैं। एक उदाहरण के रूप में अनुवर्ती है:

for niter in xrange(max_iter): 
    sess.run(train_op) 
    if niter % 1000 == 0: 
     sess.run(eval_op) 
1

आप एक और परीक्षण कतार और इस तरह परीक्षण मॉडल के रूप में प्रशिक्षण मॉडल की एक प्रति bulid कर सकते हैं:

trainX, trainY = Queue0(batchSize, ...)... 
testX, testY= Queue1(batchSize, ...)... 
modelTrain = inference(trainX, trainY, ...) 
# reuse variables 
modelTest = inference(testX, testY, ...) 
sess.run(train_op,loss_op,trainX,trainY) 
sess.run(test_op,testX,testY) 

इस तरह अधिक स्मृति का उपभोग कर सकते क्योंकि 2 मॉडल प्रारंभ कर रहे हैं, बेहतर समाधान देखने की उम्मीद