Spectra-Backend / main.py
ShadowGard3n's picture
backend changes
9308201
import json
import os
import tempfile
import time

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse

# Import your custom PyPI library
from graphvision import GraphExtractor
# Create the ASGI application object that the route decorators below attach to.
app = FastAPI(title="STEM Sight Backend")
# Register CORS handling so cross-origin clients can call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the service publicly.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows any browser extension to connect
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Instantiate the graph extractor once at import time; the upload endpoints
# below all call this single shared instance.
print("Initializing STEM Sight Vision Engine...")
vision_engine = GraphExtractor()
@app.get("/")
async def root():
    """Liveness endpoint: return a fixed JSON message showing the API is up."""
    status_payload = {"message": "STEM Sight API is online and ready."}
    return status_payload
def generate_audio_summary(extraction_result: dict) -> str:
    """Build a spoken-style summary of extracted chart data without an LLM.

    Args:
        extraction_result: Parsed extraction output. The keys read here are
            ``chart_type``, ``title``, ``data``, ``x_axis_label``,
            ``y_axis_label``, ``total_points`` and ``summary``; any of them
            may be missing or ``None``.

    Returns:
        A plain-text sentence (or several) describing the chart. Unsupported
        chart types yield a fixed "not available" message instead of raising.
    """
    # A null or non-string chart type must not crash the request; treat it as
    # "unknown" so it falls through to the fallback message.
    chart_type = str(extraction_result.get("chart_type") or "unknown").lower()
    title = extraction_result.get("title")
    title_text = f"titled {title}" if title else "without a specific title"

    if chart_type == "pie":
        return _summarize_pie_chart(extraction_result, title_text)
    if chart_type in ("hbar_categorical", "vbar_categorical", "hbar", "vbar"):
        return _summarize_bar_chart(extraction_result, title_text)
    if chart_type == "dot_line":
        return _summarize_dot_line_chart(extraction_result, title_text)
    if chart_type == "line":
        # For line charts the extractor supplies a ready-made text summary.
        summary = extraction_result.get("summary", "")
        if not summary:
            return f"This is a line chart {title_text}, but the Vision Engine could not generate a summary."
        return summary

    # --- FALLBACK ---
    return f"Data has been extracted for a {chart_type} chart, but the summary feature for this specific format is not available."


def _axis_label(extraction_result: dict, key: str, default: str) -> str:
    """Return the axis label under *key*, or *default* if missing, null or empty."""
    # `or` (rather than a .get default) also covers an explicit null/"" label,
    # which would otherwise be read aloud as "None" or as nothing at all.
    return extraction_result.get(key) or default


def _summarize_pie_chart(extraction_result: dict, title_text: str) -> str:
    """Describe the largest and smallest slices of a pie chart."""
    data = extraction_result.get("data")
    if not data:
        return f"This is a pie chart {title_text}, but no data could be extracted."

    max_cat = max(data, key=data.get)
    min_cat = min(data, key=data.get)
    return (
        f"This is a pie chart {title_text}. "
        f"The largest portion is {max_cat} at {data[max_cat]}. "
        f"The smallest portion is {min_cat} at {data[min_cat]}."
    )


def _summarize_bar_chart(extraction_result: dict, title_text: str) -> str:
    """Describe the highest, lowest and remaining bars of a bar chart."""
    data = extraction_result.get("data")
    if not data:
        return f"This is a bar chart {title_text}, but no data could be extracted."

    x_label = _axis_label(extraction_result, "x_axis_label", "the X axis")
    y_label = _axis_label(extraction_result, "y_axis_label", "the Y axis")

    max_item = max(data, key=lambda d: d.get("value", 0))
    min_item = min(data, key=lambda d: d.get("value", 0))
    summary = (
        f"This is a bar chart {title_text}, showing {y_label} against {x_label}. "
        f"The highest value is {max_item.get('category')} at {max_item.get('value')}. "
        f"The lowest value is {min_item.get('category')} at {min_item.get('value')}. "
    )

    # Leave out the max and min bars so they are not read twice.
    other_items = [item for item in data if item != max_item and item != min_item]
    if other_items:
        # Comma-separated so the text-to-speech engine pauses between bars.
        other_points_text = ", ".join(
            f"{item.get('category')} at {item.get('value')}" for item in other_items
        )
        summary += f"The other values are: {other_points_text}."
    return summary


