|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
|
|
|
|
def load_rag_benchmark_tester_ds(ds_name="llmware/rag_instruct_benchmark_tester", split="train"):
    """Load the RAG instruct benchmark dataset and return one split as a list.

    Args:
        ds_name: Dataset id passed to ``datasets.load_dataset``. Defaults to the
            llmware RAG benchmark tester used by this script.
        split: Name of the split to return from the loaded dataset dict
            (default ``"train"``, the split the benchmark ships in).

    Returns:
        list: every sample of the chosen split, in dataset order. Callers in this
        file index samples by ``"context"``, ``"query"`` and ``"answer"``.
    """
    # Local import: `datasets` is only needed when the benchmark is actually loaded.
    from datasets import load_dataset

    dataset = load_dataset(ds_name)

    print("update: loading RAG Benchmark test dataset - ", dataset)

    # Materialise the split so callers get a plain, re-iterable list of samples.
    return list(dataset[split])
|
|
|
|
|
def _build_prompt(entry):
    """Format one benchmark sample as the ``<human>: ... <bot>:`` prompt the model expects."""
    return f"<human>: {entry['context']}\n{entry['query']}\n<bot>:"


def _clean_output(text):
    """Trim decoded generation text down to the bot's answer.

    Cuts everything from the first literal ``<|endoftext|>`` onward, then, if a
    ``<bot>:`` marker is present, keeps only the text after its first occurrence.
    """
    # Defensive: with skip_special_tokens=True a real EOS token is usually already
    # removed, but the marker can still appear if it was generated as plain text.
    text = text.partition("<|endoftext|>")[0]
    _, marker, after = text.partition("<bot>:")
    if marker:
        text = after
    return text


def run_test(model_name, test_ds, *, max_new_tokens=100, do_sample=True, temperature=0.3):
    """Generate an answer for every benchmark sample and print it beside the gold answer.

    Args:
        model_name: Model id or local path passed to ``from_pretrained`` for both
            the model and the tokenizer.
        test_ds: Iterable of samples with ``"context"``, ``"query"`` and
            ``"answer"`` keys (e.g. the list from ``load_rag_benchmark_tester_ds``).
        max_new_tokens: Generation budget per answer (default 100).
        do_sample: Sample (True, the default) or decode greedily (False, which
            gives reproducible output).
        temperature: Sampling temperature; only forwarded when ``do_sample`` is True.

    Returns:
        0. Results are printed, not returned.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("\nRAG Performance Test - 200 questions")
    print("update: model - ", model_name)
    print("update: device - ", device)

    # NOTE(review): trust_remote_code=True executes Python code shipped with the
    # model repository; only point this at repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    model.to(device)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Only pass temperature when sampling; generate() warns if it is set for greedy decoding.
    sampling_kwargs = {"do_sample": True, "temperature": temperature} if do_sample else {"do_sample": False}

    for i, entry in enumerate(test_ds):
        inputs = tokenizer(_build_prompt(entry), return_tensors="pt")

        # Generated ids include the prompt; remember where the new tokens start.
        start_of_output = len(inputs.input_ids[0])

        # Pass the mask explicitly: pad_token_id == eos_token_id, so generate()
        # cannot infer it from the inputs and would warn on every call.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        outputs = model.generate(
            inputs.input_ids.to(device),
            attention_mask=attention_mask,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            **sampling_kwargs,
        )

        output_only = _clean_output(
            tokenizer.decode(outputs[0][start_of_output:], skip_special_tokens=True)
        )

        print("\n")
        print(i, "llm_response - ", output_only)
        print(i, "gold_answer - ", entry["answer"])

    return 0
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Benchmark the default model, or the model id given as the first CLI argument.
    model_name = sys.argv[1] if len(sys.argv) > 1 else "llmware/dragon-stablelm-7b-v0"

    test_ds = load_rag_benchmark_tester_ds()

    # run_test prints its results; its return value carries no information.
    run_test(model_name, test_ds)
|
|
|
|
|
|