2015-12-28 6 views
8

बहाल मैं जवाब के सुझाव को लागू करने की कोशिश कर रहा हूँ: Tensorflow: how to save/restore a model?tensorflow: बचत और सत्र

मैंने किसी चीज़ जो एक sklearn शैली में एक tensorflow मॉडल लपेटता है।

import tensorflow as tf 
class tflasso(): 
    saver = tf.train.Saver() 
    def __init__(self, 
       learning_rate = 2e-2, 
       training_epochs = 5000, 
        display_step = 50, 
        BATCH_SIZE = 100, 
        ALPHA = 1e-5, 
        checkpoint_dir = "./", 
      ): 
     ... 

    def _create_network(self): 
     ... 


    def _load_(self, sess, checkpoint_dir = None): 
     if checkpoint_dir: 
      self.checkpoint_dir = checkpoint_dir 

     print("loading a session") 
     ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir) 
     if ckpt and ckpt.model_checkpoint_path: 
      self.saver.restore(sess, ckpt.model_checkpoint_path) 
     else: 
      raise Exception("no checkpoint found") 
     return 

    def fit(self, train_X, train_Y , load = True): 
     self.X = train_X 
     self.xlen = train_X.shape[1] 
     # n_samples = y.shape[0] 

     self._create_network() 
     tot_loss = self._create_loss() 
     optimizer = tf.train.AdagradOptimizer(self.learning_rate).minimize(tot_loss) 

     # Initializing the variables 
     init = tf.initialize_all_variables() 
     " training per se" 
     getb = batchgen(self.BATCH_SIZE) 

     yvar = train_Y.var() 
     print(yvar) 
     # Launch the graph 
     NUM_CORES = 3 # Choose how many cores to use. 
     sess_config = tf.ConfigProto(inter_op_parallelism_threads=NUM_CORES, 
                  intra_op_parallelism_threads=NUM_CORES) 
     with tf.Session(config= sess_config) as sess: 
      sess.run(init) 
      if load: 
       self._load_(sess) 
      # Fit all training data 
      for epoch in range(self.training_epochs): 
       for (_x_, _y_) in getb(train_X, train_Y): 
        _y_ = np.reshape(_y_, [-1, 1]) 
        sess.run(optimizer, feed_dict={ self.vars.xx: _x_, self.vars.yy: _y_}) 
       # Display logs per epoch step 
       if (1+epoch) % self.display_step == 0: 
        cost = sess.run(tot_loss, 
          feed_dict={ self.vars.xx: train_X, 
            self.vars.yy: np.reshape(train_Y, [-1, 1])}) 
        rsq = 1 - cost/yvar 
        logstr = "Epoch: {:4d}\tcost = {:.4f}\tR^2 = {:.4f}".format((epoch+1), cost, rsq) 
        print(logstr) 
        self.saver.save(sess, self.checkpoint_dir + 'model.ckpt', 
         global_step= 1+ epoch) 

      print("Optimization Finished!") 
     return self 

जब मैं चलाएँ:

tfl = tflasso() 
tfl.fit(train_X, train_Y , load = False) 

मैं आउटपुट प्राप्त:

Epoch: 50 cost = 38.4705 R^2 = -1.2036 
    b1: 0.118122 
Epoch: 100 cost = 26.4506 R^2 = -0.5151 
    b1: 0.133597 
Epoch: 150 cost = 22.4330 R^2 = -0.2850 
    b1: 0.142261 
Epoch: 200 cost = 20.0361 R^2 = -0.1477 
    b1: 0.147998 

हालांकि, जब मैं मानकों को ठीक करने की कोशिश (यहां तक ​​कि वस्तु की हत्या के बिना): tfl.fit(train_X, train_Y , load = True)

मुझे अजीब परिणाम मिलते हैं। सबसे पहले, भारित मूल्य सहेजे गए किसी से मेल नहीं खाता है।

loading a session 
loaded b1: 0.1   <------- Loaded another value than saved 
Epoch: 50 cost = 30.8483 R^2 = -0.7670 
    b1: 0.137484 

लोड करने का सही तरीका क्या है, और शायद पहले सहेजे गए चर का निरीक्षण करें?

+0

tensorflow प्रलेखन बहुत बुनियादी उदाहरण से रहित है, आप उदाहरण फ़ोल्डर में खुदाई और ज्यादातर अपने स्वयं के – diffeomorphism

