2016-04-16 28 views
22

मेरे पास आकार [None, 9, 2] (जहां None बैच है) के tensorflow में इनपुट है।tensorflow में Flatten बैच

इस पर आगे की क्रियाएं (उदा। मैटम) करने के लिए मुझे इसे [None, 18] आकार में बदलने की आवश्यकता है। यह कैसे करना है?

उत्तर

2

रनटाइम के दौरान के माध्यम से बैच आयाम के मूल्य को प्राप्त करने के लिए आप गतिशील रीशेपिंग का उपयोग कर सकते हैं, tf.reshape में नए आयामों के पूरे सेट की गणना करें। सूची लंबाई जानने के बिना वर्ग सूची में वर्ग सूची को दोबारा बदलने का एक उदाहरण यहां दिया गया है।

tf.reset_default_graph() 
sess = tf.InteractiveSession("") 
a = tf.placeholder(dtype=tf.int32) 
# get [9] 
ashape = tf.shape(a) 
# slice the list from 0th to 1st position 
ashape0 = tf.slice(ashape, [0], [1]) 
# reshape list to scalar, ie from [9] to 9 
ashape0_flat = tf.reshape(ashape0,()) 
# tf.sqrt doesn't support int, so cast to float 
ashape0_flat_float = tf.to_float(ashape0_flat) 
newshape0 = tf.sqrt(ashape0_flat_float) 
# convert [3, 3] Python list into [3, 3] Tensor 
newshape = tf.pack([newshape0, newshape0]) 
# tf.reshape doesn't accept float, so convert back to int 
newshape_int = tf.to_int32(newshape) 
a_reshaped = tf.reshape(a, newshape_int) 
sess.run(a_reshaped, feed_dict={a: np.ones((9))}) 

आप बैच का आकार जानने के बिना

array([[1, 1, 1], 
     [1, 1, 1], 
     [1, 1, 1]], dtype=int32) 
+0

मैं इस समाधान में या Tensorflow में किसी भी विधि 'tf.batch' नहीं दिख रहा है ... – Muneeb

34

आप tf.reshape() के साथ आसानी से यह कर सकते हैं देखना चाहिए।

x = tf.placeholder(tf.float32, shape=[None, 9,2]) 
shape = x.get_shape().as_list()  # a list: [None, 9, 2] 
dim = numpy.prod(shape[1:])   # dim = prod(9,2) = 18 
x2 = tf.reshape(x, [-1, dim])   # -1 means "all" 

-1 अंतिम पंक्ति में कोई बात नहीं क्या batchsize क्रम में है पूरे स्तंभ का मतलब है। आप इसे tf.reshape() में देख सकते हैं।


अद्यतन: आकार = [कोई नहीं, 3, कोई नहीं]

धन्यवाद @kbrose। उन मामलों के लिए जहां 1 से अधिक आयाम अपरिभाषित हैं, हम tf.shape() का उपयोग tf.reduce_prod() वैकल्पिक रूप से कर सकते हैं।

x = tf.placeholder(tf.float32, shape=[None, 3, None]) 
dim = tf.reduce_prod(tf.shape(x)[1:]) 
x2 = tf.reshape(x, [-1, dim]) 

tf.shape() एक आकार टेंसर देता है जिसे रनटाइम में मूल्यांकन किया जा सकता है। Tf.get_shape() और tf.shape() के बीच का अंतर in the doc देखा जा सकता है।

मैंने tf.contrib.layers.flatten() को दूसरे में भी कोशिश की। यह पहले मामले के लिए सबसे आसान है, लेकिन यह दूसरे को संभाल नहीं सकता है।

+0

यह अच्छी तरह से काम करता है अगर आप सभी अन्य आयामों के आकार पता है, लेकिन है नहीं तो अन्य आयामों अज्ञात है आकार। जैसे 'x = tf.placeholder (tf.float32, आकार = [कोई नहीं, 9, कोई नहीं])' – kbrose

+1

धन्यवाद @kbrose। मैंने मामले के लिए जवाब अद्यतन किया है। – weitang114

+0

@ weitang114 बहुत बढ़िया! – kbrose