“Transcendental-Programmer” committed on
Commit
4ac0bf8
·
1 Parent(s): b0741bf

fix: Update chart generation and LLM agent functionality

Browse files
Files changed (3) hide show
  1. app.py +9 -2
  2. chart_generator.py +26 -4
  3. llm_agent.py +100 -70
app.py CHANGED
@@ -16,7 +16,9 @@ logging.getLogger('PIL').setLevel(logging.WARNING)
16
 
17
  app = Flask(__name__, static_folder=os.path.join(os.path.dirname(__file__), '..', 'static'))
18
 
19
- CORS(app)
 
 
20
  agent = LLM_Agent()
21
 
22
  UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '..', 'data', 'uploads')
@@ -53,7 +55,12 @@ def plot():
53
  @app.route('/static/<path:filename>')
54
  def serve_static(filename):
55
  logging.info(f"Serving static file: {filename}")
56
- return send_from_directory(app.static_folder, filename)
 
 
 
 
 
57
 
58
  @app.route('/upload', methods=['POST'])
59
  def upload_file():
 
16
 
17
  app = Flask(__name__, static_folder=os.path.join(os.path.dirname(__file__), '..', 'static'))
18
 
19
+ # Configure CORS to allow all origins for development
20
+ CORS(app, origins=["*"], supports_credentials=True)
21
+
22
  agent = LLM_Agent()
23
 
24
  UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '..', 'data', 'uploads')
 
55
  @app.route('/static/<path:filename>')
56
  def serve_static(filename):
57
  logging.info(f"Serving static file: {filename}")
58
+ response = send_from_directory(app.static_folder, filename)
59
+ # Add CORS headers for images
60
+ response.headers.add('Access-Control-Allow-Origin', '*')
61
+ response.headers.add('Access-Control-Allow-Headers', 'Content-Type')
62
+ response.headers.add('Access-Control-Allow-Methods', 'GET')
63
+ return response
64
 
65
  @app.route('/upload', methods=['POST'])
66
  def upload_file():
chart_generator.py CHANGED
@@ -27,32 +27,54 @@ class ChartGenerator:
27
  missing_cols.append(y)
28
  if missing_cols:
29
  logging.error(f"Missing columns in data: {missing_cols}")
 
30
  raise ValueError(f"Missing columns in data: {missing_cols}")
31
 
32
- fig, ax = plt.subplots()
 
 
 
 
 
33
  for y in y_cols:
34
  color = plot_args.get('color', None)
35
  if plot_args.get('chart_type', 'line') == 'bar':
36
  ax.bar(self.data[x_col], self.data[y], label=y, color=color)
37
  else:
38
- ax.plot(self.data[x_col], self.data[y], label=y, color=color)
39
 
40
  ax.set_xlabel(x_col)
 
 
41
  ax.legend()
 
42
 
 
 
 
43
 
44
  chart_filename = 'chart.png'
45
  output_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static', 'images')
46
  if not os.path.exists(output_dir):
47
  os.makedirs(output_dir)
 
48
 
49
  full_path = os.path.join(output_dir, chart_filename)
50
 
51
  if os.path.exists(full_path):
52
  os.remove(full_path)
 
53
 
54
- plt.savefig(full_path)
 
 
55
 
56
- logging.info(f"Chart generated and saved to {full_path}")
 
 
 
 
 
 
57
 
58
  return os.path.join('static', 'images', chart_filename)
 
27
  missing_cols.append(y)
28
  if missing_cols:
29
  logging.error(f"Missing columns in data: {missing_cols}")
30
+ logging.info(f"Available columns: {list(self.data.columns)}")
31
  raise ValueError(f"Missing columns in data: {missing_cols}")
32
 
33
+ # Clear any existing plots
34
+ plt.clf()
35
+ plt.close('all')
36
+
37
+ fig, ax = plt.subplots(figsize=(10, 6))
38
+
39
  for y in y_cols:
40
  color = plot_args.get('color', None)
41
  if plot_args.get('chart_type', 'line') == 'bar':
42
  ax.bar(self.data[x_col], self.data[y], label=y, color=color)
43
  else:
44
+ ax.plot(self.data[x_col], self.data[y], label=y, color=color, marker='o')
45
 
46
  ax.set_xlabel(x_col)
47
+ ax.set_ylabel('Value')
48
+ ax.set_title(f'{plot_args.get("chart_type", "line").title()} Chart')
49
  ax.legend()
50
+ ax.grid(True, alpha=0.3)
51
 
52
+ # Rotate x-axis labels if needed
53
+ if len(self.data[x_col]) > 5:
54
+ plt.xticks(rotation=45)
55
 
56
  chart_filename = 'chart.png'
57
  output_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static', 'images')
58
  if not os.path.exists(output_dir):
59
  os.makedirs(output_dir)
