Елена Fomina committed on
Commit
a602c25
1 Parent(s): dbd9204

Add application file

Browse files
Files changed (1) hide show
  1. app.py +58 -4
app.py CHANGED
@@ -1,7 +1,61 @@
 
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import transformers
import gradio as gr
import torch

from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the small DialoGPT checkpoint once at import time; the same
# tokenizer/model pair is shared by every chat request.
# NOTE(review): first run downloads weights from the Hugging Face hub.
tokenizer = GPT2Tokenizer.from_pretrained('microsoft/DialoGPT-small')
model = GPT2LMHeadModel.from_pretrained('microsoft/DialoGPT-small')
# Inference only: switch off dropout / training-mode layers.
model.eval()
9
 
10
def chat(message, history):
    """Generate one DialoGPT reply and append it to the conversation.

    Args:
        message: The user's new utterance (plain text).
        history: List of (user_message, bot_response) tuples from previous
            turns, or None on the first call (Gradio "state" input).

    Returns:
        A (history, history) pair — the updated conversation, duplicated for
        the "chatbot" display output and the "state" output.

    Note: only the most recent exchange in ``history`` is fed back to the
    model as context (this matches the original branching, where every
    branch ended up encoding just ``history[-1]``).
    """
    history = history or []

    # Encode the new user turn, terminated by EOS as DialoGPT expects.
    new_user_input_ids = tokenizer.encode(
        message + tokenizer.eos_token, return_tensors='pt')

    if history:
        # Context = last user message + last bot response, each EOS-terminated.
        last_message, last_response = history[-1]
        encoded_message = tokenizer.encode(
            last_message + tokenizer.eos_token, return_tensors='pt')
        encoded_response = tokenizer.encode(
            last_response + tokenizer.eos_token, return_tensors='pt')
        bot_input_ids = torch.cat(
            [encoded_message, encoded_response, new_user_input_ids], dim=-1)
    else:
        # First turn: no context, just the user's message.
        bot_input_ids = new_user_input_ids

    # Sample a continuation; pad with EOS so generation stops cleanly.
    chat_history_ids = model.generate(
        bot_input_ids, max_length=1000, do_sample=True, top_p=0.9,
        temperature=0.8, pad_token_id=tokenizer.eos_token_id)

    # The reply is everything generated after the prompt tokens.
    response = tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0],
        skip_special_tokens=True)

    history.append((message, response))
    return history, history
49
+
50
# Demo metadata shown in the Gradio page header.
title = "DialoGPT-small"
description = "Gradio demo for dialog using DialoGPT"

# Wire chat() to the UI: a text box plus hidden state in, a chatbot widget
# plus updated state out. State carries the conversation between calls.
iface = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    allow_screenshot=False,
    allow_flagging="never",
    title=title,
    description=description,
)
iface.launch(debug=True)