gorkemgoknar commited on
Commit
9136a7f
1 Parent(s): 9aa52b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -23,19 +23,23 @@ SPECIAL_TOKENS = ["<bos>", "<eos>", "<speaker1>", "<speaker2>", "<pad>"]
23
 
24
  #See document for experiment https://www.linkedin.com/pulse/ai-goes-job-interview-g%C3%B6rkem-g%C3%B6knar/
25
 
26
-
27
  def get_chat_response(name,history=[], input_txt = "Hello , what is your name?"):
28
 
29
  ai_history = history.copy()
30
 
31
- ai_history.append(input_txt)
32
  ai_history_e = [tokenizer.encode(e) for e in ai_history]
33
 
34
  personality = "My name is " + name
35
 
36
  bos, eos, speaker1, speaker2 = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-1])
37
- sequence = [[bos] + tokenizer.encode(personality)] + ai_history_e
 
 
 
 
38
  sequence = [sequence[0]] + [[speaker2 if (len(sequence)-i) % 2 else speaker1] + s for i, s in enumerate(sequence[1:])]
 
39
  sequence = list(chain(*sequence))
40
 
41
  #bot_input_ids = tokenizer.encode(personality + tokenizer.eos_token + input_txt + tokenizer.eos_token , return_tensors='pt')
 
23
 
24
  #See document for experiment https://www.linkedin.com/pulse/ai-goes-job-interview-g%C3%B6rkem-g%C3%B6knar/
25
 
 
26
  def get_chat_response(name,history=[], input_txt = "Hello , what is your name?"):
27
 
28
  ai_history = history.copy()
29
 
30
+ #ai_history.append(input_txt)
31
  ai_history_e = [tokenizer.encode(e) for e in ai_history]
32
 
33
  personality = "My name is " + name
34
 
35
  bos, eos, speaker1, speaker2 = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-1])
36
+
37
+ #persona first, history next, input text must be at the end
38
+ #[[bos, persona] , [history] , [input]]
39
+ sequence = [[bos] + tokenizer.encode(personality)] + ai_history_e + [tokenizer.encode(input_txt)]
40
+ ##[[bos, persona] , [speaker1 .., speaker2 .., speaker1 ... speaker2 ... , [input]]
41
  sequence = [sequence[0]] + [[speaker2 if (len(sequence)-i) % 2 else speaker1] + s for i, s in enumerate(sequence[1:])]
42
+
43
  sequence = list(chain(*sequence))
44
 
45
  #bot_input_ids = tokenizer.encode(personality + tokenizer.eos_token + input_txt + tokenizer.eos_token , return_tensors='pt')