2016-07-02 5 views
24

क्या कोई फ़ंक्शन कॉल या कोई अन्य तरीका है जो टेंसफोर्लो ग्राफ में पैरामीटर की कुल संख्या को गिनने के लिए है?एक tensorflow मॉडल में ट्रेन करने योग्य मानकों की कुल संख्या कैसे गिनें?

तक मापदंडों मेरा मतलब है: trainable चर के एक एन मंद वेक्टर एन मानकों है, एक NxM मैट्रिक्स N*M मानकों, आदि है तो अनिवार्य रूप से मैं एक में सभी trainable चर के आकार आयामों के उत्पाद योग करने के लिए करना चाहते हैं tensorflow सत्र।

+0

अपने प्रश्न का विवरण और शीर्षक से मेल नहीं खाते (जब तक कि मैं ग्राफ और मॉडल की शब्दावली को भ्रमित नहीं कर रहा हूं)। प्रश्न में आप एक ग्राफ और शीर्षक के बारे में पूछते हैं जो आप मॉडल के बारे में पूछते हैं। क्या होगा यदि आपके पास दो अलग-अलग मॉडल हों? मैं इस सवाल पर स्पष्ट करने का सुझाव दूंगा। –

उत्तर

35

tf.trainable_variables() में प्रत्येक चर के आकार पर लूप।

total_parameters = 0 
for variable in tf.trainable_variables(): 
    # shape is an array of tf.Dimension 
    shape = variable.get_shape() 
    print(shape) 
    print(len(shape)) 
    variable_parameters = 1 
    for dim in shape: 
     print(dim) 
     variable_parameters *= dim.value 
    print(variable_parameters) 
    total_parameters += variable_parameters 
print(total_parameters) 
+2

यदि आपके पास एक से अधिक मॉडल हैं, तो कैसे tf.trainable_variables() 'पता है कि किस का उपयोग करना है? –

+2

tf.trainable_variables() वर्तमान ग्राफ में मौजूद ट्रेनबल के रूप में चिह्नित सभी चर लौटाता है। यदि वर्तमान ग्राफ में आपके पास एक से अधिक मॉडल हैं, तो आपको उनके नामों का उपयोग करके चर को मैन्युअल रूप से फ़िल्टर करना होगा। Somethink जैसे variable.name.strartswith ("model2"): ... – nessuno

+0

यह समाधान मुझे त्रुटि देता है "अपवाद हुआ: 'int' ऑब्जेक्ट को str implicitly रूपांतरित नहीं कर सकता"। आपको नीचे दिए गए उत्तर में सुझाए गए अनुसार 'int' को स्पष्ट रूप से 'int' डालना होगा (जो मैं सही उत्तर मानता हूं) –

6

सुनिश्चित नहीं है कि वास्तव में दिया गया उत्तर वास्तव में चलता है (मुझे लगता है कि आपको मंद वस्तु को काम करने के लिए एक int में परिवर्तित करने की आवश्यकता है)।

def count_number_trainable_params(): 
    ''' 
    Counts the number of trainable variables. 
    ''' 
    tot_nb_params = 0 
    for trainable_variable in tf.trainable_variables(): 
     shape = trainable_variable.get_shape() # e.g [D,F] or [W,H,C] 
     current_nb_params = get_nb_params_shape(shape) 
     tot_nb_params = tot_nb_params + current_nb_params 
    return tot_nb_params 

def get_nb_params_shape(shape): 
    ''' 
    Computes the total number of params for a given shap. 
    Works for any number of shapes etc [D,F] or [W,H,C] computes D*F and W*H*C. 
    ''' 
    nb_params = 1 
    for dim in shape: 
     nb_params = nb_params*int(dim) 
    return nb_params 
+0

उत्तर काम करता है (r0.11.0)। तुम्हारा प्लग एन प्ले है :) –

+0

@ एफ 4। ऐसा लगता है कि इसके साथ एक बग है क्योंकि 'y' का उपयोग नहीं किया जा रहा है। –

+0

@ चार्लीपार्कर मैंने इसे कुछ सेकंड पहले तय किया;) –

4

दो मौजूदा जवाब अच्छे हैं आप अपने आप को मानकों की संख्या परिकलित में देख रहे हैं: है यहाँ एक है कि काम करता है और तुम सिर्फ कार्यों पेस्ट और उन्हें फोन कॉपी कर सकते हैं (बहुत कुछ टिप्पणियां जोड़ी) । यदि आपका प्रश्न "टेंसरफ्लो मॉडल को प्रोफाइल करने का एक आसान तरीका है" के आधार पर और अधिक था, तो मैं अत्यधिक tfprof में देखने की अनुशंसा करता हूं। यह पैरामीटर की संख्या की गणना सहित, आपके मॉडल प्रोफाइल करता है।

1

मैं अपने बराबर लेकिन छोटे कार्यान्वयन में फेंक देंगे:

def count_params(): 
    "print number of trainable variables" 
    size = lambda v: reduce(lambda x, y: x*y, v.get_shape().as_list()) 
    n = sum(size(v) for v in tf.trainable_variables()) 
    print "Model size: %dK" % (n/1000,) 
12

मेरे पास एक और भी छोटा संस्करण, numpy का उपयोग कर का उपयोग करते हुए एक लाइन समाधान:

np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]) 
+0

, v में size_as_list() फ़ंक्शन नहीं है लेकिन केवल get_shape() फ़ंक्शन – mustafa

+0

मुझे लगता है कि पुराने संस्करणों में नहीं है .shape लेकिन get_shape()। मेरा जवाब अपडेट किया गया। वैसे भी, मैंने v.shape.as_list() लिखा है और v.shape_as_list() नहीं। –

+5

'np.sum ([np.prod (v.shape) vf में tf.trainable_variables()]) 'tensorFlow 1.2 में भी काम करता है –

संबंधित मुद्दे