5to9 committed on
Commit c6e8ef5
Parent(s): 7977c5d

0.6 defining chat template for pharia

Files changed (1): app.py (+53 -16)
app.py CHANGED
@@ -25,8 +25,34 @@ tokenizer_b, model_b = None, None
 torch_dtype = torch.bfloat16
 attn_implementation = "flash_attention_2"
 
+def apply_chat_template(messages, add_generation_prompt=False):
+    """
+    Apply the chat template manually to each message in a list.
+    messages: list of dicts, each containing a 'role' and a 'content' key.
+    """
+    pharia_template = """<|begin_of_text|>"""
+    role_map = {
+        "system": "<|start_header_id|>system<|end_header_id|>\n",
+        "user": "<|start_header_id|>user<|end_header_id|>\n",
+        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n",
+    }
+
+    # Iterate through the messages and apply the template for each role
+    for message in messages:
+        role = message["role"]
+        content = message["content"]
+        pharia_template += role_map.get(role, "") + content + "<|eot_id|>\n"
+
+    # Add the assistant generation prompt if required
+    if add_generation_prompt:
+        pharia_template += "<|start_header_id|>assistant<|end_header_id|>\n"
+
+    return pharia_template
+
+
 def load_model_a(model_id):
-    global tokenizer_a, model_a
+    global tokenizer_a, model_a, model_id_a
+    model_id_a = model_id  # need to access model_id alongside the tokenizer
     tokenizer_a = AutoTokenizer.from_pretrained(model_id)
     logging.debug(f"model A: {tokenizer_a.eos_token}")
     try:
@@ -50,7 +76,8 @@ def load_model_a(model_id):
     return gr.update(label=model_id)
 
 def load_model_b(model_id):
-    global tokenizer_b, model_b
+    global tokenizer_b, model_b, model_id_b
+    model_id_b = model_id
     tokenizer_b = AutoTokenizer.from_pretrained(model_id)
     logging.debug(f"model B: {tokenizer_b.eos_token}")
     try:
@@ -92,20 +119,30 @@ def generate_both(system_prompt, input_text, chatbot_a, chatbot_b, max_new_token
         chat_history_b.append({"role": "user", "content": user})
         chat_history_b.append({"role": "assistant", "content": assistant})
 
-    base_messages = system_prompt_list + chat_history_a + input_text_list
-    new_messages = system_prompt_list + chat_history_b + input_text_list
-
-    input_ids_a = tokenizer_a.apply_chat_template(
-        base_messages,
-        add_generation_prompt=True,
-        return_tensors="pt"
-    ).to(model_a.device)
-
-    input_ids_b = tokenizer_b.apply_chat_template(
-        new_messages,
-        add_generation_prompt=True,
-        return_tensors="pt"
-    ).to(model_b.device)
+    new_messages_a = system_prompt_list + chat_history_a + input_text_list
+    new_messages_b = system_prompt_list + chat_history_b + input_text_list
+
+    if "pharia" in model_id_a.lower():
+        logging.debug("model a is pharia based, applying own template")
+        formatted_message_a = apply_chat_template(new_messages_a, add_generation_prompt=True)
+        input_ids_a = tokenizer_a(formatted_message_a, return_tensors="pt").input_ids.to(model_a.device)
+    else:
+        input_ids_a = tokenizer_a.apply_chat_template(
+            new_messages_a,
+            add_generation_prompt=True,
+            return_tensors="pt"
+        ).to(model_a.device)
+
+    if "pharia" in model_id_b.lower():
+        logging.debug("model b is pharia based, applying own template")
+        formatted_message_b = apply_chat_template(new_messages_b, add_generation_prompt=True)
+        input_ids_b = tokenizer_b(formatted_message_b, return_tensors="pt").input_ids.to(model_b.device)
+    else:
+        input_ids_b = tokenizer_b.apply_chat_template(
+            new_messages_b,
+            add_generation_prompt=True,
+            return_tensors="pt"
+        ).to(model_b.device)
 
     generation_kwargs_a = dict(
         input_ids=input_ids_a,
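
For reference, a minimal sketch of what the manual template produces for a short conversation. The messages below are illustrative; the special tokens are the ones defined in apply_chat_template above, assumed here to be the Pharia prompt format.

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]

prompt = apply_chat_template(messages, add_generation_prompt=True)
print(prompt)
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>
# You are a helpful assistant.<|eot_id|>
# <|start_header_id|>user<|end_header_id|>
# What is the capital of France?<|eot_id|>
# <|start_header_id|>assistant<|end_header_id|>

One thing worth checking: since the template already starts with <|begin_of_text|>, tokenizing the formatted string via tokenizer_a(formatted_message_a, ...) may prepend a second begin-of-text token if the tokenizer adds special tokens by default; passing add_special_tokens=False would avoid that, depending on how the Pharia tokenizer is configured.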