AndrewChar commited on
Commit
ac1b258
1 Parent(s): 0c57adf

Update App.py

Browse files
Files changed (1) hide show
  1. App.py +124 -4
App.py CHANGED
@@ -1,7 +1,127 @@
 
 
 
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import random
3
+ import torch
4
  import gradio as gr
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+ from util_funcs import getLengthParam, calcAnswerLengthByProbability, cropContext
7
 
8
+ def chat_function(Message, History): # model, tokenizer
 
9
 
10
+ input_user = Message
11
+
12
+ history = History or []
13
+
14
+ chat_history_ids = torch.zeros((1, 0), dtype=torch.int) if history == [] else torch.tensor(history[-1][2], dtype=torch.long)
15
+
16
+ # encode the new user input, add parameters and return a tensor in Pytorch
17
+ lengthId = getLengthParam(input_user, tokenizer)
18
+ new_user_input_ids = tokenizer.encode(f"|0|{lengthId}|" \
19
+ + input_user + tokenizer.eos_token, return_tensors="pt")
20
+ # append the new user input tokens to the chat history
21
+ chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
22
+
23
+ # Длину ожидаемой фразы мы рассчитаем на основании последнего инпута
24
+ # Например, я не люблю когда на мой длинный ответ отвечают короткой фразой
25
+ # Но пойдем через вероятности:
26
+ # при длинном инпуте 60% что будет длинный ответ (3), 30% что средний (2), 10% что короткий (1)
27
+ # при среднем инпуте 50% что ответ будет средний (2), и по 25% на оба остальных случая
28
+ # при коротком инпуте 50% что ответ будет короткий (1), 30% что средний (2) и 20% что длинный (3)
29
+ # см. функцию calcAnswerLengthByProbability()
30
+
31
+ next_len = calcAnswerLengthByProbability(lengthId)
32
+
33
+ # encode the new user input, add parameters and return a tensor in Pytorch
34
+ new_user_input_ids = tokenizer.encode(f"|1|{next_len}|", return_tensors="pt")
35
+
36
+ # append the new user input tokens to the chat history
37
+ chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
38
+
39
+ chat_history_ids = cropContext(chat_history_ids, 10)
40
+
41
+ print(tokenizer.decode(chat_history_ids[-1]))# uncomment for debug
42
+
43
+ # save previous len
44
+ input_len = chat_history_ids.shape[-1]
45
+ # generated a response; PS you can read about the parameters at hf.co/blog/how-to-generate
46
+
47
+ temperature = 0.6
48
+
49
+ # Обрезаем контекст до нужной длины с конца
50
+
51
+ # Создадим копию изначальных данных на случай если придется перегенерировать ответ
52
+ chat_history_ids_initial = chat_history_ids
53
+
54
+ while True:
55
+ chat_history_ids = model.generate(
56
+ chat_history_ids,
57
+ num_return_sequences=1,
58
+ min_length = 2,
59
+ max_length=512,
60
+ no_repeat_ngram_size=3,
61
+ do_sample=True,
62
+ top_k=50,
63
+ top_p=0.9,
64
+ temperature = temperature,
65
+ mask_token_id=tokenizer.mask_token_id,
66
+ eos_token_id=tokenizer.eos_token_id,
67
+ unk_token_id=tokenizer.unk_token_id,
68
+ pad_token_id=tokenizer.pad_token_id,
69
+ device='cpu'
70
+ )
71
+
72
+ answer = tokenizer.decode(chat_history_ids[:, input_len:][0], skip_special_tokens=True)
73
+
74
+ if (len(answer) > 0 and answer[-1] != ',' and answer[-1] != ':'):
75
+ break
76
+ else:
77
+ if (temperature <= 0.1):
78
+ temperature -= 0.1
79
+
80
+ # Случай когда надо перегенерировать ответ наступил, берем изначальный тензор
81
+ chat_history_ids = chat_history_ids_initial
82
+
83
+ history.append((input_user, answer, chat_history_ids.tolist()))
84
+ html = "<div class='chatbot'>"
85
+ for user_msg, resp_msg, _ in history:
86
+ if user_msg != '-':
87
+ html += f"<div class='user_msg'>{user_msg}</div>"
88
+ if resp_msg != '-':
89
+ html += f"<div class='resp_msg'>{resp_msg}</div>"
90
+ html += "</div>"
91
+ return html, history
92
+
93
+ # Download checkpoint:
94
+
95
+ checkpoint = "avorozhko/ruDialoGpt3-medium-finetuned-context"
96
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
97
+ model = AutoModelForCausalLM.from_pretrained(checkpoint)
98
+ model = model.eval()
99
+
100
+ # Gradio
101
+ title = "Чат-бот для поднятия настроения"
102
+ description = """
103
+ Данный бот постарается поднять вам настроение, так как он знает 26700 анекдотов.
104
+ Но чувство юмора у него весьма специфичное.
105
+ Бот не знает матерных слов и откровенных пошлостей, но кто такой Вовочка и Поручик Ржевский знает )
106
+ """
107
+ article = "<p style='text-align: center'><a href='https://huggingface.co/avorozhko/ruDialoGpt3-medium-finetuned-context'>Бот на основе дообученной GPT-3</a></p>"
108
+
109
+ iface = gr.Interface(fn=chat_function,
110
+ inputs=[gr.inputs.Textbox(lines=3, placeholder="Что вы хотите сказать боту..."), "state"],
111
+ outputs=["html", "state"],
112
+ title=title, description=description, article=article,
113
+ theme='dark-grass',
114
+ css= """
115
+ .chatbox {display:flex;flex-direction:column}
116
+ .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
117
+ .user_msg {background-color:#1e4282;color:white;align-self:start}
118
+ .resp_msg {background-color:#552a2a;align-self:self-end}
119
+ .panels.unaligned {flex-direction: column !important;align-items: initial!important;}
120
+ .panels.unaligned :last-child {order: -1 !important;}
121
+ """,
122
+ allow_screenshot=False,
123
+ allow_flagging='never'
124
+ )
125
+
126
+ if __name__ == "__main__":
127
+ iface.launch(debug=True, share=True)