dh-mc committed
Commit 79eed96
1 Parent(s): 5983ad7

fixed bug for Mistral RAG/chat template

Files changed (2)
  1. app.py +14 -3
  2. app_modules/llm_inference.py +1 -1
app.py CHANGED
@@ -86,7 +86,14 @@ else:
     model = model.to(device)


-def chat(message, history, temperature, repetition_penalty, do_sample, max_tokens):
+def chat(
+    message,
+    history,
+    temperature=0,
+    repetition_penalty=1.1,
+    do_sample=True,
+    max_tokens=1024,
+):
     print("repetition_penalty:", repetition_penalty)
     chat = []
     for item in history:
@@ -136,9 +143,12 @@ def chat(message, history, temperature, repetition_penalty, do_sample, max_tokens):
     partial_text += f"1. Total Repetitions: {total_repetitions:.3f}\n"

     if index >= 0:  # RAG
-        scores = calc_bleu_rouge_scores(
-            [answer], [questions[index]["wellFormedAnswers"]], debug=True
+        key = (
+            "wellFormedAnswers"
+            if "wellFormedAnswers" in questions[index]
+            else "answers"
         )
+        scores = calc_bleu_rouge_scores([answer], [questions[index][key]], debug=True)

         partial_text += "\n\n Performance Metrics:\n"
         partial_text += f'1. BLEU: {scores["bleu_scores"]["bleu"]:.3f}\n'
@@ -150,6 +160,7 @@ def chat(message, history, temperature, repetition_penalty, do_sample, max_tokens):
 demo = gr.ChatInterface(
     fn=chat,
     examples=examples,
+    cache_examples=False,
     additional_inputs_accordion=gr.Accordion(
         label="⚙️ Parameters", open=False, render=False
     ),
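The second hunk guards the RAG scoring path: not every question record carries a "wellFormedAnswers" field, so the scorer now falls back to "answers". A minimal standalone sketch of the fixed path, assuming calc_bleu_rouge_scores wraps Hugging Face's evaluate metrics (the repo's real helper is not shown in this diff, so this stub is an assumption) and a hypothetical question record:

import evaluate  # pip install evaluate

bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")

def calc_bleu_rouge_scores(predictions, references, debug=False):
    # Both metrics accept a list of reference lists, one list per prediction.
    result = {
        "bleu_scores": bleu.compute(predictions=predictions, references=references),
        "rouge_scores": rouge.compute(predictions=predictions, references=references),
    }
    if debug:
        print("calc_bleu_rouge_scores:", result)
    return result

# Hypothetical record: "wellFormedAnswers" exists only for some questions,
# hence the fallback to "answers"; this record would have raised a KeyError
# under the old questions[index]["wellFormedAnswers"] lookup.
question = {"answers": ["Paris is the capital of France."]}
key = "wellFormedAnswers" if "wellFormedAnswers" in question else "answers"
scores = calc_bleu_rouge_scores(
    ["The capital of France is Paris."], [question[key]], debug=True
)
print(f'BLEU: {scores["bleu_scores"]["bleu"]:.3f}')

The other two changes look related: giving chat default argument values and setting cache_examples=False both presumably keep gr.ChatInterface from invoking chat at startup with missing additional inputs while it tries to cache the examples.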
app_modules/llm_inference.py CHANGED
@@ -166,7 +166,7 @@ class LLMInference(metaclass=abc.ABCMeta):
     def apply_chat_template(self, user_message):
         result = (
             []
-            if self.llm_loader.model_name.lower().startswith("gemma")
+            if re.search(r"gemma|mistral", self.llm_loader.model_name, re.IGNORECASE)
             else [
                 {
                     "role": "system",