
How to Efficiently Run Large-Scale Prediction with a PyTorch Model in Spark, Part 2: Practice and Optimization with pandas_udf


pandas_udf is a further optimization over the plain udf: data is exchanged with the Python workers in Arrow batches rather than row by row, so programs built on pandas_udf generally run faster (a row-at-a-time udf baseline is sketched after the listing for comparison). Here we use pandas_udf to speed up distributed inference:

import numpy as np
import pandas as pd
import torch
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import FloatType

# Enable Arrow support so data is passed to the udf in columnar batches,
# and cap the number of records per batch. (In Spark 3.x the config key is
# "spark.sql.execution.arrow.pyspark.enabled".)
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "64")

# Ship get_model.py (which defines the model architecture) to every executor
# so the import below also works inside the udf workers.
sc.addFile('get_model.py')
from get_model import get_model
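
For context, a minimal, hypothetical sketch of what get_model.py could contain: the only requirement is that get_model() returns an uninitialized instance of the same architecture that produced model.pt, so each executor can reconstruct the network and then load the broadcast weights. The layer types and sizes below are placeholders, not the actual model from this post.

# get_model.py -- hypothetical placeholder architecture; replace with your own
import torch.nn as nn

def get_model():
    # Must match the architecture whose weights are stored in model.pt.
    # An Embedding first layer is consistent with the .long() inputs built in the udf.
    return nn.Sequential(
        nn.Embedding(10000, 64),      # placeholder vocabulary/embedding sizes
        nn.Flatten(),
        nn.Linear(64 * 20, 1),        # placeholder: assumes sequences of length 20
    )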

model_path = '/path/to/model.pt'
data_path = '/path/to/data'

# model is the trained model; broadcast only its state_dict, which is much
# cheaper to serialize and send to executors than the full model object.
model = torch.load(model_path)
bc_model_state = sc.broadcast(model.state_dict())


def get_model_for_eval():
    # Rebuild the network on the executor and load the broadcast state_dict.
    model = get_model()
    model.load_state_dict(bc_model_state.value)
    model.eval()
    return model

# Naive alternative (not used): broadcasting the whole model object instead
# of just its weights.
# model = torch.load(model_path)
# model = sc.broadcast(model)


@pandas_udf(FloatType())
def predict_batch_udf(arr: pd.Series) -> pd.Series:
    # One model instance per Arrow batch, rebuilt from the broadcast weights.
    model = get_model_for_eval()
    # model.to(device)  # move the model to GPU here if one is available
    # Each element of `arr` is a string-encoded feature list; eval() parses it
    # back into a Python list before stacking the batch into one array.
    arr = np.vstack(arr.map(lambda x: eval(x)).values)
    arr = torch.tensor(arr).long()
    with torch.no_grad():
        predictions = list(model(arr).cpu().numpy())
    return pd.Series(predictions)

# Run distributed inference; the udf is applied batch by batch on the executors.
# data is assumed here to be a Parquet dataset with a string-encoded 'features' column.
data = spark.read.parquet(data_path)
data = data.withColumn('predictions', predict_batch_udf('features'))
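
To make the efficiency claim above concrete, here is a hedged sketch of the row-at-a-time baseline using a plain udf. It parses features and runs one forward pass per record (and here even reloads the weights per record), which is exactly the per-row overhead that the batched, Arrow-backed pandas_udf version avoids; predict_row_udf is an illustrative name, not part of the original post.

from pyspark.sql.functions import udf

@udf(FloatType())
def predict_row_udf(x):
    # Row-at-a-time baseline: one tensor and one forward pass per record.
    model = get_model_for_eval()
    features = torch.tensor([eval(x)]).long()
    with torch.no_grad():
        return float(model(features).cpu().numpy().ravel()[0])

# Same call pattern, but typically much slower on large data:
# data = data.withColumn('predictions', predict_row_udf('features'))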

References:
How to run inference of a pytorch model on pyspark dataframe (create new column with prediction) using pandas_udf?
Distributed model inference using PyTorch
How to use custom classes with Apache Spark (pyspark)?
Coupling PySpark Transformation with PyTorch Inference
