2017-01-04 11 views
8

मैं TensorFlow जानने की कोशिश कर रहा हूँ और में उदाहरण का अध्ययन: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynbटेंसरफ्लो: dataset.train.next_batch कैसे परिभाषित किया गया है?

मैं तो नीचे दिए गए कोड में कुछ सवाल हैं:

for epoch in range(training_epochs): 
    # Loop over all batches 
    for i in range(total_batch): 
     batch_xs, batch_ys = mnist.train.next_batch(batch_size) 
     # Run optimization op (backprop) and cost op (to get loss value) 
     _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs}) 
    # Display logs per epoch step 
    if epoch % display_step == 0: 
     print("Epoch:", '%04d' % (epoch+1), 
       "cost=", "{:.9f}".format(c)) 

mnist के बाद से सिर्फ एक डाटासेट है, mnist.train.next_batch वास्तव में क्या मतलब है? dataset.train.next_batch कैसे परिभाषित किया गया था?

धन्यवाद!

उत्तर

19

mnist ऑब्जेक्ट read_data_sets() function से tf.contrib.learn मॉड्यूल में परिभाषित किया गया है। mnist.train.next_batch(batch_size) विधि here लागू की गई है, और यह दो सरणी का एक टुपल लौटाती है, जहां पहला batch_size एमएनआईएसटी छवियों के बैच का प्रतिनिधित्व करता है, और दूसरा उन छवियों के अनुरूप batch-size लेबल का बैच दर्शाता है।

छवियों आकार [batch_size, 784] (के बाद से वहाँ एक MNIST छवि में 784 पिक्सल कर रहे हैं) के 2-डी NumPy सरणी के रूप में वापस कर रहे हैं, और लेबल आकार [batch_size] के एक 1-डी NumPy सरणी या तो के रूप में वापस कर रहे हैं (यदि read_data_sets() था one_hot=False के साथ बुलाया गया) या आकार के 2-डी NumPy सरणी [batch_size, 10] (यदि read_data_sets()one_hot=True के साथ बुलाया गया था)।

+7

यह उल्लेख के लायक है कि [next_batch] (https://github.com/tensorflow/tensorflow/blob/7c36309c37b04843030664cdc64aca2bb7d6ecaa/tensorflow/contrib/learn/python/learn/datasets/mnist.py#L160) जाने के बाद उदाहरण फेरबदल है उनमें से सभी युग के माध्यम से। आप 'dataSet._index_in_epoch' द्वारा युग में कहां हैं, ट्रैक कर सकते हैं, जैसे' mnist.train._index_in_epoch' –

संबंधित मुद्दे