2017-09-09 31 views
10

केरास fit_generator() मॉडल विधि एक जेनरेटर की अपेक्षा करता है जो आकार (इनपुट, लक्ष्य) के tuples पैदा करता है, जहां दोनों तत्व NumPy arrays हैं। The documentation का अर्थ यह है कि अगर मैं जनरेटर में Dataset iterator लपेटता हूं, और टेंसर को न्यूमपी सरणी में परिवर्तित करना सुनिश्चित करता हूं, तो मुझे जाना अच्छा होना चाहिए।कैसे टेंसरफ्लो के डेटासेट एपीआई और केरास को उचित रूप से संयोजित करें?

import numpy as np 
import os 
import keras.backend as K 
from keras.layers import Dense, Input 
from keras.models import Model 
import tensorflow as tf 
from tensorflow.contrib.data import Dataset 

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

with tf.Session() as sess: 
    def create_data_generator(): 
     dat1 = np.arange(4).reshape(-1, 1) 
     ds1 = Dataset.from_tensor_slices(dat1).repeat() 

     dat2 = np.arange(5, 9).reshape(-1, 1) 
     ds2 = Dataset.from_tensor_slices(dat2).repeat() 

     ds = Dataset.zip((ds1, ds2)).batch(4) 
     iterator = ds.make_one_shot_iterator() 
     while True: 
      next_val = iterator.get_next() 
      yield sess.run(next_val) 

datagen = create_data_generator() 

input_vals = Input(shape=(1,)) 
output = Dense(1, activation='relu')(input_vals) 
model = Model(inputs=input_vals, outputs=output) 
model.compile('rmsprop', 'mean_squared_error') 
model.fit_generator(datagen, steps_per_epoch=1, epochs=5, 
        verbose=2, max_queue_size=2) 

यहाँ त्रुटि मैं प्राप्त होते हैं:: इस कोड, हालांकि, मुझे एक त्रुटि देता है, अजीब पर्याप्त

Using TensorFlow backend. 
Epoch 1/5 
Exception in thread Thread-1: 
Traceback (most recent call last): 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__ 
    fetch, allow_tensor=True, allow_operation=True)) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element 
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2787, in _as_graph_element_locked 
    raise ValueError("Tensor %s is not an element of this graph." % obj) 
ValueError: Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph. 

During handling of the above exception, another exception occurred: 

Traceback (most recent call last): 
    File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner 
    self.run() 
    File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 864, in run 
    self._target(*self._args, **self._kwargs) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py", line 568, in data_generator_task 
    generator_output = next(self._generator) 
    File "./datagen_test.py", line 25, in create_data_generator 
    yield sess.run(next_val) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 895, in run 
    run_metadata_ptr) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1109, in _run 
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 413, in __init__ 
    self._fetch_mapper = _FetchMapper.for_fetch(fetches) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 233, in for_fetch 
    return _ListFetchMapper(fetch) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in __init__ 
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in <listcomp> 
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 241, in for_fetch 
    return _ElementFetchMapper(fetches, contraction_fn) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 277, in __init__ 
    'Tensor. (%s)' % (fetch, str(e))) 
ValueError: Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=(?, 1) dtype=int64> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.) 

Traceback (most recent call last): 
    File "./datagen_test.py", line 34, in <module> 
    verbose=2, max_queue_size=2) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper 
    return func(*args, **kwargs) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 2011, in fit_generator 
    generator_output = next(output_generator) 
StopIteration 

next(datagen) रखने वाली पंक्ति जोड़ने सीधे के बाद मैं कहाँ से प्रारंभ datagen बस चलाने के लिए कोड का कारण बनता है ठीक है, कोई त्रुटि के साथ।

मेरा मूल कोड क्यों काम नहीं करता है? जब मैं उस कोड को अपने कोड में जोड़ता हूं तो यह क्यों काम करना शुरू कर देता है? क्या केरस के साथ टेंसरफ्लो के डेटासेट एपीआई का उपयोग करने का एक और अधिक प्रभावी तरीका है जिसमें टेंसर को न्यूमपी सरणी में परिवर्तित करने और फिर से वापस शामिल करने में शामिल नहीं है?

+0

मुझे यकीन नहीं है कि यही कारण है, लेकिन मुझे यह वास्तव में अजीब लगता है कि आप 'ब्लॉक' के अंदर एक फ़ंक्शन को परिभाषित करते हैं। –

+0

जाहिर है, जनरेटर परिभाषा के अंदर 'साथ' ब्लॉक डालने से कोड अतिरिक्त लाइन के साथ और बिना दोनों काम करता है, हालांकि मैं शपथ ले सकता था कि मैंने पहले इसे इस तरह से आजमाया। इस बात को ध्यान में रखते हुए (मुझे लगता है) टेंसरफ्लो 'सत्र का काम, हालांकि, मुझे नहीं लगता कि इसे कोई फर्क क्यों पड़ता है। एक और रहस्य – Jason

+0

