File size: 5,503 Bytes
ae84b44
 
 
21a5dba
ae84b44
9283ef4
ae84b44
7f3bae1
ae84b44
7f3bae1
ae84b44
7f3bae1
ae84b44
7f3bae1
ae84b44
 
 
 
7f3bae1
ae84b44
7f3bae1
ae84b44
 
3e462aa
ae84b44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f3bae1
ae84b44
 
 
 
 
 
 
 
 
 
 
 
9283ef4
ae84b44
09f0075
ae84b44
 
 
 
 
 
 
741f4b4
ae84b44
 
 
 
 
 
 
 
ddc2c87
21a5dba
 
 
ae84b44
 
 
6966380
ae84b44
 
 
 
21a5dba
e4ef938
ae84b44
9283ef4
7f3bae1
ae84b44
7f3bae1
 
 
ae84b44
 
ac61535
21a5dba
 
e4ef938
 
9283ef4
3e462aa
21a5dba
ac61535
ae84b44
 
 
 
 
 
21a5dba
 
1673dcd
 
ae84b44
 
21a5dba
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from util_funcs import get_length_param

def chat_function(Message, Length_of_the_answer, Who_is_next, Base_to_On_subject_temperature, history):   # model, tokenizer
    
    input_user = Message
    
    if Length_of_the_answer == 'short':
        next_len = '1'
    elif Length_of_the_answer == 'medium':
        next_len = '2'
    elif Length_of_the_answer == 'long':
        next_len = '3'
    else:
        next_len = '-'
        
    if Who_is_next == 'Kirill':
        next_who = 'G'
    elif Who_is_next == 'Me':
        next_who = 'H'
        
    history = history or []
    chat_history_ids = torch.zeros((1, 0), dtype=torch.int) if history == [] else torch.tensor(history[-1][2], dtype=torch.long)

    # encode the new user input, add parameters and return a tensor in Pytorch
    if len(input_user) != 0:

        new_user_input_ids = tokenizer.encode(f"|0|{get_length_param(input_user, tokenizer)}|" \
                                              + input_user + tokenizer.eos_token, return_tensors="pt")
        # append the new user input tokens to the chat history
        chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    else:
        input_user = '-'
        
    if next_who == "G":

        # encode the new user input, add parameters and return a tensor in Pytorch
        new_user_input_ids = tokenizer.encode(f"|1|{next_len}|", return_tensors="pt")
        # append the new user input tokens to the chat history
        chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

        # print(tokenizer.decode(chat_history_ids[-1])) # uncomment to see full gpt input

        # save previous len
        input_len = chat_history_ids.shape[-1]
        # generated a response; PS you can read about the parameters at hf.co/blog/how-to-generate
        chat_history_ids = model.generate(
            chat_history_ids,
            num_return_sequences=1,                     # use for more variants, but have to print [i]
            max_length=512,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            temperature = float(Base_to_On_subject_temperature),                          # 0 for greedy
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )

        response = tokenizer.decode(chat_history_ids[:, input_len:][0], skip_special_tokens=True)
    else:
        response = '-'
        
    history.append((input_user, response, chat_history_ids.tolist()))        
    # depricated -- gr.set_state(history)

    html = "<div class='chatbot'>"
    for user_msg, resp_msg, _ in history:
        if user_msg != '-':
            html += f"<div class='user_msg'>{user_msg}</div>"
        if resp_msg != '-':
            html += f"<div class='resp_msg'>{resp_msg}</div>"
    html += "</div>"
    return html, history





# Download the fine-tuned checkpoint. chat_function reads the module-level
# `tokenizer` and `model` names defined here.
checkpoint = "Kirili4ik/ruDialoGpt3-medium-finetuned-telegram-6ep"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model = model.eval()  # inference mode: disables dropout

# Gradio UI
title = "Chat with Kirill (in Russian)"
# Russian, roughly: "Chat with me here — but it's a bot instead of me. Leave the
# message empty so Kirill keeps talking (he likes sending several messages in a
# row). Use the slider to make answers more generic or more on-topic. More about
# the technique at the link below."
description = "Тут можно поболтать со мной. Но вместо меня бот. Оставь сообщение пустым, чтобы Кирилл продолжил говорить - он очень любит писать подряд несколько сообщений в чате. Используй слайдер, чтобы ответы были более общими или более конкретными (ближе к теме). Подробнее о технике по ссылке внизу."
article = "<p style='text-align: center'><a href='https://github.com/Kirili4ik/ruDialoGpt3-finetune-colab'>Github with fine-tuning GPT-3 on your chat</a></p>"
# Example rows: [Message, Length_of_the_answer, Who_is_next, temperature].
# The trailing "state" input is not part of the examples.
examples = [
    ["В чем смысл жизни?", 'medium', 'Kirill', 0.95],
    ["Когда у тебя ближайший собес?", 'medium', 'Kirill', 0.85],
    ["Сколько тебе лет, Кирилл?", 'medium', 'Kirill', 0.85],
]

# The CSS selectors must match the class names emitted by chat_function
# ('chatbot', 'user_msg', 'resp_msg'); the container rule previously targeted
# '.chatbox', so its flex layout never applied.
iface = gr.Interface(
    chat_function,
    [
        "text",
        gr.inputs.Radio(["short", "medium", "long"], default='medium'),
        gr.inputs.Radio(["Kirill", "Me"], default='Kirill'),
        gr.inputs.Slider(0, 1.5, default=0.5),
        "state",
    ],
    ["html", "state"],
    title=title,
    description=description,
    article=article,
    examples=examples,
    css="""
        .chatbot {display:flex;flex-direction:column}
        .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
        .user_msg {background-color:cornflowerblue;color:white;align-self:start}
        .resp_msg {background-color:lightgray;align-self:self-end}
    """,
    allow_screenshot=True,
    allow_flagging=False,
    api_mode=True,
)

if __name__ == "__main__":
    iface.launch()