2016-04-02 12 views
6

में पूर्व-प्रशिक्षित word2vec वैक्टर इंजेक्शन इंजेक्शन मैं मौजूदा tensorflow seq2seq मॉडल में pretrained word2vec वैक्टर इंजेक्षन करने की कोशिश कर रहा था।टेंसरफ्लो seq2seq

this answer के बाद, मैंने निम्नलिखित कोड का निर्माण किया। लेकिन ऐसा लगता है कि यह प्रदर्शन में सुधार नहीं करता है, हालांकि वैरिएबल में मान अपडेट किए गए हैं।

मेरी समझ में त्रुटि इस तथ्य के कारण हो सकती है कि एंबोडिंगवापर या एम्बेडिंग_टैशन_डेकोडर शब्दावली आदेश से स्वतंत्र रूप से एम्बेडिंग बनाते हैं?

टेंसफोर्लो मॉडल में प्रक्षेपित वैक्टरों को लोड करने का सबसे अच्छा तरीका क्या होगा?

SOURCE_EMBEDDING_KEY = "embedding_attention_seq2seq/RNN/EmbeddingWrapper/embedding" 
TARGET_EMBEDDING_KEY = "embedding_attention_seq2seq/embedding_attention_decoder/embedding" 


def inject_pretrained_word2vec(session, word2vec_path, input_size, dict_dir, source_vocab_size, target_vocab_size): 
    word2vec_model = word2vec.load(word2vec_path, encoding="latin-1") 
    print("w2v model created!") 
    session.run(tf.initialize_all_variables()) 

    assign_w2v_pretrained_vectors(session, word2vec_model, SOURCE_EMBEDDING_KEY, source_vocab_path, source_vocab_size) 
    assign_w2v_pretrained_vectors(session, word2vec_model, TARGET_EMBEDDING_KEY, target_vocab_path, target_vocab_size) 


def assign_w2v_pretrained_vectors(session, word2vec_model, embedding_key, vocab_path, vocab_size): 
    vectors_variable = [v for v in tf.trainable_variables() if embedding_key in v.name] 
    if len(vectors_variable) != 1: 
     print("Word vector variable not found or too many. key: " + embedding_key) 
     print("Existing embedding trainable variables:") 
     print([v.name for v in tf.trainable_variables() if "embedding" in v.name]) 
     sys.exit(1) 

    vectors_variable = vectors_variable[0] 
    vectors = vectors_variable.eval() 

    with gfile.GFile(vocab_path, mode="r") as vocab_file: 
     counter = 0 
     while counter < vocab_size: 
      vocab_w = vocab_file.readline().replace("\n", "") 
      # for each word in vocabulary check if w2v vector exist and inject. 
      # otherwise dont change the value. 
      if word2vec_model.__contains__(vocab_w): 
       w2w_word_vector = word2vec_model.get_vector(vocab_w) 
       vectors[counter] = w2w_word_vector 
      counter += 1 

    session.run([vectors_variable.initializer], 
      {vectors_variable.initializer.inputs[1]: vectors}) 

उत्तर

5

मैं seq2seq उदाहरण से परिचित नहीं हूँ, लेकिन सामान्य रूप में आप अपने embeddings इंजेक्षन करने के लिए निम्न कोड का उपयोग कर सकते हैं:

आप आप कहाँ निर्माण का ग्राफ़ बनाने:

with tf.device("/cpu:0"): 
    embedding = tf.get_variable("embedding", [vocabulary_size, embedding_size])  
    inputs = tf.nn.embedding_lookup(embedding, input_data) 

जब आप निष्पादित करें (अपना ग्राफ बनाने के बाद और प्रशिक्षण देने से पहले), बस एम्बेड किए गए चर पर अपने सहेजे गए एम्बेडिंग असाइन करें:

session.run(tf.assign(embedding, embeddings_that_you_want_to_use)) 

विचार यह है कि embedding_lookup चर में मौजूद input_data मानों को प्रतिस्थापित करेगा।

संबंधित मुद्दे