60
+ logging.info(f"Created output directory: {output_dir}")
61
 
62
  full_path = os.path.join(output_dir, chart_filename)
63
 
64
  if os.path.exists(full_path):
65
  os.remove(full_path)
66
+ logging.info(f"Removed existing chart file: {full_path}")
67
 
68
+ # Save with high DPI for better quality
69
+ plt.savefig(full_path, dpi=300, bbox_inches='tight', facecolor='white')
70
+ plt.close(fig)
71
 
72
+ # Verify file was created
73
+ if os.path.exists(full_path):
74
+ file_size = os.path.getsize(full_path)
75
+ logging.info(f"Chart generated and saved to {full_path} (size: {file_size} bytes)")
76
+ else:
77
+ logging.error(f"Failed to create chart file at {full_path}")
78
+ raise FileNotFoundError(f"Chart file was not created at {full_path}")
79
 
80
  return os.path.join('static', 'images', chart_filename)
llm_agent.py CHANGED
@@ -10,6 +10,7 @@ import os
10
  from dotenv import load_dotenv
11
  import ast
12
  import requests
 
13
 
14
  load_dotenv()
15
 
@@ -36,7 +37,7 @@ class LLM_Agent:
36
  def process_request(self, data):
37
  start_time = time.time()
38
  logging.info(f"Processing request data: {data}")
39
- query = data['query']
40
  data_path = data.get('file_path')
41
  model_choice = data.get('model', 'bart')
42
 
@@ -49,8 +50,16 @@ class LLM_Agent:
49
  else:
50
  logging.info(f"File exists at path: {data_path}")
51
 
52
- # Few-shot + persona prompt for Flan-UL2 (best model)
53
- flan_prompt = (
 
 
 
 
 
 
 
 
54
  "You are VizBot, an expert data visualization assistant. "
55
  "Given a user's natural language request about plotting data, output ONLY a valid Python dictionary with keys: x, y, chart_type, and color (if specified). "
56
  "Do not include any explanation or extra text.\n\n"
@@ -66,73 +75,94 @@ class LLM_Agent:
66
  f"User: {query}\nOutput:"
67
  )
68
 
69
- # Re-initialize data processor and chart generator if a file is specified
70
- if data_path:
71
- self.data_processor = DataProcessor(data_path)
72
- # Log loaded columns
73
- loaded_columns = self.data_processor.get_columns()
74
- logging.info(f"Loaded columns from data: {loaded_columns}")
75
- self.chart_generator = ChartGenerator(self.data_processor.data)
76
-
77
- if model_choice == 'bart':
78
- # Use local fine-tuned BART model
79
- inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
80
- outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
81
- response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
82
- elif model_choice == 'flan-t5-base':
83
- # Use Hugging Face Inference API with Flan-T5-Base model
84
- api_url = "https://api-inference.huggingface.co/models/google/flan-t5-base"
85
- headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}", "Content-Type": "application/json"}
86
- response = requests.post(api_url, headers=headers, json={"inputs": flan_prompt})
87
- if response.status_code != 200:
88
- logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
89
- response_text = "Error: Unable to get response from Flan-T5-Base API. Please try again later."
90
- else:
91
- try:
92
- resp_json = response.json()
93
- response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
94
- except Exception as e:
95
- logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
96
- response_text = f"Error: Unexpected response from Flan-T5-Base API."
97
- elif model_choice == 'flan-ul2':
98
- # Use Hugging Face Inference API with Flan-UL2 model
99
- api_url = "https://api-inference.huggingface.co/models/google/flan-ul2"
100
- # Corrected model name to "google/flan-ul2" does not exist, use "google/flan-t5-xxl" as best available
101
- api_url = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
102
- headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}", "Content-Type": "application/json"}
103
- response = requests.post(api_url, headers=headers, json={"inputs": flan_prompt})
104
- if response.status_code != 200:
105
- logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
106
- response_text = "Error: Unable to get response from Flan-T5-XXL API. Please try again later."
 
 
 
 
 
 
107
  else:
108
- try:
109
- resp_json = response.json()
110
- response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
111
- except Exception as e:
112
- logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
113
- response_text = f"Error: Unexpected response from Flan-T5-XXL API."
114
- else:
115
- # Default fallback to local fine-tuned BART model
116
- inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
117
- outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
118
- response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
119
 
120
- logging.info(f"LLM response text: {response_text}")
121
- try:
122
- plot_args = ast.literal_eval(response_text)
123
- except (SyntaxError, ValueError):
124
- plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
125
- logging.warning(f"Invalid LLM response. Using default plot args: {plot_args}")
126
- if LLM_Agent.validate_plot_args(plot_args):
 
 
 
 
 
 
 
 
 
 
 
 
127
  chart_path = self.chart_generator.generate_chart(plot_args)
