# Source page header (Hugging Face Space): author Ihssane123
# Commit message: "Update app.py"
# Commit: 8e2fec9 (verified)
import random
import gradio as gr
import pandas as pd
import numpy as np
from Src.Processing import load_data
from Src.Processing import process_data
from Src.Inference import load_model
import torchvision.models as models
from torchvision import transforms
from Src.Processing_img import (
get_style_model_and_losses,
image_loader,
run_style_transfer,
save_image,
gram_matrix
)
import torch
import tempfile
import time # For simulating delays
from PIL import Image, ImageDraw, ImageFont # Ensure ImageDraw and ImageFont are imported
import os # For file operations
import mne
import matplotlib.pyplot as plt
import io
import pyvista as pv
import matplotlib.cm as cm
import gradio as gr
# Global PyVista configuration, applied once when this module is imported.
pv.set_plot_theme("document") # A simple theme
# Select the 'html' backend for PyVista's Jupyter/notebook rendering.
pv.set_jupyter_backend('html')
# --- Emotion lookup tables and demo data ---
# Model output index -> short emotion code (keys are 0..6, in this order).
int_to_emotion = dict(enumerate(('sad', 'dis', 'fear', 'neu', 'joy', 'ten', 'ins')))

# Short emotion code -> display name shown to the user.
abr_to_emotion = dict(
    sad="sadness",
    dis="disgust",
    fear="fear",
    neu="neutral",
    joy="joy",
    ten='Tenderness',
    ins="inspiration",
)

# Placeholder values used to initialise the (hidden) emotion bar chart;
# real values are produced by the PSD analysis.
dummy_emotion_data = pd.DataFrame({
    'Emotion': [int_to_emotion[i] for i in range(7)],
    'Value': [0.8, 0.6, 0.1, 0.4, 0.7, 0.2, 0.3],
})
# --- Local image paths and model configuration ---
# Base directory holding one sub-folder of paintings per painter.
PAINTERS_BASE_DIR = "Painters"
# Base directory holding one sub-folder of content images per emotion.
EMOTION_BASE_DIR = "Emotions"
# Built with os.path.join so the path is portable; the previous literal
# "models\lstm_..." contained the invalid escape sequence "\l" and only
# resolved on Windows.
model_path = os.path.join("models", "lstm_emotion_model_state.pth")

# Hyper-parameters passed to load_model() when rebuilding the LSTM.
input_size = 320
hidden_size = 50
output_size = 7  # one output per entry of int_to_emotion
num_layers = 1

# Painters offered in the dropdown.
painters = ["Pablo Picasso", "Vincent van Gogh", "Salvador Dalí"]
# Directory that contains the selectable PSD files.
Base_Dir = "Datasets"

# (filename, caption) pairs per painter. Kept for reference/placeholders;
# the gallery content itself is read from the file system.
PAINTER_PLACEHOLDER_DATA = {
    "Pablo Picasso": [
        ("Dora Maar with Cat (1941).png", "Dora Maar with Cat (1941)"),
        ("The Weeping Woman (1937).png", "The Weeping Woman (1937)"),
        ("Three Musicians (1921).png", "Three Musicians (1921)"),
    ],
    "Vincent van Gogh": [
        ("Sunflowers (1888).png", "Sunflowers (1888)"),
        ("The Starry Night (1889).png", "The Starry Night (1889)"),
        ("The Potato Eaters (1885).png", "The Potato Eaters (1885)"),
    ],
    "Salvador Dalí": [
        ("Persistence of Memory (1931).png", "Persistence of Memory (1931)"),
        ("Swans Reflecting Elephants (1937).png", "Swans Reflecting Elephants (1937)"),
        ("Sleep (1937).png", "Sleep (1937)"),
    ],
}

