# MathAssistant / app.py
# jujutechnology's picture
# Update app.py
# 87188de verified
# raw
# history blame
# 32.4 kB
import streamlit as st
import google.generativeai as genai
import os
import json
import base64
from dotenv import load_dotenv
from streamlit_local_storage import LocalStorage
import re
import streamlit.components.v1 as components
import math # Needed for trigonometry in dynamic visuals
# --- PAGE CONFIGURATION ---
_PAGE_CONFIG = dict(
    page_title="Math Jegna - Your AI Math Tutor",
    page_icon="🧠",
    layout="wide",
)
st.set_page_config(**_PAGE_CONFIG)

# Browser localStorage bridge; chats are read from / written to it under the
# "math_mentor_chats" key elsewhere in this script.
localS = LocalStorage()
# --- HELPER FUNCTIONS ---
def format_chat_for_download(chat_history):
    """Format a chat history as Markdown text for the download button.

    Args:
        chat_history: list of message dicts with "role" and "content" keys.
            Messages whose role is "user" are labelled "You"; every other role
            is labelled "Math Mentor". Extra keys (e.g. "visual_html") are ignored.

    Returns:
        A Markdown string: a title line followed by one section per message,
        each terminated by a horizontal rule.
    """
    # Collect the pieces and join once instead of growing a string with +=.
    sections = ["# Math Mentor Chat\n\n"]
    for message in chat_history:
        speaker = "You" if message["role"] == "user" else "Math Mentor"
        sections.append(f"**{speaker}:**\n{message['content']}\n\n---\n\n")
    return "".join(sections)
def convert_role_for_gemini(role):
    """Map a Streamlit chat role onto the role name the Gemini API expects.

    Streamlit's "assistant" becomes Gemini's "model"; any other role (notably
    "user") is passed through unchanged.
    """
    return "model" if role == "assistant" else role
def should_generate_visual(user_prompt, ai_response):
    """Decide whether a visual aid is worth rendering for this exchange.

    The user's prompt and the model's reply are joined with a space, lower-cased,
    and searched for any trigger phrase. Matching is plain substring search, so
    e.g. "add" also fires on "addition".

    Returns:
        True if any trigger phrase occurs, otherwise False.
    """
    triggers = (
        # arithmetic and counting
        "add", "subtract", "multiply", "times", "divide", "counting", "numbers",
        # fractions
        "fraction", "half", "quarter", "third", "parts", "whole",
        # geometry
        "shape", "triangle", "circle", "square", "rectangle",
        # money
        "money", "coins", "dollars", "cents", "change",
        # telling time
        "time", "clock", "hours", "minutes", "o'clock",
        # measurement
        "measurement", "length", "height", "weight",
        # place value
        "place value", "tens", "ones", "hundreds",
        # patterns
        "pattern", "sequence", "skip counting",
        # comparison and layouts
        "greater than", "less than", "equal", "compare",
        "number line", "array", "grid",
    )
    haystack = " ".join((user_prompt, ai_response)).lower()
    for trigger in triggers:
        if trigger in haystack:
            return True
    return False
