codelion committed on
Commit
9e6170e
·
verified ·
1 Parent(s): 74ad440

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -113
app.py CHANGED
@@ -226,67 +226,33 @@ def evaluate_prompt(prompt: str, dataset_name: str, split: str, num_samples: int
226
  max_tokens=500,
227
  )
228
 
229
- prediction = response.choices[0].message.content.strip()
230
-
231
- # Smart evaluation - handle both math and text answers
232
- target_str = str(target).strip()
233
- pred_str = prediction.strip()
234
-
235
- def extract_answer(text):
236
- """Extract answer from text - handles GSM8K format and general text"""
237
- # GSM8K format: "#### NUMBER" at the end
238
- if "####" in text:
239
- parts = text.split("####")
240
- if len(parts) > 1:
241
- answer_part = parts[-1].strip()
242
- # Remove comma separators (1,000 -> 1000)
243
- answer_part = answer_part.replace(',', '')
244
- return answer_part
245
-
246
- # Try to extract last number from free-form text
247
- numbers = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', text)
248
- if numbers:
249
- # Return the last number found (usually the final answer)
250
- return numbers[-1].replace(',', '')
251
-
252
- return text
253
-
254
- def is_mathematically_equal(str1, str2):
255
- """Check if two strings represent the same mathematical value"""
256
- try:
257
- # Try to convert both to floats and compare
258
- num1 = float(str1.replace(',', ''))
259
- num2 = float(str2.replace(',', ''))
260
- # Use small epsilon for float comparison
261
- return abs(num1 - num2) < 1e-6
262
- except (ValueError, AttributeError):
263
- # If conversion fails, do string comparison
264
- return str1.lower().strip() == str2.lower().strip()
265
-
266
- # Extract answers
267
- target_answer = extract_answer(target_str)
268
- pred_answer = extract_answer(pred_str)
269
-
270
- # Check if answers match mathematically or textually
271
- is_correct = is_mathematically_equal(target_answer, pred_answer)
272
-
273
- # Fallback: check for semantic equivalents for sentiment analysis
274
- if not is_correct:
275
- target_lower = target_answer.lower()
276
- pred_lower = pred_answer.lower()
277
-
278
- # Sentiment mappings with expanded synonyms
279
- positive_words = ["positive", "good", "great", "excellent", "wonderful", "fantastic",
280
- "amazing", "love", "best", "1", "pos", "admiration", "appreciation",
281
- "praise", "favorable", "approve"]
282
- negative_words = ["negative", "bad", "poor", "terrible", "awful", "worst", "hate",
283
- "0", "neg", "criticism", "disdain", "disapproval", "unfavorable",
284
- "critique", "condemn", "sarcasm"]
285
-
286
- if target_lower in ["1", "positive", "pos"]:
287
- is_correct = any(word in pred_lower for word in positive_words)
288
- elif target_lower in ["0", "negative", "neg"]:
289
- is_correct = any(word in pred_lower for word in negative_words)
290
 
291
  if is_correct:
292
  correct += 1
@@ -606,64 +572,33 @@ def evaluate(prompt: str) -> dict:
606
  max_tokens=500,
607
  )
608
 
609
- prediction = response.choices[0].message.content.strip()
610
 
611
- # Smart evaluation - handle both math and text answers
612
- target_str = str(target).strip()
613
- pred_str = prediction.strip()
614
 
615
- def extract_answer(text):
616
- """Extract answer from text - handles GSM8K format and general text"""
617
- import re
618
 
619
- # GSM8K format: "#### NUMBER" at the end
620
- if "####" in text:
621
- parts = text.split("####")
622
- if len(parts) > 1:
623
- answer_part = parts[-1].strip()
624
- answer_part = answer_part.replace(',', '')
625
- return answer_part
626
 
627
- # Try to extract last number from free-form text
628
- numbers = re.findall(r'-?\\d+(?:,\\d{{3}})*(?:\\.\\d+)?', text)
629
- if numbers:
630
- return numbers[-1].replace(',', '')
631
 
632
- return text
 
 
 
 
 
 
 
633
 
