|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
from data_processor import DataProcessor |
|
from chart_generator import ChartGenerator |
|
from image_verifier import ImageVerifier |
|
from huggingface_hub import login |
|
import logging |
|
import time |
|
import os |
|
from dotenv import load_dotenv |
|
import ast |
|
import requests |
|
import json |
|
|
|
# Load variables from a .env file into the process environment at import time,
# so that later os.getenv() lookups in this module (for example
# HUGGINGFACEHUB_API_TOKEN) can pick them up.
load_dotenv() |
|
|
|
class LLM_Agent:
    """Turn a natural-language plotting request into a rendered chart.

    A language model (the local fine-tuned BART model, or a model served by
    the Hugging Face Inference API) is asked for a plot specification: a dict
    with the keys ``x``, ``y``, ``chart_type`` and optionally ``color``. The
    specification is handed to ``ChartGenerator.generate_chart`` and the
    resulting chart path is passed to ``ImageVerifier.verify``.
    """

    # Plot specification substituted whenever the model output is unusable.
    # Stored as text because the remote-model path returns it in place of a
    # model reply, so it goes through the same parsing step.
    _DEFAULT_PLOT_ARGS_TEXT = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"

    # Remote model choices mapped to their Inference API endpoints.
    # NOTE(review): 'flan-ul2' is routed to google/flan-t5-xxl, exactly as in
    # the original code -- confirm that this mismatch is intentional.
    _HF_API_URLS = {
        'flan-t5-base': "https://api-inference.huggingface.co/models/google/flan-t5-base",
        'flan-ul2': "https://api-inference.huggingface.co/models/google/flan-t5-xxl",
    }

    def __init__(self, data_path=None):
        """Create the data/chart/verification helpers and load the local model.

        Args:
            data_path: Passed straight to ``DataProcessor``; may be ``None``.
        """
        logging.info("Initializing LLM_Agent")
        self.data_processor = DataProcessor(data_path)
        self.chart_generator = ChartGenerator(self.data_processor.data)
        self.image_verifier = ImageVerifier()

        model_path = "ArchCoder/fine-tuned-bart-large"
        self.query_tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.query_model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

    @staticmethod
    def validate_plot_args(plot_args):
        """Check that ``plot_args`` is a usable plot specification.

        A scalar ``y`` value is wrapped into a one-element list in place.

        Args:
            plot_args: Object parsed from the model output.

        Returns:
            ``True`` when ``plot_args`` is a dict containing ``x``, ``y`` and
            ``chart_type``; ``False`` otherwise (including for non-dict
            values, which previously raised ``TypeError``).
        """
        # literal_eval can legitimately yield a list, string or number; those
        # are invalid specifications rather than errors.
        if not isinstance(plot_args, dict):
            return False
        required_keys = ('x', 'y', 'chart_type')
        if not all(key in plot_args for key in required_keys):
            return False
        if not isinstance(plot_args['y'], list):
            plot_args['y'] = [plot_args['y']]
        return True

    @staticmethod
    def _build_prompt(query):
        """Return the few-shot prompt sent to the remote models for ``query``."""
        return (
            "You are VizBot, an expert data visualization assistant. "
            "Given a user's natural language request about plotting data, output ONLY a valid Python dictionary with keys: x, y, chart_type, and color (if specified). "
            "Do not include any explanation or extra text.\n\n"
            "Example 1:\n"
            "User: plot the sales in the years with red line\n"
            "Output: {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line', 'color': 'red'}\n\n"
            "Example 2:\n"
            "User: show employee expenses and net profit over the years\n"
            "Output: {'x': 'Year', 'y': ['Employee expense', 'Net profit'], 'chart_type': 'line'}\n\n"
            "Example 3:\n"
            "User: display the EBITDA for each year with a blue bar\n"
            "Output: {'x': 'Year', 'y': ['EBITDA'], 'chart_type': 'bar', 'color': 'blue'}\n\n"
            f"User: {query}\nOutput:"
        )

    def _generate_with_local_model(self, query):
        """Run the locally loaded seq2seq model on the raw ``query`` text."""
        inputs = self.query_tokenizer(
            query, return_tensors="pt", max_length=512, truncation=True
        )
        outputs = self.query_model.generate(
            **inputs, max_length=100, num_return_sequences=1
        )
        return self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)

    @staticmethod
    def _generate_with_inference_api(api_url, prompt):
        """POST ``prompt`` to a Hugging Face Inference API endpoint.

        Returns the generated text, or the default specification text when
        the API answers with a non-200 status or an unusable payload.
        Exceptions raised by ``requests.post`` itself (timeouts, connection
        errors) are not caught here.
        """
        default_text = LLM_Agent._DEFAULT_PLOT_ARGS_TEXT

        token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
        if token:
            headers = {"Authorization": f"Bearer {token}"}
        else:
            # An unset token would otherwise be sent as the literal "None".
            logging.warning(
                "HUGGINGFACEHUB_API_TOKEN is not set; sending request without Authorization header"
            )
            headers = {}

        response = requests.post(
            api_url, headers=headers, json={"inputs": prompt}, timeout=30
        )
        if response.status_code != 200:
            logging.error(
                "Hugging Face API error: %s %s", response.status_code, response.text
            )
            return default_text

        try:
            resp_json = response.json()
            if isinstance(resp_json, list):
                generated = resp_json[0]['generated_text']
            else:
                generated = resp_json.get('generated_text', '')
        except (ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
            logging.error(
                "Error parsing Hugging Face API response: %s, raw: %s", e, response.text
            )
            return default_text

        # Empty or non-text payloads cannot be parsed later on.
        if not isinstance(generated, str) or not generated:
            return default_text
        return generated

    @staticmethod
    def _strip_code_fence(response_text):
        """Remove surrounding whitespace, a ``` fence and a ``python`` tag."""
        text = response_text.strip()
        if text.startswith("```") and text.endswith("```"):
            text = text[3:-3].strip()
        # Text that still starts with "python" can never be a valid literal,
        # so dropping the tag is safe whether or not a fence was present.
        if text.startswith("python"):
            text = text[6:].strip()
        return text

    @staticmethod
    def _parse_plot_args(response_text):
        """Parse model output into plot arguments, falling back to the default."""
        try:
            plot_args = ast.literal_eval(response_text)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as e:
            logging.warning("Invalid LLM response: %s. Response: %s", e, response_text)
            return ast.literal_eval(LLM_Agent._DEFAULT_PLOT_ARGS_TEXT)

        if not LLM_Agent.validate_plot_args(plot_args):
            logging.warning("Invalid plot arguments. Using default.")
            return ast.literal_eval(LLM_Agent._DEFAULT_PLOT_ARGS_TEXT)
        return plot_args

    def process_request(self, data):
        """Generate and verify a chart for one request.

        Args:
            data: Mapping read with ``.get`` for ``query`` (request text),
                ``file_path`` (optional data file to load before plotting)
                and ``model`` (``'bart'`` by default; ``'flan-t5-base'`` and
                ``'flan-ul2'`` use the Inference API; any other value falls
                back to the local model).

        Returns:
            A dict with ``response`` (cleaned model output, or
            ``"Error: ..."``), ``chart_path`` (empty string on failure) and
            ``verified`` (``False`` on failure).
        """
        start_time = time.time()
        logging.info("Processing request data: %s", data)
        query = data.get('query', '')
        data_path = data.get('file_path')
        model_choice = data.get('model', 'bart')

        if data_path:
            logging.info("Data path received: %s", data_path)
            # Uses the module-level ``os`` import. A function-local
            # ``import os`` here would make ``os`` a local name and break
            # every later ``os`` use whenever no file_path is supplied.
            if not os.path.exists(data_path):
                logging.error("File does not exist at path: %s", data_path)
            else:
                logging.info("File exists at path: %s", data_path)

            self.data_processor = DataProcessor(data_path)
            loaded_columns = self.data_processor.get_columns()
            logging.info("Loaded columns from data: %s", loaded_columns)
            self.chart_generator = ChartGenerator(self.data_processor.data)

        try:
            # Non-string model values cannot name a remote model and use the
            # local model, like any other unrecognised choice.
            api_url = None
            if isinstance(model_choice, str):
                api_url = self._HF_API_URLS.get(model_choice)

            if api_url is not None:
                response_text = self._generate_with_inference_api(
                    api_url, self._build_prompt(query)
                )
            else:
                response_text = self._generate_with_local_model(query)

            logging.info("LLM response text: %s", response_text)

            response_text = self._strip_code_fence(response_text)
            plot_args = self._parse_plot_args(response_text)

            chart_path = self.chart_generator.generate_chart(plot_args)
            verified = self.image_verifier.verify(chart_path, query)

            return {
                "response": response_text,
                "chart_path": chart_path,
                "verified": verified,
            }
        except Exception as e:
            # Top-level boundary: report the failure to the caller instead of
            # raising, but keep the traceback in the log.
            logging.exception("Error processing request: %s", e)
            return {
                "response": f"Error: {str(e)}",
                "chart_path": "",
                "verified": False,
            }
        finally:
            logging.info("Processed request in %s seconds", time.time() - start_time)
|
|