def create_visual_manipulative(user_prompt, ai_response):
    """-- SMART VISUAL ROUTER --
    Parse the user's prompt and return HTML for the most relevant visual aid.

    The checks run in priority order and the first match wins:
    clock times, fractions, multiplication arrays, addition/subtraction blocks,
    number lines, place value, and finally keyword-only fallbacks that show a
    default example.

    Args:
        user_prompt: the student's message; all matching is done on this text.
        ai_response: the model's reply. Currently unused; kept so existing
            callers keep working.

    Returns:
        An HTML string, or None when no visual applies or building one raised
        (the error is then reported via ``st.error``).
    """
    # Widest number line we will draw. create_number_line emits one SVG tick
    # per integer, so an uncapped, user-chosen span could produce an enormous
    # page fragment.
    max_number_line_span = 100
    try:
        user_lower = user_prompt.lower()

        # Priority 1: Time / Clock (e.g., "7:30", "4 o'clock")
        time_match = re.search(r"(\d{1,2}):(\d{2})", user_lower) or re.search(
            r"(\d{1,2})\s*o'clock", user_lower
        )
        if time_match:
            groups = time_match.groups()
            hour = int(groups[0])
            # The "o'clock" pattern has a single group, meaning minute 0.
            minute = int(groups[1]) if len(groups) > 1 and groups[1] else 0
            if 1 <= hour <= 12 and 0 <= minute <= 59:
                return create_clock_visual(hour, minute)

        # Priority 2: Fractions (e.g., "2/5", "fraction 3/8")
        fraction_match = re.search(r"(\d+)/(\d+)", user_lower)
        if fraction_match:
            num, den = int(fraction_match.group(1)), int(fraction_match.group(2))
            if 0 < num <= den <= 16:  # Keep it visually clean
                return create_dynamic_fraction_circle(num, den)

        # Priority 3: Multiplication arrays ("3 times 5", "4 x 6", "2 * 7", "3 × 4")
        mult_match = re.search(r"(\d+)\s*(?:x|×|\*|times)\s*(\d+)", user_lower)
        if mult_match:
            rows, cols = int(mult_match.group(1)), int(mult_match.group(2))
            if rows <= 10 and cols <= 10:  # Keep arrays reasonable
                return create_multiplication_array(rows, cols)

        # Priority 4: Addition / subtraction blocks
        is_addition = any(word in user_lower for word in ("add", "plus", "+"))
        # A bare "-" only counts as a minus sign between digits ("7 - 3");
        # otherwise hyphenated words like "number-line" or "3-sided" would
        # trigger subtraction blocks.
        is_subtraction = (
            any(word in user_lower for word in ("subtract", "minus", "take away"))
            or re.search(r"\d\s*-\s*\d", user_lower) is not None
        )
        if is_addition or is_subtraction:
            numbers = re.findall(r"\d+", user_prompt)
            if len(numbers) >= 2:
                num1, num2 = int(numbers[0]), int(numbers[1])
                operation = "add" if is_addition else "subtract"
                if num1 <= 20 and num2 <= 20:
                    return create_counting_blocks(num1, num2, operation)

        # Priority 5: Number lines
        if "number line" in user_lower:
            numbers = [int(n) for n in re.findall(r"\d+", user_prompt)]
            if numbers and max(numbers) - min(numbers) <= max_number_line_span:
                start = min(numbers) - 2
                end = max(numbers) + 2
                return create_number_line(start, end, numbers, "Your Numbers on the Line")

        # Priority 6: Place value
        if "place value" in user_lower:
            numbers = re.findall(r"\d+", user_prompt)
            if numbers:
                num = int(numbers[0])
                if num <= 999:
                    return create_place_value_blocks(num)

        # Fallbacks: static or default-example visuals chosen by keyword alone.
        # Word boundaries stop "apart" from matching "part" and, importantly,
        # "times" (multiplication) from matching "time" and showing a clock.
        if "fraction" in user_lower or re.search(r"\bparts?\b", user_lower):
            return create_dynamic_fraction_circle(1, 2)  # Show a default example
        if "shape" in user_lower:
            return create_shape_explorer()
        if "money" in user_lower or "coin" in user_lower:
            return create_money_counter()
        if re.search(r"\btime\b|\bclock", user_lower):
            return create_clock_visual(10, 10)  # Show a default example
        return None  # No relevant visual found
    except Exception as e:
        # Best-effort: a broken visual must not break the chat turn.
        st.error(f"Could not create visual: {e}")
        return None
# --- VISUAL TOOLBOX FUNCTIONS ---
def create_counting_blocks(num1, num2, operation):
    """(Dynamic) Draw coloured counting blocks for ``num1 + num2`` or ``num1 − num2``.

    Args:
        num1: first operand, drawn as red blocks.
        num2: second operand, drawn as teal blocks.
        operation: ``'add'`` for addition; any other value means subtraction.

    Returns:
        An HTML snippet showing both operands and the answer as groups of
        blocks. If a subtraction goes below zero, the answer is labelled with
        its true negative value and no answer blocks are drawn, because a
        negative count cannot be shown as blocks.
    """
    is_add = operation == "add"
    symbol = "+" if is_add else "−"
    result = num1 + num2 if is_add else num1 - num2
    # Typeset negatives with the same U+2212 minus sign used for the operator.
    result_label = f"−{-result}" if result < 0 else str(result)

    def block_group(label, count, colour, frame):
        # One labelled, framed group containing `count` coloured squares.
        blocks = (
            f'<div style="width: 25px; height: 25px; background: {colour}; border-radius: 5px;"></div>'
            * max(0, count)
        )
        return (
            f'<div style="display: flex; flex-wrap: wrap; gap: 5px; {frame} padding: 5px; border-radius: 5px; '
            'align-items: center; justify-content: center; min-width: 100px;">'
            f'<div style="width: 100%; text-align:center; color: white; font-weight: bold;">{label}</div>'
            f"{blocks}</div>"
        )

    dashed = "border: 2px dashed #FFE066;"
    solid = "border: 2px solid white; background: rgba(255,255,255,0.2);"
    return "\n".join([
        "",
        '<div style="padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 15px; margin: 10px 0;">',
        f'<h3 style="color: white; text-align: center; margin-bottom: 20px;">🧮 Counting Blocks: {num1} {symbol} {num2}</h3>',
        '<div style="display: flex; justify-content: center; align-items: center; gap: 20px; flex-wrap: wrap;">',
        "<!-- Blocks for Num1 -->",
        block_group(num1, num1, "#FF6B6B", dashed),
        f'<div style="font-size: 40px; color: #FFE066;">{symbol}</div>',
        "<!-- Blocks for Num2 -->",
        block_group(num2, num2, "#4ECDC4", dashed),
        '<div style="font-size: 40px; color: #FFE066;">=</div>',
        "<!-- Blocks for Answer -->",
        block_group(result_label, result, "#95E1D3", solid),
        "</div>",
        "</div>",
    ])
