|
"""Load models to use them as a narrator and a common-sense oracle in the PAYADOR pipeline.""" |
|
import google.generativeai as genai |
|
import requests |
|
import replicate |
|
import os |
|
|
|
def get_llm(model_name: str = "gemini-1.5-flash") -> object:
    """Build the LLM wrapper that matches ``model_name``.

    Gemini model names produce a ``GeminiModel``; the Llama 3 names hosted on
    Replicate produce a ``ReplicateModel``. Any other name yields ``None``,
    so callers should check the result before calling ``prompt_model`` on it.

    The ``API_key`` strings handed to the wrappers look like the names of the
    environment variables holding the credentials, not the credentials
    themselves (NOTE(review): neither wrapper currently reads that argument).
    """
    gemini_names = ("gemini-1.0-pro", "gemini-1.5-pro", "gemini-1.5-flash")
    replicate_names = ("meta/meta-llama-3-70b", "meta/meta-llama-3-70b-instruct")

    if model_name in gemini_names:
        return GeminiModel(API_key="GOOGLE_API_KEY", model_name=model_name)
    if model_name in replicate_names:
        return ReplicateModel(API_key="REPLICATE_API_TOKEN", model_name=model_name)

    # Unrecognised model name: report "no model" instead of raising.
    return None
|
|
|
|
|
class ReplicateModel():
    """Narrator / common-sense oracle backed by a Llama 3 model on Replicate."""

    # Llama 3 chat-format markers placed around the system and user messages.
    _SYSTEM_HEADER = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    _USER_HEADER = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
    _ASSISTANT_HEADER = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

    def __init__(self, API_key: str,
                 model_name: str = "meta/meta-llama-3-70b-instruct",
                 temperature: float = 0.7) -> None:
        """Store the generation settings for later ``replicate.run`` calls.

        Args:
            API_key: Accepted for interface compatibility but not used here;
                ``replicate.run`` authenticates on its own (by default from
                the ``REPLICATE_API_TOKEN`` environment variable).
            model_name: Replicate model identifier passed to ``replicate.run``.
            temperature: Sampling temperature sent with every request.
        """
        self.temperature = temperature
        self.model_name = model_name

    @classmethod
    def _format_prompt(cls, system_msg: str, user_msg: str) -> str:
        """Return the complete Llama 3 chat prompt for one system/user turn."""
        return (cls._SYSTEM_HEADER + system_msg
                + cls._USER_HEADER + user_msg
                + cls._ASSISTANT_HEADER)

    def prompt_model(self, system_msg: str, user_msg: str) -> str:
        """Prompt the Replicate model and return its full text output.

        Args:
            system_msg: System instructions for the model.
            user_msg: The user turn to answer.

        Returns:
            The streamed output chunks from ``replicate.run`` joined into one string.
        """
        payload = {
            "top_p": 0.1,
            "min_tokens": 0,
            "temperature": self.temperature,
            # Send the fully formatted conversation as the prompt and keep the
            # template as the bare placeholder. Replicate substitutes
            # placeholders inside ``prompt_template``, so splicing system_msg
            # into the template broke on any "{" or "}" it contained (e.g.
            # JSON examples). Substituted values are not re-parsed.
            "prompt": self._format_prompt(system_msg, user_msg),
            "prompt_template": "{prompt}",
        }

        output = replicate.run(self.model_name, input=payload)

        return "".join(output)
|
|
|
class GeminiModel():
    """Narrator / common-sense oracle backed by a Google Gemini model."""

    # Harm categories whose filters are switched off (threshold BLOCK_NONE)
    # on every request made through this wrapper.
    _UNFILTERED_CATEGORIES = (
        "HARM_CATEGORY_DANGEROUS",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )

    def __init__(self, API_key: str, model_name: str = "gemini-1.0-pro") -> None:
        """Initialize the Gemini model using an API key.

        Args:
            API_key: Currently ignored; the key is always read from the
                ``GOOGLE_API_KEY`` environment variable.
            model_name: Gemini model identifier passed to
                ``genai.GenerativeModel``.
        """
        self.safety_settings = self._default_safety_settings()

        # NOTE: genai.configure sets the key on the google.generativeai module,
        # not just on this instance.
        genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

        self.model = genai.GenerativeModel(model_name)

    @classmethod
    def _default_safety_settings(cls) -> list:
        """Return a fresh list of safety settings disabling every listed category."""
        return [
            {"category": category, "threshold": "BLOCK_NONE"}
            for category in cls._UNFILTERED_CATEGORIES
        ]

    def prompt_model(self, system_msg: str, user_msg: str) -> str:
        """Prompt the Gemini model.

        The system and user messages are sent as one text, separated by a
        blank line, with this instance's safety settings.

        Returns:
            The ``.text`` of the generated response.
        """
        return self.model.generate_content(
            system_msg + "\n\n" + user_msg,
            safety_settings=self.safety_settings,
        ).text
|
|