|
|
import os |
|
|
import random |
|
|
import asyncio |
|
|
import ssl |
|
|
from dotenv import load_dotenv |
|
|
from llama_index.core.agent.workflow import AgentWorkflow |
|
|
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI |
|
|
from openinference.instrumentation.llama_index import LlamaIndexInstrumentor |
|
|
|
|
|
from opentelemetry.sdk.trace.export import SimpleSpanProcessor |
|
|
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter |
|
|
from opentelemetry import trace |
|
|
from opentelemetry.sdk.trace import TracerProvider |
|
|
from langfuse import get_client |
|
|
from rich.pretty import pprint |
|
|
import aiohttp |
|
|
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import base64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from tools import ( |
|
|
get_tavily_tool, |
|
|
get_arxiv_reader, |
|
|
get_wikipedia_reader, |
|
|
get_wikipedia_tool, |
|
|
get_arxiv_tool, |
|
|
get_search_tool, |
|
|
get_calculator_tool, |
|
|
get_hub_stats_tool, |
|
|
get_hub_stats, |
|
|
) |
|
|
|
|
|
# Load API keys and Langfuse configuration from the local env file before
# anything below reads os.environ.
load_dotenv("env.local")
|
|
|
|
|
class LlamaIndexAgent:
    """Tool-using agent built on LlamaIndex's AgentWorkflow, traced to Langfuse.

    Wires search / calculator / Wikipedia / arXiv / Hub-stats tools around a
    HuggingFace-hosted Qwen model, and exports OpenTelemetry spans to Langfuse
    over OTLP. Call :meth:`run_query` per question and :meth:`close` once when
    finished to shut the tracer provider down.
    """

    def __init__(self):
        # --- Tools --------------------------------------------------------
        self.tavily_tool = get_tavily_tool()
        self.arxiv_reader = get_arxiv_reader()
        self.wikipedia_reader = get_wikipedia_reader()
        self.wikipedia_tool = get_wikipedia_tool(self.wikipedia_reader)
        self.arxiv_tool = get_arxiv_tool(self.arxiv_reader)
        self.search_tool = get_search_tool()
        self.calculator_tool = get_calculator_tool()
        self.hub_stats_tool = get_hub_stats_tool()

        # System prompt lives in a sibling file; explicit encoding so the
        # read does not depend on the platform default.
        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            self.system_prompt = f.read()

        print("system_prompt loaded:", self.system_prompt[:80], "...")
        print("DEBUG: search_tool:", self.search_tool, type(self.search_tool))
        print("DEBUG: calculator_tool:", self.calculator_tool, type(self.calculator_tool))
        print("DEBUG: wikipedia_tool:", self.wikipedia_tool, type(self.wikipedia_tool))
        print("DEBUG: arxiv_tool:", self.arxiv_tool, type(self.arxiv_tool))
        print("DEBUG: hub_stats_tool:", self.hub_stats_tool, type(self.hub_stats_tool))

        # search_tool and calculator_tool are sequences of tools (hence the
        # unpacking); the rest are single tool objects.
        all_tools = [*self.search_tool, *self.calculator_tool, self.wikipedia_tool, self.arxiv_tool, self.hub_stats_tool]
        print("DEBUG: All tools list:", all_tools)
        print("DEBUG: Types in all_tools:", [type(t) for t in all_tools])

        # --- LLM + workflow ----------------------------------------------
        self.llm = HuggingFaceInferenceAPI(model_name="Qwen/Qwen2.5-Coder-32B-Instruct", streaming=False, client_kwargs={"timeout": 60})
        self.alfred = AgentWorkflow.from_tools_or_functions(
            all_tools,
            llm=self.llm,
            system_prompt=self.system_prompt
        )

        # --- Langfuse / OpenTelemetry export -----------------------------
        # Fail fast with a clear message instead of the opaque TypeError that
        # `None + "/api/public/otel"` would raise.
        langfuse_host = os.environ.get("LANGFUSE_HOST")
        if not langfuse_host:
            raise RuntimeError("LANGFUSE_HOST is not set; cannot configure OTLP span export")

        LANGFUSE_AUTH = base64.b64encode(
            f"{os.getenv('LANGFUSE_PUBLIC_KEY')}:{os.getenv('LANGFUSE_SECRET_KEY')}".encode()
        ).decode()
        os.environ['OTEL_EXPORTER_OTLP_ENDPOINT'] = langfuse_host.rstrip("/") + "/api/public/otel"
        os.environ['OTEL_EXPORTER_OTLP_HEADERS'] = f"Authorization=Basic {LANGFUSE_AUTH}"

        self.tracer_provider = TracerProvider()
        self.tracer_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter()))
        trace.set_tracer_provider(self.tracer_provider)

        self.instrumentor = LlamaIndexInstrumentor(
            public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
            secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
            host=langfuse_host
        )
        # Instrumentation is applied lazily on the first run_query call;
        # this flag prevents re-instrumenting on every call/retry.
        self._instrumented = False

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((
            aiohttp.client_exceptions.ClientConnectionError,
            aiohttp.client_exceptions.ClientOSError,
            ssl.SSLError,
            KeyError,
            ConnectionError
        ))
    )
    async def run_query(self, query: str):
        """Run *query* through the agent workflow inside a Langfuse trace.

        Transient connection/SSL errors are re-raised so tenacity can retry
        (up to 3 attempts with exponential backoff). Unexpected exceptions are
        recorded on the trace and returned as an "AGENT ERROR: ..." string.

        Returns the workflow response object (or the error string above).
        """
        # BUG FIX: instrument exactly once. The original instrumented on every
        # call (and every retry attempt), stacking duplicate handlers.
        if not self._instrumented:
            self.instrumentor.instrument()
            self._instrumented = True

        langfuse = get_client()

        with langfuse.start_as_current_span(name="llamaindex-query") as span:
            span.update_trace(user_id="user_123", input={"query": query})

            try:
                response = await self.alfred.run(query)
            except aiohttp.client_exceptions.ClientConnectionError as e:
                # Record on the trace, then propagate so tenacity retries.
                span.update_trace(output={"response": f"Connection error: {e}"})
                raise
            except aiohttp.client_exceptions.ClientOSError as e:
                span.update_trace(output={"response": f"Client OS error: {e}"})
                raise
            except ssl.SSLError as e:
                span.update_trace(output={"response": f"SSL error: {e}"})
                raise
            except (KeyError, ConnectionError) as e:
                span.update_trace(output={"response": f"Session/Connection error: {e}"})
                raise
            except Exception as e:
                # Non-retriable failure: record it and degrade to a string result.
                span.update_trace(output={"response": f"General error: {e}"})
                return f"AGENT ERROR: {e}"

            span.update_trace(output={"response": str(response)})

        langfuse.flush()
        # BUG FIX: the original called self.tracer_provider.shutdown() here,
        # which permanently disabled span export after the first query and
        # broke subsequent calls/retries. Shutdown now lives in close().
        return response

    def close(self):
        """Flush and shut down the tracer provider. Call once when finished."""
        self.tracer_provider.shutdown()
|
|
|
|
|
def main():
    """Entry point: build the agent and run one hard-coded demo query."""
    agent = LlamaIndexAgent()
    question = "what is the capital of maharashtra?"
    print(f"Running query: {question}")
    answer = asyncio.run(agent.run_query(question))
    print("\n🎩 Agents's Response:")
    print(answer)


if __name__ == "__main__":
    main()