|
from typing import Any, List, Optional |
|
|
|
from smolagents import CodeAgent, InferenceClientModel, HfApiModel, TransformersModel, LiteLLMModel |
|
|
|
|
|
|
|
import os |
|
# Hugging Face token from the environment (None when HF_TOKEN is unset).
# NOTE(review): not read by the Agent class below; presumably kept for the
# InferenceClientModel alternative documented further down or for importers
# of this module — confirm before removing.
hf_token = os.getenv("HF_TOKEN")

from utils.logger import get_logger

# Module-level logger used by Agent for lifecycle messages.
logger = get_logger(__name__)

# Alternative model configuration (Hugging Face Inference Providers routed to
# Together AI) that can replace the Groq LiteLLMModel built in Agent.__init__.
# InferenceClientModel selects the inference provider with the ``provider=``
# keyword; the previously noted ``api_provider=`` is not a parameter of it.
#
#     self.model = InferenceClientModel(
#         model_id="deepseek-ai/DeepSeek-R1",
#         provider="together",
#     )
|
|
|
|
|
class Agent:
    """Callable wrapper around a smolagents ``CodeAgent`` for answering questions.

    The wrapped agent runs on the Groq-hosted
    ``groq/deepseek-r1-distill-llama-70b`` model through LiteLLM. It is given
    the caller's tools plus the smolagents base tools (``add_base_tools=True``),
    and generated code may import the modules listed in ``self.imports``.

    Args:
        tools (Optional[List[Any]]): Tools to provide to the agent. ``None``
            (the default) means no extra tools; the base tools are still added.
        prompt (Optional[str]): Custom prompt template. It is filled with
            ``str.format(question=..., context=...)``, so it may reference the
            ``{question}`` and ``{context}`` fields. Any other ``{...}`` field
            raises at call time, and literal braces must be doubled
            (``{{`` / ``}}``).
    """

    # Quote characters that may wrap the model's final answer.
    _QUOTES = "'\""

    def __init__(
        self,
        tools: Optional[List[Any]] = None,
        prompt: Optional[str] = None,
    ):
        """Build the LiteLLM model, the CodeAgent and the prompt template.

        Raises:
            KeyError: If the ``GROQ_API_KEY`` environment variable is unset or
                empty.
        """
        logger.info("Initializing Agent")

        api_key = os.environ.get("GROQ_API_KEY")
        if not api_key:
            # Fail fast with an actionable message instead of a bare KeyError
            # or a later authentication failure inside the model call.
            raise KeyError(
                "GROQ_API_KEY environment variable must be set to use the Groq model"
            )
        self.model = LiteLLMModel(
            model_id="groq/deepseek-r1-distill-llama-70b",
            temperature=0.2,
            api_key=api_key,
        )

        # CodeAgent expects a list of tools, so None becomes an empty list.
        # Copying also keeps later mutation of the caller's list from leaking
        # into this agent.
        self.tools = list(tools) if tools is not None else []

        # Modules that code generated by the agent is allowed to import.
        self.imports = [
            "pandas",
            "numpy",
            "os",
            "requests",
            "tempfile",
            "datetime",
            "json",
            "time",
            "re",
            "openpyxl",
            "pathlib",
        ]

        self.agent = CodeAgent(
            model=self.model,
            tools=self.tools,
            add_base_tools=True,
            additional_authorized_imports=self.imports,
        )

        self.prompt = prompt or (
            """You are an advanced AI assistant specialized in solving complex tasks using tools.

Key Instructions:
1. ALWAYS use tools for file paths in the context:
   - Use read_file for text files
   - Use extract_text_from_image for images (.png, .jpg)
2. Final answer must ONLY contain the requested result in the EXACT format needed.

QUESTION: {question}
{context}
ANSWER:"""
        )

        logger.info("Agent initialized")

    @staticmethod
    def _clean_answer(answer: Any) -> str:
        """Normalize the agent's final answer to a bare string.

        The answer is converted with ``str()`` and surrounding whitespace is
        removed. Then, repeatedly, one *matching* pair of enclosing single or
        double quotes is peeled off, and the remainder is stripped again.
        Unbalanced quotes, such as the inch mark in ``12"`` or a trailing
        possessive apostrophe, are preserved.

        Args:
            answer: Whatever the agent returned (string, number, ...).

        Returns:
            The cleaned answer text.
        """
        text = str(answer).strip()
        while len(text) >= 2 and text[0] == text[-1] and text[0] in Agent._QUOTES:
            text = text[1:-1].strip()
        return text

    def __call__(self, question: str, file_path: Optional[str] = None) -> str:
        """Answer ``question``, optionally pointing the agent at a file.

        Args:
            question: The question to answer.
            file_path: Optional path to a file the agent should inspect. When
                given, a ``CONTEXT: File available at: ...`` line is added to
                the prompt; otherwise the context is empty.

        Returns:
            The agent's final answer with whitespace and matching enclosing
            quotes removed.

        Raises:
            KeyError / IndexError: From ``str.format`` if a custom prompt
                template contains fields other than ``{question}`` and
                ``{context}``.
        """
        context = f"CONTEXT: File available at: {file_path}" if file_path else ""
        answer = self.agent.run(
            self.prompt.format(question=question, context=context)
        )
        return self._clean_answer(answer)