2017-01-26 43 views
6

में RNN की विकलता गणना करने के लिए कैसे मैं शब्द Word RNNtensorflow

RNN की विकलता की गणना कैसे करें टेन्सर प्रवाह की RNN implmentation चल रहा हूँ।

के बाद प्रशिक्षण में कोड है कि प्रत्येक युग में प्रशिक्षण नुकसान और अन्य बातों से पता चलता है:

for e in range(model.epoch_pointer.eval(), args.num_epochs): 
     sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e))) 
     data_loader.reset_batch_pointer() 
     state = sess.run(model.initial_state) 
     speed = 0 
     if args.init_from is None: 
      assign_op = model.batch_pointer.assign(0) 
      sess.run(assign_op) 
      assign_op = model.epoch_pointer.assign(e) 
      sess.run(assign_op) 
     if args.init_from is not None: 
      data_loader.pointer = model.batch_pointer.eval() 
      args.init_from = None 
     for b in range(data_loader.pointer, data_loader.num_batches): 
      start = time.time() 
      x, y = data_loader.next_batch() 
      feed = {model.input_data: x, model.targets: y, model.initial_state: state, 
        model.batch_time: speed} 
      summary, train_loss, state, _, _ = sess.run([merged, model.cost, model.final_state, 
                 model.train_op, model.inc_batch_pointer_op], feed) 
      train_writer.add_summary(summary, e * data_loader.num_batches + b) 
      speed = time.time() - start 
      if (e * data_loader.num_batches + b) % args.batch_size == 0: 
       print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ 
        .format(e * data_loader.num_batches + b, 
          args.num_epochs * data_loader.num_batches, 
          e, train_loss, speed)) 
      if (e * data_loader.num_batches + b) % args.save_every == 0 \ 
        or (e==args.num_epochs-1 and b == data_loader.num_batches-1): # save for the last result 
       checkpoint_path = os.path.join(args.save_dir, 'model.ckpt') 
       saver.save(sess, checkpoint_path, global_step = e * data_loader.num_batches + b) 
       print("model saved to {}".format(checkpoint_path)) 
    train_writer.close() 

उत्तर

4

परियोजना आप संदर्भित कर रहे उपयोगों sequence_to_sequence_loss_by_example, जो क्रोस एंट्रोपी नुकसान देता है। इसलिए प्रशिक्षण की परेशानी की गणना करने के लिए, आपको केवल here समझाए गए नुकसान को समाप्त करने की आवश्यकता है।

train_perplexity = tf.exp(train_loss) 

हम, क्योंकि TensorFlow प्राकृतिक लघुगणक (TF Documentation) के साथ क्रोस एंट्रोपी नुकसान का आकलन करता एक आधार के रूप के बजाय ई 2 उपयोग करने के लिए है। संकेत के लिए धन्यवाद, @ माथियास एरो और @ कोलिन स्को।

विस्तृत विवरण

दो संभाव्यता वितरण पी और क्यू के क्रोस एंट्रोपी हमें बिट्स की न्यूनतम औसत संख्या हम पी, की घटनाओं एन्कोड करने के लिए जब हम प्र तो, पी के आधार पर एक कोडिंग योजना विकसित की जरूरत बताता है सच वितरण है, जिसे हम आमतौर पर नहीं जानते हैं। हम जितना संभव हो सके पी के करीब एक क्यू खोजना चाहते हैं, ताकि हम एक अच्छी कोडिंग योजना विकसित कर सकें, जितना संभव हो सके प्रति बिट कुछ बिट्स के साथ।

मुझे बिट्स नहीं कहना चाहिए, क्योंकि अगर हम क्रॉस-एन्ट्रॉपी की गणना में आधार 2 का उपयोग करते हैं तो हम केवल बिट्स का उपयोग कर सकते हैं। लेकिन टेंसरफ्लो प्राकृतिक लॉगरिदम का उपयोग करता है, इसलिए इसके बजाय नट्स में क्रॉस-एन्ट्रॉपी को मापने दें।

तो मान लीजिए कि हमारे पास एक खराब भाषा मॉडल है जो कहता है कि कॉर्पस में प्रत्येक टोकन (चरित्र/शब्द) अगले के रूप में समान रूप से संभव है। 1000 टोकन के एक कॉर्पस के लिए, इस मॉडल में लॉग (1000) = 6.9 नट का क्रॉस-एन्ट्रॉपी होगा। अगले टोकन की भविष्यवाणी करते समय, इसे प्रत्येक चरण में 1000 टोकन के बीच समान रूप से चुनना होगा।

एक बेहतर भाषा मॉडल एक संभाव्यता वितरण क्यू निर्धारित करेगा जो पी के करीब है। इस प्रकार, क्रॉस-एन्ट्रॉपी कम है - हमें 3.9 नट्स का क्रॉस-एन्ट्रॉपी मिल सकती है। अब हम विकलता को मापने के लिए चाहते हैं, हम बस क्रोस एंट्रोपी exponentiate:

exp (3.9) = 49.4

तो, नमूने, जिसके लिए हम नुकसान की गणना पर, अच्छा मॉडल था जैसा कि इसे लगभग 50 टोकन के बीच समान रूप से और स्वतंत्र रूप से चुनना था।

+0

मेरे मामले में ट्रेन की हानि 6.3 है, तो आप कह रहे हैं कि ट्रेन की परेशानी 2^6 = 64 होगी? –

+0

@ शानखान हां। आपका मॉडल प्रशिक्षण डेटा पर उलझन में है जैसे कि इसे प्रत्येक शब्द के लिए 64 विकल्पों के बीच यादृच्छिक रूप से चुनना पड़ा। –

+1

क्या डाउनवॉटर पर टिप्पणी करने की देखभाल क्यों होगी? –

0

यह निर्भर करता है कि आपका नुकसान फ़ंक्शन आपको बेस 2 या बेस ई में डेटा की लॉग संभावना देता है या नहीं। यह मॉडल legacy_seq2seq.sequence_loss_by_example का उपयोग कर रहा है, जो टेंसरफ्लो की बाइनरी क्रॉसेंट्रॉपी का उपयोग करता है, जो appears to use logs of base e है। इसलिए, भले ही हम एक अलग संभावना वितरण (पाठ) से निपट रहे हैं, हमें ई के साथ विस्तार करना चाहिए, यानी कॉलिन स्को के सुझाव के रूप में tf.exp (train_loss) का उपयोग करें।