def _summarize_dot_line_chart(extraction_result: dict, title_text: str) -> str:
    """Describe a dot/scatter chart: overview plus per-class highest and lowest Y."""
    data = extraction_result.get("data")
    # Guard BEFORE touching len(data): a null "data" field must produce the
    # friendly message rather than a TypeError.
    if not data:
        return f"This is a line chart {title_text}, but no data could be extracted."

    x_label = _axis_label(extraction_result, "x_axis_label", "the X axis")
    y_label = _axis_label(extraction_result, "y_axis_label", "the Y axis")
    total_points = extraction_result.get("total_points", len(data))

    # Group points by class, keeping first-seen order of the classes.
    categories = {}
    for item in data:
        categories.setdefault(item.get("class", "unknown"), []).append(item)

    classes = list(categories)
    # str() keeps the join safe when class names are not strings.
    classes_text = ", ".join(str(name) for name in classes[:3])
    if len(classes) > 3:
        classes_text += f", and {len(classes) - 3} other categories"

    summary = (
        f"This is a scatter plot {title_text}, with {x_label} on the X axis and {y_label} on the Y axis. "
        f"It shows {total_points} data points across categories like {classes_text}. "
    )

    category_summaries = []
    for cat_name, points in categories.items():
        max_item = max(points, key=lambda d: d.get("y", 0))
        min_item = min(points, key=lambda d: d.get("y", 0))
        # Fixed sentence shape that always includes the X coordinate.
        category_summaries.append(
            f"For {cat_name}, the highest value is {max_item.get('y')} when X is {max_item.get('x')}, "
            f"and the lowest value is {min_item.get('y')} when X is {min_item.get('x')}."
        )

    # Space-joined so each class reads as its own sentence.
    return summary + " ".join(category_summaries)
@app.post("/analyze-graph", response_class=PlainTextResponse)
async def analyze_graph(file: UploadFile = File(...)):
    """Extract data from an uploaded chart image and return a spoken summary.

    The upload is written to a temporary file, passed to
    ``vision_engine.extract`` and the resulting JSON is turned into text by
    ``generate_audio_summary``.

    Args:
        file: The uploaded chart image.

    Returns:
        Plain text: the summary, an apology containing the extractor's
        ``error`` value, or an "An error occurred..." message when any
        exception is raised along the way (exceptions are not propagated).
    """
    temp_image_path = None
    try:
        start_time = time.time()

        # 1. Save the uploaded image temporarily.
        # The client-supplied filename is untrusted, so it is never used as a
        # path: only a purely alphanumeric extension is kept, and the file
        # name itself is generated by tempfile (unique per request).
        extension = os.path.splitext(file.filename or "")[1]
        if not extension[1:].isalnum():
            extension = ""
        with tempfile.NamedTemporaryFile(
            prefix="temp_", suffix=extension, delete=False
        ) as buffer:
            temp_image_path = buffer.name
            buffer.write(await file.read())
        print(f"⏱️ Image received and saved in: {time.time() - start_time:.2f} seconds")

        # 2. Extract structured data
        extract_start = time.time()
        print(f"Extracting data from {file.filename}...")
        extraction_json_string = vision_engine.extract(temp_image_path)
        print(f"⏱️ AI Extraction finished in: {time.time() - extract_start:.2f} seconds")

        extraction_result = json.loads(extraction_json_string)
        print(f"Extracted data: {extraction_result}")
        if "error" in extraction_result:
            return f"I am sorry, I could not clearly identify the data in this graph. Reason: {extraction_result['error']}"

        # 3. Generate summary using hardcoded logic instead of Groq
        audio_script = generate_audio_summary(extraction_result)
        print(f"✅ TOTAL TIME: {time.time() - start_time:.2f} seconds")
        return audio_script
    except Exception as e:
        return f"An error occurred while analyzing the graph: {str(e)}"
    finally:
        # Remove the temp file on every path, including extraction failures.
        if temp_image_path is not None and os.path.exists(temp_image_path):
            os.remove(temp_image_path)