634
- def is_mathematically_equal(str1, str2):
635
- """Check if two strings represent the same mathematical value"""
636
- try:
637
- num1 = float(str1.replace(',', ''))
638
- num2 = float(str2.replace(',', ''))
639
- return abs(num1 - num2) < 1e-6
640
- except (ValueError, AttributeError):
641
- return str1.lower().strip() == str2.lower().strip()
642
-
643
- # Extract answers
644
- target_answer = extract_answer(target_str)
645
- pred_answer = extract_answer(pred_str)
646
-
647
- # Check if answers match mathematically or textually
648
- is_correct = is_mathematically_equal(target_answer, pred_answer)
649
-
650
- # Fallback: check for semantic equivalents for sentiment analysis
651
- if not is_correct:
652
- target_lower = target_answer.lower()
653
- pred_lower = pred_answer.lower()
654
-
655
- # Sentiment mappings with expanded synonyms
656
- positive_words = ["positive", "good", "great", "excellent", "wonderful", "fantastic",
657
- "amazing", "love", "best", "1", "pos", "admiration", "appreciation",
658
- "praise", "favorable", "approve"]
659
- negative_words = ["negative", "bad", "poor", "terrible", "awful", "worst", "hate",
660
- "0", "neg", "criticism", "disdain", "disapproval", "unfavorable",
661
- "critique", "condemn", "sarcasm"]
662
-
663
- if target_lower in ["1", "positive", "pos"]:
664
- is_correct = any(word in pred_lower for word in positive_words)
665
- elif target_lower in ["0", "negative", "neg"]:
666
- is_correct = any(word in pred_lower for word in negative_words)
667
 
668
  if is_correct:
669
  correct += 1
 
226
  max_tokens=500,
227
  )
228
 
229
+ prediction = response.choices[0].message.content.strip().lower()
230
+
231
+ # IMDB labels: 0 = negative, 1 = positive
232
+ true_label = int(target) # 0 or 1
233
+
234
+ # Check for sentiment classification in first 100 chars (to avoid long explanations)
235
+ pred_start = prediction[:100]
236
+
237
+ # Look for clear positive/negative indicators
238
+ has_positive = ("positive" in pred_start and "sentiment" in pred_start) or \
239
+ ("this is positive" in pred_start) or \
240
+ ("sentiment: positive" in pred_start)
241
+
242
+ has_negative = ("negative" in pred_start and "sentiment" in pred_start) or \
243
+ ("this is negative" in pred_start) or \
244
+ ("sentiment: negative" in pred_start)
245
+
246
+ # Prediction must be unambiguous
247
+ if has_positive and not has_negative:
248
+ predicted_label = 1
249
+ elif has_negative and not has_positive:
250
+ predicted_label = 0
251
+ else:
252
+ # Ambiguous or no clear signal = wrong
253
+ predicted_label = -1
254
+
255
+ is_correct = (predicted_label == true_label)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
 
257
  if is_correct:
258
  correct += 1
 
572
  max_tokens=500,
573
  )
574
 
575
+ prediction = response.choices[0].message.content.strip().lower()
576
 
577
+ # IMDB labels: 0 = negative, 1 = positive
578
+ true_label = int(target) # 0 or 1
 
579
 
580
+ # Check for sentiment classification in first 100 chars (to avoid long explanations)
581
+ pred_start = prediction[:100]
 
582
 
583
+ # Look for clear positive/negative indicators
584
+ has_positive = ("positive" in pred_start and "sentiment" in pred_start) or \
585
+ ("this is positive" in pred_start) or \
586
+ ("sentiment: positive" in pred_start)
 
 
 
587
 
588
+ has_negative = ("negative" in pred_start and "sentiment" in pred_start) or \
589
+ ("this is negative" in pred_start) or \
590
+ ("sentiment: negative" in pred_start)
 
591
 
592
+ # Prediction must be unambiguous
593
+ if has_positive and not has_negative:
594
+ predicted_label = 1
595
+ elif has_negative and not has_positive:
596
+ predicted_label = 0
597
+ else:
598
+ # Ambiguous or no clear signal = wrong
599
+ predicted_label = -1
600
 
601
+ is_correct = (predicted_label == true_label)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
602
 
603
  if is_correct:
604
  correct += 1