2016-03-18 12 views
9

मैं this link में ट्यूटोरियल का अनुसरण कर रहा हूं और मॉडल (नीचे) के लिए मूल्यांकन विधि बदलने की कोशिश कर रहा हूं। मैं एक शीर्ष -5 मूल्यांकन प्राप्त करना चाहते हैं और मैं निम्नलिखित कोड का उपयोग करने के कोशिश कर रहा हूँ:TensorFlow in_top_k मूल्यांकन इनपुट argumants

topFiver=tf.nn.in_top_k(y, y_, 5, name=None) 

बहरहाल, यह निम्न त्रुटि पैदावार:

File "AlexNet.py", line 111, in <module> 
    topFiver = tf.nn.in_top_k(pred, y, 5, name=None) 
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_nn_ops.py", line 346, in in_top_k 
    targets=targets, k=k, name=name) 
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 486, in apply_op 
    _Attr(op_def, input_arg.type_attr)) 
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 59, in _SatisfiesTypeConstraint 
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list))) 
TypeError: DataType float32 for attr 'T' not in list of allowed values: int32, int64 

जहां तक ​​मेरा बता सकते हैं, समस्या यह है कि tf.nn.in_top_k() केवल tf.int32 या tf.int64 डेटा के लिए काम करता है, लेकिन मेरे डेटा tf.float32 प्रारूप में है। क्या इसके समाधान की कोई युक्ति है?

उत्तर

19

targetstf.nn.in_top_k(predictions, targets, k) पर तर्क कक्षा आईडी का वेक्टर होना चाहिए (यानी predictions मैट्रिक्स में कॉलम के सूचकांक) होना चाहिए। इसका मतलब है कि यह केवल एकल श्रेणी वर्गीकरण समस्याओं के लिए काम करता है।

तो आपकी समस्या को एक एकल वर्ग समस्या है, तो मुझे लगता है कि अपने y_ टेन्सर अपने उदाहरण के लिए सच लेबल की एक एक गर्म एन्कोडिंग (उदाहरण के लिए है, क्योंकि आप भी उन्हें tf.nn.softmax_cross_entropy_with_logits() की तरह एक सेशन के लिए गुजरती हैं। कि में मामला है, आप दो विकल्प हैं:।

  • तो लेबल मूल पूर्णांक लेबल के रूप में संग्रहीत किया गया है, उनमें tf.nn.in_top_k() उन्हें सीधे एक गर्म में रूपांतरित किए बिना पारित (इसके अलावा, अपने नुकसान समारोह के रूप में उपयोग करने पर विचार tf.nn.sparse_softmax_cross_entropy_with_logits() क्योंकि यह हो सकता है अधिक कुशल बनें।)
  • यदि लेबल मूल रूप से संग्रहीत किए गए थे एक गर्म प्रारूप, आप tf.argmax() का उपयोग कर पूर्णांकों के लिए उन्हें परिवर्तित कर सकते हैं:

    labels = tf.argmax(y_, 1) 
    topFiver = tf.nn.in_top_k(y, labels, 5)