# --- PSD files the user can choose from (resolved relative to Base_Dir) ---
predefined_psd_files = ["task-emotion_psd_1.npy", "task-emotion_psd_2.npy", "task-emotion_psd_3.npy"]
# --- Core Functions ---
def upload_psd_file(selected_file_name):
    """
    Run emotion inference on the selected PSD file and build the distribution plot.

    The file is loaded from ``Base_Dir`` and the raw array is stored in the
    module-level ``np_data`` (``generate_topomap`` reads it later). The data is
    then preprocessed, passed through the LSTM model, and the predicted emotion
    index of every time step is counted.

    Args:
        selected_file_name: File name chosen in the UI, or None when nothing
            is selected.

    Returns:
        A 3-tuple ``(bar_plot, emotion_dataframe, dominant_emotion_textbox)``.
        Every code path returns three values because the click handler is
        wired to three outputs. When no file is selected or the file is
        missing, the plot and textbox are hidden and the DataFrame is empty.
    """
    global np_data

    def _hidden_result(label):
        # Placeholder outputs used when there is nothing to show.
        return (
            gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value", label=label, visible=False),
            pd.DataFrame(),
            gr.Textbox(value="", visible=False),
        )

    if selected_file_name is None:
        return _hidden_result("Emotion Distribution")

    # --- Load and Process PSD Data ---
    psd_file_path = os.path.join(Base_Dir, selected_file_name).replace(os.sep, '/')
    try:
        np_data = load_data(psd_file_path)
    except FileNotFoundError:
        print(f"Error: PSD file not found at {psd_file_path}")
        return _hidden_result("Emotion Distribution (Error: File not found)")
    print(f"np data orig {np_data.shape}")

    final_data = process_data(np_data)
    # Add a leading batch dimension for the model.
    # assumes process_data returns (sequence_length, input_size) — TODO confirm
    torch_data = torch.tensor(final_data, dtype=torch.float32).unsqueeze(0)
    print(f"Processed data shape for model: {torch_data.shape}")

    # --- Inference ---
    # The path is relative to the working directory the app is started from.
    absolute_model_path = os.path.join("models", "lstm_emotion_model_state.pth")
    loaded_model = load_model(absolute_model_path, input_size, hidden_size, output_size, num_layers)
    loaded_model.eval()  # evaluation mode
    with torch.no_grad():  # no gradients needed for inference
        # The model call is unpacked as (output, state); the state is unused.
        predicted_logits, _ = loaded_model(torch_data)

    # Most probable emotion index per time step, flattened for counting.
    final_output_indices = torch.argmax(predicted_logits, dim=2)
    all_predicted_indices = final_output_indices.view(-1)
    print(f"All predicted indices (flattened): {all_predicted_indices}")

    # minlength guarantees one count per emotion, including never-predicted ones.
    values_count = torch.bincount(all_predicted_indices, minlength=output_size)
    print(f"Raw bincount: {values_count}")

    # --- Create Emotion Distribution DataFrame ---
    # Slicing to output_size ignores any index outside the known emotions.
    emotions_count = {
        int_to_emotion[idx]: count
        for idx, count in enumerate(values_count[:output_size].tolist())
    }
    dom_emotion = max(emotions_count, key=emotions_count.get)

    emotion_data = pd.DataFrame({
        "Emotion": list(emotions_count.keys()),
        "Frequency": list(emotions_count.values())
    })
    # Sort by emotion name so the bars appear in a consistent order.
    emotion_data = emotion_data.sort_values(by="Emotion").reset_index(drop=True)
    print(f"Final emotion_data DataFrame:\n{emotion_data}")

    return gr.BarPlot(
        emotion_data,
        x="Emotion",
        y="Frequency",
        label="Emotion Distribution",
        visible=True,
        y_title="Frequency"
    ), emotion_data, gr.Textbox(abr_to_emotion[dom_emotion], visible=True)