def create_dynamic_fraction_circle(numerator, denominator):
    """(Dynamic) Draw a "fraction pizza" as an inline SVG.

    The circle is cut into ``denominator`` equal slices and the first
    ``numerator`` slices are shaded red.

    Args:
        numerator: number of shaded slices.
        denominator: total number of slices.

    Returns:
        An HTML snippet. If ``0 < numerator <= denominator`` does not hold,
        returns a short ``<p>`` message instead.
    """
    if not (0 < numerator <= denominator):
        return "<p>I can only show proper fractions!</p>"
    width, height, radius = 150, 150, 60
    cx, cy = width / 2, height / 2

    if denominator == 1:
        # A single 360° slice is an arc whose start and end points coincide
        # (up to float rounding). SVG drops arcs with identical endpoints, so
        # the pie would render blank. Draw the whole disc directly instead.
        slices_html = (
            f'<circle cx="{cx}" cy="{cy}" r="{radius}" fill="#FF6B6B" stroke="#333" stroke-width="2"/>'
        )
    else:
        angle_step = 360 / denominator
        # The large-arc flag is only needed when a slice spans more than half the circle.
        large_arc_flag = 1 if angle_step > 180 else 0
        slices = []
        for i in range(denominator):
            fill_color = "#FF6B6B" if i < numerator else "#DDDDDD"
            # Shift by -90° so the first slice starts at 12 o'clock.
            start_rad = math.radians(i * angle_step - 90)
            end_rad = math.radians((i + 1) * angle_step - 90)
            x1, y1 = cx + radius * math.cos(start_rad), cy + radius * math.sin(start_rad)
            x2, y2 = cx + radius * math.cos(end_rad), cy + radius * math.sin(end_rad)
            path_d = f"M {cx},{cy} L {x1},{y1} A {radius},{radius} 0 {large_arc_flag},1 {x2},{y2} Z"
            slices.append(f'<path d="{path_d}" fill="{fill_color}" stroke="#333" stroke-width="2"/>')
        slices_html = "".join(slices)

    html = f"""<div style="padding: 20px; background: linear-gradient(135deg, #A8EDEA 0%, #FED6E3 100%); border-radius: 15px; margin: 10px 0;"><h3 style="color: #333; text-align: center;">Fraction Pizza: {numerator}/{denominator}</h3><div style="display: flex; justify-content: center;"><svg width="{width}" height="{height}">{slices_html}</svg></div><p style="color: #333; text-align: center; margin-top: 15px; font-size: 18px;">The pizza is cut into <b>{denominator}</b> equal slices, and we are showing <b>{numerator}</b> of them! 🍕</p></div>"""
    return html
def create_clock_visual(hours, minutes):
    """(Dynamic) Draw an analog clock face showing ``hours:minutes``.

    The minute hand turns 6° per minute. The hour hand turns 30° per hour plus
    a proportional share for the elapsed minutes. Only the 12, 3, 6 and 9
    numerals are drawn.
    """
    minute_turn = minutes * 6
    hour_turn = (hours % 12 + minutes / 60) * 30

    numerals = "".join(
        f'<text x="{x}" y="{y}" text-anchor="middle" font-size="20" font-weight="bold" fill="#333">{label}</text>'
        for x, y, label in ((125, 45, "12"), (205, 130, "3"), (125, 215, "6"), (45, 130, "9"))
    )

    def hand(tip_y, colour, thickness, angle):
        # A hand drawn pointing at 12 from the centre, then rotated into place.
        return (
            f'<line x1="125" y1="125" x2="125" y2="{tip_y}" stroke="{colour}" '
            f'stroke-width="{thickness}" stroke-linecap="round" '
            f'transform="rotate({angle}, 125, 125)"/>'
        )

    face = (
        '<svg width="250" height="250" viewBox="0 0 250 250" '
        'style="background: white; border-radius: 50%; border: 8px solid #FFE066;">'
        '<circle cx="125" cy="125" r="110" fill="white" stroke="#333" stroke-width="2"/>'
        + numerals
        + hand(40, "#FF6B6B", 6, hour_turn)
        + hand(25, "#4ECDC4", 4, minute_turn)
        + '<circle cx="125" cy="125" r="8" fill="#333"/></svg>'
    )
    caption = (
        '<div style="text-align: center; margin-top: 20px;">'
        f'<p style="color: #FFE066; font-size: 24px; font-weight: bold;">This clock shows {hours:02d}:{minutes:02d}</p>'
        '<p style="color: white; font-size: 16px;">The short <span style="color:#FF6B6B">red</span> hand points to the hour. '
        'The long <span style="color:#4ECDC4">blue</span> hand points to the minutes.</p></div>'
    )
    return (
        '<div style="padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 15px; margin: 10px 0;">'
        '<h3 style="color: white; text-align: center; margin-bottom: 20px;">🕐 Learning Time!</h3>'
        f'<div style="display: flex; justify-content: center;">{face}</div>'
        f"{caption}</div>"
    )
