dh-mc committed on
Commit
5983ad7
1 Parent(s): 0e8d94e

show metrics in Gradio app

Browse files
Files changed (2) hide show
  1. app.py +23 -4
  2. app_modules/utils.py +17 -4
app.py CHANGED
@@ -9,6 +9,7 @@ from transformers import (
9
  import os
10
  from threading import Thread
11
  import subprocess
 
12
 
13
  from dotenv import find_dotenv, load_dotenv
14
 
@@ -93,10 +94,11 @@ def chat(message, history, temperature, repetition_penalty, do_sample, max_token
93
  if item[1] is not None:
94
  chat.append({"role": "assistant", "content": item[1]})
95
 
 
96
  if [message] in examples:
97
  index = examples.index([message])
98
  message = f"{qa_system_prompt}\n\n{questions[index]['context']}\n\nQuestion: {message}"
99
- print(message)
100
 
101
  chat.append({"role": "user", "content": message})
102
 
@@ -105,6 +107,10 @@ def chat(message, history, temperature, repetition_penalty, do_sample, max_token
105
  streamer = TextIteratorStreamer(
106
  tok, timeout=200.0, skip_prompt=True, skip_special_tokens=True
107
  )
 
 
 
 
108
  generate_kwargs = dict(
109
  model_inputs,
110
  streamer=streamer,
@@ -114,9 +120,6 @@ def chat(message, history, temperature, repetition_penalty, do_sample, max_token
114
  eos_token_id=terminators,
115
  )
116
 
117
- if temperature == 0:
118
- generate_kwargs["do_sample"] = False
119
-
120
  t = Thread(target=model.generate, kwargs=generate_kwargs)
121
  t.start()
122
 
@@ -125,6 +128,22 @@ def chat(message, history, temperature, repetition_penalty, do_sample, max_token
125
  partial_text += new_text
126
  yield partial_text
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  yield partial_text
129
 
130
 
 
9
  import os
10
  from threading import Thread
11
  import subprocess
12
+ from app_modules.utils import calc_bleu_rouge_scores, detect_repetitions
13
 
14
  from dotenv import find_dotenv, load_dotenv
15
 
 
94
  if item[1] is not None:
95
  chat.append({"role": "assistant", "content": item[1]})
96
 
97
+ index = -1
98
  if [message] in examples:
99
  index = examples.index([message])
100
  message = f"{qa_system_prompt}\n\n{questions[index]['context']}\n\nQuestion: {message}"
101
+ print("RAG prompt:", message)
102
 
103
  chat.append({"role": "user", "content": message})
104
 
 
107
  streamer = TextIteratorStreamer(
108
  tok, timeout=200.0, skip_prompt=True, skip_special_tokens=True
109
  )
110
+
111
+ if temperature == 0:
112
+ temperature = 0.01
113
+
114
  generate_kwargs = dict(
115
  model_inputs,
116
  streamer=streamer,
 
120
  eos_token_id=terminators,
121
  )
122
 
 
 
 
123
  t = Thread(target=model.generate, kwargs=generate_kwargs)
124
  t.start()
125
 
 
128
  partial_text += new_text
129
  yield partial_text
130
 
131
+ answer = partial_text
132
+ (newline_score, repetition_score, total_repetitions) = detect_repetitions(answer)
133
+ partial_text += "\n\nRepetition Metrics:\n"
134
+ partial_text += f"1. Newline Score: {newline_score:.3f}\n"
135
+ partial_text += f"1. Repetition Score: {repetition_score:.3f}\n"
136
+ partial_text += f"1. Total Repetitions: {total_repetitions:.3f}\n"
137
+
138
+ if index >= 0: # RAG
139
+ scores = calc_bleu_rouge_scores(
140
+ [answer], [questions[index]["wellFormedAnswers"]], debug=True
141
+ )
142
+
143
+ partial_text += "\n\n Performance Metrics:\n"
144
+ partial_text += f'1. BLEU: {scores["bleu_scores"]["bleu"]:.3f}\n'
145
+ partial_text += f'1. RougeL: {scores["rouge_scores"]["rougeL"]:.3f}\n'
146
+
147
  yield partial_text
148
 
149
 
app_modules/utils.py CHANGED
@@ -191,15 +191,28 @@ bleu = evaluate.load("bleu")
191
  rouge = evaluate.load("rouge")
192
 
193
 
194
- def calc_metrics(df):
195
- predictions = [df["answer"][i] for i in range(len(df))]
196
- references = [df["ground_truth"][i] for i in range(len(df))]
 
197
 
198
  bleu_scores = bleu.compute(
199
  predictions=predictions, references=references, max_order=1
200
  )
201
  rouge_scores = rouge.compute(predictions=predictions, references=references)
202
- return {"bleu_scores": bleu_scores, "rouge_scores": rouge_scores}
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
 
205
  pattern_abnormal_newlines = re.compile(r"\n{5,}")
 
191
  rouge = evaluate.load("rouge")
192
 
193
 
194
+ def calc_bleu_rouge_scores(predictions, references, debug=False):
195
+ if debug:
196
+ print("predictions:", predictions)
197
+ print("references:", references)
198
 
199
  bleu_scores = bleu.compute(
200
  predictions=predictions, references=references, max_order=1
201
  )
202
  rouge_scores = rouge.compute(predictions=predictions, references=references)
203
+ result = {"bleu_scores": bleu_scores, "rouge_scores": rouge_scores}
204
+
205
+ if debug:
206
+ print("result:", result)
207
+
208
+ return result
209
+
210
+
211
+ def calc_metrics(df):
212
+ predictions = [df["answer"][i] for i in range(len(df))]
213
+ references = [df["ground_truth"][i] for i in range(len(df))]
214
+
215
+ return calc_bleu_rouge_scores(predictions, references)
216
 
217
 
218
  pattern_abnormal_newlines = re.compile(r"\n{5,}")