def update_paintings(painter_name):
    """
    Build the gallery content for one painter from the files on disk.

    Lists the painter's folder under ``PAINTERS_BASE_DIR`` in sorted order and
    keeps the image files (.png, .jpg, .jpeg, .gif; case-insensitive).

    Returns:
        A list of ``(file_path, caption)`` tuples, where the caption is the
        file name without its extension. The list is empty when the folder
        does not exist.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.gif')
    painter_dir = os.path.join(PAINTERS_BASE_DIR, painter_name).replace(os.sep, '/')
    print(painter_dir)

    # A missing folder simply yields an empty gallery.
    entries = sorted(os.listdir(painter_dir)) if os.path.isdir(painter_dir) else []

    gallery_items = []
    for entry in entries:
        if not entry.lower().endswith(image_suffixes):
            continue
        image_path = os.path.join(painter_dir, entry).replace(os.sep, '/')
        print(image_path)
        caption, _extension = os.path.splitext(entry)
        gallery_items.append((image_path, caption))

    print(f"Loaded paintings for {painter_name}: {gallery_items}")
    return gallery_items
def generate_my_art(painter, chosen_painting, dom_emotion):
    """
    Stylise a random image of the dominant emotion with the chosen painting.

    This is a generator, so every result, including validation messages,
    is yielded. A plain ``return value`` inside a generator would never reach
    the UI.

    Args:
        painter: Painter name; also the sub-folder of ``PAINTERS_BASE_DIR``.
        chosen_painting: File name of the selected painting inside the
            painter's folder (used as the style image).
        dom_emotion: Dominant emotion; also the sub-folder of
            ``EMOTION_BASE_DIR`` from which the content image is picked.

    Yields:
        A 3-tuple ``(status, content_image_path, stylized_image_path)``.
        On validation failure the status is a plain message and both paths
        are None.
    """
    print("generating started")
    print(f"painter: {painter}")
    print(f"choosen painting: {chosen_painting}")

    # Validate inputs first, before touching the file system or the model.
    if not painter or not chosen_painting:
        yield "Please select a painter and a painting.", None, None
        return
    if not dom_emotion:
        yield "Please analyze a PSD file first to detect the dominant emotion.", None, None
        return

    # Style image path
    img_style_pth = os.path.join(PAINTERS_BASE_DIR, painter, chosen_painting)
    print(f"img_stype_path: {img_style_pth}")

    # Content image: a random file from the dominant emotion's folder.
    emotion_pth = os.path.join(EMOTION_BASE_DIR, dom_emotion)
    candidates = os.listdir(emotion_pth) if os.path.isdir(emotion_pth) else []
    if not candidates:
        yield f"No content images found for emotion '{dom_emotion}'.", None, None
        return
    image_name = random.choice(candidates)
    original_image_pth = os.path.join(emotion_pth, image_name)
    print(f"original img _path: {original_image_pth}")

    final_message = f"Art generated based on {painter}'s {chosen_painting} style!"

    # --- Neural Style Transfer ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Use a smaller working resolution on CPU.
    imsize = 512 if torch.cuda.is_available() else 256
    loader = transforms.Compose([
        transforms.Resize((imsize, imsize)),
        transforms.ToTensor()
    ])
    content_layers = ['conv_4']
    style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
    cnn = models.vgg19(pretrained=True).features.to(device).eval()
    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

    style_img = image_loader(img_style_pth, loader, device)
    content_img = image_loader(original_image_pth, loader, device)
    # Optimisation starts from a copy of the content image.
    input_img = content_img.clone()
    output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                                content_img, style_img, input_img, content_layers, style_layers, device)

    # NOTE(review): fixed output file name; concurrent requests overwrite each other.
    stylized_img_path = 'stylized_output.jpg'
    save_image(output, stylized_img_path)
    print("Stylized image saved as 'stylized_output.jpg'")

    yield gr.Textbox(final_message), original_image_pth, stylized_img_path
def generate_topomap(n_channels=None, n_time=None):
    """
    Render a 2D topomap of the loaded PSD data and return the PNG path.

    Reads the module-level ``np_data`` set by the PSD analysis, reshapes it to
    ``(64, 630, 5)`` and plots the slice ``[:, n_time - 1, n_channels - 1]``
    on the sensor positions of the 'standard_1020' montage.

    Args:
        n_channels: 1-based index along the last axis of the reshaped data
            (size 5). The UI labels it "Channels"; presumably it selects a
            frequency band — TODO confirm. Defaults to 4 when either
            argument is None.
        n_time: 1-based index along the second axis (size 630). Defaults to
            500 when either argument is None.

    Returns:
        Path of a temporary PNG file, or None when no PSD data has been
        loaded yet.
    """
    # np_data only exists after a PSD file has been analysed.
    psd_data = globals().get("np_data")
    if psd_data is None:
        print("No PSD data loaded yet; skipping topomap generation.")
        return None

    if n_channels is None or n_time is None:
        print("they are None")
        n_channels = 4
        n_time = 500
    # Sliders may deliver floats; indices must be integers.
    band_idx = int(n_channels) - 1
    time_idx = int(n_time) - 1

    # ----------------------------
    # Load standard 10-20 montage
    # ----------------------------
    montage = mne.channels.make_standard_montage('standard_1020')
    # 64 entries, one per data row.
    # NOTE(review): 'O1' and 'O2' appear twice, so only 62 names are unique;
    # verify the list against the recording's real channel order.
    standard_64_chs = [
        'Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8',
        'FC5', 'FC1', 'FC2', 'FC6', 'T7', 'C3', 'Cz', 'C4', 'T8',
        'CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8',
        'POz', 'O1', 'Oz', 'O2', 'Fpz', 'AF7', 'AF3', 'AF4', 'AF8',
        'F5', 'F1', 'F2', 'F6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2',
        'C6', 'CP3', 'CPz', 'CP4', 'P5', 'P1', 'P2', 'P6', 'PO3', 'PO4',
        'PO7', 'PO8', 'PO9', 'PO10', 'O1', 'O2', 'FT7', 'FT8', 'TP7', 'TP8'
    ]
    ch_pos_dict = montage.get_positions()['ch_pos']
    # Take the names from the list, not from a dict, so that names,
    # positions and data rows all have the same length and order.
    channel_names = list(standard_64_chs)
    ch_pos_array = np.array([ch_pos_dict[ch] for ch in standard_64_chs])  # Nx3
    ch_pos_2d = ch_pos_array[:, :2]  # x/y only for the 2D topomap

    # ----------------------------
    # Select the slice to plot
    # ----------------------------
    new_data = psd_data.reshape(64, 630, 5)
    print(f"shape: {new_data.shape}")
    print(f"n channels: {n_channels}")
    psd_snapshot = new_data[:, time_idx, band_idx]
    print(f"shape psd :{psd_snapshot.shape}")

    # ----------------------------
    # Plot 2D topomap using MNE and save it
    # ----------------------------
    fig, ax = plt.subplots()
    try:
        mne.viz.plot_topomap(
            psd_snapshot,
            ch_pos_2d,
            names=channel_names,
            show=False,
            axes=ax
        )
        # Reserve a file name, then close the handle before writing to it.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
            output_path = tmpfile.name
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even when plotting fails.
        plt.close(fig)
    return output_path
# PSD files offered in the radio group below.
predefined_psd_files = ["task-emotion_psd_1.npy", "task-emotion_psd_2.npy", "task-emotion_psd_3.npy"]

# --- Gradio Interface Definition ---
with gr.Blocks(css=".gradio-container { max-width: 2000px; margin: auto; }") as demo:
    # Holds the emotion-distribution DataFrame produced by the PSD analysis.
    current_emotion_df_state = gr.State(value=pd.DataFrame())

    # Header Section
    gr.Markdown(
        """
