आप कर सकते हैं बहाल करने के स्लिम तरीका आज़माएं -।। slim.assign_from_checkpoint
https://github.com/tensorflow/tensorflow/blob/129665119ea60640f7ed921f36db9b5c23455224/tensorflow/contrib/slim/python/slim/learning.py
अनुरूप हिस्सा:
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
predictions = nets.vgg.vgg_16(images)
print [v.name for v in slim.get_variables_to_restore(exclude=['fc8']) ]
:
*************************************************
* Fine-Tuning Part of a model from a checkpoint *
*************************************************
Rather than initializing all of the weights of a given model, we sometimes
only want to restore some of the weights from a checkpoint. To do this, one
need only filter those variables to initialize as follows:
...
# Create the train_op
train_op = slim.learning.create_train_op(total_loss, optimizer)
checkpoint_path = '/path/to/old_model_checkpoint'
# Specify the variables to restore via a list of inclusion or exclusion
# patterns:
variables_to_restore = slim.get_variables_to_restore(
include=["conv"], exclude=["fc8", "fc9])
# or
variables_to_restore = slim.get_variables_to_restore(exclude=["conv"])
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
checkpoint_path, variables_to_restore)
# Create an initial assignment function.
def InitAssignFn(sess):
sess.run(init_assign_op, init_feed_dict)
# Run training.
slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn)
अद्यतन
मैं निम्नलिखित की कोशिश की
वहाँ स्लिम स्रोतों में संबंधित दस्तावेज़ है 03,210 और यह निर्गम (छोटा) हो गया है:
[u'vgg_16/conv1/conv1_1/weights:0',
u'vgg_16/conv1/conv1_1/biases:0',
…
u'vgg_16/fc6/weights:0',
u'vgg_16/fc6/biases:0',
u'vgg_16/fc7/weights:0',
u'vgg_16/fc7/biases:0',
u'vgg_16/fc8/weights:0',
u'vgg_16/fc8/biases:0']
तो यह लगता है कि आप vgg_16
साथ गुंजाइश उपसर्ग चाहिए:
print [v.name for v in slim.get_variables_to_restore(exclude=['vgg_16/fc8']) ]
(छोटा) देता है:
[u'vgg_16/conv1/conv1_1/weights:0',
u'vgg_16/conv1/conv1_1/biases:0',
…
u'vgg_16/fc6/weights:0',
u'vgg_16/fc6/biases:0',
u'vgg_16/fc7/weights:0',
u'vgg_16/fc7/biases:0']
अद्यतन 2
पूरा उदाहरण है कि निष्पादित करता है (अपने सिस्टम पर) त्रुटियों के बिना।
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets
s = tf.Session(config=tf.ConfigProto(gpu_options={'allow_growth':True}))
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
predictions = nets.vgg.vgg_16(images, 200)
variables_to_restore = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
init_assign_op, init_feed_dict = slim.assign_from_checkpoint('./vgg16.ckpt', variables_to_restore)
s.run(init_assign_op, init_feed_dict)
vgg16.ckpt
ऊपर के उदाहरण में एक चौकी 1000 कक्षाएं VGG16 मॉडल के लिए tf.train.Saver
ने बचा लिया है।
init_assign_op, init_feed_dict = slim.assign_from_checkpoint('./vgg16.ckpt', slim.get_variables_to_restore())
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
1 init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
----> 2 './vgg16.ckpt', slim.get_variables_to_restore())
/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/framework/python/ops/variables.pyc in assign_from_checkpoint(model_path, var_list)
527 assign_ops.append(var.assign(placeholder_value))
528
--> 529 feed_dict[placeholder_value] = var_value.reshape(var.get_shape())
530
531 assign_op = control_flow_ops.group(*assign_ops)
ValueError: total size of new array must be unchanged
मुझे लगता है कि पहले से ही विधि की कोशिश की है:
200 कक्षाओं मॉडल के सभी चर के साथ इस चौकी का उपयोग करना (fc8 सहित) निम्न त्रुटि देता है। यह अभी भी मुझे एक ही त्रुटि देता है: 'InvalidArgumentError (ट्रैस बैक के लिए ऊपर देखें): असाइन दोनों tensors के आकार से मिलान करने की आवश्यकता है।lhs आकार = [1,1,4096,200] rhs आकार = [1,1,4096,1000] \t [[नोड: save_1/Assign_32 = असाइन करें [टी = DT_FLOAT, _class = ["loc: @ vgg_16/fc8/वजन "], use_locking = true, validate_shape = true, _device ="/job: localhost/replica: 0/कार्य: 0/gpu: 0 "] (vgg_16/fc8/भार, save_1/restore_slice_32/_3)]]' – user1050648
कृपया एक अद्यतन उत्तर –
हाय, धन्यवाद। ऐसा लगता है कि जब तक 'num_classes' VGG16 के अनुरूप नहीं है, तब तक यह चाल चलती है। यदि आप का उपयोग करते हुए 'vgg_16' का उदाहरण शुरू कर रहे हैं, तो 1000, कक्षाओं के बजाय, 200 अभी भी त्रुटि दिखाता है। – user1050648