llongpre committed on
Commit
1b29d23
1 Parent(s): 99b0010
Files changed (1)
  1. app.py +28 -22
app.py CHANGED
@@ -8,33 +8,39 @@ MODEL_PATH = 'llongpre/DialoGPT-small-miles'
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
  model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)

- def predict(input, history=[]):
-     # tokenize the new input sentence
-     new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
-
-     # append the new user input tokens to the chat history
-     bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
-
-     # generate a response
-     history = model.generate(
-         bot_input_ids,
-         max_length=1000,
-         pad_token_id=tokenizer.eos_token_id,
-         no_repeat_ngram_size=3,
-         top_p = 0.92,
-         top_k = 50
-     ).tolist()
-
-     # convert the tokens to text, and then split the responses into lines
-     response = tokenizer.decode(history[0]).split("<|endoftext|>")
-     response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list
-
-     return response, history
+ # def predict(input, history=[]):
+ #     # tokenize the new input sentence
+ #     new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
+ #
+ #     # append the new user input tokens to the chat history
+ #     bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
+ #
+ #     # generate a response
+ #     history = model.generate(
+ #         bot_input_ids,
+ #         max_length=1000,
+ #         pad_token_id=tokenizer.eos_token_id,
+ #         no_repeat_ngram_size=3,
+ #         top_p = 0.92,
+ #         top_k = 50
+ #     ).tolist()
+ #
+ #     # convert the tokens to text, and then split the responses into lines
+ #     response = tokenizer.decode(history[0]).split("<|endoftext|>")
+ #     response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list
+ #
+ #     return response, history
+ #
+ # from transformers.utils import logging
+
+ logging.set_verbosity_info()
+ logger = logging.get_logger("transformers")
+ logger.info("INFO")

  def generate_answer(input, history=[]):
      new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
      history = history.append(input)
-     print(history)
+     logger.info(history)
      if len(history) > MAX_HISTORY:
          history = history[-MAX_HISTORY:]
      bot_input_ids = torch.cat(history, dim=-1)
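
Review note: the removed (now commented-out) predict() passes top_p and top_k to model.generate without do_sample=True. Under transformers' default greedy decoding those sampling arguments are silently ignored, so the old function never actually sampled. If this path is ever restored, a minimal sketch of the same call with sampling actually enabled (identical arguments to the commit, plus the missing flag):

    history = model.generate(
        bot_input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        do_sample=True,  # without this flag, generate() decodes greedily and ignores top_p/top_k
        top_p=0.92,
        top_k=50,
    ).tolist()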
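
Note: the new module-level calls use transformers' logging helper, but the line that provides it (from transformers.utils import logging) is committed commented out. Unless logging is imported elsewhere in app.py, the module raises a NameError on load; the stdlib logging module would not work here either, since it has no set_verbosity_info. A sketch assuming the import is simply uncommented:

    from transformers.utils import logging

    logging.set_verbosity_info()
    logger = logging.get_logger("transformers")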
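
Review note: generate_answer as committed will fail at runtime. list.append returns None, so history = history.append(input) rebinds history to None and the len(history) check raises a TypeError; the appended value is also the raw input string, while torch.cat(history, dim=-1) needs a list of tensors. The mutable default history=[] is a further gotcha, since the same list is shared across calls. A minimal corrected sketch of just the lines shown in this hunk, assuming MAX_HISTORY (defined elsewhere in app.py, outside the hunk) is meant to cap the number of stored turns:

    def generate_answer(input, history=[]):
        # encode the new user turn and keep the tensor, not the raw string, in history
        new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
        history.append(new_user_input_ids)  # append mutates in place; do not rebind
        logger.info(history)
        if len(history) > MAX_HISTORY:
            history = history[-MAX_HISTORY:]
        # each entry is a (1, seq_len) tensor, so concatenate along the last axis
        bot_input_ids = torch.cat(history, dim=-1)
        # ... generation continues past the end of this hunk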