llongpre committed on
Commit
8520ac4
1 Parent(s): 439de6d
Files changed (2)
  1. app.py +109 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,109 @@
+ # import streamlit as st
+ # from streamlit_chat import message as st_message
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+
+ MAX_HISTORY = 7
+ MODEL_PATH = 'llongpre/DialoGPT-small-miles'
+
+ def get_models():
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+     model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
+     return tokenizer, model
+
+ # if "history" not in st.session_state:
+ #     st.session_state.history = []
+ #
+ # if "history_ids" not in st.session_state:
+ #     st.session_state.history_ids = []
+ #
+ # st.title("Chat with me")
+
+ # def generate_answer():
+ #     tokenizer, model = get_models()
+ #     user_message = st.session_state.input_text
+ #     new_user_input_ids = tokenizer.encode(st.session_state.input_text + tokenizer.eos_token, return_tensors='pt')
+ #     st.session_state.history_ids.append(new_user_input_ids)
+ #     if len(st.session_state.history_ids) > MAX_HISTORY:
+ #         st.session_state.history_ids = st.session_state.history_ids[-MAX_HISTORY:]
+ #     bot_input_ids = torch.cat(st.session_state.history_ids, dim=-1)
+ #     chat_history_ids = model.generate(
+ #         bot_input_ids,
+ #         pad_token_id=tokenizer.pad_token_id,
+ #         max_length=1000,
+ #         do_sample=True,
+ #         # top_k=150,  # sample from the top k words, sorted descending by probability
+ #         top_p=0.7,  # sample from the smallest set of words whose cumulative probability exceeds p
+ #         temperature=0.95,  # near 0 approaches greedy decoding; larger values are more random
+ #         no_repeat_ngram_size=3,
+ #     )
+ #     response = chat_history_ids[:, bot_input_ids.shape[-1]:]
+ #     st.session_state.history_ids.append(response)
+ #     output = tokenizer.decode(response[0], skip_special_tokens=True)
+ #
+ #     st.session_state.history.append({"message": user_message, "is_user": True})
+ #     st.session_state.history.append({"message": output, "is_user": False})
+
+ # st.text_input("Your text message", key="input_text", on_change=generate_answer, placeholder='')
+
+ # for chat in st.session_state.history:
+ #     st_message(**chat)  # unpack each dict into st_message's keyword arguments
+
+
+ import gradio as gr
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+ model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
+
+ # Alternative handler (not wired into the Interface below): keeps the chat
+ # history as one flat list of token ids.
+ def predict(input, history=[]):
+     # tokenize the new input sentence
+     new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
+
+     # append the new user input tokens to the chat history
+     # (on the first turn, history is empty and torch.cat skips the 0-element tensor)
+     bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
+
+     # generate a response
+     history = model.generate(
+         bot_input_ids,
+         max_length=1000,
+         pad_token_id=tokenizer.eos_token_id,
+         no_repeat_ngram_size=3,
+         do_sample=True,  # without sampling enabled, top_p/top_k below have no effect
+         top_p=0.92,
+         top_k=50,
+     ).tolist()
+
+     # convert the tokens to text, then split the conversation on the EOS marker
+     response = tokenizer.decode(history[0]).split("<|endoftext|>")
+     response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)]  # pair up (user, bot) turns
+
+     return response, history
+
+ def generate_answer(input, history=[]):
+     new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
+     history.append(new_user_input_ids)  # list.append mutates in place and returns None, so don't rebind
+     if len(history) > MAX_HISTORY:
+         history = history[-MAX_HISTORY:]
+     bot_input_ids = torch.cat(history, dim=-1)
+     chat_history_ids = model.generate(
+         bot_input_ids,
+         pad_token_id=tokenizer.eos_token_id,  # GPT-2 tokenizers define no pad token, so reuse EOS
+         max_length=1000,
+         do_sample=True,
+         # top_k=150,  # sample from the top k words, sorted descending by probability
+         top_p=0.7,  # sample from the smallest set of words whose cumulative probability exceeds p
+         temperature=0.95,  # near 0 approaches greedy decoding; larger values are more random
+         no_repeat_ngram_size=3,
+     )
+     response = chat_history_ids[:, bot_input_ids.shape[-1]:]
+     history.append(response)
+
+     # the "chatbot" output expects (user, bot) message pairs, not a bare string;
+     # note the MAX_HISTORY truncation above can shift which turn shows as "user"
+     messages = [tokenizer.decode(ids[0], skip_special_tokens=True) for ids in history]
+     chat = [(messages[i], messages[i+1]) for i in range(0, len(messages)-1, 2)]
+     return chat, history
+
+
+ gr.Interface(fn=generate_answer,
+              title="DialoGPT-small-miles",
+              inputs=["text", "state"],
+              outputs=["chatbot", "state"],
+              ).launch()
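
For a quick check of the generation settings outside the Gradio UI, a minimal smoke test might look like the sketch below (not part of this commit; it repeats the generate call rather than importing app.py, since importing would also trigger .launch(), and it assumes the model weights download from the Hub):

    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("llongpre/DialoGPT-small-miles")
    model = AutoModelForCausalLM.from_pretrained("llongpre/DialoGPT-small-miles")

    # one user turn, terminated by EOS as DialoGPT expects
    ids = tokenizer.encode("Hello, how are you?" + tokenizer.eos_token, return_tensors="pt")
    reply_ids = model.generate(
        ids,
        max_length=200,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_p=0.7,
        no_repeat_ngram_size=3,
    )
    # strip the prompt tokens, keep only the model's reply
    print(tokenizer.decode(reply_ids[0, ids.shape[-1]:], skip_special_tokens=True))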
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ transformers
+ torch
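
Note: app.py also imports gradio, which is not listed here. On a Hugging Face Space using the Gradio SDK it comes preinstalled; a local run would additionally need `pip install gradio`.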