import google.generativeai as genai
import os
from dotenv import load_dotenv
from typing import Dict, Optional, Tuple, Union

generation_config = genai.types.GenerationConfig(
    max_output_tokens=4096,
    temperature=0.1
)


class GeminiModel:
    """
    Wraps a Google Gemini model for text generation.

    Generation settings (max_output_tokens=4096, temperature=0.1) are fixed
    by the module-level generation_config above.

    Args:
        model_name: The name of the model to use. Defaults to 'gemini-pro'.
    """

    def __init__(self,
                 model_name: Optional[str] = 'gemini-pro',
                 ):
        # Read GOOGLE_API_KEY from a local .env file and configure the client.
        load_dotenv()
        genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
        self.model = genai.GenerativeModel(model_name)

    def execute(self, prompt: str) -> Union[Tuple[str, Dict], str]:
        try:
            # Count the prompt tokens up front so usage can be reported.
            prompt_tokens = self.model.count_tokens(prompt).total_tokens
            print(f"Input tokens: {prompt_tokens}")
            response = self.model.generate_content(prompt, generation_config=generation_config)
            output_tokens = self.model.count_tokens(response.text).total_tokens
            print(f"Output tokens: {output_tokens}")

            return response.text, {'prompt_tokens': prompt_tokens, 'output_tokens': output_tokens}
        except Exception as e:
            # On failure, return the error message string instead of a (text, usage) tuple.
            return f"An error occurred: {e}"