# Last change: commit b634382 "Update agent.py" by mohannad-tazi (verified).
from typing import Any, List, Optional
from smolagents import CodeAgent, InferenceClientModel, HfApiModel, TransformersModel, LiteLLMModel
import os
hf_token = os.getenv("HF_TOKEN")
from utils.logger import get_logger
logger = get_logger(__name__)
"""
self.model = InferenceClientModel(
model_id="deepseek-ai/DeepSeek-R1",
api_provider="together" # Changed parameter to api_provider
)
"""
class Agent:
    """
    Callable wrapper around a smolagents ``CodeAgent`` for answering questions.

    The agent is backed by ``groq/deepseek-r1-distill-llama-70b`` through
    ``LiteLLMModel``. The API key is read from the ``GROQ_API_KEY``
    environment variable when the agent is constructed.

    Args:
        tools (Optional[List[Any]]): Extra tools to provide to the agent, in
            addition to smolagents' base tools. ``None`` (the default) means
            no extra tools.
        prompt (Optional[str]): Custom prompt template. It is filled with
            ``str.format(question=..., context=...)``, so literal braces in it
            must be doubled (``{{`` / ``}}``). A falsy value (``None`` or
            ``""``) selects ``DEFAULT_PROMPT``.

    Raises:
        KeyError: If ``GROQ_API_KEY`` is not set in the environment.
    """

    # Default template. It tells the model to open referenced files with
    # tools and to answer in the exact requested format.
    DEFAULT_PROMPT = (
        "You are an advanced AI assistant specialized in solving complex tasks using tools.\n"
        "Key Instructions:\n"
        "1. ALWAYS use tools for file paths in the context:\n"
        "- Use read_file for text files\n"
        "- Use extract_text_from_image for images (.png, .jpg)\n"
        "2. Final answer must ONLY contain the requested result in the EXACT format needed.\n"
        "QUESTION: {question}\n"
        "{context}\n"
        "ANSWER:"
    )

    def __init__(
        self,
        tools: Optional[List[Any]] = None,
        prompt: Optional[str] = None,
    ):
        logger.info("Initializing Agent")
        self.model = LiteLLMModel(
            model_id="groq/deepseek-r1-distill-llama-70b",
            temperature=0.2,
            # Indexing (not .get) makes a missing key fail fast at construction.
            api_key=os.environ["GROQ_API_KEY"],
        )
        # CodeAgent iterates over the tools it receives, so the ``None``
        # default must become an empty list rather than being passed through.
        self.tools = tools if tools is not None else []
        # Extra modules the agent-generated code is allowed to import.
        self.imports = [
            "pandas",
            "numpy",
            "os",
            "requests",
            "tempfile",
            "datetime",
            "json",
            "time",
            "re",
            "openpyxl",
            "pathlib",
        ]
        self.agent = CodeAgent(
            model=self.model,
            tools=self.tools,
            add_base_tools=True,
            additional_authorized_imports=self.imports,
        )
        self.prompt = prompt or self.DEFAULT_PROMPT
        logger.info("Agent initialized")

    def __call__(self, question: str, file_path: Optional[str] = None) -> str:
        """
        Answer ``question`` with the underlying ``CodeAgent``.

        Args:
            question: The question to answer.
            file_path: Optional path to a file relevant to the question. When
                given, it is mentioned in the prompt's context section;
                otherwise the context section is empty.

        Returns:
            The agent's final answer as a string, with surrounding whitespace
            and wrapping quotes removed.
        """
        context = f"CONTEXT: File available at: {file_path}" if file_path else ""
        answer = self.agent.run(
            self.prompt.format(question=question, context=context)
        )
        return self._clean_answer(answer)

    @staticmethod
    def _clean_answer(answer: Any) -> str:
        """
        Convert ``answer`` to ``str`` and remove whitespace and wrapping quotes.

        Whitespace is stripped first. Then matching pairs of surrounding single
        or double quotes are removed, repeatedly, so ``"'x'"`` becomes ``x``.
        Unpaired quotes are kept, so an answer ending in an apostrophe, such as
        ``The Beatles'``, is returned intact.
        """
        text = str(answer).strip()
        while len(text) >= 2 and text[0] == text[-1] and text[0] in "'\"":
            text = text[1:-1].strip()
        return text