def create_multiplication_array(rows, cols):
    """(Dynamic) Draw ``rows x cols`` as an SVG grid of dots.

    Each dot sits in a 25px cell with a 5px gutter. Dots are emitted row by
    row, and the caption spells out the product.
    """
    dot_box, spacing = 25, 5
    pitch = dot_box + spacing
    half = dot_box / 2
    dot_radius = half - 2

    dots = []
    for row in range(rows):
        for col in range(cols):
            dots.append(
                f'<circle cx="{col * pitch + half}" cy="{row * pitch + half}" r="{dot_radius}" fill="#FF6B6B"/>'
            )

    product = rows * cols
    return (
        '<div style="padding: 20px; background: linear-gradient(135deg, #FF9A9E 0%, #FECFEF 100%); border-radius: 15px; margin: 10px 0;">'
        f'<h3 style="color:#333; text-align: center;">Multiplication Array: {rows} × {cols} = {product}</h3>'
        '<div style="display: flex; justify-content: center; padding: 10px;">'
        f'<svg width="{cols * pitch}" height="{rows * pitch}">{"".join(dots)}</svg></div>'
        '<p style="color: #333; text-align: center; font-size: 18px;">'
        f"See? There are <b>{rows}</b> rows of <b>{cols}</b> dots. That's <b>{product}</b> dots in total!</p></div>"
    )