@app.post("/ask", response_class=PlainTextResponse)
async def ask_chart_rule_based(
    file: UploadFile = File(...),
    question: str = Form(...)
):
    """Answer a question about an uploaded chart using keyword rules.

    The image is run through ``vision_engine.extract``; the question is then
    matched against fixed keyword lists (highest / lowest) and against the
    category names found in the extracted data.

    Args:
        file: The uploaded chart image.
        question: The user's free-text question.

    Returns:
        Plain text: the answer, the extractor's own summary for ``line``
        charts, or a fixed message when no data was extracted or the
        question is not understood.
    """
    # 1. Extract JSON using GraphVision
    # The client-supplied filename is untrusted, so it is never used as a
    # path: only a purely alphanumeric extension is kept, and the file name
    # itself is generated by tempfile (unique per request).
    extension = os.path.splitext(file.filename or "")[1]
    if not extension[1:].isalnum():
        extension = ""
    temp_image_path = None
    try:
        with tempfile.NamedTemporaryFile(
            prefix="temp_qa_", suffix=extension, delete=False
        ) as buffer:
            temp_image_path = buffer.name
            buffer.write(await file.read())
        extraction_json_string = vision_engine.extract(temp_image_path)
    finally:
        # Remove the temp file even when extraction raises.
        if temp_image_path is not None and os.path.exists(temp_image_path):
            os.remove(temp_image_path)

    extraction_result = json.loads(extraction_json_string)
    # A null chart type must not crash on .lower().
    chart_type = str(extraction_result.get("chart_type") or "unknown").lower()
    if chart_type == "line":
        summary = extraction_result.get("summary")
        if summary:
            return summary
        return "I couldn't extract a summary from this line chart."

    data = extraction_result.get("data")
    if not data:
        return "I couldn't extract data from this chart. Please ensure the image is clear."

    # 2. Pre-process Data and Question
    question_lower = question.lower()

    # Collect the category names present in this chart, in first-seen order.
    available_categories = []
    if isinstance(data, dict):  # Pie chart: categories are the keys
        available_categories = list(data.keys())
    elif isinstance(data, list):  # Bar / scatter: one dict per data point
        for item in data:
            cat = item.get("category", item.get("class"))
            if cat and cat not in available_categories:
                available_categories.append(cat)

    # First category whose name appears in the question, if any. str() keeps
    # the match from crashing on non-string categories (e.g. numeric years).
    target_category = next(
        (cat for cat in available_categories if str(cat).lower() in question_lower),
        None,
    )

    # For list-based charts, narrow the data to the requested category.
    filtered_data = data
    if target_category and isinstance(data, list):
        filtered_data = [
            item for item in data
            if item.get("category", item.get("class")) == target_category
        ]

    # 3. Rule-Based Intent Routing
    # Highest and lowest share one code path; only the picker (max/min) and
    # the word used in the answer differ. "Highest" keywords are checked first.
    extreme_intents = (
        (("highest", "maximum", "most", "largest", "top"), max, "highest"),
        (("lowest", "minimum", "least", "smallest", "bottom"), min, "lowest"),
    )
    matched_intent = next(
        (
            intent for intent in extreme_intents
            if any(word in question_lower for word in intent[0])
        ),
        None,
    )

    if matched_intent is not None:
        _keywords, pick, label = matched_intent
        if isinstance(filtered_data, dict):  # Pie charts
            best_cat = pick(filtered_data, key=filtered_data.get)
            val = filtered_data[best_cat]
            return f"Based on the extracted data, the {label} is {best_cat} with a value of {val}."
        if isinstance(filtered_data, list):  # Bar / scatter
            best_item = pick(filtered_data, key=lambda d: d.get("value", d.get("y", 0)))
            val = best_item.get("value", best_item.get("y"))
            if target_category:
                # Include the X coordinate when the data point has one.
                x_val = best_item.get("x")
                if x_val is not None:
                    return f"For the {target_category} category, the {label} value is {val} when X is {x_val}."
                return f"For the {target_category} category, the {label} value is {val}."
            cat = best_item.get("category", best_item.get("class", "unknown"))
            return f"Based on the extracted data, the overall {label} is {cat} with a value of {val}."
    elif target_category:
        # General lookup of a specific category's value.
        if isinstance(data, dict):  # Pie charts
            val = data[target_category]
            return f"Based on the extracted data, the value for {target_category} is {val}."
        if isinstance(filtered_data, list):  # Bar charts
            if len(filtered_data) == 1:
                val = filtered_data[0].get("value", filtered_data[0].get("y"))
                return f"Based on the extracted data, the value for {target_category} is {val}."
            # Several points share this category, so ask the user to narrow it down.
            return f"The category {target_category} has {len(filtered_data)} different data points. Please ask for the highest or lowest value for this category."

    # Fallback
    return "I am sorry, I do not understand the question. Please ask for the highest value, the lowest value, or ask about a specific category."
