2016-03-27 21 views
7

पर ग्रेडिएंट प्राप्त करते समय धीमी कार्यक्षमता मैं टेंसरफ्लो के साथ एक साधारण मल्टीलायर परसेप्ट्रॉन का निर्माण कर रहा हूं, और मुझे तंत्रिका नेटवर्क के इनपुट में हानि के ग्रेडियंट (या त्रुटि सिग्नल) को भी प्राप्त करने की आवश्यकता है।टेंसरफ्लो:

यहाँ मेरी कोड है, जो काम करता है:

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y)) 
optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost) 
... 
for i in range(epochs): 
    .... 
    for batch in batches: 
     ... 
     sess.run(optimizer, feed_dict=feed_dict) 
     grads_wrt_input = sess.run(tf.gradients(cost, self.x), feed_dict=feed_dict)[0] 

आखिरी पंक्ति (grads_wrt_input...) के बिना (प्रशिक्षण पाश शामिल करने के लिए संपादित), यह वास्तव में एक CUDA मशीन पर तेज़ चलता है। हालांकि, tf.gradients() दस गुना या अधिक से प्रदर्शन को कम कर देता है।

मुझे याद है कि नोड्स पर त्रुटि सिग्नल बैकप्रोपैगेशन एल्गोरिदम में मध्यवर्ती मानों के रूप में गणना की जाती हैं, और मैंने जावा लाइब्रेरी DeepLearning4j का उपयोग करके इसे सफलतापूर्वक किया है। मैं इस धारणा के तहत भी था कि यह पहले से ही optimizer द्वारा निर्मित गणना ग्राफ में थोड़ा सा संशोधन होगा।

इसे कैसे तेजी से बनाया जा सकता है, या नुकसान के ग्रेडियेंट की गणना करने का कोई और तरीका है w.r.t. इनपुट?

+0

क्या आप सचमुच प्रशिक्षण लूप में 'tf.gradients()' को कॉल कर रहे हैं? यदि ऐसा है तो मुझे संदेह है कि ओवरहेड बैकप्रॉप ग्राफ बनाने से हर बार जब आप इसे कॉल करते हैं? – mrry

+0

मैंने स्पष्टता के लिए प्रशिक्षण लूप कोड शामिल किया है; हाँ मैं प्रशिक्षण लूप में 'tf.gradients()' को बुला रहा हूं। कार्यक्रम धीरे-धीरे धीमा हो जाता है। इस इमारत को ऊपर की ओर रोकने के लिए मुझे क्या करना चाहिए? –

+0

एक बार ग्राडेंट्स के लिए गणना ग्राफ बनाने के लिए लूप के बाहर tf.gradients को कॉल करें। इसके अलावा आप compute_gradients –

उत्तर

10

tf.gradients() फ़ंक्शन हर बार इसे एक नया बैकप्रोपैगेशन ग्राफ़ बनाता है, इसलिए मंदी का कारण यह है कि टेंसरफ्लो को लूप के प्रत्येक पुनरावृत्ति पर एक नया ग्राफ पार्स करना पड़ता है। (यह आश्चर्यजनक रूप से महंगा हो सकता है: TensorFlow के वर्तमान संस्करण ही ग्राफ बार की एक बड़ी संख्या को क्रियान्वित करने के लिए अनुकूलित है।)

सौभाग्य से समाधान आसान है: बस ढ़ाल की गणना एक बार, पाश के बाहर। इस प्रकार आप अपने कोड पुनर्गठन कर सकते हैं:

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y)) 
optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost) 
grads_wrt_input_tensor = tf.gradients(cost, self.x)[0] 
# ... 
for i in range(epochs): 
    # ... 
    for batch in batches: 
     # ... 
     _, grads_wrt_input = sess.run([optimizer, grads_wrt_input_tensor], 
             feed_dict=feed_dict) 

ध्यान दें कि, प्रदर्शन के लिए, मैं भी दो sess.run() कॉल संयुक्त। यह सुनिश्चित करता है कि आगे के प्रचार, और बैकप्रोपैगेशन के अधिकांश का पुन: उपयोग किया जाएगा।


के रूप में एक अलग रूप में, एक टिप प्रदर्शन कीड़े खोजने के लिए इस अपने प्रशिक्षण पाश शुरू करने से पहले tf.get_default_graph().finalize() कॉल करने के लिए है की तरह। यदि आप अनजाने में ग्राफ में किसी भी नोड्स को जोड़ते हैं, तो यह अपवाद उठाएगा, जिससे इन बग के कारण का पता लगाना आसान हो जाता है।

+0

का उपयोग करके अपने अनुकूलक के लिए बनाए गए ढाल ग्राफ का पुन: उपयोग कर सकते हैं, धन्यवाद, धन्यवाद! मेरा कार्यक्रम अब तेज़ है। बीटीडब्ल्यू, मुझे लगता है कि 'sess.run() 'कॉल में सूची में' grads_wrt_input' 'grads_wrt_input_tensor' होना चाहिए। –

+0

अच्छा बिंदु: बस एक फिक्स के साथ संपादित! – mrry

संबंधित मुद्दे