import threading
from queue import Queue
from typing import Optional

import ctranslate2
import gradio as gr
from transformers import AutoTokenizer


class TokenIteratorStreamer:
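    """Blocking token stream that bridges the generation and UI threads.

    The CTranslate2 step callback puts token ids on a queue; iterating the
    streamer pops them until the end-of-sequence id arrives, which stops
    the iteration.
    """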
    def __init__(self, end_token_id: int, timeout: Optional[float] = None):
        self.end_token_id = end_token_id
        self.queue = Queue()
        self.timeout = timeout

    def put(self, token: int):
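        # Called from the generation thread for each newly decoded token id.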
        self.queue.put(token, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        # Block until the generation thread produces the next token id.
        token = self.queue.get(timeout=self.timeout)
        if token == self.end_token_id:
            raise StopIteration
        return token



def generate_prompt(history):
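    # Serialize the chat history into the "<human>: ...\n<bot>: ..." turn
    # format and return CTranslate2-style string tokens.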
    prompt = ""
    for chain in history[:-1]:
        prompt += f"<human>: {chain[0]}\n<bot>: {chain[1]}\n"
    prompt += f"<human>: {history[-1][0]}\n<bot>:"
    tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))
    return tokens

def generate(streamer, history):
    """Run greedy decoding on a worker thread, pushing token ids to `streamer`."""

    def step_result_callback(result):
        streamer.put(result.token_id)
        # If decoding stops without emitting the end token (e.g. it hits
        # max_decoding_length), push it anyway so the consumer terminates.
        if result.is_last and (result.token_id != end_token_id):
            streamer.put(end_token_id)
        # Debug trace of each decoding step.
        print(f"step={result.step}, batch_id={result.batch_id}, token={result.token}")

    tokens = generate_prompt(history)

    # CTranslate2 invokes the callback per generated token when beam_size=1.
    results = translator.translate_batch(
        [tokens],
        beam_size=1,
        max_decoding_length=256,
        repetition_penalty=1.8,
        callback=step_result_callback,
    )
    return results



# Load the CTranslate2-converted model from the local "model" directory and
# the matching Hugging Face tokenizer.
translator = ctranslate2.Translator("model", intra_threads=2)
tokenizer = AutoTokenizer.from_pretrained("DKYoon/mt5-xl-lm-adapt")

# Resolve the end-of-sequence token id (equivalent to tokenizer.eos_token_id).
end_token = "</s>"
end_token_id = tokenizer.encode(end_token)[0]


with gr.Blocks() as demo:
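    # Minimal chat UI: a history pane, a message box, and a clear button.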
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
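        # Clear the textbox and append the user turn with an empty bot slot.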
        return "", history + [[user_message, ""]]

    def bot(history):
        streamer = TokenIteratorStreamer(end_token_id=end_token_id)
        generation_thread = threading.Thread(target=generate, args=(streamer, history))
        generation_thread.start()

        # Decode the accumulated tokens after each step and stream the
        # partial reply into the last chat turn.
        bot_message_tokens = []
        for token in streamer:
            bot_message_tokens.append(token)
            history[-1][1] = tokenizer.decode(bot_message_tokens)
            yield history
        generation_thread.join()

    # Record the user turn immediately (unqueued), then stream the bot reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    
# Queuing must be enabled so the generator callback can stream partial outputs.
demo.queue()
if __name__ == "__main__":
    demo.launch()