import copy
import os
import time

import requests
# from google.cloud import storage
from sentence_transformers import SentenceTransformer

from .config import max_new_tokens, streaming_url, job_url, default_payload, headers, embedding_path
from .db.db_utilities import query_db


class Model:
    '''Client class for holding the Llama 2 model and tokenizer.

    Models are loaded according to the ENVIRONMENT environment variable.
    '''

    def __init__(self, max_new_tokens: int = max_new_tokens):
        self.max_new_tokens = max_new_tokens
        # self.embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
        self.embedding_model = SentenceTransformer(embedding_path)

    def inference(self, query: str, table: str):
        '''Inference generator for Gradio text streaming.'''
        # Gradio may pass 'Court' or None; fall back to the court_opinion table.
        if table == 'Court' or table is None:
            table = 'court_opinion'
        output = [[None, '']]
        for token in self.query_model(query, table):
            output[0][1] += token
            yield output

    def get_context(self, query: str, table: str = 'court_opinion') -> str:
        """Query the vector DB for additional context and compile a new
        system prompt with that context added."""
        matches = query_db(query, self.embedding_model, table=table)
        if len(matches) > 0:
            match = '"""' + matches[0][0] + '"""'
            context = ("You are the United States Supreme Court. Use the following "
                       "historical opinion to give your ruling on a court case "
                       "description. Historical opinion: " + match)
        else:
            context = ('You are the United States Supreme Court. Give your ruling '
                       'on a court case description.')
        return context + " Answer in less than 400 words in the format Opinion: "

    def query_model(self, query: str, table: str, default_payload: dict = default_payload,
                    timeout: int = 60, **kwargs) -> str:
        """Query the model API on RunPod. Polls for up to `timeout` seconds
        (60 by default), yielding generated text until the job completes."""
        context = self.get_context(query=query, table=table)
        # Copy the shared payload so per-call overrides don't leak across requests
        # (default_payload is a module-level dict used as a mutable default).
        payload = copy.deepcopy(default_payload)
        for k, v in kwargs.items():
            payload['input']['sampling_params'][k] = v
        augmented_prompt_template = [
            {
                "role": "system",
                "content": context,
            },
            {
                "role": "user",
                "content": query,
            },
        ]
        payload["input"]["prompt"] = augmented_prompt_template
        job_id = requests.post(job_url, json=payload, headers=headers).json()['id']
        for _ in range(timeout):
            time.sleep(1)
            stream_response = requests.get(streaming_url + job_id, headers=headers).json()
            # Yield this poll's text chunks before checking status, so the final
            # chunk is not dropped when the job reports COMPLETED.
            for chunk in stream_response['stream']:
                for token in chunk['output']['text']:
                    yield token
            if stream_response['status'] == 'COMPLETED':
                break

    # def download_checkpoints(self, bucket_name: str = bucket_name):
    #     """Downloads model files from GCP storage if running in GCP."""
    #     if not os.path.exists('model/'):
    #         os.mkdir('model')
    #     storage_client = storage.Client()
    #     bucket = storage_client.bucket(bucket_name)
    #     # get tokenizer
    #     blob = bucket.blob(self.tokenizer_path)
    #     blob.download_to_filename(self.tokenizer_path)
    #     # get model files to model/
    #     model_file_paths = [self.model_path + i for i in model_files]
    #     for object_name in model_file_paths:
    #         blob = bucket.blob(object_name)
    #         blob.download_to_filename(object_name)
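

# --- Usage sketch (illustrative, kept as comments because this module's
# relative imports mean it cannot be run directly as a script) ---
# A minimal example of streaming a ruling through Model.query_model. It
# assumes the config module points at a live RunPod endpoint and that the
# vector DB behind query_db is populated; the import path, query text, and
# temperature override below are hypothetical, not part of this module.
#
# from app.model import Model  # hypothetical package path
#
# model = Model()
# ruling = ''
# for token in model.query_model(
#     "A state law bans all billboards within 500 feet of highways.",
#     table='court_opinion',
#     temperature=0.7,  # forwarded into sampling_params via **kwargs
# ):
#     ruling += token
# print(ruling)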