2017-04-09 16 views
6

में दिए गए गैर-वर्दी वितरण से प्रतिस्थापन के बिना नमूनाकरण मैं टेंसरफ्लो में numpy.random.choice(range(3),replacement=False,size=2,p=[0.1,0.2,0.7])
के समान कुछ ढूंढ रहा हूं।टेंसरफ्लो

निकटतम Op ऐसा लगता है कि tf.multinomial(tf.log(p)) जो इनपुट के रूप में लॉग इन लेता है लेकिन यह प्रतिस्थापन के बिना नमूना नहीं कर सकता है। क्या TensorFlow में एक गैर-वर्दी वितरण से नमूना करने का कोई अन्य तरीका है?

धन्यवाद।

उत्तर

2

तुम बस tf.py_func इस्तेमाल कर सकते हैं numpy.random.choice लपेट और यह एक TensorFlow सेशन के रूप में उपलब्ध बनाने के लिए:

a = tf.placeholder(tf.float32) 
size = tf.placeholder(tf.int32) 
replace = tf.placeholder(tf.bool) 
p = tf.placeholder(tf.float32) 

y = tf.py_func(np.random.choice, [a, size, replace, p], tf.float32) 

with tf.Session() as sess: 
    print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 

आप हमेशा की तरह numpy बीज निर्दिष्ट कर सकते हैं:

np.random.seed(1) 
print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 
print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 
print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 
np.random.seed(1) 
print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 
print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 
print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 
np.random.seed(1) 
print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 

प्रिंट होगा:

[ 2. 0.] 
[ 2. 1.] 
[ 0. 1.] 
[ 2. 0.] 
[ 2. 1.] 
[ 0. 1.] 
[ 2. 0.] 
संबंधित मुद्दे