def create_number_line(start, end, points, title="Number Line", max_intervals=30):
    """(Dynamic) Render a horizontal number line as an HTML/SVG snippet.

    Args:
        start: leftmost integer on the line.
        end: rightmost integer on the line. If ``start > end`` the bounds are
            swapped; if they are equal the line is widened by one.
        points: integers to mark with a labelled red dot.
        title: heading shown above the line.
        max_intervals: upper bound on the number of tick intervals. For spans
            up to this size there is one tick per integer. Wider spans only tick
            multiples of a step, which keeps the SVG small and the labels
            legible.

    Returns:
        An HTML string containing a 600px-wide SVG.
    """
    width = 600
    padding = 30
    # Accept reversed bounds rather than drawing the wrong range.
    if start > end:
        start, end = end, start
    # A zero-width range would divide by zero in the scale below.
    if start == end:
        end = start + 1
    scale = (width - 2 * padding) / (end - start)

    def to_x(n):
        return padding + (n - start) * scale

    step = max(1, math.ceil((end - start) / max(1, max_intervals)))
    first_tick = -(-start // step) * step  # smallest multiple of step >= start
    ticks_html = "".join(
        f'<g transform="translate({to_x(i)}, 50)"><line y2="10" stroke="#aaa"/>'
        f'<text y="30" text-anchor="middle" fill="#555">{i}</text></g>'
        for i in range(first_tick, end + 1, step)
    )
    points_html = "".join(
        f'<g transform="translate({to_x(p)}, 50)"><circle r="8" fill="#FF6B6B" stroke="white" stroke-width="2"/>'
        f'<text y="-15" text-anchor="middle" font-weight="bold" fill="#D63031">{p}</text></g>'
        for p in points
    )
    html = f"""<div style="padding: 20px; background: #f7f1e3; border-radius: 15px; margin: 10px 0;"><h3 style="text-align: center; color: #333;">{title}</h3><svg width="{width}" height="100"><line x1="{padding}" y1="50" x2="{width-padding}" y2="50" stroke="#333" stroke-width="2"/>{ticks_html}{points_html}</svg></div>"""
    return html
def create_place_value_blocks(number):
    """(Dynamic) Show ``number`` as base-ten blocks.

    Hundreds are drawn as flats, tens as rods and ones as cubes. Each place gets
    its own labelled column, which is omitted when that digit is 0, and a
    summary line spells out the decomposition. Digits come from floor division
    and modulo, so the layout assumes a non-negative integer; values above 999
    simply draw more than nine hundreds flats.
    """
    hundreds, remainder = divmod(number, 100)
    tens, ones = divmod(remainder, 10)

    def framed(style, cells):
        # A bordered grid wrapper around a run of unit cells.
        return "\n".join(["", f'<div style="{style}">', cells, "</div>", ""])

    def column(label, count, tray_style, pieces):
        # One labelled column of blocks; empty when its digit is zero.
        if count <= 0:
            return ""
        return "\n".join([
            "",
            '<div style="text-align: center;">',
            f"<h4>{label}: {count}</h4>",
            f'<div style="{tray_style}">{pieces}</div>',
            "</div>",
            "",
        ])

    flat = framed(
        "width: 100px; height: 100px; background: #FF6B6B; border: 2px solid #D63031; "
        "display: grid; grid-template-columns: repeat(10, 1fr); gap: 2px; padding: 2px;",
        "<div style='background:#F5A6A6'></div>" * 100,
    )
    rod = framed(
        "width: 10px; height: 100px; background: #4ECDC4; border: 2px solid #00B894; "
        "display: grid; grid-template-rows: repeat(10, 1fr); gap: 2px; padding: 2px;",
        "<div style='background:#A2E8E4'></div>" * 10,
    )
    cube = '<div style="width: 10px; height: 10px; background: #FFE066; border: 2px solid #FDCB6E;"></div>'

    sections = (
        column("Hundreds", hundreds, "display: flex; gap: 5px;", flat * hundreds),
        column("Tens", tens, "display: flex; gap: 5px; align-items: flex-end;", rod * tens),
        column(
            "Ones",
            ones,
            "display: flex; gap: 5px; align-items: flex-end; flex-wrap: wrap; width: 50px; justify-content: center;",
            cube * ones,
        ),
    )

    return "\n".join([
        "",
        '<div style="padding: 20px; background: linear-gradient(135deg, #dfe6e9 0%, #b2bec3 100%); border-radius: 15px; margin: 10px 0;">',
        f'<h3 style="color: #333; text-align: center;">Place Value Blocks for {number}</h3>',
        '<div style="display: flex; justify-content: center; align-items: flex-end; gap: 20px; flex-wrap: wrap; padding: 20px 0; min-height: 150px;">',
        *sections,
        "</div>",
        '<div style="text-align: center; margin-top: 15px; padding: 10px; background: rgba(0,0,0,0.1); border-radius: 10px;">',
        '<h4 style="color: #333; margin:0;">',
        f"{hundreds} Hundreds + {tens} Tens + {ones} Ones = {number}",
        "</h4>",
        "</div>",
        "</div>",
        "",
    ])
def create_shape_explorer():
    """(Static) Build the "Shape Explorer" card grid.

    The grid shows a circle, square, triangle and rectangle. The function takes
    no input and returns the same HTML on every call.
    """
    gallery = (
        ("Circle",
         '<circle cx="40" cy="40" r="35" fill="#FF6B6B" stroke="#333" stroke-width="3"/>',
         "Round and smooth!"),
        ("Square",
         '<rect x="12.5" y="12.5" width="55" height="55" fill="#4ECDC4" stroke="#333" stroke-width="3"/>',
         "4 equal sides!"),
        ("Triangle",
         '<polygon points="40,15 15,65 65,65" fill="#FFD93D" stroke="#333" stroke-width="3"/>',
         "3 sides and corners!"),
        ("Rectangle",
         '<rect x="10" y="25" width="60" height="30" fill="#95E1D3" stroke="#333" stroke-width="3"/>',
         "4 sides, opposite sides equal!"),
    )
    cards = "".join(
        '<div style="text-align: center; padding: 15px; background: white; border-radius: 10px; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">'
        f'<h4 style="color: #333; margin-bottom: 10px;">{name}</h4>'
        f'<svg width="80" height="80">{shape}</svg>'
        f'<p style="color: #666; font-size: 12px; margin-top: 10px;">{caption}</p></div>'
        for name, shape, caption in gallery
    )
    # The footer's style attribute contains a line break carried over from the
    # original markup; browsers treat it as plain whitespace.
    return (
        '<div style="padding: 20px; background: linear-gradient(135deg, #A8EDEA 0%, #FED6E3 100%); border-radius: 15px; margin: 10px 0;">'
        '<h3 style="color: #333; text-align: center; margin-bottom: 20px;">🔷 Shape Explorer!</h3>'
        '<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); gap: 20px; max-width: 600px; margin: 0 auto;">'
        + cards
        + '</div><p style="color: #333; text-align: center; margin-top: \n20px; font-size: 18px;">Can you find these shapes around you? 🔍✨</p></div>'
    )
def create_money_counter():
    """(Static) Build the "Money Counter" panel.

    The panel shows a penny, nickel, dime and quarter. The function takes no
    input and returns the same HTML on every call.
    """
    # (name, diameter px, face colour, rim colour, label colour, face label,
    #  caption, caption style). The quarter's caption style contains a line
    #  break carried over from the original markup; browsers treat it as whitespace.
    coins = (
        ("Penny", 50, "#CD7F32", "#8B4513", "white", "1¢", "1 cent", "color: #666; font-size: 12px;"),
        ("Nickel", 55, "#C0C0C0", "#808080", "#333", "5¢", "5 cents", "color: #666; font-size: 12px;"),
        ("Dime", 45, "#C0C0C0", "#808080", "#333", "10¢", "10 cents", "color: #666; font-size: 12px;"),
        ("Quarter", 60, "#C0C0C0", "#808080", "#333", "25¢", "25 cents", "color: #666; \nfont-size: 12px;"),
    )
    tiles = []
    for name, size, face, rim, ink, label, caption, caption_style in coins:
        tiles.append(
            '<div style="text-align: center; padding: 15px; background: white; border-radius: 10px;">'
            f'<h4 style="color: #333;">{name}</h4>'
            f'<div style="width: {size}px; height: {size}px; background: {face}; border-radius: 50%; '
            'margin: 10px auto; display: flex; align-items: center; justify-content: center; '
            f'border: 3px solid {rim};">'
            f'<span style="color: {ink}; font-weight: bold;">{label}</span></div>'
            f'<p style="{caption_style}">{caption}</p></div>'
        )
    return (
        '<div style="padding: 20px; background: linear-gradient(135deg, #FFE259 0%, #FFA751 100%); border-radius: 15px; margin: 10px 0;">'
        '<h3 style="color: #333; text-align: center; margin-bottom: 20px;">💰 Money Counter!</h3>'
        '<div style="display: flex; justify-content: center; gap: 30px; flex-wrap: wrap;">'
        + "".join(tiles)
        + '</div><p style="color: #333; text-align: center; margin-top: 20px; font-size: 18px;">Practice counting coins to make different amounts! 🪙✨</p></div>'
    )
# --- APPLICATION SETUP: API key & model configuration, session state, dialogs, and layout ---
# --- API KEY & MODEL CONFIGURATION ---
SYSTEM_INSTRUCTION = """
You are "Math Jegna", an AI specializing exclusively in K-12 mathematics.
Your one and only function is to solve and explain math problems for children.
You are an AI math tutor that uses the Professor B methodology developed by Everard Barrett. This methodology is designed to activate children's natural learning capacities and present mathematics as a contextual, developmental story that makes sense.
IMPORTANT: When explaining mathematical concepts to young learners, mention that colorful visual aids will be provided to help illustrate the concept. Use phrases like:
- "Let me show you this with some colorful blocks..."
- "A fun visual will help you see how this works..."
- "I'll create a picture to help you understand this fraction..."
Focus on concepts appropriate for K-12 students:
- Basic counting and number recognition
- Simple addition and subtraction (using manipulatives)
- Multiplication as arrays or groups
- Basic shapes and geometry
- Place value with hundreds, tens, ones
- Money counting and coin recognition
- Time telling with analog clocks
- Simple patterns and sequences
- Basic measurement concepts
Always use age-appropriate language and relate math to real-world examples children understand.
Core Philosophy and Principles
1. Contextual Learning Approach
Present math as a story: Every mathematical concept should be taught as part of a continuing narrative that builds connections between ideas
Use concrete manipulatives: Always relate abstract concepts to physical, visual representations
Truth-telling: Present arithmetic computations simply and truthfully without confusing steps
2. Natural Learning Activation
Leverage natural capacities: Recognize that each child has mental capabilities designed to learn naturally
Story-based retention: Use stories and visual representations that children can easily remember
Reduced anxiety: Make math fun and engaging, not scary or confusing
3. Hands-on Learning
Mental gymnastics: Use finger counting, visual blocks, and interactive elements
No rote memorization: Focus on understanding through play and exploration
Build confidence: Celebrate small victories and progress
You are strictly forbidden from answering any question that is not mathematical in nature.
If you receive a non-mathematical question, you MUST decline with: "I can only answer math questions for students. Please ask me about numbers, shapes, counting, or other math topics!"
Keep explanations simple, encouraging, and fun for young learners.
"""


def _resolve_api_key():
    """Return GOOGLE_API_KEY from Streamlit secrets, else from the environment.

    The environment (including any .env file loaded beforehand) is consulted
    only when the secret is missing or no secrets file exists.
    """
    try:
        return st.secrets["GOOGLE_API_KEY"]
    except (KeyError, FileNotFoundError):
        return os.getenv("GOOGLE_API_KEY")


load_dotenv()
api_key = _resolve_api_key()
if api_key:
    genai.configure(api_key=api_key)
    # Main text model
    model = genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        system_instruction=SYSTEM_INSTRUCTION,
    )
