Ashishkr commited on
Commit
bde65d6
1 Parent(s): 91e30ca

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +26 -12
model.py CHANGED
@@ -42,19 +42,33 @@ tokenizer = transformers.AutoTokenizer.from_pretrained(
42
  )
43
 
44
 
45
- def get_prompt(message: str, chat_history: list[tuple[str, str]],
46
- system_prompt: str) -> str:
47
- texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
48
- # The first user input is _not_ stripped
49
- do_strip = False
50
- for user_input, response in chat_history:
51
- user_input = user_input.strip() if do_strip else user_input
52
- do_strip = True
53
- texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
54
- message = message.strip() if do_strip else message
55
- texts.append(f'{message} [/INST]')
56
- return ''.join(texts)
 
 
 
 
 
57
 
 
 
 
 
 
 
 
 
 
58
 
59
  def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
60
  prompt = get_prompt(message, chat_history, system_prompt)
 
42
  )
43
 
44
 
45
+ # def get_prompt(message: str, chat_history: list[tuple[str, str]],
46
+ # system_prompt: str) -> str:
47
+ # texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
48
+ # # The first user input is _not_ stripped
49
+ # do_strip = False
50
+ # for user_input, response in chat_history:
51
+ # user_input = user_input.strip() if do_strip else user_input
52
+ # do_strip = True
53
+ # texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
54
+ # message = message.strip() if do_strip else message
55
+ # texts.append(f'{message} [/INST]')
56
+ # return ''.join(texts)
57
+
58
+
59
+
60
+ def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
61
+ texts = [f'{system_prompt}\n']
62
 
63
+ for user_input, response in chat_history[:-1]:
64
+ texts.append(f'{user_input} {response}\n')
65
+
66
+ # Getting the user input and response from the last tuple in the chat history
67
+ last_user_input, last_response = chat_history[-1]
68
+ texts.append(f' input: {last_user_input} {last_response} {message} response: ')
69
+
70
+ return ''.join(texts)
71
+
72
 
73
  def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
74
  prompt = get_prompt(message, chat_history, system_prompt)