graph_generator / app.py
angelicaporto's picture
Enhance explanation via prompt engineering
a5ad035 verified
raw
history blame
8.06 kB
# ---------------------------------------------------------------------------------
# Aplicación principal para cargar el modelo, generar prompts y explicar los datos
# ---------------------------------------------------------------------------------
import streamlit as st # type: ignore
import os
import re
import pandas as pd # type: ignore
from dotenv import load_dotenv # type: ignore # Para cambios locales
from supabase import create_client, Client # type: ignore
# from pandasai import SmartDataframe # type: ignore
from pandasai import SmartDatalake # type: ignore # Porque ya usamos más de un df (más de una tabla de nuestra db)
from pandasai.llm.local_llm import LocalLLM # type: ignore
from pandasai import Agent
import matplotlib.pyplot as plt
import time
# ---------------------------------------------------------------------------------
# Funciones auxiliares
# ---------------------------------------------------------------------------------
def generate_graph_prompt(user_query):
prompt = f"""
You are a senior data scientist analyzing European labor force data.
Given the user's request: "{user_query}"
1. Plot the relevant data using matplotlib:
- Use `df.query("geo == 'X'")` to filter the country, instead of chained comparisons.
- Avoid using filters like `df[df['geo'] == 'Germany']`.
- Include clear axis labels and a descriptive title.
- Save the plot as an image file (e.g., temp_chart.png).
2. After plotting, write a **concise analytical summary** of the trend based on those 5 years. The summary should:
- Identify the **year with the largest increase** and the percent change.
- Identify the **year with the largest decrease** and the percent change.
- Provide a **brief overall trend interpretation** (e.g., steady growth, fluctuating, recovery, etc.).
- Avoid listing every year individually, summarize intelligently.
3. Store the summary in a variable named `explanation`.
4. Return a result dictionary structured as follows:
result = {{
"type": "plot",
"value": "temp_chart.png",
"explanation": explanation
}}
IMPORTANT: Use only the data available in the input DataFrame.
"""
return prompt
#TODO: Continuar mejorando el prompt
# ---------------------------------------------------------------------------------
# Configuración de conexión a Supabase
# ---------------------------------------------------------------------------------
# Cargar variables de entorno desde archivo .env
load_dotenv()
# Conectar las credenciales de Supabase (ubicadas en "Secrets" en Streamlit)
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")
# Crear cliente Supabase
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
# Función para cargar datos de una tabla de Supabase
# Tablas posibles: fertility, geo data, labor, population, predictions
def load_data(table):
try:
if supabase:
response = supabase.from_(table).select("*").execute()
print(f"Response object: {response}") # Inspeccionar objeto completo
print(f"Response type: {type(response)}") # Verificar tipo de objeto
# Acceder a atributos relacionados a error o data
if hasattr(response, 'data'):
print(f"Response data: {response.data}")
return pd.DataFrame(response.data)
elif hasattr(response, 'status_code'):
print(f"Response status code: {response.status_code}")
elif hasattr(response, '_error'): # Versiones antiguas
print(f"Older error attribute: {response._error}")
st.error(f"Error fetching data: {response._error}")
return pd.DataFrame()
else:
st.info("Response object does not have 'data' or known error attributes. Check the logs.")
return pd.DataFrame()
else:
st.error("Supabase client not initialized. Check environment variables.")
return pd.DataFrame()
except Exception as e:
st.error(f"An error occurred during data loading: {e}")
return pd.DataFrame()
# ---------------------------------------------------------------------------------
# Cargar datos iniciales
# ---------------------------------------------------------------------------------
# TODO: La idea es luego usar todas las tablas, cuando ya funcione.
# Se puede si el modelo funciona con las gráficas, sino que toca mejorarlo porque serían consultas más complejas.
labor_data = load_data("labor")
fertility_data = load_data("fertility")
# population_data = load_data("population")
# predictions_data = load_data("predictions")
# TODO: Buscar la forma de disminuir la latencia (muchos datos = mucha latencia)
# ---------------------------------------------------------------------------------
# Inicializar LLM desde Ollama con PandasAI
# ---------------------------------------------------------------------------------
# ollama_llm = LocalLLM(api_base="http://localhost:11434/v1",
# model="gemma3:12b",
# temperature=0.1,
# max_tokens=8000)
lm_studio_llm = LocalLLM(api_base="http://localhost:1234/v1") # el modelo es gemma-3-12b-it-qat
# sdl = SmartDatalake([labor_data, fertility_data, population_data, predictions_data], config={"llm": ollama_llm}) # DataFrame PandasAI-ready.
# sdl = SmartDatalake([labor_data, fertility_data], config={"llm": ollama_llm})
# agent = Agent([labor_data], config={"llm": lm_studio_llm}) # TODO: Probar Agent con multiples dfs
agent = Agent(
[
labor_data,
fertility_data
],
config={
"llm": lm_studio_llm,
"enable_cache": False,
"enable_filter_extraction": False # evita errores de parseo
}
)
# ---------------------------------------------------------------------------------
# Configuración de la app en Streamlit
# ---------------------------------------------------------------------------------
# Título de la app
st.title("Europe GraphGen :blue[Graph generator] :flag-eu:")
# TODO: Poner instrucciones al usuario sobre cómo hacer un muy buen prompt (sin tecnisismos, pensando en el usuario final)
# Entrada de usuario para describir el gráfico
user_input = st.text_input("What graphics do you have in mind")
generate_button = st.button("Generate")
if generate_button and user_input:
with st.spinner('Generating answer...'):
try:
print(f"\nGenerating prompt...\n")
prompt = generate_graph_prompt(user_input)
print(f"\nPrompt generated\n")
start_time = time.time()
answer = agent.chat(prompt)
print(f"\nAnswer type: {type(answer)}\n") # Verificar tipo de objeto
print(f"\nAnswer content: {answer}\n") # Inspeccionar contenido de la respuesta
print(f"\nFull result: {agent.last_result}\n")
full_result = agent.last_result
explanation = full_result.get("explanation", "")
elapsed_time = time.time() - start_time
print(f"\nExecution time: {elapsed_time:.2f} seconds\n")
if isinstance(answer, str) and os.path.isfile(answer):
# Si el output es una ruta válida a imagen
im = plt.imread(answer)
st.image(im)
os.remove(answer) # Limpiar archivo temporal
if explanation:
st.markdown(f"**Explanation:** {explanation}")
else:
# Si no es una ruta válida, mostrar como texto
st.markdown(str(answer))
except Exception as e:
st.error(f"Error generating answer: {e}")
# TODO: Output estructurado si vemos que es necesario.