else:
    st.error("🚨 Google API Key not found! Please add it to your secrets or a local .env file.")
    st.stop()
# --- SESSION STATE & LOCAL STORAGE INITIALIZATION ---
def _load_shared_chat():
    """Decode a chat shared through the ``?shared_chat=`` query parameter.

    Returns:
        The decoded list of message dicts. Returns None when the parameter is
        absent or is not URL-safe base64 of a JSON list of
        ``{"role", "content"}`` dicts. The value comes from the URL, so it is
        treated as untrusted.
    """
    shared_chat_b64 = st.query_params.get("shared_chat")
    if not shared_chat_b64:
        return None
    try:
        chat = json.loads(base64.urlsafe_b64decode(shared_chat_b64).decode())
    except (TypeError, ValueError):
        # binascii.Error, UnicodeDecodeError and JSONDecodeError all subclass ValueError.
        return None
    if isinstance(chat, list) and all(
        isinstance(m, dict) and "role" in m and "content" in m for m in chat
    ):
        return chat
    return None


def _load_saved_chats():
    """Restore ``(chats, active_chat_key)`` from browser local storage.

    Returns:
        None when nothing usable was saved. Otherwise returns a pair whose
        active key is guaranteed to name one of the restored chats, so later
        ``chats[active_chat_key]`` lookups cannot raise KeyError.
    """
    saved_data_json = localS.getItem("math_mentor_chats")
    if not saved_data_json:
        return None
    try:
        # NOTE(review): the component may hand back an already-parsed value; accept both.
        saved_data = json.loads(saved_data_json) if isinstance(saved_data_json, str) else saved_data_json
    except ValueError:
        return None
    if not isinstance(saved_data, dict):
        return None
    chats = saved_data.get("chats")
    if not isinstance(chats, dict) or not chats:
        return None
    active_chat_key = saved_data.get("active_chat_key", "New Chat")
    if not isinstance(active_chat_key, str) or active_chat_key not in chats:
        active_chat_key = next(iter(chats))
    return chats, active_chat_key


