import gradio as gr
import json
import re
import torch
from transformers import GPT2Tokenizer, T5ForConditionalGeneration 


# Token pattern: standalone punctuation, Cyrillic-initial words, digit-initial
# tokens (with an optional decimal part), or runs of any other non-space
# characters; each match keeps its trailing whitespace so the original
# spacing can be reconstructed later.
# Earlier, simpler variant kept for reference:
# re_tokens = re.compile(r"[а-яА-Я]+\s*|\d+(?:\.\d+)?\s*|[^а-яА-Я\d\s]+\s*")
re_tokens = re.compile(r"(?:[.,!?]|[а-яА-Я]\S*|\d\S*(?:\.\d+)?|[^а-яА-Я\d\s]+)\s*")


def tokenize(text):
    return re.findall(re_tokens, text)
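# Worked example (derived by hand from the regex above; verify by running):
#   tokenize("я купил iphone 12X") -> ["я ", "купил ", "iphone ", "12X"]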


def strip_numbers(s):
    """
    Insert spaces between digit groups of three: `1234567` -> `1 234 567`.
    """
    result = []
    for part in s.split():
        if part.isdigit():
            # Split off the leftmost group (1-3 digits) so the remainder
            # divides evenly into groups of three, then repeat.
            while len(part) > 3:
                result.append(part[:-3 * ((len(part) - 1) // 3)])
                part = part[-3 * ((len(part) - 1) // 3):]
            if part:
                result.append(part)
        else:
            result.append(part)
    return " ".join(result)
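# Worked example (derived by hand; non-numeric tokens pass through unchanged):
#   strip_numbers("цена 142990 руб") -> "цена 142 990 руб"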


def construct_prompt(text):
    """
    From `я купил iphone 12X за 142 990 руб без 3-x часов 12:00, и т.д.` \
    to `<SC1>я купил [iphone 12X]<extra_id_0> за [142 990]<extra_id_1> руб без [3-x]<extra_id_2> часов [12:00]<extra_id_3>, и т.д.`.
    """
    result = "<SC1>"
    etid = 0
    token_to_add = ""
    # A trailing empty sentinel token flushes the last accumulated span.
    for token in tokenize(text) + [""]:
        if not re.search(r"[a-zA-Z\d]", token):
            if token_to_add:
                # Separate the span's content from its trailing non-word
                # characters, which stay outside the brackets.
                end_match = re.search(r"(.+?)(\W*)$", token_to_add, re.M).groups()
                result += f"[{strip_numbers(end_match[0])}]<extra_id_{etid}>{end_match[1]}"
                etid += 1
                token_to_add = ""
            result += token
        else:
            token_to_add += token
    return result


def construct_answer(prompt: str, prediction: str) -> str:
    re_prompt = re.compile(r"\[([^\]]+)\]<extra_id_(\d+)>")
    re_pred = re.compile(r"<extra_id_(\d+)>(.+?)(?=<extra_id_\d+>|</s>)")
    # Map each extra_id number to the span the model generated for it.
    pred_data = {}
    for match in re.finditer(re_pred, prediction.replace("\n", " ")):
        pred_data[match[1]] = match[2].strip()
    # Splice the predictions back into the prompt; if the model skipped a
    # span, fall back to the original bracketed text.
    while match := re.search(re_prompt, prompt):
        replace = pred_data.get(match[2], match[1])
        prompt = prompt[:match.span()[0]] + replace + prompt[match.span()[1]:]
    return prompt.replace("<SC1>", "")
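# Worked example (the prediction string here is made up for illustration;
# real model output may differ):
#   construct_answer("<SC1>за [142 990]<extra_id_0> руб",
#                    "<extra_id_0>сто сорок две тысячи девятьсот девяносто</s>")
#   -> "за сто сорок две тысячи девятьсот девяносто руб"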


# Only the keys of examples.json are used below, as clickable chat examples.
with open("examples.json") as f:
    test_examples = json.load(f)


# Note: this checkpoint ships a GPT2-style BPE tokenizer for a T5 seq2seq model.
tokenizer = GPT2Tokenizer.from_pretrained("saarus72/russian_text_normalizer", eos_token='</s>')
model = T5ForConditionalGeneration.from_pretrained("saarus72/russian_text_normalizer")


def predict(text):
    input_ids = torch.tensor([tokenizer.encode(text)])
    outputs = model.generate(input_ids, max_new_tokens=50, eos_token_id=tokenizer.eos_token_id, early_stopping=True)
    # Skip the leading decoder start token before decoding.
    return tokenizer.decode(outputs[0][1:])
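# The raw decode keeps the sentinel tokens, e.g. (illustrative; actual output
# depends on the model):
#   predict("<SC1>за [142 990]<extra_id_0> руб")
#   -> "<extra_id_0>сто сорок две тысячи девятьсот девяносто</s>"
# construct_answer() then splices these spans back into the prompt.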


def norm(message, history):
    prompt = construct_prompt(message)
    # Show the constructed prompt right away, then stream in the prediction
    # and the final answer once generation finishes.
    yield f"Prompt:\n```\n{prompt}\n```\nPrediction:\n..."
    prediction = predict(prompt)
    answer = construct_answer(prompt, prediction)
    yield f"Prompt:\n```\n{prompt}\n```\nPrediction:\n```\n{prediction}\n```\n{answer}"


demo = gr.ChatInterface(fn=norm, stop_btn=None, examples=list(test_examples.keys())).queue()
demo.launch()