2016-01-26 13 views
16

से सभी चर कैसे प्राप्त करें मेरे पास एक सेटअप है जहां मुझे मुख्य प्रारंभिकरण के बाद एक LSTM प्रारंभ करने की आवश्यकता है जो tf.initialize_all_variables() का उपयोग करता है। अर्थात। मैं tf.initialize_variables([var_list])Tensorflow: rnn_cell.BasicLSTM और rnn_cell.MultiRNNCell

कॉल करना चाहते हैं वहाँ दोनों के लिए सभी आंतरिक trainable चर इकट्ठा करने के लिए जिस तरह से है: कि मैं बस प्रारंभ कर सकते हैं

  • rnn_cell.BasicLSTM
  • rnn_cell.MultiRNNCell

तो इन पैरामीटर?

मुख्य कारण मुझे यह चाहिए क्योंकि मैं पहले से कुछ प्रशिक्षित मूल्यों को फिर से शुरू नहीं करना चाहता हूं।

उत्तर

17

आपकी समस्या का समाधान करने का सबसे आसान तरीका चरणीय दायरे का उपयोग करना है। किसी दायरे के भीतर चर के नामों को इसके नाम से उपसर्ग किया जाएगा।

cell = rnn_cell.BasicLSTMCell(num_nodes) 

with tf.variable_scope("LSTM") as vs: 
    # Execute the LSTM cell here in any way, for example: 
    for i in range(num_steps): 
    output[i], state = cell(input_data[i], state) 

    # Retrieve just the LSTM variables. 
    lstm_variables = [v for v in tf.all_variables() 
        if v.name.startswith(vs.name)] 

# [..] 
# Initialize the LSTM variables. 
tf.initialize_variables(lstm_variables) 

यह MultiRNNCell साथ उसी तरह काम करेगा: यहाँ एक लघु स्निपेट है।

संपादित करें: tf.all_variables()

+0

यह सही है, धन्यवाद। यह नहीं पता था कि 'tf.trainable_variables()' गुंजाइश का सम्मान करता है, लेकिन मुझे लगता है कि हिंडसाइट में यह समझ में आता है! – bge0

+1

'tf.trainable_variables() 'के बजाय' tf.all_variables()' को जोड़ना बेहतर विकल्प होगा। मुख्य रूप से क्योंकि ऐसी चीजें हैं जो ऑप्टिमाइज़र की तरह हैं जिनके पास ट्रेन करने योग्य चर नहीं हैं, हालांकि उन्हें अभी भी प्रारंभ करने की आवश्यकता होगी। – bge0

+1

धन्यवाद, आप सही हैं। मैंने कोड अपडेट किया। –

11

को tf.trainable_variables बदल तुम भी tf.get_collection() उपयोग कर सकते हैं:

cell = rnn_cell.BasicLSTMCell(num_nodes) 
with tf.variable_scope("LSTM") as vs: 
    # Execute the LSTM cell here in any way, for example: 
    for i in range(num_steps): 
    output[i], state = cell(input_data[i], state) 

    lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name) 

(आंशिक रूप से राफाल के जवाब से नकल)

ध्यान दें कि अंतिम पंक्ति में सूची समझ के बराबर है राफल का कोड

असल में, tensorflow चर के वैश्विक संग्रह को संग्रहीत करता है, जिसे tf.all_variables() या tf.get_collection(tf.GraphKeys.VARIABLES) द्वारा लाया जा सकता है। यदि आप tf.get_collection() फ़ंक्शन में scope (स्कोप नाम) निर्दिष्ट करते हैं, तो आप संग्रह में केवल तनख्वाह (इस मामले में चर) प्राप्त करते हैं जिनके स्कोप निर्दिष्ट दायरे में हैं।

संपादित करें: आप केवल ट्रेन करने योग्य चर प्राप्त करने के लिए tf.GraphKeys.TRAINABLE_VARIABLES का उपयोग भी कर सकते हैं। लेकिन चूंकि वेनिला बेसिकलस्टमेलेल किसी भी गैर-ट्रेन करने योग्य चर को प्रारंभ नहीं करता है, दोनों कार्यशील रूप से समकक्ष होंगे। डिफ़ॉल्ट ग्राफ संग्रह की पूरी सूची के लिए, this देखें।

+0

में अन्य सभी प्रकार के चर के अनुरूप हो सके। राफल के समाधान की तुलना में यह बेहतर तरीका है :-) –

+1

जैसा कि मैंने ऊपर टिप्पणी की है, आपको शायद बेहतर उपयोग करना चाहिए। tf.get_collection (..., स्कोप = बनामनाम + "/") 'क्योंकि "एलएसटीएम 2" नामक एक और गुंजाइश हो सकती है। – Albert