|
|
|
import time |
|
import os |
|
import json |
|
from werkzeug.utils import secure_filename |
|
import re |
|
import ast |
|
import sqlite3 |
|
import random |
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
from llmware.models import ModelCatalog |
|
from llmware.prompts import Prompt |
|
|
|
def model_test_run_general(fp="", fn="sql_test_100_simple_s.jsonl",
                           model_name="llmware/slim-sql-1b-v0"):

    """Run a generative SQL fine-tuned model over a JSONL test set and report accuracy.

    Loads *model_name* via HuggingFace transformers, wraps it with llmware's
    ModelCatalog, then iterates a shuffled JSONL test file of
    {"question", "answer", "context"} rows, running inference per row and
    counting case-insensitive exact matches against the gold answer.

    Parameters
    ----------
    fp : str
        Folder path containing the test file (default: current directory).
    fn : str
        JSONL test file name, one sample per line.
    model_name : str
        HuggingFace model identifier to load and evaluate.

    Returns
    -------
    list of dict
        One entry per test sample with keys: "number", "llm_response",
        "gold_answer", "prompt", "usage".
    """

    t0 = time.time()

    print("update: model_name - ", model_name)

    # NOTE(review): trust_remote_code executes model-repo code — acceptable for a
    # known first-party model, but do not point this at untrusted repos.
    custom_hf_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)

    model = ModelCatalog().load_hf_generative_model(custom_hf_model, hf_tokenizer,
                                                    instruction_following=False,
                                                    prompt_wrapper="human_bot")

    # low temperature for near-deterministic SQL generation
    model.temperature = 0.3

    print("\nupdate: Starting Generative Instruct Custom Fine-tuned Test")

    t1 = time.time()

    print("update: time loading model - ", t1 - t0)

    prompt_list = []

    # fix: context manager closes the file handle (original leaked an open file)
    with open(os.path.join(fp, fn), "r", encoding="utf-8") as opened_file:

        for rows in opened_file:

            rows = json.loads(rows)

            prompt_list.append({"question": rows["question"],
                                "answer": rows["answer"],
                                "context": rows["context"]})

    random.shuffle(prompt_list)

    total_response_output = []

    perfect_match = 0

    for i, entries in enumerate(prompt_list):

        prompt = entries["question"]

        # normalize the schema context: strip newlines, collapse whitespace,
        # drop double-quotes (raw strings fix the original's invalid "\s" escape)
        context = re.sub(r"[\n\r]", "", entries["context"])
        context = re.sub(r"\s+", " ", context)
        context = re.sub(r"\"", "", context)

        answer = entries.get("answer", "")

        output = model.inference(prompt, add_context=context, add_prompt_engineering=True)

        print("\nupdate: model question - ", prompt)

        # strip quote characters from both sides before comparing
        llm_response = re.sub(r"['\"]", "", output["llm_response"])
        answer = re.sub(r"['\"]", "", answer)

        print("update: model response - ", i, llm_response)
        print("update: model gold answer - ", answer)

        if llm_response.strip().lower() == answer.strip().lower():
            perfect_match += 1
            print("update: 100% MATCH")

        # running accuracy over the samples processed so far
        print("update: perfect match accuracy - ", perfect_match / (i + 1))

        total_response_output.append({"number": i,
                                      "llm_response": output["llm_response"],
                                      "gold_answer": answer,
                                      "prompt": prompt,
                                      "usage": output["usage"]})

    t2 = time.time()

    print("update: total processing time: ", t2 - t1)

    return total_response_output
|
|
|
if __name__ == "__main__":
    # guard so importing this module does not trigger a full model download/run
    output = model_test_run_general()
|
|