llongpre commited on
Commit
21a1749
1 Parent(s): 8520ac4
Files changed (1) hide show
  1. app.py +3 -51
app.py CHANGED
@@ -1,57 +1,10 @@
1
- # import streamlit as st
2
- # from streamlit_chat import message as st_message
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
 
6
  MAX_HISTORY = 7
7
  MODEL_PATH = 'llongpre/DialoGPT-small-miles'
8
 
9
- def get_models():
10
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
11
- model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
12
- return tokenizer, model
13
-
14
- # if "history" not in st.session_state:
15
- # st.session_state.history = []
16
- #
17
- # if "history_ids" not in st.session_state:
18
- # st.session_state.history_ids = []
19
- #
20
- # st.title("Chat with me")
21
-
22
- # def generate_answer():
23
- # tokenizer, model = get_models()
24
- # user_message = st.session_state.input_text
25
- # new_user_input_ids = tokenizer.encode(st.session_state.input_text + tokenizer.eos_token, return_tensors='pt')
26
- # st.session_state.history_ids.append(new_user_input_ids)
27
- # if len(st.session_state.history_ids) > MAX_HISTORY:
28
- # st.session_state.history_ids = st.session_state.history_ids[-MAX_HISTORY:]
29
- # bot_input_ids = torch.cat(st.session_state.history_ids, dim=-1)
30
- # chat_history_ids = model.generate(
31
- # bot_input_ids,
32
- # pad_token_id=tokenizer.pad_token_id,
33
- # max_length=1000,
34
- # do_sample=True,
35
- # # top_k=150, # sample from the top k words sorted descending by probability
36
- # top_p=0.7, # choose smallest possible words whose cumulative probability exceeds p
37
- # temperature = 0.95, # 0 greedy, inf is random
38
- # no_repeat_ngram_size=3,
39
- # )
40
- # response = chat_history_ids[:, bot_input_ids.shape[-1]:]
41
- # st.session_state.history_ids.append(response)
42
- # output = tokenizer.decode(response[0], skip_special_tokens=True)
43
- #
44
- # st.session_state.history.append({"message": user_message, "is_user": True})
45
- # st.session_state.history.append({"message": output, "is_user": False})
46
-
47
- # st.text_input("Your text message", key="input_text", on_change=generate_answer, placeholder='')
48
-
49
- # for chat in st.session_state.history:
50
- # st_message(**chat) # unpacking
51
-
52
-
53
- import gradio as gr
54
-
55
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
56
  model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
57
 
@@ -80,8 +33,7 @@ def predict(input, history=[]):
80
 
81
  def generate_answer(input, history=[]):
82
  new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
83
- # st.session_state.history_ids.append(new_user_input_ids)
84
- history = history.append(new_user_input_ids)
85
  if len(history) > MAX_HISTORY:
86
  history = history[-MAX_HISTORY:]
87
  bot_input_ids = torch.cat(history, dim=-1)
@@ -96,8 +48,8 @@ def generate_answer(input, history=[]):
96
  no_repeat_ngram_size=3,
97
  )
98
  response = chat_history_ids[:, bot_input_ids.shape[-1]:]
99
- history.append(response)
100
  output = tokenizer.decode(response[0], skip_special_tokens=True)
 
101
 
102
  return output, history
103
 
 
1
+ import gradio as gr
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
5
  MAX_HISTORY = 7
6
  MODEL_PATH = 'llongpre/DialoGPT-small-miles'
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
9
  model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
10
 
 
33
 
34
  def generate_answer(input, history=[]):
35
  new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
36
+ history = history.append(input)
 
37
  if len(history) > MAX_HISTORY:
38
  history = history[-MAX_HISTORY:]
39
  bot_input_ids = torch.cat(history, dim=-1)
 
48
  no_repeat_ngram_size=3,
49
  )
50
  response = chat_history_ids[:, bot_input_ids.shape[-1]:]
 
51
  output = tokenizer.decode(response[0], skip_special_tokens=True)
52
+ history.append(output)
53
 
54
  return output, history
55