if "chats" not in st.session_state:
    shared_chat = _load_shared_chat()
    if shared_chat is not None:
        st.session_state.chats = {"Shared Chat": shared_chat}
        st.session_state.active_chat_key = "Shared Chat"
        st.query_params.clear()
    else:
        saved = _load_saved_chats()
        if saved is not None:
            st.session_state.chats, st.session_state.active_chat_key = saved
        else:
            st.session_state.chats = {
                "New Chat": [
                    {"role": "assistant", "content": "Hello! I'm Math Jegna, your friendly math helper! 🧠✨ I love helping students learn math with colorful pictures and fun activities. What would you like to learn about today? Maybe counting, shapes, or solving a math problem? 🌟"}
                ]
            }
            st.session_state.active_chat_key = "New Chat"
# --- RENAME DIALOG ---
@st.dialog("Rename Chat")
def rename_chat(chat_key):
    """Modal dialog that renames the chat ``chat_key``.

    Surrounding whitespace is stripped from the new name. Empty names, and
    names already used by a different chat, are rejected with an inline error.
    Saving the current name unchanged is accepted as a no-op rename. On
    success the chat keeps its position in the sidebar list, becomes the
    active chat, and the app reruns.
    """
    st.write(f"Enter a new name for '{chat_key}':")
    new_name = (st.text_input("New Name", key=f"rename_input_{chat_key}") or "").strip()
    if st.button("Save", key=f"save_rename_{chat_key}"):
        chats = st.session_state.chats
        if not new_name:
            st.error("Name cannot be empty.")
        elif new_name != chat_key and new_name in chats:
            st.error("A chat with this name already exists.")
        else:
            # Rebuild the dict so the renamed chat keeps its slot in the
            # (insertion-ordered) sidebar listing instead of moving to the end.
            st.session_state.chats = {
                (new_name if key == chat_key else key): messages
                for key, messages in chats.items()
            }
            st.session_state.active_chat_key = new_name
            st.rerun()
# --- DELETE CONFIRMATION DIALOG ---
@st.dialog("Delete Chat")
def delete_chat(chat_key):
    """Modal dialog asking the user to confirm deletion of ``chat_key``.

    On confirmation the chat is removed. If it was the active chat, the first
    remaining chat becomes active. If no chats remain, a fresh "New Chat" with a
    greeting is created and selected. The app then reruns.
    """
    st.warning(f"Are you sure you want to delete '{chat_key}'? This cannot be undone.")
    confirmed = st.button("Yes, Delete", type="primary", key=f"confirm_delete_{chat_key}")
    if not confirmed:
        return

    chats = st.session_state.chats
    del chats[chat_key]

    if st.session_state.active_chat_key == chat_key:
        if not chats:
            # Never leave the app without a chat to display.
            chats["New Chat"] = [
                {"role": "assistant", "content": "Hello! Let's start a new math adventure! 🚀"}
            ]
            st.session_state.active_chat_key = "New Chat"
        else:
            st.session_state.active_chat_key = next(iter(chats))
    st.rerun()
