# phoenix-byte/src/model.py
import copy
import os
import time
from typing import Iterator

import requests
# from google.cloud import storage
from sentence_transformers import SentenceTransformer
from .config import max_new_tokens, streaming_url, job_url, default_payload, headers, embedding_path
from .db.db_utilities import query_db
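
# Assumed shapes of the imported config values (illustrative, not verified
# against src/config.py): `default_payload` is a RunPod-style job payload,
# e.g. {'input': {'prompt': None, 'sampling_params': {...}}}; `job_url` and
# `streaming_url` point at the serverless /run and /stream/ endpoints; and
# `headers` carries the RunPod API key.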


class Model:
    '''Client class for querying a Llama 2 chat model served behind a RunPod
    endpoint. Holds the sentence-transformers embedding model used to retrieve
    historical-opinion context from the vector DB.
    '''

    def __init__(self, max_new_tokens: int = max_new_tokens):
        self.max_new_tokens = max_new_tokens
        # self.embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
        self.embedding_model = SentenceTransformer(embedding_path)

    def inference(self, query: str, table: str):
        '''Inference function for Gradio text streaming. Yields the chat
        history in Gradio chatbot format: a list of [user, bot] pairs.'''
        # Fall back to the default table when the UI sends 'Court' or nothing
        if table == 'Court' or table is None:
            table = 'court_opinion'
        output = [[None, '']]
        for token in self.query_model(query, table):
            output[0][1] += token
            yield output

    def get_context(self, query: str, table: str = 'court_opinion') -> str:
        """Queries the vector DB for a related historical opinion and compiles
        a system prompt with the added context."""
        matches = query_db(query, self.embedding_model, table=table)
        if len(matches) > 0:
            match = '"""' + matches[0][0] + '"""'
            context = "You are the United States Supreme Court. Use the following historical opinion to give your ruling on a court case description. Historical opinion: " + match
        else:
            context = 'You are the United States Supreme Court. Give your ruling on a court case description.'
        return context + " Answer in less than 400 words in the format Opinion: <opinion> "
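    # Example of a compiled prompt string (illustrative; the quoted opinion is
    # whatever `query_db` returns as the best match):
    #   You are the United States Supreme Court. Use the following historical
    #   opinion to give your ruling on a court case description. Historical
    #   opinion: """..."""  Answer in less than 400 words in the format
    #   Opinion: <opinion>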

    def query_model(self, query: str, table: str, default_payload: dict = default_payload, timeout: int = 60, **kwargs) -> Iterator[str]:
        """Queries the model API on RunPod. Polls for up to `timeout` seconds
        (60 by default), yielding generated text as the job streams."""
        context = self.get_context(query=query, table=table)
        # Deep-copy the shared payload so per-call sampling overrides don't
        # mutate module-level state between requests
        payload = copy.deepcopy(default_payload)
        for k, v in kwargs.items():
            payload['input']['sampling_params'][k] = v
augmented_prompt_template = [
{
"role": "system",
"content": context,
},
{
"role": "user",
"content": query,
}
]
default_payload["input"]["prompt"] = augmented_prompt_template
job_id = requests.post(job_url, json=default_payload, headers=headers).json()['id']
for i in range(timeout):
time.sleep(1)
stream_response = requests.get(streaming_url+ job_id, headers=headers).json()
if stream_response['status'] == 'COMPLETED':
break
for i in stream_response['stream']:
for j in i['output']['text']:
yield j
# def download_checkpoints(self, bucket_name: str = bucket_name):
# """Downloads model files from gcp storage if running in gcp."""
# if not(os.path.exists('model/')):
# os.mkdir('model')
# storage_client = storage.Client()
# bucket = storage_client.bucket(bucket_name)
# # get tokenizer
# blob = bucket.blob(self.tokenizer_path)
# blob.download_to_filename(self.tokenizer_path)
# # get model files to models/
# model_file_paths = [self.model_path + i for i in model_files]
# for object_name in model_file_paths:
# blob = bucket.blob(object_name)
# blob.download_to_filename(object_name)
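

if __name__ == '__main__':
    # Minimal usage sketch; run as `python -m src.model` so the relative
    # imports resolve. Assumes src.config points at a live RunPod endpoint
    # and a valid sentence-transformers checkpoint; the query is made up.
    model = Model()
    for token in model.query_model('A dispute over riparian water rights between two states.', 'court_opinion'):
        print(token, end='', flush=True)
    print()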