<h1 style="text-align: center;font-size: 5em; padding: 20px; font-weight: bold;">Brain Emotion Decoder 🧠🎨</h1>
<p style="text-align: center; font-size: 1.5em; color: #555;font-weight: bold;">
Imagine seeing your deepest feelings transform into art. We decode the underlying emotions from your brain activity,
generating a personalized artwork that comes to life within an interactive 3D brain model. Discover the art of your inner self.
</p>
"""
    )

    with gr.Row():
        # Left Column: Input and Emotion Distribution
        with gr.Column(scale=1):
            gr.Markdown('<h2 style="font-size: 2em;">1. Choose a PSD file</h2>')
            # Radio buttons to select from predefined files
            psd_file_selection = gr.Radio(
                choices=predefined_psd_files,
                label="Select a PSD file for analysis",
                value=predefined_psd_files[0],  # Default selection
                interactive=True
            )
            # Button to trigger PSD analysis
            analyze_psd_button = gr.Button("Analyze PSD File", variant="secondary")

            gr.Markdown('<h2 style="font-size: 2em;">2. Emotion Distribution</h2>')
            # Bar plot for emotion distribution
            emotion_distribution_plot = gr.BarPlot(
                dummy_emotion_data,
                x="Emotion",
                y="Value",
                label="Emotion Distribution",
                height=300,
                x_title="Emotion Type",
                y_title="Frequency",
                visible=False  # Hidden until analysis is triggered
            )
            # Filled by the analysis; also an input of the art generation.
            dom_emotion = gr.Textbox(label = "dominant emotion", visible=False)

        # Right Column: Art Museum and Generation
        with gr.Column(scale=1):
            gr.Markdown("<h3>Your Art Mesum</h3>")  # Kept original heading
            gr.Markdown("<h3>3. Choose your favourite painter</h3>")
            painter_dropdown = gr.Dropdown(
                choices=painters,
                value="Pablo Picasso",  # Default selection
                label="Select a Painter"
            )
            gr.Markdown("<h3>4. Choose your favourite painting</h3>")
            # Gallery to display paintings for selection
            painting_gallery = gr.Gallery(
                value=update_paintings("Pablo Picasso"),  # Initial load for Picasso's paintings
                label="Select a Painting",
                height=300,
                columns=3,
                rows=1,
                object_fit="contain",
                preview=True,  # Allows clicking to see larger image
                interactive=True,  # Make it selectable
                elem_id="painting_gallery",
                visible=True,  # Should be visible by default
            )
            # Hidden textbox storing the file name of the selected painting.
            selected_painting_name = gr.Textbox(visible=False)
            # Button to trigger art generation
            generate_button = gr.Button("Generate My Art", variant="primary", scale=0)
            # Status message for image generation
            status_message = gr.Textbox(
                value="Click 'Generate My Art' to begin.",
                label="Generation Status",
                interactive=False,
                show_label=False,
                lines=1
            )

    # Output section
    gr.Markdown(
        """
