5

में अमान्य लेबल कॉलम त्रुटि के साथ इनपुट दिया गया था। मैं SCALA में रैंडम वन क्लासिफायर मॉडल का उपयोग करके 5-गुना क्रॉस सत्यापन का उपयोग करके सटीकता खोजने का प्रयास कर रहा हूं।RandomForestClassifier को Apache Spark

java.lang.IllegalArgumentException: लेकिन मैं जबकि चल रहा है निम्न त्रुटि हो रही है RandomForestClassifier अमान्य लेबल स्तंभ लेबल के साथ इनपुट दिया गया था निर्दिष्ट वर्गों की संख्या के बिना,। स्ट्रिंग इंडेक्सर देखें।

लाइन पर उपरोक्त त्रुटि हो रही है ---> वैल cvModel = cv.fit (trainingData)

इस प्रकार कोड है जो मैं डेटा के पार सत्यापन के लिए इस्तेमाल किया यादृच्छिक वन का उपयोग कर सेट है:

import org.apache.spark.ml.Pipeline 
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} 
import org.apache.spark.ml.classification.RandomForestClassifier 
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator 
import org.apache.spark.mllib.linalg.Vectors 
import org.apache.spark.mllib.regression.LabeledPoint 

val data = sc.textFile("exprogram/dataset.txt") 
val parsedData = data.map { line => 
val parts = line.split(',') 
LabeledPoint(parts(41).toDouble, 
Vectors.dense(parts(0).split(',').map(_.toDouble))) 
} 


val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) 
val training = splits(0) 
val test = splits(1) 

val trainingData = training.toDF() 

val testData = test.toDF() 

val nFolds: Int = 5 
val NumTrees: Int = 5 

val rf = new  
RandomForestClassifier() 
     .setLabelCol("label") 
     .setFeaturesCol("features") 
     .setNumTrees(NumTrees) 

val pipeline = new Pipeline() 
     .setStages(Array(rf)) 

val paramGrid = new ParamGridBuilder() 
      .build() 

val evaluator = new MulticlassClassificationEvaluator() 
    .setLabelCol("label") 
    .setPredictionCol("prediction") 
    .setMetricName("precision") 

val cv = new CrossValidator() 
    .setEstimator(pipeline) 
    .setEvaluator(evaluator) 
    .setEstimatorParamMaps(paramGrid) 
    .setNumFolds(nFolds) 

val cvModel = cv.fit(trainingData) 

val results = cvModel.transform(testData) 
.select("label","prediction").collect 

val numCorrectPredictions = results.map(row => 
if (row.getDouble(0) == row.getDouble(1)) 1 else 0).foldLeft(0)(_ + _) 
val accuracy = 1.0D * numCorrectPredictions/results.size 

println("Test set accuracy: %.3f".format(accuracy)) 

क्या कोई भी बता सकता है कि उपर्युक्त कोड में क्या गलती है।

उत्तर

8

RandomForestClassifier, कई अन्य एमएल एल्गोरिदम के समान, लेबल कॉलम और लेबल मानों पर सेट किए जाने वाले विशिष्ट मेटाडेटा को [0, 1, 2 ..., #classes) से अभिन्न मान होने के लिए युगल के रूप में दर्शाया जाना आवश्यक है। आम तौर पर इसे TransformersStringIndexer जैसे अपस्ट्रीम द्वारा नियंत्रित किया जाता है। चूंकि आप लेबल को मैन्युअल रूप से मेटाडेटा फ़ील्ड को कनवर्ट करते हैं और क्लासिफायर यह पुष्टि नहीं कर सकते कि ये आवश्यकताएं संतुष्ट हैं।

val df = Seq(
    (0.0, Vectors.dense(1, 0, 0, 0)), 
    (1.0, Vectors.dense(0, 1, 0, 0)), 
    (2.0, Vectors.dense(0, 0, 1, 0)), 
    (2.0, Vectors.dense(0, 0, 0, 1)) 
).toDF("label", "features") 

val rf = new RandomForestClassifier() 
    .setFeaturesCol("features") 
    .setNumTrees(5) 

rf.setLabelCol("label").fit(df) 
// java.lang.IllegalArgumentException: RandomForestClassifier was given input ... 

आप या तो फिर से एनकोड लेबल स्तंभ StringIndexer का उपयोग कर:

import org.apache.spark.ml.feature.StringIndexer 

val indexer = new StringIndexer() 
    .setInputCol("label") 
    .setOutputCol("label_idx") 
    .fit(df) 

rf.setLabelCol("label_idx").fit(indexer.transform(df)) 

या set required metadata manually:

val meta = NominalAttribute 
    .defaultAttr 
    .withName("label") 
    .withValues("0.0", "1.0", "2.0") 
    .toMetadata 

rf.setLabelCol("label_meta").fit(
    df.withColumn("label_meta", $"label".as("", meta)) 
) 

नोट:

लेबलउपयोग करके बनाए गए 113,210 आवृत्ति पर निर्भर नहीं मूल्य:

indexer.labels 
// Array[String] = Array(2.0, 0.0, 1.0) 

PySpark:

from pyspark.sql.types import StructField, DoubleType 

StructField(
    "label", DoubleType(), False, 
    {"ml_attr": { 
     "name": "label", 
     "type": "nominal", 
     "vals": ["0.0", "1.0", "2.0"] 
    }} 
) 
:

अजगर मेटाडाटा क्षेत्रों में स्कीमा पर सीधे सेट किया जा सकता