# generation_test_sql_slim_hf.py - test harness for llmware/slim-sql-1b-v0
import time
import os
import json
import re
import sqlite3
import random

from transformers import AutoModelForCausalLM, AutoTokenizer
from llmware.models import ModelCatalog
def model_test_run_general():
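    """ Loads slim-sql-1b-v0 from Hugging Face, registers it as a custom
    model in llmware, runs it over a 100-sample SQL test set, and scores
    the responses against the gold answers by exact string match. """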
    t0 = time.time()

    model_name = "llmware/slim-sql-1b-v0"

    print("update: model_name - ", model_name)

    custom_hf_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
    # now, we have 'imported' our own custom 'instruct' model into llmware
    model = ModelCatalog().load_hf_generative_model(custom_hf_model, hf_tokenizer, instruction_following=False,
                                                    prompt_wrapper="human_bot")
    model.temperature = 0.3
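    # note: an alternative path (assumption - applies only if the model is
    # registered in the standard llmware catalog) would be:
    #   model = ModelCatalog().load_model(model_name)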
    # run direct inference on model
    print("\nupdate: Starting Generative Instruct Custom Fine-tuned Test")

    t1 = time.time()
    print("update: time loading model - ", t1 - t0)
fp = ""
fn = "sql_test_100_simple_s.jsonl"
opened_file = open(os.path.join(fp, fn), "r")
prompt_list = []
for i, rows in enumerate(opened_file):
# print("update: ", i, rows)
rows = json.loads(rows)
new_entry = {"question": rows["question"],
"answer": rows["answer"],
"context": rows["context"]}
prompt_list.append(new_entry)
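    # illustrative record shape (hypothetical values - the real keys are as used above):
    #   {"question": "How many heads of the departments are older than 56?",
    #    "context": "CREATE TABLE head (age INTEGER)",
    #    "answer": "SELECT COUNT(*) FROM head WHERE age > 56"}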
    random.shuffle(prompt_list)

    total_response_output = []
    perfect_match = 0
    for i, entries in enumerate(prompt_list):

        prompt = entries["question"]

        # normalize the context - strip newlines, collapse whitespace, drop double-quotes
        context = re.sub(r"[\n\r]", "", entries["context"])
        context = re.sub(r"\s+", " ", context)
        context = re.sub(r"\"", "", context)

        answer = ""
        if "answer" in entries:
            answer = entries["answer"]

        output = model.inference(prompt, add_context=context, add_prompt_engineering=True)
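        # output is a dict returned by llmware inference - "llm_response"
        # carries the generated SQL and "usage" the token accounting
        # (both keys are consumed below)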
print("\nupdate: model question - ", prompt)
llm_response = re.sub("['\"]", "", output["llm_response"])
answer = re.sub("['\"]", "", answer)
print("update: model response - ", i, llm_response)
print("update: model gold answer - ", answer)
if llm_response.strip().lower() == answer.strip().lower():
perfect_match += 1
print("update: 100% MATCH")
print("update: perfect match accuracy - ", perfect_match / (i+1))
        core_output = {"number": i,
                       "llm_response": output["llm_response"],
                       "gold_answer": answer,
                       "prompt": prompt,
                       "usage": output["usage"]}

        total_response_output.append(core_output)

    t2 = time.time()
    print("update: total processing time: ", t2 - t1)

    return total_response_output
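# ---------------------------------------------------------------------------
# optional extra check (a sketch, not part of the original test): exact string
# match above is a strict proxy, since two differently-written queries can be
# semantically equivalent.  one way to loosen it is to build an in-memory
# SQLite table from the context and compare the result sets of the generated
# and gold queries.  this assumes the context is a single well-formed
# CREATE TABLE statement, which holds for simple test sets like this one but
# is not guaranteed in general.  sql_outputs_match is a hypothetical helper,
# not called by the script above.


def sql_outputs_match(context, llm_sql, gold_sql):

    """ Returns True if llm_sql and gold_sql both execute against an empty
    table built from the context and produce the same result set.  On an
    empty table this mainly confirms that the generated SQL is valid and
    references real columns. """

    try:
        conn = sqlite3.connect(":memory:")
        conn.execute(context)                    # create the table from the context
        llm_result = conn.execute(llm_sql).fetchall()
        gold_result = conn.execute(gold_sql).fetchall()
        conn.close()
        return llm_result == gold_result
    except sqlite3.Error:
        # generated sql failed to parse or execute against the schema
        return False
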
if __name__ == "__main__":

    output = model_test_run_general()