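# NOTE(review): the next three lines look like page-header residue (author
# caption, commit message, commit hash); kept as comments so the file parses.
# BigDong's picture
# add Ultra-FineWeb lighteval task python file
# 4ed02d8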
# NOTE: This script is not fully tested.
import json
import sys
from pyspark.sql import SparkSession
from pyspark import SparkConf
from pyspark.sql.functions import col, udf, lit
from pyspark.sql.types import MapType, StringType, FloatType
from preprocess_content import fasttext_preprocess_func
from fasttext_infer import fasttext_infer
def get_fasttext_pred(content: str):
    """Classify a text with fastText and keep only positive predictions.

    The text is normalized with ``fasttext_preprocess_func`` and the result is
    scored with ``fasttext_infer``.

    Args:
        content (str): text. May be ``None`` when the source column holds a
            null value.

    Returns:
        Optional[str]: json string with pred_label and pred_score when the
            predicted label is ``__label__pos``; ``None`` otherwise
            (including when ``content`` is ``None``).
    """
    # Spark hands null column values to a Python UDF as None; return early
    # instead of passing None on to the preprocessing step. The caller drops
    # rows whose prediction is None, so null texts are simply filtered out.
    if content is None:
        return None

    norm_content = fasttext_preprocess_func(content)
    label, score = fasttext_infer(norm_content)
    if label != '__label__pos':
        return None
    return json.dumps({'pred_label': label, 'pred_score': score}, ensure_ascii=False)
if __name__ == "__main__":
    # Usage: <script> <input_path> <save_path>
    # Fail with a readable message instead of a bare IndexError when the
    # positional arguments are missing. Extra arguments are still ignored.
    if len(sys.argv) < 3:
        sys.exit("Usage: {} <input_path> <save_path>".format(sys.argv[0]))
    input_path = sys.argv[1]
    save_path = sys.argv[2]
    # Name of the input column that holds the text to classify.
    content_key = "content"

    # NOTE(review): the Hive dynamic-partition options are carried over
    # unchanged; nothing in this script writes to a Hive table.
    spark = (SparkSession.builder.enableHiveSupport()
             .config("hive.exec.dynamic.partition", "true")
             .config("hive.exec.dynamic.partition.mode", "nonstrict")
             .appName("FastTextInference")
             .getOrCreate())
    try:
        # get_fasttext_pred returns a JSON string or None, so the UDF is
        # declared as StringType (the same type udf() uses by default).
        predict_udf = udf(get_fasttext_pred, StringType())

        # For JSON input, use spark.read.json(input_path) instead.
        df = spark.read.parquet(input_path)
        df = df.withColumn("fasttext_pred", predict_udf(col(content_key)))
        # Keep only rows with a positive prediction (non-null UDF result).
        df = df.filter(col("fasttext_pred").isNotNull())

        # coalesce caps the output at 1000 partitions without a shuffle.
        # For JSON output, use .json(save_path) instead of .parquet(save_path).
        df.coalesce(1000).write.mode("overwrite").parquet(save_path)
    finally:
        # Always release the Spark session, even if the job above fails.
        spark.stop()