2016-10-23 5 views
6

का उपयोग करके मैं tensorflow में भारित क्रॉस एंट्रॉपी हानि को कैसे कार्यान्वित कर सकता हूं। मैं tensorflow (कैफे से आ रहा हूं) का उपयोग करना शुरू कर रहा हूं, और मैं sparse_softmax_cross_entropy_with_logits हानि का उपयोग कर रहा हूं। फ़ंक्शन एक शॉट एन्कोडिंग के बजाय 0,1,...C-1 जैसे लेबल स्वीकार करता है। अब, मैं कक्षा लेबल के आधार पर भारोत्तोलन का उपयोग करना चाहता हूं; मुझे पता है कि यह मैट्रिक्स गुणा के साथ संभवतः किया जा सकता है यदि मैं softmax_cross_entropy_with_logits (एक हॉट एन्कोडिंग) का उपयोग करता हूं, तो sparse_softmax_cross_entropy_with_logits के साथ ऐसा करने का कोई तरीका है?sparse_softmax_cross_entropy_with_logits

उत्तर

1

कक्षा वजन को लॉग्स द्वारा गुणा किया जाता है, ताकि यह अभी भी sparse_softmax_cross_entropy_with_logits के लिए काम करता है। "टेंसर प्रवाह में कक्षा असंतुलित बाइनरी वर्गीकृत के लिए हानि फ़ंक्शन" के लिए this solution का संदर्भ लें।

एक तरफ ध्यान दें के रूप में, आप सीधे में sparse_softmax_cross_entropy

tf.contrib.losses.sparse_softmax_cross_entropy(logits, labels, weight=1.0, scope=None) 

इस विधि वजन पारित कर सकते हैं क्रोस एंट्रोपी नुकसान का उपयोग कर

tf.nn.sparse_softmax_cross_entropy_with_logits. 

वजन में कमी के लिए एक गुणांक के रूप में कार्य के लिए है। यदि एक स्केलर प्रदान किया जाता है, तो नुकसान को दिए गए मूल्य से आसानी से बढ़ाया जाता है। यदि वजन आकार [बैच_साइज] का टेंसर है, तो हानि भार प्रत्येक संबंधित नमूने पर लागू होते हैं।

+0

मैं अगर वहाँ एक रास्ता एक गर्म लेबल से बचने के लिए था सोच रहा था; दिए गए लिंक में, वहाँ है, क्योंकि अभी भी वजन वेक्टर के साथ एक गर्म लेबल की मैट्रिक्स गुणा करने की आवश्यकता। एक और तरीका लंबाई बैचसाइज के वजन वेक्टर का उपयोग करेगा, लेकिन फिर मुझे प्रत्येक बैच के लिए इस वेक्टर की गणना करना होगा; मैं इसे कैसे परिभाषित कर सकता हूं (चूंकि यह लेबल पर निर्भर करता है) बिना किसी लेबल लेबल मैट्रिक्स की गणना किए? –

+3

मुझे नहीं लगता कि यह जवाब सही है। 'Tf.contrib.losses.sparse_softmax_cross_entropy' में वजन प्रति-नमूना प्रति प्रति नमूना है। – andong777

+0

यह सही है, यह सिर्फ कष्टप्रद है। आप प्रत्येक अद्यतन के लिए एक भार पारित करेंगे और यह वर्तमान अद्यतन में मौजूद विशेष वर्ग पर निर्भर करेगा। तो यदि आपके पास आकार 3 का बैच था और कक्षाएं 1,1,2 थीं। और आप वज़न कक्षा 1 को 50% पर लेना चाहते थे, तो आप इस हानि समारोह का उपयोग करेंगे और वज़न तर्क को मानकों के साथ एक टेंसर पास करेंगे [0.5,0.5,1.0]। यह प्रभावी रूप से आपकी कक्षा का भार देगा ... सुरुचिपूर्ण? नहीं। प्रभावी हाँ। –

2

विशिष्ट रूप से बाइनरी वर्गीकरण के लिए, weighted_cross_entropy_with_logits है, जो भारित सॉफ्टमैक्स क्रॉस एन्ट्रॉपी की गणना करता है।

sparse_softmax_cross_entropy_with_logits एक उच्च कुशल गैर भारित ऑपरेशन के लिए पूंछ है (SparseSoftmaxXentWithLogitsOp जो हुड के नीचे SparseXentEigenImpl का उपयोग करता है देखें), तो यह "प्लगेबल" नहीं है।

मल्टी-क्लास केस में, आपका विकल्प या तो एक हॉट एन्कोडिंग पर स्विच हो या tf.losses.sparse_softmax_cross_entropy हानिकारक फ़ंक्शन को हैकी तरीके से उपयोग करें, जैसा कि पहले से ही सुझाव दिया गया है, जहां आपको वर्तमान बैच में लेबल के आधार पर वजन कम करना होगा ।

2
import tensorflow as tf 
import numpy as np 

np.random.seed(123) 
sess = tf.InteractiveSession() 

# let's say we have the logits and labels of a batch of size 6 with 5 classes 
logits = tf.constant(np.random.randint(0, 10, 30).reshape(6, 5), dtype=tf.float32) 
labels = tf.constant(np.random.randint(0, 5, 6), dtype=tf.int32) 

# specify some class weightings 
class_weights = tf.constant([0.3, 0.1, 0.2, 0.3, 0.1]) 

# specify the weights for each sample in the batch (without having to compute the onehot label matrix) 
weights = tf.gather(class_weights, labels) 

# compute the loss 
tf.losses.sparse_softmax_cross_entropy(labels, logits, weights).eval()