128
- else:
129
- logging.warning("Invalid plot arguments. Using default.")
130
- chart_path = self.chart_generator.generate_chart({'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'})
131
- verified = self.image_verifier.verify(chart_path, query)
132
- end_time = time.time()
133
- logging.info(f"Processed request in {end_time - start_time} seconds")
134
- return {
135
- "response": response_text,
136
- "chart_path": chart_path,
137
- "verified": verified
138
- }
 
 
 
 
 
 
 
 
 
 
 
10
  from dotenv import load_dotenv
11
  import ast
12
  import requests
13
+ import json
14
 
15
  load_dotenv()
16
 
 
37
  def process_request(self, data):
38
  start_time = time.time()
39
  logging.info(f"Processing request data: {data}")
40
+ query = data.get('query', '')
41
  data_path = data.get('file_path')
42
  model_choice = data.get('model', 'bart')
43
 
 
50
  else:
51
  logging.info(f"File exists at path: {data_path}")
52
 
53
+ # Re-initialize data processor and chart generator if a file is specified
54
+ if data_path:
55
+ self.data_processor = DataProcessor(data_path)
56
+ # Log loaded columns
57
+ loaded_columns = self.data_processor.get_columns()
58
+ logging.info(f"Loaded columns from data: {loaded_columns}")
59
+ self.chart_generator = ChartGenerator(self.data_processor.data)
60
+
61
+ # Enhanced prompt for better model responses
62
+ enhanced_prompt = (
63
  "You are VizBot, an expert data visualization assistant. "
64
  "Given a user's natural language request about plotting data, output ONLY a valid Python dictionary with keys: x, y, chart_type, and color (if specified). "
65
  "Do not include any explanation or extra text.\n\n"
 
75
  f"User: {query}\nOutput:"
76
  )
77
 
78
+ try:
79
+ if model_choice == 'bart':
80
+ # Use local fine-tuned BART model
81
+ inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
82
+ outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
83
+ response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
84
+ elif model_choice == 'flan-t5-base':
85
+ # Use Hugging Face Inference API with Flan-T5-Base model
86
+ api_url = "https://api-inference.huggingface.co/models/google/flan-t5-base"
87
+ headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}"}
88
+ payload = {"inputs": enhanced_prompt}
89
+
90
+ response = requests.post(api_url, headers=headers, json=payload, timeout=30)
91
+ if response.status_code != 200:
92
+ logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
93
+ response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
94
+ else:
95
+ try:
96
+ resp_json = response.json()
97
+ response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
98
+ if not response_text:
99
+ response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
100
+ except Exception as e:
101
+ logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
102
+ response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
103
+ elif model_choice == 'flan-ul2':
104
+ # Use Hugging Face Inference API with Flan-T5-XXL model (best available)
105
+ api_url = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
106
+ headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}"}
107
+ payload = {"inputs": enhanced_prompt}
108
+
109
+ response = requests.post(api_url, headers=headers, json=payload, timeout=30)
110
+ if response.status_code != 200:
111
+ logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
112
+ response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
113
+ else:
114
+ try:
115
+ resp_json = response.json()
116
+ response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
117
+ if not response_text:
118
+ response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
119
+ except Exception as e:
120
+ logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
121
+ response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
122
  else:
123
+ # Default fallback to local fine-tuned BART model
124
+ inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
125
+ outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
126
+ response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
127
 
128
+ logging.info(f"LLM response text: {response_text}")
129
+
130
+ # Clean and parse the response
131
+ response_text = response_text.strip()
132
+ if response_text.startswith("```") and response_text.endswith("```"):
133
+ response_text = response_text[3:-3].strip()
134
+ if response_text.startswith("python"):
135
+ response_text = response_text[6:].strip()
136
+
137
+ try:
138
+ plot_args = ast.literal_eval(response_text)
139
+ except (SyntaxError, ValueError) as e:
140
+ logging.warning(f"Invalid LLM response: {e}. Response: {response_text}")
141
+ plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
142
+
143
+ if not LLM_Agent.validate_plot_args(plot_args):
144
+ logging.warning("Invalid plot arguments. Using default.")
145
+ plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
146
+
147
  chart_path = self.chart_generator.generate_chart(plot_args)
148
+ verified = self.image_verifier.verify(chart_path, query)
149
+
150
+ end_time = time.time()
151
+ logging.info(f"Processed request in {end_time - start_time} seconds")
152
+
153
+ return {
154
+ "response": response_text,
155
+ "chart_path": chart_path,
156
+ "verified": verified
157
+ }
158
+
159
+ except Exception as e:
160
+ logging.error(f"Error processing request: {e}")
161
+ end_time = time.time()
162
+ logging.info(f"Processed request in {end_time - start_time} seconds")
163
+
164
+ return {
165
+ "response": f"Error: {str(e)}",
166
+ "chart_path": "",
167
+ "verified": False
168
+ }