Distributed Machine Learning - Spark MLlib
最编程
2024-05-02 06:58:15
...
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Configure a random forest classifier (SEED, train and test are defined earlier).
rf = RandomForestClassifier(
    labelCol="TARGET",
    featuresCol="features",
    maxDepth=8,
    numTrees=500,
    subsamplingRate=1.0,
    featureSubsetStrategy='auto',
    seed=SEED
)
# Train a random forest model.
model = rf.fit(train)
# Score the data and compute areaUnderROC from the raw predictions and true labels.
evaluator = BinaryClassificationEvaluator(
    labelCol="TARGET",
    metricName='areaUnderROC'
)
train_auc = evaluator.evaluate(model.transform(train))
test_auc = evaluator.evaluate(model.transform(test))
print(f"Train auc: {train_auc:.4f}")
print(f"Test auc: {test_auc:.4f}")
Java HotSpot(TM) 64-Bit Server VM warning: CodeCache is full. Compiler has been disabled.
Java HotSpot(TM) 64-Bit Server VM warning: Try increasing the code cache size using -XX:ReservedCodeCacheSize=
CodeCache: size=131072Kb used=38814Kb max_used=39023Kb free=92257Kb
bounds [0x000000010464c000, 0x0000000106c9c000, 0x000000010c64c000]
total_blobs=13345 nmethods=12309 adapters=949
compilation: disabled (not enough contiguous free space left)
Train auc: 0.7526
Test auc: 0.7235
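To make the areaUnderROC numbers above easier to interpret, here is a minimal pure-Python sketch of what the metric measures: the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counted as one half. The function name roc_auc and the toy labels and scores are assumptions made for illustration and are not part of the original notebook. Spark's BinaryClassificationEvaluator derives the value from the ROC curve rather than from this pairwise loop, so its result may differ slightly.

```python
def roc_auc(labels, scores):
    """Pairwise definition of ROC AUC (hypothetical helper, for illustration only).

    Returns the fraction of (positive, negative) pairs in which the positive
    example has the higher score; tied pairs contribute 0.5.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")

    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy data (made up): two negatives and two positives.
labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
auc = roc_auc(labels, scores)
print(f"auc: {auc:.4f}")
```

An AUC of 0.5 corresponds to random ranking and 1.0 to a perfect ranking. Read this way, the test value of 0.7235 says the model ranks a positive above a negative in roughly 72% of such pairs.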
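import pandas as pd

# Pair each importance score with its feature name and sort in descending order.
# This assumes every assembler input column is a single numeric value, so that
# the importance vector and the column list have the same length.
feature_imp = pd.Series(
    model.featureImportances.toArray(),
    index=assembler.getInputCols()
).sort_values(ascending=False)
print(feature_imp.head(20))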
EXT_SOURCE_2 0.183568
EXT_SOURCE_3 0.175979
EXT_SOURCE_1 0.094980
DAYS_EMPLOYED 0.050050
OCCUPATION_TYPE 0.032153
DAYS_BIRTH 0.032032
NAME_EDUCATION_TYPE 0.025601
DAYS_LAST_PHONE_CHANGE 0.022394
AMT_GOODS_PRICE 0.019779
REGION_RATING_CLIENT_W_CITY 0.014936
CODE_GENDER_M 0.014736
REGION_RATING_CLIENT 0.012078
ORGANIZATION_TYPE 0.011209
AMT_CREDIT 0.010922
NAME_INCOME_TYPE_Working 0.010745
DAYS_ID_PUBLISH 0.010505
FLAG_DOCUMENT_3 0.009315
OWN_CAR_AGE 0.009004
AMT_ANNUITY 0.007916
TOTALAREA_MODE 0.007510
dtype: float64