import random
import gradio as gr
import pandas as pd
import numpy as np
from Src.Processing import load_data, process_data
from Src.Inference import load_model
import torchvision.models as models
from torchvision import transforms
from Src.Processing_img import (
    get_style_model_and_losses,
    image_loader,
    run_style_transfer,
    save_image,
    gram_matrix
)
import torch
import tempfile
from PIL import Image, ImageDraw, ImageFont  # ImageDraw/ImageFont support placeholder images
import os  # For file operations
import mne
import matplotlib.pyplot as plt
import io
import pyvista as pv
import matplotlib.cm as cm

pv.set_plot_theme("document")  # A simple theme
pv.set_jupyter_backend('html')
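
# NOTE: pyvista is configured here for in-browser (HTML) rendering -- the
# backend for the interactive 3D brain view mentioned in the header copy.
# These calls only set global defaults; no plotter is created in this file.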

# --- Data for demonstration ---
# Dummy data for the emotion-distribution bar chart.
# In the real app this comes from the PSD analysis.
dummy_emotion_data = pd.DataFrame({
    'Emotion': ['sad', 'dis', 'fear', 'neu', 'joy', 'ten', 'ins'],
    'Value': [0.8, 0.6, 0.1, 0.4, 0.7, 0.2, 0.3]
})

int_to_emotion = {
    0: 'sad',
    1: 'dis',
    2: 'fear',
    3: 'neu',
    4: 'joy',
    5: 'ten',
    6: 'ins'
}

abr_to_emotion = {
    'sad': "sadness",
    'dis': "disgust",
    'fear': "fear",
    'neu': "neutral",
    'joy': "joy",
    'ten': "Tenderness",  # note: capitalisation must match the folder name under Emotions/
    'ins': "inspiration"
}
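
# The two maps compose: model class index -> abbreviation -> display name,
# e.g. int_to_emotion[3] == 'neu' and abr_to_emotion['neu'] == "neutral".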

# --- Local Image Paths Setup for Dynamic Loading ---
# Base directories for the painters' images and the emotion-keyed content images.
# In a Hugging Face Space these are folders like 'Painters/' in the repository.
PAINTERS_BASE_DIR = "Painters"
EMOTION_BASE_DIR = "Emotions"

model_path = os.path.join("models", "lstm_emotion_model_state.pth")  # built with os.path.join so the path works on any OS
input_size = 320
hidden_size = 50
output_size = 7
num_layers = 1
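
# Model dimensions: input_size of 320 presumably corresponds to the flattened
# 64 channels x 5 frequency bands of the PSD features (see the reshape in
# generate_topomap below); output_size of 7 matches the classes in int_to_emotion.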

# Define painters and some example "filenames" to create placeholders for
painters = ["Pablo Picasso", "Vincent van Gogh", "Salvador Dalí"]
Base_Dir = "Datasets"

# This dictionary defines what placeholder files to create and their captions.
# The actual gallery content will be read from the file system.
PAINTER_PLACEHOLDER_DATA = {
    "Pablo Picasso": [
        ("Dora Maar with Cat (1941).png", "Dora Maar with Cat (1941)"),
        ("The Weeping Woman (1937).png", "The Weeping Woman (1937)"),
        ("Three Musicians (1921).png", "Three Musicians (1921)"),
    ],
    "Vincent van Gogh": [
        ("Sunflowers (1888).png", "Sunflowers (1888)"),
        ("The Starry Night (1889).png", "The Starry Night (1889)"),
        ("The Potato Eaters (1885).png", "The Potato Eaters (1885)"),
    ],
    "Salvador Dalí": [
        ("Persistence of Memory (1931).png", "Persistence of Memory (1931)"),
        ("Swans Reflecting Elephants (1937).png", "Swans Reflecting Elephants (1937)"),
        ("Sleep (1937).png", "Sleep (1937)"),
    ],
}
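
# A minimal bootstrap sketch for the placeholder files described above. This is
# an illustrative helper (not called anywhere): it draws a captioned grey tile
# for any listed painting missing on disk, so the gallery is never empty on a
# fresh checkout. The 256x256 size and colours are arbitrary choices.
def create_placeholder_images():
    for painter_name, items in PAINTER_PLACEHOLDER_DATA.items():
        painter_dir = os.path.join(PAINTERS_BASE_DIR, painter_name)
        os.makedirs(painter_dir, exist_ok=True)
        for filename, caption in items:
            path = os.path.join(painter_dir, filename)
            if not os.path.exists(path):
                img = Image.new("RGB", (256, 256), color=(210, 210, 210))
                ImageDraw.Draw(img).text((8, 120), caption, fill=(0, 0, 0))
                img.save(path)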

# --- Define the specific PSD files to choose from ---
predefined_psd_files = ["task-emotion_psd_1.npy", "task-emotion_psd_2.npy", "task-emotion_psd_3.npy"]  # full paths work here too, if the files live elsewhere

# --- Core Functions ---
def upload_psd_file(selected_file_name):
    """
    Processes a selected PSD file, runs inference, and prepares the emotion
    distribution data.
    """
    if selected_file_name is None:
        # No file selected: keep the plot hidden and return placeholders for
        # every output bound to this handler (plot, state DataFrame, textbox).
        return (
            gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value", label="Emotion Distribution", visible=False),
            pd.DataFrame(),
            gr.Textbox(visible=False),
        )

    # --- Load and Process PSD Data ---
    psd_file_path = os.path.join(Base_Dir, selected_file_name).replace(os.sep, '/')
    try:
        global np_data  # kept module-global so generate_topomap can reuse the loaded PSD
        np_data = load_data(psd_file_path)
        print(f"np data orig {np_data.shape}")
    except FileNotFoundError:
        print(f"Error: PSD file not found at {psd_file_path}")
        # Hide the plot and return placeholders for the other two outputs.
        return (
            gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value", label="Emotion Distribution (Error: File not found)", visible=False),
            pd.DataFrame(),
            gr.Textbox(visible=False),
        )

    final_data = process_data(np_data)
    # The LSTM expects (batch, sequence_length, input_size); assuming process_data
    # yields (sequence_length, input_size), add a batch dimension.
    torch_data = torch.tensor(final_data, dtype=torch.float32).unsqueeze(0)
    print(f"Processed data shape for model: {torch_data.shape}")

    # --- Inference ---
    # model_path is relative to where app.py is run; the 'models' directory is
    # expected at the project root.
    absolute_model_path = os.path.join("models", "lstm_emotion_model_state.pth")
    loaded_model = load_model(absolute_model_path, input_size, hidden_size, output_size, num_layers)
    loaded_model.eval()  # Set model to evaluation mode

    with torch.no_grad():  # Disable gradient calculation for inference
        predicted_logits, _ = loaded_model(torch_data)  # LSTM returns (output, (h_n, c_n))

    # Most probable emotion index at each time step: (batch_size, sequence_length)
    final_output_indices = torch.argmax(predicted_logits, dim=2)

    # Flatten the sequence to count overall emotion frequencies across all predictions.
    all_predicted_indices = final_output_indices.view(-1)
    print(f"All predicted indices (flattened): {all_predicted_indices}")

    # Count occurrences of each predicted emotion index; minlength ensures all 7
    # classes appear in the result even when never predicted.
    values_count = torch.bincount(all_predicted_indices, minlength=output_size)
    print(f"Raw bincount: {values_count}")

    # --- Create Emotion Distribution DataFrame ---
    # Initialise every emotion with a frequency of 0, then fill in the observed counts.
    emotions_count = {int_to_emotion[i]: 0 for i in range(output_size)}
    for idx, count in enumerate(values_count):
        if idx < output_size:  # guard against unexpected indices
            emotions_count[int_to_emotion[idx]] = count.item()

    dom_emotion = max(emotions_count, key=emotions_count.get)

    # Column names must match what gr.BarPlot expects: "Emotion" and "Frequency".
    emotion_data = pd.DataFrame({
        "Emotion": list(emotions_count.keys()),
        "Frequency": list(emotions_count.values())
    })
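
    # Illustrative result (values depend on the file analysed):
    #       Emotion  Frequency
    #   0       dis          3
    #   1      fear          0
    #   ...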

    # Sort by emotion name for a consistent plotting order.
    emotion_data = emotion_data.sort_values(by="Emotion").reset_index(drop=True)
    print(f"Final emotion_data DataFrame:\n{emotion_data}")

    # Return the updated plot, the DataFrame for the state component, and the
    # dominant-emotion textbox (now made visible).
    return gr.BarPlot(
        emotion_data,
        x="Emotion",
        y="Frequency",
        label="Emotion Distribution",
        visible=True,
        y_title="Frequency"
    ), emotion_data, gr.Textbox(abr_to_emotion[dom_emotion], visible=True)


def update_paintings(painter_name):
    """
    Updates the gallery with paintings for the selected painter by dynamically
    listing the files in the painter's directory.
    """
    painter_dir = os.path.join(PAINTERS_BASE_DIR, painter_name).replace(os.sep, '/')
    print(painter_dir)
    artist_paintings_for_gallery = []
    if os.path.isdir(painter_dir):
        for filename in sorted(os.listdir(painter_dir)):  # sort for a consistent order
            if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')):
                file_path = os.path.join(painter_dir, filename).replace(os.sep, '/')
                print(file_path)
                # Use the filename without its extension as the caption.
                title = os.path.splitext(filename)[0]
                artist_paintings_for_gallery.append((file_path, title))
    print(f"Loaded paintings for {painter_name}: {artist_paintings_for_gallery}")
    return artist_paintings_for_gallery  # the list feeds gr.Gallery directly


def generate_my_art(painter, chosen_painting, dom_emotion):
    """
    Generates the artwork. `chosen_painting` is the file name of the single
    painting picked in the gallery (set by the gallery's select event), e.g.
    'Dora Maar with Cat (1941).png'.
    """
    print("generating started")
    print(f"painter: {painter}")
    print(f"chosen painting: {chosen_painting}")
    if not painter or not chosen_painting:
        # This function is a generator (it yields below), so early exits must
        # yield their outputs rather than return them.
        yield "Please select a painter and a painting.", None, None
        return

    # Style image path
    img_style_pth = os.path.join(PAINTERS_BASE_DIR, painter, chosen_painting)
    print(f"img_style_path: {img_style_pth}")

    # --- Content image: a random example image for the dominant emotion ---
    emotion_pth = os.path.join(EMOTION_BASE_DIR, dom_emotion)
    image_name = random.choice(os.listdir(emotion_pth))
    original_image_pth = os.path.join(emotion_pth, image_name)
    print(f"original image path: {original_image_pth}")
    final_message = f"Art generated based on {painter}'s {chosen_painting} style!"

    # --- Neural Style Transfer ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    imsize = 512 if torch.cuda.is_available() else 256  # smaller size keeps CPU runs tractable
    loader = transforms.Compose([
        transforms.Resize((imsize, imsize)),
        transforms.ToTensor()
    ])
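
    # Resize((imsize, imsize)) forces both images to the same square shape: the
    # content and generated tensors must match for the content loss, and keeping
    # the style image the same size makes its Gram-matrix statistics comparable.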

    content_layers = ['conv_4']
    style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
    cnn = models.vgg19(pretrained=True).features.to(device).eval()
    # ImageNet normalisation statistics expected by the pretrained VGG-19.
    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
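
    # The pipeline below presumably follows the classic Gatys-style NST recipe
    # (as in the PyTorch tutorial backing Src.Processing_img): the input image
    # starts as a clone of the content image and is optimised so its conv_4
    # activations match the content while its Gram matrices match the style.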

    style_img = image_loader(img_style_pth, loader, device)
    content_img = image_loader(original_image_pth, loader, device)
    input_img = content_img.clone()
    output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                                content_img, style_img, input_img, content_layers, style_layers, device)

    stylized_img_path = "stylized_output.jpg"
    save_image(output, stylized_img_path)
    print(f"Stylized image saved as '{stylized_img_path}'")

    # Final results: the original (content) image and the stylised output.
    yield gr.Textbox(final_message), original_image_pth, stylized_img_path


def generate_topomap(n_channels=4, n_time=500):
    if n_channels is None or n_time is None:
        n_channels = 4
        n_time = 500

    # ----------------------------
    # Load the standard 10-20 montage
    # ----------------------------
    montage = mne.channels.make_standard_montage('standard_1020')

    # 64 distinct EEG electrode names from the standard montage. TP9/TP10 are
    # included to reach 64 unique names; adjust to match your actual cap layout.
    standard_64_chs = [
        'Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8',
        'FC5', 'FC1', 'FC2', 'FC6', 'T7', 'C3', 'Cz', 'C4', 'T8',
        'CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8',
        'POz', 'O1', 'Oz', 'O2', 'Fpz', 'AF7', 'AF3', 'AF4', 'AF8',
        'F5', 'F1', 'F2', 'F6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2',
        'C6', 'CP3', 'CPz', 'CP4', 'P5', 'P1', 'P2', 'P6', 'PO3', 'PO4',
        'PO7', 'PO8', 'PO9', 'PO10', 'TP9', 'TP10', 'FT7', 'FT8', 'TP7', 'TP8'
    ]

    # exactly 64 channels
    ch_pos_dict = montage.get_positions()['ch_pos']
    ch_pos_dict_filtered = {ch: ch_pos_dict[ch] for ch in standard_64_chs}
    channel_names = list(ch_pos_dict_filtered.keys())
    ch_pos_array = np.array([ch_pos_dict_filtered[ch] for ch in standard_64_chs])  # 64 x 3
    ch_pos_2d = ch_pos_array[:, :2]  # for the 2D topomap
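
    # Montage positions are 3D head coordinates (in metres); dropping the z
    # component gives the flat (n_channels, 2) layout mne.viz.plot_topomap
    # expects for its `pos` argument.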

    # ----------------------------
    # Pick a time index and frequency-band index
    # ----------------------------
    if 'np_data' not in globals():
        # No PSD file has been analysed yet, so there is nothing to plot.
        return None

    # Assumed layout: 64 channels x 630 timestamps x 5 frequency bands.
    new_data = np_data.reshape(64, 630, 5)
    print(f"shape: {new_data.shape}")
    print(f"n channels: {n_channels}")
    psd_snapshot = new_data[:, n_time - 1, n_channels - 1]  # sliders are 1-based

    # ----------------------------
    # Plot the 2D topomap using MNE
    # ----------------------------
    print(f"shape psd: {psd_snapshot.shape}")
    fig, ax = plt.subplots()
    mne.viz.plot_topomap(
        psd_snapshot,
        ch_pos_2d,
        names=channel_names,
        show=False,
        axes=ax
    )

    # Save the generated topomap to a temp file for gr.Image.
    tmpfile = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    fig.savefig(tmpfile.name, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return tmpfile.name
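
# Example (hypothetical, after a PSD file has been analysed so `np_data` exists):
#   png_path = generate_topomap(n_channels=1, n_time=100)  # band 1, timestamp 100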

# --- Gradio Interface Definition ---
with gr.Blocks(css=".gradio-container { max-width: 2000px; margin: auto; }") as demo:
    # State holding the emotion-distribution DataFrame produced by the analysis
    # (kept so later steps can reuse it without recomputing).
    current_emotion_df_state = gr.State(value=pd.DataFrame())

    # Header Section
    gr.Markdown(
        """
        <h1 style="text-align: center; font-size: 5em; padding: 20px; font-weight: bold;">Brain Emotion Decoder 🧠🎨</h1>
        <p style="text-align: center; font-size: 1.5em; color: #555; font-weight: bold;">
        Imagine seeing your deepest feelings transform into art. We decode the underlying emotions from your brain activity,
        generating a personalized artwork that comes to life within an interactive 3D brain model. Discover the art of your inner self.
        </p>
        """
    )

    with gr.Row():
        # Left Column: Input and Emotion Distribution
        with gr.Column(scale=1):
            gr.Markdown('<h2 style="font-size: 2em;">1. Choose a PSD file</h2>')
            # Radio buttons to select from the predefined files
            psd_file_selection = gr.Radio(
                choices=predefined_psd_files,
                label="Select a PSD file for analysis",
                value=predefined_psd_files[0],  # default selection
                interactive=True
            )

            # Button to trigger PSD analysis
            analyze_psd_button = gr.Button("Analyze PSD File", variant="secondary")

            gr.Markdown('<h2 style="font-size: 2em;">2. Emotion Distribution</h2>')
            # Bar plot for the emotion distribution
            emotion_distribution_plot = gr.BarPlot(
                dummy_emotion_data,
                x="Emotion",
                y="Value",
                label="Emotion Distribution",
                height=300,
                x_title="Emotion Type",
                y_title="Frequency",
                visible=False  # hidden until analysis is triggered
            )
            dom_emotion = gr.Textbox(label="Dominant Emotion", visible=False)

        # Right Column: Art Museum and Generation
        with gr.Column(scale=1):
            gr.Markdown("<h3>Your Art Museum</h3>")
            gr.Markdown("<h3>3. Choose your favourite painter</h3>")
            painter_dropdown = gr.Dropdown(
                choices=painters,
                value="Pablo Picasso",  # default selection
                label="Select a Painter"
            )
            gr.Markdown("<h3>4. Choose your favourite painting</h3>")
            # Gallery to display paintings for selection
            painting_gallery = gr.Gallery(
                value=update_paintings("Pablo Picasso"),  # initial load: Picasso's paintings
                label="Select a Painting",
                height=300,
                columns=3,
                rows=1,
                object_fit="contain",
                preview=True,       # clicking shows a larger image
                interactive=True,   # make it selectable
                elem_id="painting_gallery",
                visible=True,
            )

            # Hidden textbox that stores the file name of the selected painting
            selected_painting_name = gr.Textbox(visible=False)
            # Button to trigger art generation
            generate_button = gr.Button("Generate My Art", variant="primary", scale=0)

            # Status message for image generation
            status_message = gr.Textbox(
                value="Click 'Generate My Art' to begin.",
                label="Generation Status",
                interactive=False,
                show_label=False,
                lines=1
            )

    # Output section, revealed below the inputs
    gr.Markdown(
        """
        <h1 style="text-align: center;">Your Generated Artwork</h1>
        <p style="text-align: center; font-size: 1.5em; color: #555; font-weight: bold;">
        Once your brain's emotional data is processed, we pinpoint the <b>dominant emotion</b>. This single feeling inspires a <b>personalized artwork</b>, generated with <b>neural style transfer</b> and blended with <b>your chosen painting's style</b>. You can then <b>download</b> this unique visual representation of your inner self.
        </p>
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("<h3>Generated Image</h3>")
            generated_image_output = gr.Image(label="Generated Image", show_label=False, height=300)
            gr.Markdown("<h3>Blended Style Image</h3>")
            blended_image_output = gr.Image(label="Blended Style Image", show_label=False, height=300)
        with gr.Column(scale=1):
            gr.Markdown("<h3>Brain Topomap</h3>")
            # The 5-step slider indexes the PSD's frequency-band axis (see generate_topomap).
            channels_slider = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Frequency Band", interactive=True)
            timestamp_slider = gr.Slider(minimum=1, maximum=630, value=1, step=1, label="Timestamp", interactive=True)
            mne_2d_img = gr.Image(visible=True)
            # Also refresh the topomap on generation; the slider values must be
            # passed explicitly since generate_topomap takes them as inputs.
            generate_button.click(generate_topomap, inputs=[channels_slider, timestamp_slider], outputs=mne_2d_img)

    # --- Event Listeners ---
    analyze_psd_button.click(
        upload_psd_file,
        inputs=[psd_file_selection],  # the selected radio value (file name)
        outputs=[emotion_distribution_plot, current_emotion_df_state, dom_emotion]
    )

    # When the painter dropdown changes, refresh the gallery content.
    painter_dropdown.change(
        update_paintings,
        inputs=[painter_dropdown],
        outputs=[painting_gallery]
    )

    # Use gr.Gallery's .select() event to capture the clicked item; the event's
    # SelectData carries the index and value of the selection.
    def on_select(evt: gr.SelectData):
        print("this function started")
        print(f"Image index: {evt.index}\nImage value: {evt.value['image']['orig_name']}")
        return evt.value['image']['orig_name']

    painting_gallery.select(
        on_select,
        outputs=[selected_painting_name]  # stores the file name in the hidden textbox
    )

    # generate_my_art consumes the painter, the selected painting's file name,
    # and the dominant emotion produced by the analysis step.
    generate_button.click(
        generate_my_art,
        inputs=[painter_dropdown, selected_painting_name, dom_emotion],
        outputs=[status_message, generated_image_output, blended_image_output]
    )

    # Slider event listeners: re-render the topomap on change.
    channels_slider.change(fn=generate_topomap, inputs=[channels_slider, timestamp_slider], outputs=mne_2d_img)
    timestamp_slider.change(fn=generate_topomap, inputs=[channels_slider, timestamp_slider], outputs=mne_2d_img)

# Launch the demo
if __name__ == "__main__":
    # To resolve the model path relative to this file instead, something like:
    # project_root_dir = os.path.dirname(os.path.abspath(__file__))
    # _ = load_model(os.path.join(project_root_dir, model_path), input_size, hidden_size, output_size, num_layers)
    demo.launch()