import copy
import os
import time
from typing import Iterator

import requests
# from google.cloud import storage
from sentence_transformers import SentenceTransformer

from .config import max_new_tokens, streaming_url, job_url, default_payload, headers, embedding_path
from .db.db_utilities import query_db

class Model:
    '''Client class for holding the Llama2 model and tokenizer. Models are loaded
    according to the ENVIRONMENT environment variable.
    '''

    def __init__(self, max_new_tokens: int = max_new_tokens):
        self.max_new_tokens = max_new_tokens
        # self.embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
        self.embedding_model = SentenceTransformer(embedding_path)

    def inference(self, query: str, table: str):
        '''Inference function for gradio text streaming.'''
        # Default to court opinions when no table is selected.
        if table == 'Court' or table is None:
            table = 'court_opinion'
        # Gradio chatbot message-pair format: [[user_message, bot_message]].
        output = [[None, '']]
        for token in self.query_model(query, table):
            output[0][1] += token
            yield output
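
    # A minimal sketch of wiring `inference` into a streaming gradio app
    # (illustrative, not part of this module; assumes gradio is installed and
    # uses the list-of-pairs Chatbot format matching [[None, text]] above):
    #
    #     import gradio as gr
    #     model = Model()
    #     demo = gr.Interface(fn=model.inference,
    #                         inputs=[gr.Textbox(label="Case description"),
    #                                 gr.Textbox(label="Table")],
    #                         outputs=gr.Chatbot())
    #     demo.launch()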

    def get_context(self, query: str, table: str = 'court_opinion') -> str:
        """Queries the vector DB for additional context and compiles a new system
        prompt with the added context."""
        matches = query_db(query, self.embedding_model, table=table)
        if len(matches) > 0:
            match = '"""' + matches[0][0] + '"""'
            context = ("You are the United States Supreme Court. Use the following historical "
                       "opinion to give your ruling on a court case description. Historical opinion: " + match)
        else:
            context = 'You are the United States Supreme Court. Give your ruling on a court case description.'
        return context + " Answer in less than 400 words in the format Opinion: <opinion> "
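
    # With a vector-DB hit, the compiled prompt looks like (illustrative):
    #   You are the United States Supreme Court. Use the following historical
    #   opinion to give your ruling on a court case description. Historical
    #   opinion: """<retrieved opinion>""" Answer in less than 400 words in
    #   the format Opinion: <opinion>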

    def query_model(self, query: str, table: str, default_payload: dict = default_payload,
                    timeout: int = 60, **kwargs) -> Iterator[str]:
        """Queries the model API on runpod, polling for up to `timeout` seconds
        (60 by default). Yields the generated text once the job is complete."""
        context = self.get_context(query=query, table=table)
        # Copy the shared payload so per-call sampling overrides don't leak
        # into subsequent requests.
        payload = copy.deepcopy(default_payload)
        for k, v in kwargs.items():
            payload['input']['sampling_params'][k] = v
        augmented_prompt_template = [
            {
                "role": "system",
                "content": context,
            },
            {
                "role": "user",
                "content": query,
            },
        ]
        payload["input"]["prompt"] = augmented_prompt_template
        job_id = requests.post(job_url, json=payload, headers=headers).json()['id']
        # Poll once per second until the job reports COMPLETED or we hit timeout.
        for _ in range(timeout):
            time.sleep(1)
            stream_response = requests.get(streaming_url + job_id, headers=headers).json()
            if stream_response['status'] == 'COMPLETED':
                break
        # Yield the generated text character by character for smooth streaming.
        for chunk in stream_response['stream']:
            for char in chunk['output']['text']:
                yield char
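
    # Example of consuming the generator directly (a minimal sketch; assumes the
    # config module points job_url/streaming_url/headers at a live runpod endpoint):
    #
    #     model = Model()
    #     for chunk in model.query_model("Case description ...", table="court_opinion"):
    #         print(chunk, end="", flush=True)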

    # def download_checkpoints(self, bucket_name: str = bucket_name):
    #     """Downloads model files from gcp storage if running in gcp."""
    #     if not os.path.exists('model/'):
    #         os.mkdir('model')
    #     storage_client = storage.Client()
    #     bucket = storage_client.bucket(bucket_name)
    #     # get tokenizer
    #     blob = bucket.blob(self.tokenizer_path)
    #     blob.download_to_filename(self.tokenizer_path)
    #     # get model files to models/
    #     model_file_paths = [self.model_path + i for i in model_files]
    #     for object_name in model_file_paths:
    #         blob = bucket.blob(object_name)
    #         blob.download_to_filename(object_name)