# LlemmaFT / mse_ollama_run_ft.py
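"""Run DeepEval metrics over a JSON test-case dataset using a fine-tuned Llemma model.

The dataset path is given with --dataset and the metric to run with --test
(ar = answer relevancy, cp = contextual precision, crec = contextual recall,
crel = contextual relevancy, f = faithfulness).
"""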
import argparse
import jsonlines
import json
# from deepeval.scorer import Scorer
from deepeval.models import OllamaModel
from deepeval.metrics import (
ContextualRelevancyMetric,
ContextualRecallMetric,
ContextualPrecisionMetric,
AnswerRelevancyMetric,
FaithfulnessMetric
)
# import docx
from deepeval.test_case import LLMTestCase
from deepeval.dataset import EvaluationDataset, Golden
from deepeval import evaluate
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from Llemma_Finetuned import Llemma_Finetuned
import ollama
import torch
from deepeval.models import DeepEvalBaseLLM
class CustomLlemma(DeepEvalBaseLLM):
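    """DeepEval-compatible wrapper around the merged LoRA fine-tune of Llemma.

    DeepEval calls generate()/a_generate() whenever a metric needs an LLM
    judgment, so the fine-tuned model acts as the evaluation model here.
    """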
def __init__(self):
self.torch_device = "cuda" if torch.cuda.is_available() else "cpu"
# finetuned
model = AutoModelForCausalLM.from_pretrained("./merged_models/llemma_lora_merged").to(self.torch_device)
tokenizer = AutoTokenizer.from_pretrained("./merged_models/llemma_lora_merged")
self.model = model
self.tokenizer = tokenizer
def load_model(self):
return self.model
def generate(self, prompt: str) -> str:
model = self.load_model()
        # Build a text-generation pipeline around the merged model. Note that the
        # pipeline is re-created on every call, which is simple but not the fastest
        # option for large evaluation runs.
        pipeline = transformers.pipeline(
            "text-generation",
            model=model,
            tokenizer=self.tokenizer,
            framework="pt",
            device=self.torch_device,
            max_length=4096,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # The pipeline returns a list of dicts; pull out the generated string so the
        # return value matches the declared -> str type.
        return pipeline(prompt)[0]["generated_text"]
# inputs = self.tokenizer(prompt, return_tensors='pt').to(self.torch_device)
# output = self.model.generate(**inputs)
# a_output = self.tokenizer.decode(output[0])
# return json.dumps(a_output)
async def a_generate(self, prompt: str) -> str:
return self.generate(prompt)
def get_model_name(self):
return "Llemma Fine-tuned"
#ollama run Hudson/llemma:7b
#deepeval set-ollama Hudson/llemma:7b
def is_invalid_length(text, length=4096):
    # Character-count guard used by the (commented-out) dataset preparation below
    # to skip prompts or answers that are too long.
    return len(text) > length
if __name__=="__main__":
# Initialize parser
parser = argparse.ArgumentParser()
# Adding optional argument
parser.add_argument("-t", "--test", help = "Test to run (ar, cp, crec, crel, f)")
parser.add_argument("-d", "--dataset", help = "Path to test case dataset")
# Read arguments from command line
args = parser.parse_args()
test_type = str(args.test)
test_data = str(args.dataset)
dataset = EvaluationDataset()
# Add as test cases
dataset.add_test_cases_from_json_file(
        # file_path is the absolute path to your .json file
file_path=test_data,
input_key_name="input",
actual_output_key_name="actual_output",
expected_output_key_name="expected_output",
context_key_name="context",
retrieval_context_key_name="retrieval_context",
)
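    # The dataset file is assumed to be a JSON list of records shaped roughly like
    # the example below, with context/retrieval_context given as lists of strings:
    # [
    #   {
    #     "input": "...question text...",
    #     "actual_output": "...model answer...",
    #     "expected_output": "...reference answer...",
    #     "context": ["..."],
    #     "retrieval_context": ["..."]
    #   }
    # ]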
# orig
# model = ollama.pull(model="Hudson/llemma:7b")
#OllamaModel(model="Hudson/llemma:7b")
custom_llm = CustomLlemma()
# finetuned
# llemma_model = AutoModelForCausalLM.from_pretrained("./train_llemma/merged_models/llemma_lora_merged")
# tokenizer = AutoTokenizer.from_pretrained("./train_llemma/merged_models/llemma_lora_merged")
# model = Llemma_Finetuned(model=llemma_model, tokenizer=tokenizer)
# sorted_rows = []
# with open('dataset_row_stl.txt', 'r') as file:
# sorted_rows = file.readlines()
# # print(sorted_rows)
# sorted_rows = sorted_rows[0:num_shot]
# sorted_rows = [int(x) for x in sorted_rows]
# print("Read in sorted rows.")
# examples = "Here are " + str(num_shot) + " examples of math questions (Q) with given answers (A).\n"
# with jsonlines.open("mse_text_img_QA_ds_test.jsonl", mode='r') as fp:
# #with open("mse_text_img_QA_ds_test.jsonl", mode='r') as fp:
# n = 0
# for j, data in enumerate(fp):
# if j + 1 in sorted_rows:
# print("Num shot row " + str(j + 1))
# # data = json.loads(line)
# examples += "Q: " + data["body"] + "\n\n"
# is_accepted = False
# best_score = float('-inf')
# output_text = ""
# for i in range(len(data["answers"])):
# if bool(data["answers"][i]["accepted"]) == True:
# if is_accepted == False:
# is_accepted = True
# best_score = int(data["answers"][i]["score"])
# output_text = data["answers"][i]["body"]
# elif int(data["answers"][i]["score"]) > best_score:
# best_score = int(data["answers"][i]["score"])
# output_text = data["answers"][i]["body"]
# elif int(data["answers"][i]["score"]) > best_score:
# best_score = int(data["answers"][i]["score"])
# output_text = data["answers"][i]["body"]
# examples += "A: " + output_text + "\n\n"
# if n == (num_shot - 1):
# examples += "Provide an answer (A) to the following math question (Q) in a similar manner to the previous example(s) given.\n\nQ: "
# # 26th line
# n += 1
# elif n >= num_shot:
# break
# else:
# continue
# print("Generated examples for", str(num_shot), "shot.")
# mse_dataset = []
# with jsonlines.open("mse_text_img_QA_ds_test.jsonl", mode='r') as reader:
# count = 0
# curr_row = 0
# for row in reader.iter(type=dict, skip_invalid=True):
# curr_row += 1
# if curr_row <= skip_to:
# continue
# elif curr_row == 33 or curr_row == 36 or curr_row == 69 \
# or curr_row == 24 or curr_row == 76 \
# or curr_row == 66 or curr_row == 9 \
# or curr_row == 26 or curr_row == 27 \
# or curr_row == 37 or curr_row == 55 \
# or curr_row == 54 or curr_row == 138 \
# or curr_row == 77 or curr_row == 84 or curr_row == 87 \
# or curr_row == 80 or curr_row == 81 or curr_row == 97 \
# or curr_row == 115 or curr_row == 106:
# print("Skipped row " + str(curr_row))
# continue
# elif curr_row in sorted_rows:
# print("Skipped row " + str(curr_row) + " because it is a shorter example")
# continue
# # question_path = "output/" + row["id"]
    # # if count <= 0:
# # print(obj)
# if count >= test_case_num:
# break
# else:
# input_text = row["body"]
# # response = ollama.generate(model='Hudson/llemma:7b', prompt=input_text)
# # actual_response = response['response']
# is_accepted = False
# best_score = float('-inf')
# output_text = ""
# # context = []
# next_best_answer = ""
# for i in range(len(row["answers"])):
# if bool(row["answers"][i]["accepted"]) == True:
# if is_accepted == False:
# is_accepted = True
# next_best_answer = output_text
# best_score = int(row["answers"][i]["score"])
# output_text = row["answers"][i]["body"]
# elif int(row["answers"][i]["score"]) > best_score:
# next_best_answer = output_text
# best_score = int(row["answers"][i]["score"])
# output_text = row["answers"][i]["body"]
# # else:
# # context.append(row["answers"][i]["body"])
# elif int(row["answers"][i]["score"]) > best_score:
# next_best_answer = output_text
# best_score = int(row["answers"][i]["score"])
# output_text = row["answers"][i]["body"]
# # else:
# # context.append(row["answers"][i]["body"])
# if next_best_answer == "" or next_best_answer is None:
# next_best_answer = row["title"]
# # test_case_dataset.append(LLMTestCase(input=input_text, actual_output=actual_response, expected_output=output_text, retrieval_context=None))
# # test_case_dataset.append(LLMTestCase(input=input_text, actual_output=model.generate(input_text), expected_output=output_text, retrieval_context=context))
# if num_shot == 0:
# i_text = json.dumps(input_text)
# e_output = json.dumps(output_text)
# r_context = json.dumps(next_best_answer)
# gen_answer = ollama.generate(model="Hudson/llemma:7b", prompt=i_text)
# a_output = json.dumps(gen_answer.response)
# # print("i_text = ", i_text)
# # print("a_output = ", a_output)
# # print("e_output = ", e_output)
# # print("r_context = ", r_context)
# # r_context = gen_answer.context
# # if is_invalid_length(i_text) or is_invalid_length(e_output) or is_invalid_length(r_context):
# # continue
# mse_dataset.append(LLMTestCase(input=i_text, actual_output=a_output, expected_output=e_output, retrieval_context=[r_context]))
# else:
# i_text = json.dumps(examples + input_text)
# e_output = json.dumps(output_text)
# r_context = json.dumps(next_best_answer)
# gen_answer = ollama.generate(model="Hudson/llemma:7b", prompt=i_text)
# a_output = json.dumps(gen_answer.response)
# # r_context = gen_answer.context
# # print("i_text = ", i_text)
# # print("a_output = ", a_output)
# # print("e_output = ", e_output)
# # print("r_context = ", r_context)
# # if is_invalid_length(i_text) or is_invalid_length(e_output) or is_invalid_length(r_context):
# # continue
# mse_dataset.append(LLMTestCase(input=i_text, actual_output=a_output, expected_output=e_output, retrieval_context=[r_context]))
# count = count + 1
# # if curr_row % 1 == 0:
# print("At", str(count), "out of", str(test_case_num), " current row =", str(curr_row))
# first_test_case = LLMTestCase(input="...", actual_output="...", context=["..."])
# second_test_case = LLMTestCase(input="...", actual_output="...", context=["..."])
# dataset = EvaluationDataset(test_cases=mse_dataset)
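    # DeepEval metrics score in [0, 1]; a test case passes a metric when its score
    # is at or above this threshold.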
pass_threshold = 0.7
# eval_output = ""
if test_type == "ar":
# answer_relevancy = AnswerRelevancyMetric(model=model, threshold=pass_threshold, async_mode=False)
answer_relevancy = AnswerRelevancyMetric(model=custom_llm, threshold=pass_threshold)
# evaluate(dataset, metrics=[answer_relevancy], out_file=out_path, run_async=True)
# with open(out_path, "a") as f:
# # f.write(dataset.evaluate([answer_relevancy]))
# eval_output = dataset.evaluate([answer_relevancy])
# evaluate(goldens=dataset.goldens, metrics=[answer_relevancy])
evaluate(dataset, metrics=[answer_relevancy])
elif test_type == "cp":
contextual_precision = ContextualPrecisionMetric(model=custom_llm, threshold=pass_threshold)
# evaluate(dataset, metrics=[contextual_precision], out_file=out_path, run_async=True)
# evaluate(dataset, metrics=[contextual_precision])
# eval_output = dataset.evaluate([contextual_precision])
# evaluate(goldens=dataset.goldens, metrics=[contextual_precision])
evaluate(dataset, metrics=[contextual_precision])
elif test_type == "crec":
contextual_recall = ContextualRecallMetric(model=custom_llm, threshold=pass_threshold)
# evaluate(dataset, metrics=[contextual_recall], out_file=out_path, run_async=True)
# evaluate(dataset, metrics=[contextual_recall])
# eval_output = dataset.evaluate([contextual_recall])
# evaluate(goldens=dataset.goldens, metrics=[contextual_recall])
evaluate(dataset, metrics=[contextual_recall])
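    # The --test help text also lists "crel" and "f"; these branches are assumed to
    # mirror the ones above, using the metrics already imported at the top of the file.
    elif test_type == "crel":
        contextual_relevancy = ContextualRelevancyMetric(model=custom_llm, threshold=pass_threshold)
        evaluate(dataset, metrics=[contextual_relevancy])
    elif test_type == "f":
        faithfulness = FaithfulnessMetric(model=custom_llm, threshold=pass_threshold)
        evaluate(dataset, metrics=[faithfulness])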
else:
print("Test case (" + test_type + ") not covered")
# with open(out_path, "a") as f:
# f.write(str(eval_output))
# Create a document
# doc = docx.Document()
# # Add a paragraph to the document
# p = doc.add_paragraph()
# # Add some formatting to the paragraph
# p.paragraph_format.line_spacing = 1
# p.paragraph_format.space_after = 0
# # Add a run to the paragraph
# run = p.add_run(eval_output)
# # Add some formatting to the run
# run.bold = False
# run.italic = False
# run.font.name = 'Arial'
# run.font.size = docx.shared.Pt(12)
# # Save the document
# doc.save(out_path)