import gradio as gr
import torch
from transformers import GPT2LMHeadModel, T5Tokenizer

model_name = "akiFQC/japanese-dialogpt-small-aozora"
tokenizer = T5Tokenizer.from_pretrained(model_name)
tokenizer.do_lower_case = True  # workaround for a bug in the tokenizer config loading
model = GPT2LMHeadModel.from_pretrained(model_name)


class DialogGPT:
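    """Generate several candidate replies and keep the best one by reranking.

    Candidates are sampled from the dialogue model and scored with
    log P(reply | context) - param_lambda * log P(reply): the conditional
    likelihood minus a scaled unconditional-likelihood penalty, so generic,
    context-independent replies are down-weighted. The first param_gamma
    tokens of each candidate are excluded from the penalty term.
    """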
    def __init__(self, tokenizer, model, n_candidate=4, param_lambda=0.10):
        self.tokenizer = tokenizer
        self.model = model
        self.model.eval()
        self.n_candidate = n_candidate  # number of candidate replies sampled per turn
        self.param_lambda = param_lambda  # weight of the unconditional-likelihood penalty
        self.param_gamma: int = 2  # leading tokens excluded from the penalty term

    def _calc_single_scores(self, token_ids):
        # Unconditional log-likelihood log P(reply) of each candidate, used as
        # the penalty term of the reranking score.
        with torch.inference_mode():
            candidate_token_ids = token_ids[:, :-1]
            label_token_ids = token_ids[:, 1:]
            outputs = self.model(candidate_token_ids)
        logits = torch.log_softmax(outputs.logits, dim=-1)

        # log-probability assigned to each target (next) token
        log_likelihood = logits.gather(
            dim=-1, index=label_token_ids.unsqueeze(-1)
        ).squeeze(-1)

        # mask out pad token positions
        mask_at_pad = label_token_ids == self.tokenizer.pad_token_id
        log_likelihood = log_likelihood.masked_fill(mask_at_pad, 0.0)
        # skip the first param_gamma tokens, then sum per candidate
        log_likelihood_per_candidate = log_likelihood[:, self.param_gamma:].sum(dim=1)
        # normalize by length
        # log_likelihood_per_candidate = log_likelihood_per_candidate / (candidate_token_ids.shape[1] - mask_at_pad.sum(dim=1))
        return log_likelihood_per_candidate

    def _calc_scores(self, sequences, scores, input_ids=None):
        # Conditional log-likelihood log P(reply | context), summed over the
        # generated tokens, recovered from the per-step generation scores.
        transition_scores = self.model.compute_transition_scores(
            sequences, scores, normalize_logits=True
        )
        if input_ids is None:
            input_length = 0
        else:
            input_length = input_ids.shape[1]
        generated_tokens = sequences[:, input_length:]  # n x l
        assert (
            generated_tokens.shape[1] == transition_scores.shape[1]
        ), f"{generated_tokens.shape[1]} != {transition_scores.shape[1]}"
        # ignore pad positions when summing
        transition_scores.masked_fill_(
            generated_tokens == self.tokenizer.pad_token_id, 0.0
        )
        transition_scores = transition_scores.sum(dim=1)
        return transition_scores

    def reply(self, message, history):
        # Returns (decoded reply string, updated token-id history).
        chat_history_ids = torch.LongTensor(history).unsqueeze(0)
        # encode the new user input, add the eos_token and return a PyTorch tensor
        new_user_input_ids = self.tokenizer.encode(
            message + self.tokenizer.eos_token, return_tensors="pt"
        )

        # append the new user input tokens to the chat history (empty on the first turn)
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

        # generate candidate responses (max_length caps history + reply at 512 tokens)
        with torch.inference_mode():
            output = self.model.generate(
                bot_input_ids,
                pad_token_id=self.tokenizer.pad_token_id,
                do_sample=True,
                top_p=0.93,
                temperature=0.5,
                repetition_penalty=1.17,
                max_time=10,
                num_return_sequences=self.n_candidate,
                max_length=512,
                min_length=4,
                forced_eos_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_scores=True,
                min_new_tokens=2,
            )

        # conditional score of each candidate: log P(reply | context)
        scores_condition_s2t = self._calc_scores(
            sequences=output.sequences, scores=output.scores, input_ids=bot_input_ids
        )
        new_token_ids = output.sequences[:, bot_input_ids.shape[-1] :]
        # unconditional-likelihood penalty, scaled by param_lambda
        single_scores = self._calc_single_scores(new_token_ids) * self.param_lambda

        # pick the candidate maximizing log P(reply | context) - lambda * log P(reply)
        total_scores = scores_condition_s2t - single_scores
        id_selected = torch.argmax(total_scores)

        chat_history_ids = output.sequences[id_selected].unsqueeze(
            0
        )  # update chat history with the selected candidate
        # remove pad tokens
        chat_history_ids = chat_history_ids[
            :, chat_history_ids[0] != self.tokenizer.pad_token_id
        ]
        # keep special tokens so the conversation can later be split on eos_token
        reply_string = self.tokenizer.decode(
            chat_history_ids[0], skip_special_tokens=False
        )
        return reply_string, chat_history_ids[0].tolist()


bot = DialogGPT(
    tokenizer,
    model,
)
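# Example (not executed by the app): the bot can be driven directly, e.g.
#   reply_string, history = bot.reply("こんにちは", [])
# where `history` is the token-id list to pass back on the next turn.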


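# Gradio callback: feeds the user message and the token-id history (held in
# gr.State) to the bot, then turns the decoded conversation into (user, bot)
# pairs for the Chatbot component.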
def predict(message, history=None):
    history = history or []
    reply_string, history = bot.reply(message, history)
    response = reply_string.split(tokenizer.eos_token)
    response = [
        (response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)
    ]  # pair consecutive turns into (user, bot) tuples
    return response, history


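# Minimal Gradio UI: submitting the textbox calls predict(), which updates both
# the visible chat log and the hidden token-id state.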
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    state = gr.State([])  # token-id history carried across turns

    with gr.Row():
        txt = gr.Textbox(
            show_label=False, placeholder="Enter text and press enter"
        ).style(container=False)

    txt.submit(predict, [txt, state], [chatbot, state])

demo.launch()