2017-04-06 6 views
7

से बना ऑपरेशन के लिए एक कस्टम ढाल कैसे पंजीकृत करें अधिक विशेष रूप से मेरे पास एक सरल fprop है जो टीएफ संचालन की संरचना है। मैं रजिस्टरग्राइडेंट का उपयोग करके अपनी खुद की ढाल विधि के साथ tensorflow gradient गणना को ओवरराइड करना चाहता हूं।टीएफ ऑपरेशंस

इस कोड के साथ क्या गलत है?

import tensorflow as tf 
from tensorflow.python.framework import ops 

@ops.RegisterGradient("MyopGrad") 
def frop_grad(op, grad): 
    x = op.inputs[0] 
    return 0 * x # zero out to see the difference: 

def fprop(x): 
    x = tf.sqrt(x) 
    out = tf.maximum(x, .2) 
    return out 

a = tf.Variable(tf.constant([5., 4., 3., 2., 1.], dtype=tf.float32)) 
h = fprop(a) 
h = tf.identity(h, name="Myop") 
grad = tf.gradients(h, a) 

g = tf.get_default_graph() 
with g.gradient_override_map({'Myop': 'MyopGrad'}): 
    with tf.Session() as sess: 
     sess.run(tf.initialize_all_variables()) 
     result = sess.run(grad) 

print(result[0]) 

मैं प्रिंट में सब शून्य देखना चाहते हैं, लेकिन इसके बजाय मैं हो रही है:

[ 0.2236068 0.25000003 0.28867513 0.35355341 0.5  ] 

उत्तर

7

आप with g.gradient_override_map({'Myop': 'MyopGrad'})

इसके अलावा के दायरे के भीतर सेशन परिभाषित करने की जरूरत है, तो आप की जरूरत है अपने नए ढाल पर Myop नाम के बजाय मानचित्र Identity

यहाँ पूर्ण कोड है:

import tensorflow as tf 
from tensorflow.python.framework import ops 

@ops.RegisterGradient("MyopGrad") 
def frop_grad(op, grad): 
    x = op.inputs[0] 
    return 0 * x # zero out to see the difference: 

def fprop(x): 
    x = tf.sqrt(x) 
    out = tf.maximum(x, .2) 
    return out 

a = tf.Variable(tf.constant([5., 4., 3., 2., 1.], dtype=tf.float32)) 
h = fprop(a) 

g = tf.get_default_graph() 
with g.gradient_override_map({'Identity': 'MyopGrad'}): 
    h = tf.identity(h, name="Myop") 
    grad = tf.gradients(h, a) 

with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 
    result = sess.run(grad) 

print(result[0]) 

आउटपुट:

[ 0. 0. 0. 0. 0.] 
+1

इस पहचान सेशन और नहीं fprop समारोह के लिए एक कस्टम ढाल समारोह को परिभाषित नहीं करता है? यदि आप शून्य से x गुणा नहीं करते हैं तो आप [5., 4., 3., 2., 1.] नहीं देखेंगे, बल्कि इसके बजाय आप पहचान() op में इनपुट देखेंगे। – Milad

संबंधित मुद्दे