ब्लॉक के साथ सत्र समाप्त होने पर सत्र बंद नहीं होता है? मुझे लगता है कि वास्तव में परिभाषाओं को शामिल नहीं किया जाना चाहिए जिसका उपयोग इसके बाहर किया जाएगा ....यदि मैं इसे प्रश्न के उत्तर के रूप में पोस्ट करता हूं, तो क्या इसे उत्तर के रूप में चिह्नित किया जाएगा? –

उत्तर

7

टेंसर को numpy arrays में परिवर्तित किए बिना Dataset का उपयोग करने का वास्तव में एक और अधिक प्रभावी तरीका है। हालांकि, यह आधिकारिक दस्तावेज पर (अभी तक?) नहीं है। रिलीज नोट से, यह एक विशेषता है जो केरास 2.0.7 में पेश की गई है। इसका उपयोग करने के लिए आपको keras> = 2.0.7 इंस्टॉल करना पड़ सकता है।

x = np.arange(4).reshape(-1, 1).astype('float32') 
ds_x = Dataset.from_tensor_slices(x).repeat().batch(4) 
it_x = ds_x.make_one_shot_iterator() 

y = np.arange(5, 9).reshape(-1, 1).astype('float32') 
ds_y = Dataset.from_tensor_slices(y).repeat().batch(4) 
it_y = ds_y.make_one_shot_iterator() 

input_vals = Input(tensor=it_x.get_next()) 
output = Dense(1, activation='relu')(input_vals) 
model = Model(inputs=input_vals, outputs=output) 
model.compile('rmsprop', 'mse', target_tensors=[it_y.get_next()]) 
model.fit(steps_per_epoch=1, epochs=5, verbose=2) 

कई मतभेद:

  1. आपूर्ति Input परत को tensor तर्क। केरा इस टेंसर से मूल्य पढ़ेगा, और इसे मॉडल के अनुकूल करने के लिए इनपुट के रूप में उपयोग करेगा।
  2. target_tensorsModel.compile() पर तर्क की आपूर्ति करें।
  3. दोनों x और y को float32 में परिवर्तित करना याद रखें। सामान्य उपयोग के तहत, केरास आपके लिए यह रूपांतरण करेगा। लेकिन अब आपको इसे खुद करना होगा।
  4. बैच आकार Dataset के निर्माण के दौरान निर्दिष्ट किया गया है। मॉडल फिटिंग को रोकने के लिए नियंत्रित करने के लिए steps_per_epoch और epochs का उपयोग करें।

संक्षेप में, का उपयोग Input(tensor=...), model.compile(target_tensors=...) और model.fit(x=None, y=None, ...) अपने डेटा tensors से पढ़ा जा करने के लिए कर रहे हैं।

+3

ऐसा लगता है कि दो अलग इटरेटर होने के लिए भी आवश्यक नहीं है। आप बस दो डेटासेट्स को ज़िप कर सकते हैं, 'next_val = it.get_next()' जैसे नोड बनाएं, और इसके आउटपुट के तत्व 'इनपुट()' और 'Model.compile()' फ़ंक्शंस में प्रदान करें। – Jason

+0

इटेटरेटर प्रारंभिकरण के बारे में क्या? क्या मैं किसी भी तरह से हर युग के साथ शुरू करने के लिए keras बता सकता हूँ? या मुझे अभी भी सत्र बनाने और इसे मैन्युअल रूप से करने की आवश्यकता है और फिर हर बार एक युग चलाएं? – backman

1

@ यू-यांग का जवाब देने के लिए इसके अलावा, आप भी tf.data.Datasetfit_generator निम्नलिखित

def tfdata_generator(images, labels, batch_size=128, shuffle=True,): 
    def map_func(image, label): 
     '''A transformation function 

     ''' 
     x_train = tf.reshape(tf.cast(image, tf.float32), image_shape) 
     y_train = tf.one_hot(tf.cast(label, tf.uint8), num_classes) 
     return [x_train, y_train] 

    dataset = tf.data.Dataset.from_tensor_slices((images, labels)) 
    dataset = dataset.map(map_func) 
    dataset = dataset.shuffle().batch(batch_size).repeat() 
    iterator = dataset.make_one_shot_iterator() 

    next_batch = iterator.get_next() 
    while True: 
     yield K.get_session().run(next_batch) 

अब आप इसे एक जनरेटर के रूप में कॉल कर सकते हैं के रूप में संशोधित एक जनरेटर बनने के लिए कर सकते हैं। इस उदाहरण में। मैंने mnist डेटासेट का उपयोग किया।

from tensorflow.contrib.learn.python.learn.datasets import mnist 

data = mnist.load_mnist() 
model = # your Keras model 

model.fit_generator(generator = tfdata_generator(data.train.images, data.train.labels), 
        steps_per_epoch=200, 
        workers = 0 , # This is important 
        verbose = 1) 
+0

यह AFAIK फिट_जेनरेटर के validation_data पैरामीटर का उपयोग करके, कैरेज़ को सत्यापन डेटा प्रदान करने का एकमात्र तरीका है – Warrick

संबंधित मुद्दे