<h1 style="text-align: center;">Your Generated Artwork</h1>
<p style="text-align: center; font-size: 1.5em; color: #555;font-weight: bold;">
Once your brain's emotional data is processed, we pinpoint the <b>dominant emotion</b>. This single feeling inspires a <b>personalized artwork</b>, generated using <b>diffusion techniques</b> and blended with <b>my AI painting style</b>. You can then <b>download</b> this unique visual representation of your inner self.
</p>
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("<h3>Generated Image</h3>")
            generated_image_output = gr.Image(label="Generated Image", show_label=False, height=300)
            gr.Markdown("<h3>Blended Style Image</h3>")
            blended_image_output = gr.Image(label="Blended Style Image", show_label=False, height=300)
        with gr.Column(scale=1):
            gr.Markdown("<h3>Brain Topomap</h3>")
            channels_slider = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Channels", interactive=True)
            timestamp_slider = gr.Slider(minimum=1, maximum=630, value=1, step=1, label="Timestamp", interactive=True)
            mne_2d_img = gr.Image(visible=True)

    # --- Event Listeners ---
    # "Generate My Art" also refreshes the topomap. generate_topomap takes two
    # arguments, so the slider values must be passed as inputs.
    generate_button.click(
        generate_topomap,
        inputs=[channels_slider, timestamp_slider],
        outputs=mne_2d_img
    )

    analyze_psd_button.click(
        upload_psd_file,
        inputs=[psd_file_selection],  # Selected radio button value (file name)
        outputs=[emotion_distribution_plot, current_emotion_df_state, dom_emotion]
    )

    # When the painter changes, reload the gallery with that painter's paintings.
    painter_dropdown.change(
        update_paintings,
        inputs=[painter_dropdown],
        outputs=[painting_gallery]
    )

    # Gallery selection: store the clicked image's original file name in the
    # hidden textbox so the generation step knows which painting was chosen.
    def on_select(evt: gr.SelectData):
        print("this function started")
        print(f"Image index: {evt.index}\nImage value: {evt.value['image']['orig_name']}")
        return evt.value['image']['orig_name']

    painting_gallery.select(
        on_select,
        outputs=[selected_painting_name]
    )

    # Art generation uses the painter, the selected painting and the dominant emotion.
    generate_button.click(
        generate_my_art,
        inputs=[painter_dropdown, selected_painting_name, dom_emotion],
        outputs=[status_message, generated_image_output, blended_image_output]
    )

    # Sliders: re-render the topomap whenever either value changes.
    channels_slider.change(fn=generate_topomap, inputs=[channels_slider, timestamp_slider], outputs=mne_2d_img)
    timestamp_slider.change(fn=generate_topomap, inputs=[channels_slider, timestamp_slider], outputs=mne_2d_img)
# Launch the demo only when this file is executed directly (not on import).
if __name__ == "__main__":
    # Optional, currently disabled: resolve the project root and load the
    # LSTM model once at start-up. project_root_dir must be defined before
    # the last two lines are re-enabled.
    # project_root_dir = os.path.dirname(os.path.abspath(__file__))
    # project_root_dir = os.path.dirname(project_root_dir)
    # print(f"Loading LSTM model from: {os.path.join(project_root_dir, model_path)}")
    # _ = load_model(os.path.join(project_root_dir, model_path), input_size, hidden_size, output_size, num_layers)
    demo.launch()