PAYADOR-experiments / models.py
sgongora27's picture
Update models.py
150ae64 verified
"""Load models to use them as a narrator and a common-sense oracle in the PAYADOR pipeline."""
import google.generativeai as genai
import requests
import replicate
import os
def get_llm(model_name: str = "gemini-1.5-flash") -> object:
    """Return the model wrapper (narrator / oracle) matching ``model_name``.

    Gemini model names produce a ``GeminiModel``; Replicate-hosted Llama 3
    names produce a ``ReplicateModel``. Any other name yields ``None``.

    The ``API_key`` strings passed below are environment-variable names, not
    keys. Neither wrapper currently reads its ``API_key`` argument: Gemini
    reads ``GOOGLE_API_KEY`` itself, and the Replicate wrapper never uses it.
    """
    gemini_names = ("gemini-1.0-pro", "gemini-1.5-pro", "gemini-1.5-flash")
    llama_names = ("meta/meta-llama-3-70b", "meta/meta-llama-3-70b-instruct")

    if model_name in gemini_names:
        return GeminiModel(API_key="GOOGLE_API_KEY", model_name=model_name)
    if model_name in llama_names:
        return ReplicateModel(API_key="REPLICATE_API_TOKEN", model_name=model_name)
    # Unknown model name: callers receive None rather than an exception.
    return None
class ReplicateModel:
    """Wrapper around a Llama 3 model hosted on Replicate.

    Exposes the same ``prompt_model(system_msg, user_msg)`` interface as
    ``GeminiModel`` so the pipeline can use either one interchangeably.
    """

    def __init__(self, API_key: str, model_name: str = "meta/meta-llama-3-70b-instruct") -> None:
        """Store the sampling temperature and the Replicate model identifier.

        Args:
            API_key: Currently unused and not stored. NOTE(review): the
                ``replicate`` client presumably authenticates from the
                REPLICATE_API_TOKEN environment variable instead — confirm.
            model_name: Replicate model identifier passed to ``replicate.run``.
        """
        self.temperature = 0.7
        self.model_name = model_name

    def prompt_model(self, system_msg: str, user_msg: str) -> str:
        """Prompt the Replicate model and return its output as one string.

        The system message is embedded in a Llama 3 chat ``prompt_template``.
        The user message is sent as ``prompt`` and is presumably substituted
        server-side for the literal ``{prompt}`` placeholder in that template.
        """
        # NOTE(review): system_msg is inserted verbatim into the template, so a
        # literal "{" or "}" in it may be treated as a placeholder by the
        # server-side substitution. Confirm, and escape if needed.
        system_instructions = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_msg}<|eot_id|><|start_header_id|>user<|end_header_id|>"
        # Named `payload` rather than `input` to avoid shadowing the builtin.
        payload = {
            "top_p": 0.1,
            "min_tokens": 0,
            "temperature": self.temperature,
            "prompt": user_msg,
            # Not an f-string: "{prompt}" must reach Replicate literally.
            "prompt_template": system_instructions + "\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
        }
        output = replicate.run(self.model_name, input=payload)
        # The output is joined, so it is expected to be an iterable of str chunks.
        return "".join(output)
class GeminiModel:
    """Wrapper around a Google Gemini ``GenerativeModel``.

    Exposes the same ``prompt_model(system_msg, user_msg)`` interface as
    ``ReplicateModel`` so the pipeline can use either one interchangeably.
    """

    # Harm categories for which blocking is disabled (threshold BLOCK_NONE),
    # in the order they are sent to the API.
    # NOTE(review): "HARM_CATEGORY_DANGEROUS" appears to be a legacy name that
    # sits alongside "HARM_CATEGORY_DANGEROUS_CONTENT". It is kept unchanged;
    # confirm that the installed google-generativeai version accepts it.
    _SAFETY_CATEGORIES = (
        "HARM_CATEGORY_DANGEROUS",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )

    def __init__(self, API_key: str, model_name: str = "gemini-1.0-pro") -> None:
        """Initialize the Gemini model with all safety filters disabled.

        Args:
            API_key: Currently unused. The key is read from the
                ``GOOGLE_API_KEY`` environment variable instead.
            model_name: Gemini model identifier passed to
                ``genai.GenerativeModel``.
        """
        self.safety_settings = [
            {"category": category, "threshold": "BLOCK_NONE"}
            for category in self._SAFETY_CATEGORIES
        ]
        # genai.configure is a module-level call, so it presumably affects
        # every genai model in this process, not only this instance.
        genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
        self.model = genai.GenerativeModel(model_name)

    def prompt_model(self, system_msg: str, user_msg: str) -> str:
        """Prompt the Gemini model and return the text of its reply.

        The system and user messages are sent as a single prompt, separated
        by a blank line. NOTE(review): ``.text`` presumably raises if the
        response carries no text (e.g. it was blocked) — confirm.
        """
        prompt = system_msg + "\n\n" + user_msg
        response = self.model.generate_content(prompt, safety_settings=self.safety_settings)
        return response.text