2016-06-10 12 views
13

मैं एक बहु लेबल समस्या पर काम कर रहा हूं और मैं अपने मॉडल की सटीकता निर्धारित करने की कोशिश कर रहा हूं।टेन्सफोर्लो, बहु लेबल सटीकता गणना

मेरे मॉडल:

NUM_CLASSES = 361 

x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS]) 
y_ = tf.placeholder(tf.float32, [None, NUM_CLASSES]) 

# create the network 
pred = conv_net(x) 

# loss 
cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(pred, y_)) 

# train step 
train_step = tf.train.AdamOptimizer().minimize(cost) 

मैं दो अलग अलग तरीकों
में सटीकता गणना करना चाहते हैं - सभी लेबल है कि सही ढंग से भविष्यवाणी कर रहे हैं% - छवियों की% जहां सभी लेबल सही ढंग से भविष्यवाणी कर रहे हैं

दुर्भाग्यवश मैं केवल उन सभी लेबलों की गणना करने में सक्षम हूं जिनकी भविष्यवाणी की गई है।

मैंने सोचा था कि इस कोड को चित्रों के% की गणना कर सकेगा, जहां सभी लेबल सही ढंग से भविष्यवाणी कर रहे हैं

correct_prediction = tf.equal(tf.round(pred), tf.round(y_)) 

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 

और कहा कि सही ढंग से भविष्यवाणी कर रहे हैं

pred_reshape = tf.reshape(pred, [ BATCH_SIZE * NUM_CLASSES, 1 ]) 
y_reshape = tf.reshape(y_, [ BATCH_SIZE * NUM_CLASSES, 1 ]) 

correct_prediction_all = tf.equal(tf.round(pred_reshape), tf.round(y_reshape)) 

accuracy_all = tf.reduce_mean(tf.cast(correct_prediction_all, tf.float32)) 

किसी भी तरह लेबल की जुटना सभी लेबलों का यह कोड% एक छवि से संबंधित खो गया है और मुझे यकीन नहीं है कि क्यों।

उत्तर

18

मेरा मानना ​​है कि अपने कोड में बग में है: correct_prediction = tf.equal(tf.round(pred), tf.round(y_))

predअसुरक्षित लॉग (यानि अंतिम सिग्मोइड के बिना) होना चाहिए।

correct_prediction = tf.equal(tf.round(tf.nn.sigmoid(pred)), tf.round(y_)) 

तो गणना करने के लिए:

  • सटीकता मतलब

    यहाँ आप (दोनों अंतराल [0, 1] में) ताकि आप लिखने की sigmoid(pred) और y_ के उत्पादन की तुलना करना चाहते सब कुछ खत्म हो लेबल:

accuracy1 = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
  • शुद्धता जहां सभी लेबल सही होने की जरूरत:
all_labels_true = tf.reduce_min(tf.cast(correct_prediction), tf.float32), 1) 
accuracy2 = tf.reduce_mean(all_labels_true) 
+0

धन्यवाद एक आकर्षण की तरह काम करता है, मैं सही ढंग से समझ है कि reduce_min एक छवि के सभी लेबल एक साथ पैक करते हैं? – MaMiFreak

+0

बैच के प्रत्येक तत्व के लिए यह न्यूनतम सही_प्रदर्शन लेगा। यह न्यूनतम 1 है यदि सभी तत्व 1 (यानी सभी भविष्यवाणियां सही हैं) और 0 यदि कम से कम एक तत्व गलत है (0 के बराबर) –

+0

मैंने इस दृष्टिकोण की कोशिश की, लेकिन मुझे प्रत्येक युग के लिए समान सटीकता मान मिलता है: 'ट्रेन सटीकता: 0.984375 परीक्षण सटीकता: 0.984375'। कोई विचार क्यों ऐसा होता है? https://stackoverflow.com/questions/49210520/cant-get-correct-acccuracy-for-multi-label-prediction – Peterdk

संबंधित मुद्दे