File size: 1,569 Bytes
4ed02d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# NOTE: This script is not fully tested.
import json
import sys

from pyspark.sql import SparkSession
from pyspark import SparkConf
from pyspark.sql.functions import col, udf, lit
from pyspark.sql.types import MapType, StringType, FloatType

from preprocess_content import fasttext_preprocess_func
from fasttext_infer import fasttext_infer


def get_fasttext_pred(content: str, positive_label: str = '__label__pos'):
    """Classify ``content`` with fastText and keep only positive predictions.

    Intended to be wrapped as a Spark UDF. A ``None`` result marks the row
    for removal by the caller.

    Args:
        content (str): raw text to classify; ``None`` for null cells.
        positive_label (str): label treated as a positive prediction.

    Returns:
        Optional[str]: JSON string ``{"pred_label": ..., "pred_score": ...}``
        when the predicted label equals ``positive_label``. Returns ``None``
        when it does not, or when ``content`` is ``None``.
    """
    # Spark passes None to Python UDFs for null cells. It is not visible here
    # whether fasttext_preprocess_func accepts None, so skip those rows.
    if content is None:
        return None

    norm_content = fasttext_preprocess_func(content)
    label, score = fasttext_infer(norm_content)

    if label != positive_label:
        return None

    # float() converts NumPy scalars, which json.dumps rejects, into plain
    # Python floats; a Python float passes through unchanged.
    # NOTE(review): assumes fasttext_infer returns a numeric score -- confirm.
    return json.dumps(
        {'pred_label': label, 'pred_score': float(score)},
        ensure_ascii=False,
    )

if __name__ == "__main__":
    # Usage: <script> <input_parquet_path> <output_parquet_path>
    if len(sys.argv) < 3:
        sys.exit(f"Usage: {sys.argv[0]} <input_path> <save_path>")

    input_path = sys.argv[1]
    save_path = sys.argv[2]

    # Column holding the text to classify.
    content_key = "content"

    spark = (SparkSession.builder.enableHiveSupport()
             .config("hive.exec.dynamic.partition", "true")
             .config("hive.exec.dynamic.partition.mode", "nonstrict")
             .appName("FastTextInference")
             .getOrCreate())

    try:
        # No explicit returnType: udf() defaults to StringType, which matches
        # the JSON string (or None) that get_fasttext_pred returns.
        predict_udf = udf(get_fasttext_pred)

        # For JSON input/output, swap .parquet(...) for .json(...) below.
        df = spark.read.parquet(input_path)
        df = df.withColumn("fasttext_pred", predict_udf(col(content_key)))
        # get_fasttext_pred returns None for non-positive predictions;
        # drop those rows.
        df = df.filter(col("fasttext_pred").isNotNull())

        # coalesce() only reduces the partition count and avoids a shuffle.
        # Because it is a narrow transformation, it can also cap the upstream
        # UDF stage at 1000 tasks. Use repartition(1000) instead if inference
        # parallelism matters more than avoiding the shuffle.
        df.coalesce(1000).write.mode("overwrite").parquet(save_path)
    finally:
        # Release the Spark session even if reading, inference or writing
        # fails.
        spark.stop()