import re
import json
import multiprocessing
from functools import partial

import tqdm
import transformers

from test import delete_extra_zero

# Fallback special tokens. The literal values below are assumptions (the usual
# Llama/Mistral-style defaults); the original values were stripped, so adjust
# them to match the target tokenizer if needed.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "<unk>"


def load_tokenizer(model_name_or_path):
    print(f"+ [Model] Initializing Tokenizer: {model_name_or_path}")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name_or_path,
        padding_side="right",
        use_fast=False,
    )
    if 'phi' in model_name_or_path:
        # Phi tokenizers ship no pad token; reuse the unk token for padding.
        tokenizer.pad_token = tokenizer.unk_token
    else:
        if tokenizer.pad_token is None:
            # The check above is for a missing pad token, so register one
            # alongside the other default special tokens.
            tokenizer.add_special_tokens({
                "pad_token": DEFAULT_PAD_TOKEN,
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            })
    return tokenizer


def evaluate_expression(string):
    """Return 1 if the arithmetic inside the first <<lhs=rhs>> annotation is correct, else 0."""
    correct_count = 0
    # Find the expression inside <<...>>
    match = re.search(r'<<(.+?)>>', string)
    if match:
        expression = match.group(1)
        try:
            # Separate the expression before and after the '='
            before_equal, after_equal = expression.split('=')
            # Evaluate the expression before the '='
            computed_value = float(eval(before_equal.strip()))
            # Convert the part after the '=' to a float for comparison
            actual_value = float(delete_extra_zero(after_equal.strip().replace(",", "")))
            # Compare the computed value with the stated value
            if abs(computed_value - actual_value) <= 1e-3:
                correct_count = 1
        except Exception as e:
            print(f"Error evaluating expression: {expression}. Error: {e}")
    return correct_count


def process_line(tokenizer, line):
    acc = []
    line = json.loads(line)
    for idx in range(len(line['outputs'])):
        item = line['outputs'][idx]
        v_scores = item['vscores']
        solution_tokens = item['tokens']
        if item['label']:
            # Split the solution at newline tokens (token id 13), collapsing runs of newlines.
            split_token_id = 13
            split_indices = [0]
            split_indices.extend([i + 1 for i, token_id in enumerate(solution_tokens)
                                  if token_id == split_token_id and solution_tokens[i - 1] != split_token_id])
            split_indices.append(len(solution_tokens))
            # Verifier score at the end of each segment, then the change between consecutive segments.
            segment_v_scores = [v_scores[split_indices[i]] for i in range(1, len(split_indices))]
            score_changes = [(segment_v_scores[i] - segment_v_scores[i - 1]) for i in range(1, len(segment_v_scores))]
            # Alternative indexing (score each segment at its first token instead):
            # segment_v_scores = [v_scores[split_indices[i]] for i in range(len(split_indices) - 1)]
            # score_changes = [(segment_v_scores[i] - segment_v_scores[i - 1]) for i in range(1, len(segment_v_scores))]
            if len(score_changes) and min(score_changes) < 0:
                # Segment right after the largest drop in verifier score.
                max_change_index = score_changes.index(min(score_changes)) + 1
                highlighted_solution = []
                for i in range(len(split_indices) - 1):
                    segment = solution_tokens[split_indices[i]:split_indices[i + 1]]
                    if i == max_change_index:
                        detect_string = tokenizer.decode(segment[:-1])
                        highlighted_solution.append(detect_string + "\n")
                        matches = re.findall(r'<<([^>>]+)>>', detect_string)
                        if not matches:
                            continue
                        # Count 1 when the flagged segment's calculator annotation is actually wrong.
                        is_false = not evaluate_expression(detect_string)
                        if is_false:
                            acc.append(1)
                        else:
                            acc.append(0)
                    else:
                        highlighted_solution.append(tokenizer.decode(segment))
    return acc


def process_line2(tokenizer, line):
    # Variant of process_line that collects filtered items; currently an unfinished stub.
    filter_list = []
    line = json.loads(line)
    for idx in range(len(line['outputs'])):
        item = line['outputs'][idx]
        v_scores = item['vscores']
        solution_tokens = item['tokens']
    return filter_list
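# --- Illustrative sketch, not part of the original pipeline -----------------
# The block below shows the record shape that `process_line` expects (the field
# names 'outputs', 'tokens', 'vscores', 'label' come from the code above) and
# runs it end to end on made-up data. The token ids, the verifier scores, and
# the _StubTokenizer vocabulary are all invented for illustration; real usage
# decodes with the Hugging Face tokenizer returned by `load_tokenizer`.
class _StubTokenizer:
    """Toy stand-in for the HF tokenizer: maps made-up ids to text fragments."""
    _vocab = {13: "\n", 8: "so 2*5 =", 2: "<<2*5=11>>11", 6: "apples."}

    def decode(self, token_ids):
        return " ".join(self._vocab.get(t, f"tok{t}") for t in token_ids)


def _demo_process_line():
    record = {
        "outputs": [
            {
                # 10 hypothetical token ids; id 13 plays the role of the newline token.
                "tokens": [101, 5, 7, 13, 9, 4, 13, 8, 2, 6],
                # 11 verifier scores (one per token plus a trailing score), dropping sharply
                # before the last segment, which carries a deliberately wrong <<2*5=11>>.
                "vscores": [0.9, 0.9, 0.85, 0.8, 0.8, 0.75, 0.7, 0.7, 0.4, 0.3, 0.2],
                "label": True,
            }
        ]
    }
    # Expected output (assuming delete_extra_zero leaves "11" unchanged): [1],
    # i.e. the segment flagged by the vscore drop does contain a wrong calculation.
    print(process_line(_StubTokenizer(), json.dumps(record)))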
"/data/OVM-Mistral-7b/mistral7b-ep2-n100-scahead-mse-lm-token" tokenizer = load_tokenizer(model_dir) number = multiprocessing.cpu_count() pool = multiprocessing.Pool(1) acc = [] with open(file_path, 'r', encoding='utf-8') as fp: lines = fp.readlines() with tqdm.tqdm(total=len(lines)) as pbar: # Correct usage of tqdm func = partial(process_line, tokenizer) for result in pool.imap(func, lines): acc.extend(result) pbar.update() print(f"acc : {sum(acc)/len(acc)}") # Ensure that the main module is being run if __name__ == '__main__': pass # load_vescores() # Example usage # strings = [ # "In the fourth game, Clayton scored the average of his points from the first three games. This is 24+14+10 = <<24+14+10=40>>40 points." # "In the fourth game, Clayton scored the average of his points from the first three games. This is 24+14+10 = <<24+14+10=48>>", # "Another example where 2*5 = <<2*5=10>>" # ] # # accuracy = evaluate_expression(strings) # print(f"Accuracy: {accuracy}")