2016-01-10 6 views
5

स्पार्क डेटाफ्रेम को TFRecords फ़ाइल में परिवर्तित करने के लिए मैं टेन्सफोर्लो रिकॉर्डवाइटर क्लास का शुद्ध जावा/स्कैला कार्यान्वयन लिखने की कोशिश कर रहा हूं। प्रलेखन के अनुसार, TFRecords में, प्रत्येक रिकॉर्ड इस प्रकार से स्वरूपित है:लिखने के लिए शुद्ध जावा/स्कैला कोड Tensorflow TFRecords डेटा फ़ाइल

:

uint64 length 
uint32 masked_crc32_of_length 
byte data[length] 
uint32 masked_crc32_of_data 

और सीआरसी

masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul 

वर्तमान में मुखौटा, मैं निम्नलिखित कोड के साथ अमरूद कार्यान्वयन के साथ CRC की गणना

import com.google.common.hash.Hashing 

object CRC32 { 
    val kMaskDelta = 0xa282ead8 

    def hash(in: Array[Byte]): Int = { 
    val hashing = Hashing.crc32c() 
    hashing.hashBytes(in).asInt() 
    } 

    def mask(crc: Int): Int ={ 
    ((crc >> 15) | (crc << 17)) + kMaskDelta 
    } 
} 

मेरी कोड के बाकी है:

डेटा एन्कोडिंग हिस्सा डॉन है कोड का निम्न भाग के साथ ई:

object LittleEndianEncoding { 
    def encodeLong(in: Long): Array[Byte] = { 
    val baos = new ByteArrayOutputStream() 
    val out = new LittleEndianDataOutputStream(baos) 
    out.writeLong(in) 
    baos.toByteArray 
    } 

    def encodeInt(in: Int): Array[Byte] = { 
    val baos = new ByteArrayOutputStream() 
    val out = new LittleEndianDataOutputStream(baos) 

    out.writeInt(in) 
    baos.toByteArray 
    } 
} 

रिकॉर्ड प्रोटोकॉल बफर के साथ उत्पन्न कर रहे हैं:

import com.google.protobuf.ByteString 
import org.tensorflow.example._ 

import collection.JavaConversions._ 
import collection.mutable._ 

object TFRecord { 

    def int64Feature(in: Long): Feature = { 

    val valueBuilder = Int64List.newBuilder() 
    valueBuilder.addValue(in) 

    Feature.newBuilder() 
     .setInt64List(valueBuilder.build()) 
     .build() 
    } 


    def floatFeature(in: Float): Feature = { 
    val valueBuilder = FloatList.newBuilder() 
    valueBuilder.addValue(in) 
    Feature.newBuilder() 
     .setFloatList(valueBuilder.build()) 
     .build() 
    } 

    def floatVectorFeature(in: Array[Float]): Feature = { 
    val valueBuilder = FloatList.newBuilder() 
    in.foreach(valueBuilder.addValue) 

    Feature.newBuilder() 
     .setFloatList(valueBuilder.build()) 
     .build() 
    } 

    def bytesFeature(in: Array[Byte]): Feature = { 
    val valueBuilder = BytesList.newBuilder() 
    valueBuilder.addValue(ByteString.copyFrom(in)) 
    Feature.newBuilder() 
     .setBytesList(valueBuilder.build()) 
     .build() 
    } 

    def makeFeatures(features: HashMap[String, Feature]): Features = { 
    Features.newBuilder().putAllFeature(features).build() 
    } 


    def makeExample(features: Features): Example = { 
    Example.newBuilder().setFeatures(features).build() 
    } 

} 

और यहाँ कैसे मैं चीजों को एक साथ उपयोग के क्रम में मेरी TFRecords फ़ाइल उत्पन्न करने के लिए का एक उदाहरण है:

val label = TFRecord.int64Feature(1) 
val feature = TFRecord.floatVectorFeature(Array[Float](1, 2, 3, 4)) 
val features = TFRecord.makeFeatures(HashMap[String, Feature] ("feature"->feature, "label"-> label)) 
val ex = TFRecord.makeExample(features) 
val exSerialized = ex.toByteArray() 
val length = LittleEndianEncoding.encodeLong(exSerialized.length) 
val crcLength = LittleEndianEncoding.encodeInt(CRC32.mask(CRC32.hash(length))) 
val crcEx = LittleEndianEncoding.encodeInt(CRC32.mask(CRC32.hash(exSerialized))) 

val out = new FileOutputStream(new File("test.tfrecords")) 
out.write(length) 
out.write(crcLength) 
out.write(exSerialized) 
out.write(crcEx) 
out.close() 

जब मैं फ़ाइल मैं TFRecordReader साथ Tensorflow अंदर मिला पढ़ने की कोशिश, मैं निम्नलिखित त्रुटि मिलती है:

W tensorflow/core/common_runtime/executor.cc:1076] 0x24cc430 Compute status: Data loss: corrupted record at 0 

मुझे संदेह है कि सीआरसी मास्क गणना सही नहीं है या अंतराल जावा और सी ++ जेनरेट की गई फ़ाइल के बीच समान नहीं है।

+0

आपको त्रुटि संदेश कहां मिल रहा है? –

+0

जब मैं tensorflow में फ़ाइल पढ़ता हूं तो मुझे डेटा दूषित त्रुटि मिलती है। – jrabary

+0

मास्क फ़ंक्शन 'masked_crc = ((crc <15) के साथ प्राप्त परिणाम की तुलना में सही नहीं है। (Crc << 17)) + 0xa282ead8ul' – jrabary

उत्तर

6

मेरे कार्यान्वयन के साथ समस्या सीआरसी मास्क की गणना है। यहाँ मैंने पाया ठीक है:

def mask(crc: Int): Int ={ 
    ((crc >>> 15) | (crc << 17)) + kMaskDelta 
} 

कुंजी के बजाय >>

1

FWIW अहस्ताक्षरित पारी बिटवाइज़ ऑपरेटर >>> का प्रयोग होता है, Tensorflow टीम पढ़ने/TFRecords लिखने के लिए उपयोगिता कोड प्रदान की गई है, जो कर सकते हैं found in the ecosystem repo

संबंधित मुद्दे