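"""Thin wrapper around the THUDM/chatglm2-6b model from the Hugging Face Hub.

Provides blocking generation (generate), token streaming (stream_generate),
and the model's multi-turn chat interface (stream_chat).
"""
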
from typing import List, Optional, Tuple

import torch
from transformers import AutoTokenizer, AutoModel
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList

DEFAULT_MODEL_PATH = "THUDM/chatglm2-6b"


class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Replace NaN/inf logits with a safe one-hot distribution so sampling never crashes."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            # Zero out the degenerate scores and put all probability mass on
            # token id 5, mirroring the safeguard in the upstream ChatGLM code.
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


class ChatGLM2:
    def __init__(self, model_path: Optional[str] = None):
        self.model_path = model_path or DEFAULT_MODEL_PATH
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        # .float() keeps the weights in fp32 so the model also runs on CPU-only machines.
        self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).float()

    def generate(
        self,
        prompt: str,
        do_sample: bool = True,
        max_length: int = 8192,
        num_beams: int = 1,
        temperature: float = 0.8,
        top_p: float = 0.8,
    ):
        logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor}
        inputs = self.tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.model.device)
        outputs = self.model.generate(**inputs, **gen_kwargs)
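        # Drop the prompt tokens so only the newly generated continuation is decoded.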
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
        response = self.tokenizer.decode(outputs)
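        # process_response() is defined by the remote ChatGLM2 model code and
        # post-processes the raw decoded text.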
        response = self.model.process_response(response)
        return response

    def stream_generate(
        self,
        prompt: str,
        do_sample: bool = True,
        max_length: int = 8192,
        temperature: float = 0.8,
        top_p: float = 0.8,
    ):
        logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor}
        inputs = self.tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.model.device)
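        # stream_generate() is provided by the remote ChatGLM2 model code and
        # yields progressively longer output tensors as tokens are produced.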
        for outputs in self.model.stream_generate(**inputs, **gen_kwargs):
            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
            response = self.tokenizer.decode(outputs)
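            # Skip chunks that end in a partial multi-byte character,
            # which decodes as the U+FFFD replacement character.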
            if response and response[-1] != "�":
                response = self.model.process_response(response)
                yield response

    def stream_chat(
        self,
        query: str,
        history: List[Tuple[str, str]],
        max_length: int = 8192,
        do_sample: bool = True,
        top_p: float = 0.8,
        temperature: float = 0.8,
    ):
        # stream_chat() comes from the remote ChatGLM2 model code; each item is a
        # (response_so_far, updated_history) pair.
        yield from self.model.stream_chat(
            self.tokenizer, query, history,
            max_length=max_length, do_sample=do_sample, top_p=top_p, temperature=temperature,
        )
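

# Minimal usage sketch (assumes network access to download the checkpoint;
# fp32 CPU inference is slow, so expect a wait on the first call).
if __name__ == "__main__":
    llm = ChatGLM2()
    print(llm.generate("Hello"))

    # stream_chat yields (partial_response, updated_history) pairs; the returned
    # history can be fed back in for multi-turn conversations.
    history: List[Tuple[str, str]] = []
    for response, history in llm.stream_chat("What can you do?", history):
        pass
    print(response)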