from concurrent.futures import ProcessPoolExecutor
import asyncio
import logging
import os
import threading
import traceback
import uuid
from typing import Dict, List, Optional
from urllib.parse import unquote

import matplotlib

# Select the non-interactive Agg backend before pyplot is imported so that
# chart rendering works in a headless server/container environment.
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from dotenv import load_dotenv
from fastapi import FastAPI, Header, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain_experimental.tools import PythonAstREPLTool
from langchain_groq import ChatGroq
from pandasai import SmartDataframe
from pydantic import BaseModel

from cerebras_csv_agent import query_csv_agent_cerebras
from cerebras_report_generator import generate_csv_report_cerebras
from csv_service import clean_data, extract_chart_filenames, generate_csv_data, get_csv_basic_info
from gemini_report_generator import generate_csv_report_gemini
from groq_report_generator import generate_csv_report_groq
from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
from orc_agent_main_cerebras import csv_orchestrator_chat_cerebras
from orchestrator_agent import csv_orchestrator_chat_gemini
from python_code_executor_service import CsvChatResult, PythonExecutor
from supabase_service import upload_file_to_supabase
from util_service import _prompt_generator, process_answer

app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# os.cpu_count() can return None on some platforms; fall back to 1.
max_cpus = os.cpu_count() or 1
logger.info(f"Max CPUs: {max_cpus}")

os.makedirs("/app/cache", exist_ok=True)
os.makedirs("/app", exist_ok=True)

# Ensure the pandasai log file exists so later appends cannot fail.
open("/app/pandasai.log", "a").close()

os.makedirs("/app/generated_charts", exist_ok=True)

load_dotenv()

image_file_path = os.getenv("IMAGE_FILE_PATH")
image_not_found = os.getenv("IMAGE_NOT_FOUND")
allowed_hosts = [h.strip() for h in os.getenv("ALLOWED_HOSTS", "").split(",") if h.strip()]

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_hosts,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# GROQ_API_KEYS is a comma-separated list of keys. Defaulting to "" avoids an
# AttributeError at import time when the variable is unset.
groq_api_keys = [k.strip() for k in os.getenv("GROQ_API_KEYS", "").split(",") if k.strip()]
model_name = os.getenv("GROQ_LLAMA_MODEL")

class CsvUrlRequest(BaseModel):
    csv_url: str


class ImageRequest(BaseModel):
    image_path: str
    chat_id: str


class FileProps(BaseModel):
    fileName: str
    filePath: str
    fileType: str


class Files(BaseModel):
    csv_files: List[FileProps]
    image_files: List[FileProps]


class FileBoxProps(BaseModel):
    files: Files
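
# Illustrative shape of a FileBoxProps payload, as the report generators
# appear to return it (see the report_files.files.* checks below); the
# field values are hypothetical:
# {
#     "files": {
#         "csv_files": [{"fileName": "summary.csv", "filePath": "/app/cache/summary.csv", "fileType": "csv"}],
#         "image_files": [{"fileName": "chart.png", "filePath": "/app/generated_charts/chart.png", "fileType": "png"}]
#     }
# }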

# Round-robin cursors for rotating through the Groq API keys. Each call path
# keeps its own cursor, guarded by a lock so concurrent requests do not race.
current_groq_key_index = 0
current_groq_key_lock = threading.Lock()

current_langchain_key_index = 0
current_langchain_key_lock = threading.Lock()

@app.get("/")
async def root():
    return {"message": "CSV Chat Service-1 server is running"}


@app.get("/ping")
async def ping():
    return {"message": "Pong !!"}

@app.post("/api/basic_csv_data")
async def basic_csv_data(request: CsvUrlRequest):
    try:
        decoded_url = unquote(request.csv_url)
        logger.info(f"Fetching CSV data from URL: {decoded_url}")

        # Run the CPU-bound CSV parsing in the process pool so the event loop
        # stays responsive under concurrent requests.
        loop = asyncio.get_running_loop()
        csv_data = await loop.run_in_executor(
            process_executor, get_csv_basic_info, decoded_url
        )
        logger.info(f"CSV data fetched successfully: {csv_data}")
        return {"data": csv_data}
    except Exception as e:
        logger.error(f"Error while fetching CSV data: {e}")
        raise HTTPException(status_code=400, detail=f"Failed to retrieve CSV data: {str(e)}")
@app.post("/api/get-chart") |
|
|
async def get_image(request: ImageRequest, authorization: str = Header(None)): |
|
|
if not authorization: |
|
|
raise HTTPException(status_code=401, detail="Authorization header missing") |
|
|
|
|
|
if not authorization.startswith("Bearer "): |
|
|
raise HTTPException(status_code=401, detail="Invalid authorization header format") |
|
|
|
|
|
token = authorization.split(" ")[1] |
|
|
if not token: |
|
|
raise HTTPException(status_code=401, detail="Token missing") |
|
|
if token != os.getenv("AUTH_TOKEN"): |
|
|
raise HTTPException(status_code=403, detail="Invalid token") |
|
|
|
|
|
try: |
|
|
logger.info("Groq Chat created a chat for the user query...") |
|
|
image_file_path = request.image_path |
|
|
unique_file_name =f'{str(uuid.uuid4())}.png' |
|
|
logger.info("Uploading the chart to supabase...") |
|
|
image_public_url = await upload_file_to_supabase(f"{image_file_path}", unique_file_name, chat_id=request.chat_id) |
|
|
logger.info("Image uploaded to Supabase and Image URL is... ", {image_public_url}) |
|
|
os.remove(image_file_path) |
|
|
return {"image_url": image_public_url} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error: {e}") |
|
|
return {"answer": "error"} |
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/csv_data") |
|
|
async def get_csv_data(request: CsvUrlRequest): |
|
|
try: |
|
|
decoded_url = unquote(request.csv_url) |
|
|
logger.info(f"Fetching CSV data from URL: {decoded_url}") |
|
|
|
|
|
loop = asyncio.get_running_loop() |
|
|
csv_data = await loop.run_in_executor( |
|
|
process_executor, generate_csv_data, decoded_url |
|
|
) |
|
|
return csv_data |
|
|
except Exception as e: |
|
|
logger.error(f"Error while fetching CSV data: {e}") |
|
|
raise HTTPException(status_code=400, detail=f"Failed to retrieve CSV data: {str(e)}") |
|
|
|
|
|
|
|
|

class ExecutionRequest(BaseModel):
    chat_id: str
    csv_url: str
    codeExecutionPayload: CsvChatResult
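
# Illustrative request body for /api/code_execution_csv (the exact
# codeExecutionPayload shape is defined by CsvChatResult; values hypothetical):
# {
#     "chat_id": "chat-123",
#     "csv_url": "https://example.com/data.csv",
#     "codeExecutionPayload": {...}
# }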
@app.post("/api/code_execution_csv") |
|
|
async def code_execution_csv( |
|
|
request_data: ExecutionRequest, |
|
|
authorization: Optional[str] = Header(None) |
|
|
): |
|
|
|
|
|
expected_token = os.environ.get("AUTH_TOKEN") |
|
|
if not authorization or not expected_token or authorization.replace("Bearer ", "") != expected_token: |
|
|
raise HTTPException(status_code=401, detail="Unauthorized") |
|
|
|
|
|
try: |
|
|
|
|
|
logger.info("Incoming request data:", request_data) |
|
|
|
|
|
|
|
|
decoded_url = unquote(request_data.csv_url) |
|
|
df = clean_data(decoded_url) |
|
|
executor = PythonExecutor(df) |
|
|
formatted_output = await executor.process_response(request_data.codeExecutionPayload, request_data.chat_id) |
|
|
return {"answer": formatted_output} |
|
|
|
|
|
except Exception as e: |
|
|
logger.info("Processing error:", str(e)) |
|
|
return {"error": "Failed to execute request", "message": str(e)} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|

def groq_chat(csv_url: str, question: str):
    """Answer a question about a CSV with pandasai + Groq, rotating API keys
    on rate-limit errors."""
    global current_groq_key_index

    while True:
        with current_groq_key_lock:
            if current_groq_key_index >= len(groq_api_keys):
                return {"error": "All API keys exhausted."}
            current_api_key = groq_api_keys[current_groq_key_index]

        try:
            data = clean_data(csv_url)
            llm = ChatGroq(model=model_name, api_key=current_api_key)

            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),
                    'custom_chart_filename': chart_filename,
                    'enable_cache': False
                }
            )

            answer = df.chat(question)

            # Convert the answer into a JSON-serialisable structure, replacing
            # out-of-range floats (NaN/inf) that JSON cannot represent.
            if isinstance(answer, pd.DataFrame):
                # applymap applies element-wise (DataFrame.map on pandas >= 2.1)
                processed = answer.applymap(handle_out_of_range_float).to_dict(orient="records")
            elif isinstance(answer, pd.Series):
                processed = answer.apply(handle_out_of_range_float).to_dict()
            elif isinstance(answer, list):
                processed = [handle_out_of_range_float(item) for item in answer]
            elif isinstance(answer, dict):
                processed = {k: handle_out_of_range_float(v) for k, v in answer.items()}
            else:
                processed = {"answer": str(handle_out_of_range_float(answer))}

            return processed

        except Exception as e:
            error_message = str(e)
            # Only rotate keys on rate-limit errors; anything else is a real
            # failure and is returned to the caller.
            if "rate limit" in error_message.lower() or "429" in error_message:
                logger.warning("Rate limit exceeded. Switching to next API key.")
                with current_groq_key_lock:
                    current_groq_key_index += 1
                    if current_groq_key_index >= len(groq_api_keys):
                        return {"error": "All API keys exhausted."}
            else:
                logger.error("Error in groq_chat: %s", e)
                return {"error": error_message}
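
# groq_chat returns either a JSON-serialisable answer or a dict with an
# "error" key; callers are expected to check for "error" before using it.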

def langchain_csv_chat(csv_url: str, question: str, chart_required: bool):
    """Answer a question with a langchain pandas-dataframe agent, trying each
    Groq API key at most once."""
    global current_langchain_key_index

    data = clean_data(csv_url)
    attempts = 0

    while attempts < len(groq_api_keys):
        with current_langchain_key_lock:
            if current_langchain_key_index >= len(groq_api_keys):
                current_langchain_key_index = 0
            api_key = groq_api_keys[current_langchain_key_index]
            current_langchain_key_index += 1
        attempts += 1

        try:
            llm = ChatGroq(model=model_name, api_key=api_key)
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib
            })

            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True
            )

            prompt = _prompt_generator(question, chart_required, csv_url)
            result = agent.invoke({"input": prompt})
            return result.get("output")

        except Exception as e:
            error_message = str(e)
            # The loop already advanced the key cursor, so on a rate-limit
            # error we simply retry with the next key; fail fast otherwise.
            if "rate limit" in error_message.lower() or "429" in error_message:
                logger.warning("Rate limit exceeded. Switching to next API key.")
            else:
                logger.error(f"Error with API key {api_key}: {error_message}")
                return {"error": error_message}

    return {"error": "All API keys exhausted"}

async def handle_detailed_answer(decoded_url, query, conversation_history, chat_id):
    """
    Try CSV processing with the Cerebras orchestrator first, then fall back
    to Gemini if Cerebras fails or returns None.
    """
    try:
        logger.info("Processing detailed answer with Cerebras orchestrator...")
        orchestrator_answer = await asyncio.to_thread(
            csv_orchestrator_chat_cerebras, decoded_url, query, conversation_history, chat_id
        )
        if orchestrator_answer is not None:
            logger.info(f"Cerebras answer successful: {str(orchestrator_answer)[:200]}...")
            return {"answer": jsonable_encoder(orchestrator_answer)}
        logger.warning("Cerebras orchestrator returned None")
    except Exception as e:
        logger.error(f"Cerebras orchestrator failed: {str(e)}")

    try:
        logger.info("Falling back to Gemini orchestrator...")
        orchestrator_answer = await asyncio.to_thread(
            csv_orchestrator_chat_gemini, decoded_url, query, conversation_history, chat_id
        )
        if orchestrator_answer is not None:
            logger.info(f"Gemini answer successful: {str(orchestrator_answer)[:200]}...")
            return {"answer": jsonable_encoder(orchestrator_answer)}
        logger.warning("Gemini orchestrator returned None")
    except Exception as e:
        logger.error(f"Gemini orchestrator failed: {str(e)}")

    logger.error("Both Cerebras and Gemini orchestrators failed or returned None")
    return {"answer": None}
@app.post("/api/csv-chat") |
|
|
async def csv_chat(request: Dict, authorization: str = Header(None)): |
|
|
|
|
|
if not authorization or not authorization.startswith("Bearer "): |
|
|
logger.error("Authorization failed: Missing or invalid authorization header") |
|
|
raise HTTPException(status_code=401, detail="Invalid authorization") |
|
|
|
|
|
token = authorization.split(" ")[1] |
|
|
if token != os.getenv("AUTH_TOKEN"): |
|
|
logger.error("Authorization failed: Invalid token") |
|
|
raise HTTPException(status_code=403, detail="Invalid token") |
|
|
|
|
|
logger.info("Authorization successful") |
|
|
|
|
|
try: |
|
|
|
|
|
query = request.get("query") |
|
|
csv_url = request.get("csv_url") |
|
|
decoded_url = unquote(csv_url) |
|
|
detailed_answer = request.get("detailed_answer") |
|
|
conversation_history = request.get("conversation_history", []) |
|
|
generate_report = request.get("generate_report") |
|
|
chat_id = request.get("chat_id") |
|
|
|
|
|
logger.info(f"Request parameters: query='{query[:100]}...', csv_url='{csv_url}', detailed_answer={detailed_answer}, generate_report={generate_report}, chat_id={chat_id}") |
|
|
|
|
|
|
|
|

        # Report generation: try Cerebras, then Gemini, then Groq, returning
        # the first non-empty result.
        if generate_report is True:
            logger.info("Starting report generation process...")

            logger.info("Attempting report generation with Cerebras...")
            try:
                report_files = await generate_csv_report_cerebras(csv_url, query, chat_id, conversation_history)
                if report_files is not None and (report_files.files.csv_files or report_files.files.image_files):
                    logger.info(f"Cerebras report generation successful: {len(report_files.files.csv_files)} CSV files, {len(report_files.files.image_files)} image files")
                    return {"answer": jsonable_encoder(report_files)}
                logger.warning("Cerebras report generation returned empty or None result")
            except Exception as cerebras_error:
                logger.error(f"Cerebras report generation failed: {str(cerebras_error)}")

            logger.info("Falling back to Gemini for report generation...")
            try:
                report_files = await generate_csv_report_gemini(csv_url, query, chat_id, conversation_history)
                if report_files is not None and (report_files.files.csv_files or report_files.files.image_files):
                    logger.info(f"Gemini report generation successful: {len(report_files.files.csv_files)} CSV files, {len(report_files.files.image_files)} image files")
                    return {"answer": jsonable_encoder(report_files)}
                logger.warning("Gemini report generation returned empty or None result")
            except Exception as gemini_error:
                logger.error(f"Gemini report generation failed: {str(gemini_error)}")

            logger.error("Both Cerebras and Gemini report generation failed")

            logger.info("Attempting report generation with Groq as a last resort...")
            try:
                report_files = await generate_csv_report_groq(csv_url, query, chat_id, conversation_history)
                if report_files is not None and (report_files.files.csv_files or report_files.files.image_files):
                    logger.info(f"Groq report generation successful: {len(report_files.files.csv_files)} CSV files, {len(report_files.files.image_files)} image files")
                    return {"answer": jsonable_encoder(report_files)}
                logger.warning("Groq report generation returned empty or None result")
            except Exception as groq_error:
                logger.error(f"Groq report generation failed: {str(groq_error)}")

            logger.error("All report generation methods failed")

        # Initial "getting started" questions go straight to the langchain agent.
        if if_initial_chat_question(query):
            logger.info("Processing as initial chat question with langchain...")
            try:
                answer = await asyncio.to_thread(
                    langchain_csv_chat, decoded_url, query, False
                )
                logger.info(f"Langchain initial chat answer: {str(answer)[:200]}...")
                return {"answer": jsonable_encoder(answer)}
            except Exception as e:
                logger.error(f"Langchain initial chat failed: {str(e)}")

        if detailed_answer is True:
            logger.info("Processing detailed answer with orchestrator...")
            return await handle_detailed_answer(decoded_url, query, conversation_history, chat_id)

        logger.info("Processing with standard CSV agent (Cerebras)...")
        try:
            result = await query_csv_agent_cerebras(decoded_url, query, chat_id)
            logger.info(f"Standard CSV agent (Cerebras) result: {str(result)[:200]}...")
            if result is not None and result != "":
                return {"answer": result}
            logger.warning("Standard CSV agent (Cerebras) returned empty or None result")
        except Exception as e:
            logger.error(f"Standard CSV agent (Cerebras) failed: {str(e)}")

        logger.info("Falling back to langchain CSV chat...")
        try:
            lang_answer = await asyncio.to_thread(
                langchain_csv_chat, decoded_url, query, False
            )
            logger.info(f"Langchain fallback result: {str(lang_answer)[:200]}...")

            # process_answer() is truthy when the answer looks like an
            # error/empty response.
            if process_answer(lang_answer):
                logger.error("Langchain fallback produced error response")
                return {"answer": "error"}

            logger.info("Langchain fallback successful")
            return {"answer": jsonable_encoder(lang_answer)}
        except Exception as e:
            logger.error(f"Langchain fallback failed: {str(e)}")

        logger.error("All processing methods failed")
        return {"answer": "error"}

    except Exception as e:
        logger.error(f"Critical error processing request: {str(e)}")
        logger.error(f"Error traceback: {traceback.format_exc()}")
        return {"answer": "error"}

def handle_out_of_range_float(value):
    """Handle out-of-range float values for JSON serialization."""
    if isinstance(value, float):
        if np.isnan(value):
            logger.debug("Converting NaN to None")
            return None
        elif np.isinf(value):
            logger.debug("Converting Infinity to string")
            return "Infinity"
    return value
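
# Illustrative behaviour (not doctests):
#   handle_out_of_range_float(float("nan"))  -> None
#   handle_out_of_range_float(float("inf"))  -> "Infinity"
#   handle_out_of_range_float(3.14)          -> 3.14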


instructions = """

- Ensure that each value is clearly visible: adjust the font size, rotate the labels, or truncate them to improve readability if needed.
- For multiple charts, put all of them in a single file.
- Use a colorblind-friendly palette.
- Read the above instructions and follow them.

"""

current_groq_chart_key_index = 0
current_groq_chart_lock = threading.Lock()


def model():
    """Return a ChatGroq client bound to the current chart-generation API key."""
    with current_groq_chart_lock:
        if current_groq_chart_key_index >= len(groq_api_keys):
            raise Exception("All API keys exhausted for chart generation")
        api_key = groq_api_keys[current_groq_chart_key_index]
    return ChatGroq(model=model_name, api_key=api_key)


def groq_chart(csv_url: str, question: str):
    """Generate a chart with pandasai + Groq, rotating API keys on rate limits."""
    global current_groq_chart_key_index

    for attempt in range(len(groq_api_keys)):
        try:
            data = clean_data(csv_url)
            with current_groq_chart_lock:
                current_api_key = groq_api_keys[current_groq_chart_key_index]

            llm = ChatGroq(model=model_name, api_key=current_api_key)

            chart_filename = f"chart_{uuid.uuid4()}.png"
            chart_path = os.path.join("generated_charts", chart_filename)

            df = SmartDataframe(
                data,
                config={
                    'llm': llm,
                    'save_charts': True,
                    'open_charts': False,
                    'save_charts_path': os.path.dirname(chart_path),
                    'custom_chart_filename': chart_filename,
                    'enable_cache': False
                }
            )

            answer = df.chat(question + instructions)

            if process_answer(answer):
                return "Chart not generated"
            return answer

        except Exception as e:
            error = str(e)
            # Rotate to the next key on rate-limit errors; the modulo keeps
            # the index from running past the end of the key list.
            if "rate limit" in error.lower() or "429" in error:
                with current_groq_chart_lock:
                    current_groq_chart_key_index = (current_groq_chart_key_index + 1) % len(groq_api_keys)
            else:
                logger.error(f"Chart generation error: {error}")
                return {"error": error}

    return {"error": "All API keys exhausted for chart generation"}


current_langchain_chart_key_index = 0
current_langchain_chart_lock = threading.Lock()

# Reserve a couple of cores for the event loop and LLM calls; max(1, ...)
# keeps the pool valid on machines with two or fewer cores.
process_executor = ProcessPoolExecutor(max_workers=max(1, max_cpus - 2))
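
# A ProcessPoolExecutor (rather than a thread pool) is used because chart
# generation and CSV parsing are CPU-bound; separate processes let several
# requests render charts in parallel without blocking the event loop.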


def langchain_csv_chart(csv_url: str, question: str, chart_required: bool):
    """
    Generate a chart using the langchain-based method.

    Notes:
    - No shared deletion of cache.
    - Close matplotlib figures after generation.
    - Return a list of full chart file paths.
    """
    global current_langchain_chart_key_index

    data = clean_data(csv_url)

    for attempt in range(len(groq_api_keys)):
        try:
            with current_langchain_chart_lock:
                api_key = groq_api_keys[current_langchain_chart_key_index]
                current_key = current_langchain_chart_key_index
                current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)

            llm = ChatGroq(model=model_name, api_key=api_key)
            tool = PythonAstREPLTool(locals={
                "df": data,
                "pd": pd,
                "np": np,
                "plt": plt,
                "sns": sns,
                "matplotlib": matplotlib,
                "uuid": uuid
            })

            agent = create_pandas_dataframe_agent(
                llm,
                data,
                agent_type="tool-calling",
                verbose=True,
                allow_dangerous_code=True,
                extra_tools=[tool],
                return_intermediate_steps=True
            )

            result = agent.invoke({"input": _prompt_generator(question, True, csv_url)})
            output = result.get("output", "")

            # Free figure memory; the agent may leave figures open.
            plt.close('all')

            # image_file_path comes from the IMAGE_FILE_PATH env var and is
            # expected to be set in this deployment.
            chart_files = extract_chart_filenames(output)
            if len(chart_files) > 0:
                return [os.path.join(image_file_path, f) for f in chart_files]

            if attempt < len(groq_api_keys) - 1:
                logger.info(f"Langchain chart error (key {current_key}): {output}")

        except Exception as e:
            error_message = str(e)
            # Rotate to the next key on rate-limit errors; fail fast otherwise.
            if "rate limit" in error_message.lower() or "429" in error_message:
                with current_langchain_chart_lock:
                    current_langchain_chart_key_index = (current_langchain_chart_key_index + 1) % len(groq_api_keys)
                logger.warning(f"Rate limit exceeded. Switching to next API key: {groq_api_keys[current_langchain_chart_key_index]}")
            else:
                logger.error(f"Error with API key {api_key}: {error_message}")
                return {"error": error_message}

    logger.error("All API keys exhausted for chart generation")
    return "Chart generation failed after all retries"


@app.post("/api/csv-chart")
async def csv_chart(request: dict, authorization: str = Header(None)):
    """
    Endpoint for generating a chart from CSV data.

    Uses a ProcessPoolExecutor to run the CPU-bound chart generation
    functions in separate processes so that multiple requests can run in
    parallel.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Authorization required")

    token = authorization.split(" ")[1]
    if token != os.getenv("AUTH_TOKEN"):
        raise HTTPException(status_code=403, detail="Invalid credentials")

    try:
        query = request.get("query", "")
        csv_url = unquote(request.get("csv_url", ""))
        detailed_answer = request.get("detailed_answer", False)
        conversation_history = request.get("conversation_history", [])
        generate_report = request.get("generate_report", False)
        chat_id = request.get("chat_id", "")

        if generate_report is True:
            report_files = await generate_csv_report_gemini(csv_url, query, chat_id, conversation_history)
            if report_files is not None:
                return {"orchestrator_response": jsonable_encoder(report_files)}

        loop = asyncio.get_running_loop()

        if if_initial_chart_question(query):
            langchain_result = await loop.run_in_executor(
                process_executor, langchain_csv_chart, csv_url, query, True
            )
            logger.info(f"Langchain chart result: {langchain_result}")
            if isinstance(langchain_result, list) and len(langchain_result) > 0:
                unique_file_name = f'{uuid.uuid4()}.png'
                logger.info("Uploading the chart to Supabase...")
                image_public_url = await upload_file_to_supabase(langchain_result[0], unique_file_name, chat_id=chat_id)
                logger.info(f"Image uploaded to Supabase; public URL: {image_public_url}")
                os.remove(langchain_result[0])
                return {"image_url": image_public_url}

        if detailed_answer is True:
            orchestrator_answer = await asyncio.to_thread(
                csv_orchestrator_chat_gemini, csv_url, query, conversation_history, chat_id
            )
            if orchestrator_answer is not None:
                return {"orchestrator_response": jsonable_encoder(orchestrator_answer)}

        logger.info("Trying Cerebras AI llama...")
        result = await query_csv_agent_cerebras(csv_url, query, chat_id)
        logger.info(f"Cerebras AI result: {result}")
        if result is not None and result != "":
            return {"orchestrator_response": jsonable_encoder(result)}

        logger.error("Cerebras AI llama response failed, trying langchain Groq...")
        langchain_paths = await loop.run_in_executor(
            process_executor, langchain_csv_chart, csv_url, query, True
        )
        logger.info(f"Fallback langchain chart result: {langchain_paths}")
        if isinstance(langchain_paths, list) and len(langchain_paths) > 0:
            unique_file_name = f'{uuid.uuid4()}.png'
            logger.info("Uploading the chart to Supabase...")
            image_public_url = await upload_file_to_supabase(langchain_paths[0], unique_file_name, chat_id=chat_id)
            logger.info(f"Image uploaded to Supabase; public URL: {image_public_url}")
            os.remove(langchain_paths[0])
            return {"image_url": image_public_url}

        logger.error("All chart generation methods failed")
        return {"answer": "error"}

    except Exception as e:
        logger.error(f"Critical chart error: {str(e)}")
        return {"answer": "error"}