उत्तर

10

टी एल; डॉ: आप केवल एक बार इस वर्ग rework करने के लिए इतना है कि self.create_network() कहा जाता है (मैं) की कोशिश करनी चाहिए, और (ii) tf.train.Saver() से पहले का निर्माण किया है।

यहां दो सूक्ष्म मुद्दे हैं, जो कोड संरचना के कारण हैं, और tf.train.Saver constructor का डिफ़ॉल्ट व्यवहार है। जब आप कोई तर्क नहीं देते हैं (जैसे कि आपके कोड में), यह आपके प्रोग्राम में चर के वर्तमान सेट को एकत्र करता है, और उन्हें सहेजने और पुनर्स्थापित करने के लिए ग्राफ को ऑप्स जोड़ता है। आपके कोड में, जब आप tflasso() पर कॉल करते हैं, तो यह एक सेवर का निर्माण करेगा, और कोई चर नहीं होगा (क्योंकि create_network() अभी तक नहीं कहा गया है)। नतीजतन, चेकपॉइंट खाली होना चाहिए।

दूसरा मुद्दा यह है कि — डिफ़ॉल्ट रूप से — सहेजे गए चेकपॉइंट का प्रारूप name property of a variable से वर्तमान मूल्य पर एक मानचित्र है। आप ही नाम के दो चर बनाते हैं तो वे TensorFlow द्वारा स्वचालित रूप से "uniquified" हो जाएगा:

v = tf.Variable(..., name="weights") 
assert v.name == "weights" 
w = tf.Variable(..., name="weights") 
assert v.name == "weights_1" # The "_1" is added by TensorFlow. 

इस का परिणाम जब आप tfl.fit() को दूसरी कॉल में self.create_network() कहते हैं, चर सब होगा, है चेकपॉइंट — में संग्रहीत नामों से अलग-अलग नाम या नेटवर्क के बाद सेवर का निर्माण किया गया होता। (आप सेवर निर्माता के लिए एक नाम- Variable शब्दकोश पारित करके इस व्यवहार से बच सकते हैं, लेकिन यह आम तौर पर काफी अजीब है।)

दो मुख्य समाधान हैं:

  1. tflasso.fit() की प्रत्येक कॉल में, बनाने के नया मॉडल tf.Graph को परिभाषित करके, फिर उस ग्राफ में नेटवर्क बनाने और tf.train.Saver बनाने के द्वारा।

  2. अनुशंसित नेटवर्क है, तो tflasso निर्माता में tf.train.Saver बनाएँ, और tflasso.fit() की प्रत्येक कॉल पर इस ग्राफ का पुन: उपयोग।ध्यान दें कि आपको चीजों को पुनर्गठित करने के लिए कुछ और काम करने की आवश्यकता हो सकती है (विशेष रूप से, मुझे यकीन नहीं है कि आप self.X और self.xlen के साथ क्या करते हैं) लेकिन placeholders और इसे खिलाकर इसे प्राप्त करना संभव होना चाहिए।

+0

धन्यवाद पर इसके बारे में समझ बनाने के लिए है! 'X'' के इनपुट आकार को सेट करने के लिए 'xlen' का उपयोग' self._create_network() 'के भीतर किया जाता है (प्लेसहोल्डर इनिट:' self.vars.xx = tf.placeholder (" float ", shape = [none, self.xlen ]) ')। आप जो कहते हैं उससे, पसंदीदा तरीका 'xlen' को प्रारंभकर्ता को पास करना है। –

+0

ऑब्जेक्ट के पुन: प्रारंभ करने पर यूनिकिफायर/स्पष्ट पुराने टीएफ चर को रीसेट करने का कोई तरीका है? –

+1

ऐसा करने के लिए आपको एक नया 'tf.Graph' बनाने की आवश्यकता है और इसे (i) नेटवर्क बनाने से पहले इसे डिफ़ॉल्ट बनाएं और (ii)' सेवर 'बनाएं। यदि आप tf.Graph() के रूप में 'tflasso.fit()' के शरीर को asfde() के रूप में लपेटते हैं, as_default(): 'ब्लॉक, और उस ब्लॉक के अंदर' सेवर 'निर्माण को स्थानांतरित करें, तो नाम प्रत्येक बार समान होना चाहिए कॉल करें 'फिट() '। – mrry