# --- MAIN APP LAYOUT ---
with st.sidebar:
    st.title("🧮 Math Jegna")
    st.write("Your K-8 AI Math Tutor")
    st.divider()

    # Chat history list: one row per chat with switch / rename / delete controls.
    for chat_key in list(st.session_state.chats):
        col1, col2, col3 = st.columns([0.6, 0.2, 0.2])
        with col1:
            is_active = st.session_state.active_chat_key == chat_key
            if st.button(
                chat_key,
                key=f"switch_{chat_key}",
                use_container_width=True,
                type="primary" if is_active else "secondary",
            ):
                st.session_state.active_chat_key = chat_key
                st.rerun()
        with col2:
            if st.button("✏️", key=f"rename_{chat_key}", help="Rename Chat"):
                rename_chat(chat_key)
        with col3:
            if st.button("🗑️", key=f"delete_{chat_key}", help="Delete Chat"):
                delete_chat(chat_key)

    if st.button("➕ New Chat", use_container_width=True):
        new_chat_name = f"Chat {len(st.session_state.chats) + 1}"
        # Ensure the name is unique
        while new_chat_name in st.session_state.chats:
            new_chat_name += "*"
        st.session_state.chats[new_chat_name] = [
            {"role": "assistant", "content": "Ready for a new math problem! What's on your mind? 😃"}
        ]
        st.session_state.active_chat_key = new_chat_name
        st.rerun()

    st.divider()

    # Save chats to local storage
    if st.button("💾 Save Chats", use_container_width=True):
        data_to_save = {
            "chats": st.session_state.chats,
            "active_chat_key": st.session_state.active_chat_key,
        }
        localS.setItem("math_mentor_chats", json.dumps(data_to_save))
        st.toast("Chats saved to your browser!", icon="✅")

    # Download chat button
    active_chat_history = st.session_state.chats[st.session_state.active_chat_key]
    download_str = format_chat_for_download(active_chat_history)
    st.download_button(
        label="📥 Download Chat",
        data=download_str,
        file_name=f"{st.session_state.active_chat_key.replace(' ', '_')}_history.md",
        mime="text/markdown",
        use_container_width=True,
    )

    # Share chat button
    if st.button("🔗 Share Chat", use_container_width=True):
        # Encode only role/content. Stored visuals are raw HTML that can run to
        # many kilobytes (place-value blocks emit one <div> per unit cell), which
        # would push the link past practical URL length limits.
        shareable_chat = [
            {"role": m["role"], "content": m["content"]}
            for m in st.session_state.chats[st.session_state.active_chat_key]
        ]
        chat_json = json.dumps(shareable_chat)
        chat_b64 = base64.urlsafe_b64encode(chat_json.encode()).decode()
        # This part might need adjustment depending on how Streamlit Community Cloud handles base URLs
        share_url = f"https://huggingface.co/spaces/YOUR_SPACE_HERE?shared_chat={chat_b64}"  # Placeholder
        st.code(share_url)
        st.info("Copy the URL above to share this specific chat! (You might need to update the base URL)")
st.header(f"Chatting with Math Jegna: _{st.session_state.active_chat_key}_")

# Replay the active conversation, re-rendering any visual stored with a reply.
for entry in st.session_state.chats[st.session_state.active_chat_key]:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])
        stored_visual = entry.get("visual_html")
        if stored_visual:
            components.html(stored_visual, height=400, scrolling=True)
# User input
if prompt := st.chat_input("Ask a K-8 math question..."):
    active_chat = st.session_state.chats[st.session_state.active_chat_key]

    # Build the model history from the turns *before* this prompt.
    # send_message(prompt) below adds the prompt itself, so including it here
    # as well would send the student's question to the model twice.
    gemini_chat_history = [
        {"role": convert_role_for_gemini(m["role"]), "parts": [m["content"]]}
        for m in active_chat
    ]

    # Add user message to chat history
    active_chat.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate response
    with st.chat_message("assistant"):
        with st.spinner("Math Jegna is thinking..."):
            try:
                chat_session = model.start_chat(history=gemini_chat_history)
                response = chat_session.send_message(prompt, stream=True)
                full_response = ""
                response_container = st.empty()
                for chunk in response:
                    full_response += chunk.text
                    # Trailing block cursor signals that streaming is still in progress.
                    response_container.markdown(full_response + " ▌")
                response_container.markdown(full_response)

                # After generating text, decide if a visual is needed and generate it
                visual_html_content = None
                if should_generate_visual(prompt, full_response):
                    visual_html_content = create_visual_manipulative(prompt, full_response)
                if visual_html_content:
                    components.html(visual_html_content, height=400, scrolling=True)

                # Store the reply together with its visual so it re-renders on replay.
                active_chat.append({
                    "role": "assistant",
                    "content": full_response,
                    "visual_html": visual_html_content,
                })
            except genai.types.generation_types.BlockedPromptException:
                error_message = "I can only answer math questions for students. Please ask me about numbers, shapes, or other math topics!"
                st.error(error_message)
                active_chat.append({"role": "assistant", "content": error_message, "visual_html": None})
            except Exception as e:
                st.error(f"An error occurred: {e}")