File size: 3,279 Bytes
d3b6eff
 
 
 
85e407f
d3b6eff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch, sys
import transformers

try: model_path = sys.argv[1]
except: model_path = "e3.0"

print(f"Loading {model_path} ...")

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_path, 
    device_map = "auto",
    torch_dtype = torch.bfloat16,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(".")

from qwen_vocab import old2new, new2old
STOP_WORDS = "<|im_end|> <|endoftext|>".split()


def map_tids(map_dict, tids):
    return [ map_dict[x] for x in tids if x in map_dict ]


class KeywordsStoppingCriteria(transformers.StoppingCriteria):
    def __init__(self, str):
        self.keyword_ids = tokenizer.encode(str)
        self.keyword_ids = map_tids(old2new, self.keyword_ids)
        self.keyword_len = len(self.keyword_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_token_ids = input_ids[0][-self.keyword_len:]
        return last_token_ids.tolist() == self.keyword_ids

stop_criteria_list = transformers.StoppingCriteriaList(
    [ KeywordsStoppingCriteria(x) for x in STOP_WORDS ]
)


def get_answer(q):
    if len(q) < 3: return "..."

    prompt = f"<|im_start|>user\n{q}<|im_end|>\n<|im_start|>assistant"
    old_tids = tokenizer.encode(prompt)

    new_tids = map_tids(old2new, old_tids)
    new_old_tids = map_tids(new2old, new_tids)

    new_prompt = tokenizer.decode(new_old_tids)

    if new_old_tids != old_tids:
        print(f"!!! Cảnh báo sự trimm vocab làm mất thông tin !!!")
        print(f"!!! old prompt: {prompt}")
        print(f"!!! new prompt: {new_prompt}")

    inputs = tokenizer(new_prompt, return_tensors="pt").to(model.device)

    assert inputs["input_ids"][0].tolist() == new_old_tids

    for i, x in enumerate(new_tids):
        inputs["input_ids"][0][i] = x

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.3,
            top_p=1.0, top_k=30, do_sample=True,
            repetition_penalty=1.1,
            stopping_criteria=stop_criteria_list,
            pad_token_id=tokenizer.pad_token_id,
        )

    answer_tids = output_ids[0][len(inputs["input_ids"][0]) : ] # bỏ đi prompt tokens
    old_tids = map_tids(new2old, answer_tids.tolist())

    # print(prompt, answer_tids, old_tids) # DEBUG
    return tokenizer.decode(old_tids)\
        .split("<|im_end|>")[0].split("<end_of_turn>")[0].strip()


from utils import *
while True:
    # bỏ qua lỗi utf-8 encoding trong trường hợp nhập text từ console
    try: q = input(f"Bạn: {GREEN}").encode('utf-8', 'ignore').decode('utf-8', 'ignore')
    except Exception as e: print(f"{RESET}{e}"); q = ""

    reset_timer(timer=model_path)
    a = get_answer(q).strip()
    print(f"{RESET}Bot: {RED}{a}{RESET}")
    measure_time("timespent", timer=model_path)

'''
python3 model_chat.py ../Qwen2.5-1.5B-Instruct__trimm_vocab

python3 model_chat.py ../Qwen2.5-1.5B-Instruct

số tuổi của An trừ đi số tuổi của Lan là 3, An 10 tuổi hỏi Lan mấy tuổi?

ai tạo ra bạn

Bạn: tạo ra một câu hoàn chỉnh với từ "thực hiện"
Bot: Thì ra, việc thực hiện kế hoạch của chúng ta cần được lên lịch cụ thể.
'''