|
|
|
import json
import sys
from typing import Optional

from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, udf
from pyspark.sql.types import FloatType, MapType, StringType

from fasttext_infer import fasttext_infer
from preprocess_content import fasttext_preprocess_func
|
|
|
|
|
def get_fasttext_pred(content: str) -> Optional[str]:
    """Run fastText inference on a text and keep only positive predictions.

    The text is normalized with ``fasttext_preprocess_func`` before being
    passed to ``fasttext_infer``.

    Args:
        content (str): Raw text to classify.

    Returns:
        Optional[str]: JSON string with ``pred_label`` and ``pred_score``
        when the predicted label is ``__label__pos``; ``None`` otherwise,
        so that non-positive rows can be dropped with ``isNotNull``
        downstream.
    """
    norm_content = fasttext_preprocess_func(content)
    label, score = fasttext_infer(norm_content)

    if label == '__label__pos':
        # ensure_ascii=False keeps non-ASCII text readable in the output.
        return json.dumps({'pred_label': label, 'pred_score': score}, ensure_ascii=False)
    return None
|
|
|
if __name__ == "__main__": |
|
|
|
input_path = sys.argv[1] |
|
save_path = sys.argv[2] |
|
|
|
content_key = "content" |
|
|
|
spark = (SparkSession.builder.enableHiveSupport() |
|
.config("hive.exec.dynamic.partition", "true") |
|
.config("hive.exec.dynamic.partition.mode", "nonstrict") |
|
.appName("FastTextInference") |
|
.getOrCreate()) |
|
|
|
predict_udf = udf(get_fasttext_pred) |
|
|
|
|
|
df = spark.read.parquet(input_path) |
|
df = df.withColumn("fasttext_pred", predict_udf(col(content_key))) |
|
df = df.filter(col("fasttext_pred").isNotNull()) |
|
|
|
|
|
df.coalesce(1000).write.mode("overwrite").parquet(save_path) |
|
|
|
spark.stop() |
|
|