# import os
# import json
# from fastapi import FastAPI, File, UploadFile
# from fastapi.responses import PlainTextResponse
# from fastapi.middleware.cors import CORSMiddleware
# from groq import Groq
# import time
# # Import your newly updated PyPI library!
# from graphvision import GraphExtractor
# app = FastAPI(title="STEM Sight Backend")
# app.add_middleware(
# CORSMiddleware,
# allow_origins=["*"], # Allows any browser extension to connect
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
# # Initialize the Groq Client (Looks for the GROQ_API_KEY environment variable)
# groq_client = Groq()
# # Initialize your custom PyPI library
# print("Initializing STEM Sight Vision Engine...")
# vision_engine = GraphExtractor()
# @app.get("/")
# async def root():
# return {"message": "STEM Sight API is online and ready."}
# @app.post("/analyze-graph", response_class=PlainTextResponse)
# async def analyze_graph(file: UploadFile = File(...)):
# try:
# start_time = time.time()
# # 1. Save the uploaded image temporarily
# temp_image_path = f"temp_{file.filename}"
# with open(temp_image_path, "wb") as buffer:
# buffer.write(await file.read())
# print(f"⏱️ Image received and saved in: {time.time() - start_time:.2f} seconds")
# # 2. Extract structured data
# extract_start = time.time()
# print(f"Extracting data from {file.filename}...")
# extraction_json_string = vision_engine.extract(temp_image_path)
# print(f"⏱️ AI Extraction finished in: {time.time() - extract_start:.2f} seconds")
# if os.path.exists(temp_image_path):
# os.remove(temp_image_path)
# extraction_result = json.loads(extraction_json_string)
# print(f"Extracted data: {extraction_result}")
# if "error" in extraction_result:
# return f"I'm sorry, I couldn't clearly identify the data in this graph. Reason: {extraction_result['error']}"
# graph_type = extraction_result.get("chart_type", "unknown")
# graph_data = extraction_result.get("data", [])
# x_label = extraction_result.get("x_axis_label", "Unknown X-Axis")
# y_label = extraction_result.get("y_axis_label", "Unknown Y-Axis")
# title = extraction_result.get("title", "Untitled Graph")
# prompt = f"""
# You are an accessibility assistant for visually impaired students.
# I am giving you extracted data from a {graph_type} chart.
# Title: {title}
# X-Axis Label: {x_label}
# Y-Axis Label: {y_label}
# Please summarize this data in one short, conversational, and easy-to-understand paragraph.
# Point out the largest and smallest values if relevant.
# Do not use markdown, bold text, or asterisks. Write it exactly as it should be spoken out loud by a text-to-speech engine.
# Data:
# {graph_data}
# """
# # 3. Send to Groq
# groq_start = time.time()
# print("Generating audio script with Groq Llama 3...")
# chat_completion = groq_client.chat.completions.create(
# messages=[{"role": "user", "content": prompt}],
# model="llama-3.1-8b-instant",
# temperature=0.4,
# )
# print(f"⏱️ Groq Llama 3 finished in: {time.time() - groq_start:.2f} seconds")
# print(f"✅ TOTAL TIME: {time.time() - start_time:.2f} seconds")
# return chat_completion.choices[0].message.content.strip()
# except Exception as e:
# return f"An error occurred while analyzing the graph: {str(e)}"