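"""Score a single text file with the Ultra-FineWeb fastText quality classifier.

Example invocation (script name and paths are illustrative; adjust to your
checkout):

    python infer_single_content.py \
        --language en \
        --tokenizer-path local_tokenizer \
        --content-file scripts/local_scripts/single_content.txt
"""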
import argparse
import os
import re
import unicodedata
from typing import Tuple

import fasttext
from transformers import LlamaTokenizerFast


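# Paths to the trained fastText quality classifiers, resolved relative to the
# working directory (assumed repo layout; adjust if your checkout differs):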
language_model_map = {
    "en": "classifiers/ultra_fineweb_en.bin",
    "zh": "classifiers/ultra_fineweb_zh.bin"
}

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--language", type=str, required=True, help="Inference language, support: en, zh.")
    parser.add_argument("--tokenizer-path", type=str, default="local_tokenizer", help="Tokenizer path.")
    parser.add_argument("--content-file", type=str, default="scripts/local_scripts/single_content.txt", help="Content file to infer.")
    return parser.parse_args()


def fasttext_preprocess_func(content: str, tokenizer: LlamaTokenizerFast) -> str:
    """Preprocess content for fastText inference.

    Args:
        content (str): Content to process.
        tokenizer (LlamaTokenizerFast): Tokenizer used for word segmentation.

    Returns:
        str: Processed, normalized content.
    """

    # 1. collapse runs of three or more newlines into a single blank line
    content = re.sub(r'\n{3,}', '\n\n', content)

    # 2. lowercase the content
    content = content.lower()

    # 3. remove diacritics: NFKD-decompose, then drop combining marks
    content = ''.join(
        c for c in unicodedata.normalize('NFKD', content)
        if unicodedata.category(c) != 'Mn')

    # 4. word segmentation: decode each token id back to its surface form
    token_ids = tokenizer.encode(content, add_special_tokens=False)
    single_text_list = [tokenizer.decode([token_id]) for token_id in token_ids]
    content = ' '.join(single_text_list)

    # 5. escape control chars (\n, \t, \r -> \\n, \\t, \\r) so they are saved
    # as literal sequences in the txt file
    content = content.replace('\n', '\\n')
    content = content.replace('\r', '\\r')
    content = content.replace('\t', '\\t')
    content = re.sub(r' +', ' ', content)
    content = content.strip()

    return content
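
# Illustrative effect of the preprocessing (hypothetical input; the exact
# subword segmentation depends on the tokenizer vocabulary): "Héllo\tWorld"
# becomes lowercase, accent-free, space-joined tokens with the tab escaped,
# e.g. "hello \\t world".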


def fasttext_infer(norm_content: str, fasttext_model: "fasttext.FastText._FastText") -> Tuple[str, float]:
    """Run fastText inference on preprocessed content.

    Args:
        norm_content (str): Preprocessed (normalized) input text.
        fasttext_model: Loaded fastText classifier.

    Returns:
        Tuple[str, float]: Predicted label and the score of the positive class.
    """

    pred_label, pred_prob = fasttext_model.predict(norm_content)
    pred_label = pred_label[0]
    # fastText probabilities can numerically exceed 1.0; clamp to 1
    _score = min(pred_prob.tolist()[0], 1)
    # always report the score of the positive class
    if pred_label == "__label__neg":
        _score = 1 - _score

    return pred_label, _score
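
# Usage sketch (illustrative; names follow this script):
#     model = fasttext.load_model(language_model_map["en"])
#     label, score = fasttext_infer(norm_content, model)
# A "__label__neg" prediction with probability p is reported as score 1 - p,
# so `score` is always the estimated probability of the positive class.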


def main():
    args = parse_args()
    language = args.language
    tokenizer_path = args.tokenizer_path
    content_file = args.content_file

    assert language in ["en", "zh"], f"Language {language} is not supported, please check the language."
    assert os.path.exists(content_file), f"Content file {content_file} does not exist, please check the content file."

    fasttext_model_path = language_model_map[language]

    # load tokenizer
    tokenizer = LlamaTokenizerFast.from_pretrained(tokenizer_path)

    # load fasttext model
    fasttext_model = fasttext.load_model(fasttext_model_path)

    # read the content (utf-8 so Chinese input decodes correctly on any platform)
    with open(content_file, "r", encoding="utf-8") as f:
        content = f.read()
    # first preprocess the content
    norm_content = fasttext_preprocess_func(content, tokenizer)
    # then run inference
    pred_label, pred_score = fasttext_infer(norm_content, fasttext_model)
    # finally print the result
    print("-" * 100)
    print(f"Content: {content}")
    print()
    print(f"Normalized content: {norm_content}")
    print()
    print(f"  - Pred label: {pred_label}")
    print(f"  - Pred score: {pred_score}")
    print("-" * 100)


if __name__ == "__main__":
    main()
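
# Batch-scoring sketch (illustrative, not part of the original script): the
# same preprocess/infer pair scores many documents with one loaded model, e.g.
#
#     tokenizer = LlamaTokenizerFast.from_pretrained("local_tokenizer")
#     model = fasttext.load_model(language_model_map["en"])
#     scores = [
#         fasttext_infer(fasttext_preprocess_func(doc, tokenizer), model)[1]
#         for doc in docs  # `docs` is a hypothetical list of raw texts
#     ]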