# NOTE: The lines originally here were residue from a web file viewer (Hugging Face
# Spaces status text, "File size: 22,154 Bytes", commit hashes 16d8457 / 8e2fec9 and
# a 1-520 line-number gutter). They are not part of the Python source.
import random
import gradio as gr
import pandas as pd
import numpy as np
from Src.Processing import load_data
from Src.Processing import process_data
from Src.Inference import load_model
import torchvision.models as models
from torchvision import transforms
from Src.Processing_img import (
get_style_model_and_losses,
image_loader,
run_style_transfer,
save_image,
gram_matrix
)
import torch
import tempfile
import time # For simulating delays
from PIL import Image, ImageDraw, ImageFont # Ensure ImageDraw and ImageFont are imported
import os # For file operations
import mne
import matplotlib.pyplot as plt
import io
import pyvista as pv
import matplotlib.cm as cm
import gradio as gr
# Global PyVista configuration, executed once at import time.
# NOTE(review): no other visible code in this file uses `pv` — confirm these
# two calls are still needed before keeping the pyvista dependency.
pv.set_plot_theme("document") # A simple theme
# Select PyVista's 'html' Jupyter backend.
pv.set_jupyter_backend('html')
# --- Data for demonstration ---
# Placeholder values for the emotion bar chart. They are only used to
# initialise the (hidden) plot before a real analysis has been run.
dummy_emotion_data = pd.DataFrame({
    'Emotion': ['sad', 'dis', 'fear', 'neu', 'joy', 'ten', 'ins'],
    'Value': [0.8, 0.6, 0.1, 0.4, 0.7, 0.2, 0.3]
})

# Model class index -> abbreviated emotion label.
int_to_emotion = {
    0: 'sad',
    1: 'dis',
    2: 'fear',
    3: 'neu',
    4: 'joy',
    5: 'ten',
    6: 'ins'
}

# Abbreviated emotion label -> full emotion name. The full name is what gets
# shown in the UI and what is joined onto EMOTION_BASE_DIR to find images.
abr_to_emotion = {
    'sad': "sadness",
    'dis': "disgust",
    'fear': "fear",
    'neu': "neutral",
    'joy': "joy",
    'ten': 'Tenderness',
    'ins': "inspiration"
}

# --- Local Image Paths Setup for Dynamic Loading ---
# Base directories, relative to the working directory the app is started from.
PAINTERS_BASE_DIR = "Painters"
EMOTION_BASE_DIR = "Emotions"

# Built with os.path.join so the path is valid on every platform. The previous
# literal "models\lstm_..." contained the invalid escape sequence "\l" and only
# worked as a path on Windows.
model_path = os.path.join("models", "lstm_emotion_model_state.pth")

# Hyper-parameters passed to load_model when restoring the LSTM.
input_size = 320
hidden_size = 50
output_size = 7
num_layers = 1

# Painters offered in the dropdown; each name is also a sub-folder of PAINTERS_BASE_DIR.
painters = ["Pablo Picasso", "Vincent van Gogh", "Salvador Dalí"]

# Directory that holds the selectable PSD files.
Base_Dir = "Datasets"

# Example (file name, caption) pairs per painter.
# The actual gallery content is read from the file system.
PAINTER_PLACEHOLDER_DATA = {
    "Pablo Picasso": [
        ("Dora Maar with Cat (1941).png", "Dora Maar with Cat (1941)"),
        ("The Weeping Woman (1937).png", "The Weeping Woman (1937)"),
        ("Three Musicians (1921).png", "Three Musicians (1921)"),
    ],
    "Vincent van Gogh": [
        ("Sunflowers (1888).png", "Sunflowers (1888)"),
        ("The Starry Night (1889).png", "The Starry Night (1889)"),
        ("The Potato Eaters (1885).png", "The Potato Eaters (1885)"),
    ],
    "Salvador Dalí": [
        ("Persistence of Memory (1931).png", "Persistence of Memory (1931)"),
        ("Swans Reflecting Elephants (1937).png", "Swans Reflecting Elephants (1937)"),
        ("Sleep (1937).png", "Sleep (1937)"),
    ],
}

# --- Define the specific PSD files to choose from ---
# File names are resolved against Base_Dir when a file is analysed.
predefined_psd_files = ["task-emotion_psd_1.npy", "task-emotion_psd_2.npy", "task-emotion_psd_3.npy"]
# --- Core Functions (Simulated) ---
def upload_psd_file(selected_file_name):
    """
    Run emotion inference on a selected PSD file and build its emotion distribution.

    The file is loaded from ``Base_Dir`` with ``load_data``, transformed with
    ``process_data`` and fed to the model restored by ``load_model``. The
    arg-max class of every prediction is counted per emotion.

    Side effect: the raw loaded array is stored in the module-level
    ``np_data`` so that ``generate_topomap`` can plot it later.

    Args:
        selected_file_name: File name chosen in the UI (relative to
            ``Base_Dir``), or ``None`` when nothing is selected.

    Returns:
        A 3-tuple ``(bar_plot, emotion_dataframe, dominant_emotion_textbox)``,
        one value per output component wired to this callback. When no file
        is selected or the file cannot be found, the plot and text box are
        hidden and the DataFrame is empty.
    """
    global np_data

    def _hidden_result(label):
        # Every exit path must return one value per wired output component
        # (plot, state, text box); returning fewer makes Gradio raise.
        return (
            gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value", label=label, visible=False),
            pd.DataFrame(),
            gr.Textbox("", visible=False),
        )

    if selected_file_name is None:
        return _hidden_result("Emotion Distribution")

    # --- Load and Process PSD Data ---
    psd_file_path = os.path.join(Base_Dir, selected_file_name).replace(os.sep, '/')
    try:
        np_data = load_data(psd_file_path)
    except FileNotFoundError:
        print(f"Error: PSD file not found at {psd_file_path}")
        return _hidden_result("Emotion Distribution (Error: File not found)")
    print(f"np data orig {np_data.shape}")

    final_data = process_data(np_data)
    # Add a leading batch dimension of size 1.
    # NOTE(review): assumes process_data returns (sequence_length, input_size) — confirm.
    torch_data = torch.tensor(final_data, dtype=torch.float32).unsqueeze(0)
    print(f"Processed data shape for model: {torch_data.shape}")

    # --- Inference ---
    # The path is relative to the directory the app is started from.
    absolute_model_path = os.path.join("models", "lstm_emotion_model_state.pth")
    loaded_model = load_model(absolute_model_path, input_size, hidden_size, output_size, num_layers)
    loaded_model.eval()  # Set model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation for inference
        # The model is expected to return a pair; only the first item is used.
        predicted_logits, _ = loaded_model(torch_data)

    # Most probable class along the last (dim=2) axis, flattened for counting.
    all_predicted_indices = torch.argmax(predicted_logits, dim=2).view(-1)
    print(f"All predicted indices (flattened): {all_predicted_indices}")

    # minlength guarantees one bucket per known class even if never predicted.
    values_count = torch.bincount(all_predicted_indices, minlength=output_size)
    print(f"Raw bincount: {values_count}")

    # --- Create Emotion Distribution DataFrame ---
    # Buckets beyond output_size (if any) have no label and are ignored.
    counts = values_count[:output_size].tolist()
    emotions_count = {
        int_to_emotion[idx].strip(): counts[idx]  # strip() guards against stray whitespace in labels
        for idx in range(output_size)
    }
    dom_emotion = max(emotions_count, key=emotions_count.get)

    emotion_data = pd.DataFrame({
        "Emotion": list(emotions_count.keys()),
        "Frequency": list(emotions_count.values())
    })
    # Sort by label so the bars appear in a stable order between runs.
    emotion_data = emotion_data.sort_values(by="Emotion").reset_index(drop=True)
    print(f"Final emotion_data DataFrame:\n{emotion_data}")

    return gr.BarPlot(
        emotion_data,
        x="Emotion",
        y="Frequency",
        label="Emotion Distribution",
        visible=True,
        y_title="Frequency"
    ), emotion_data, gr.Textbox(abr_to_emotion[dom_emotion], visible=True)
def update_paintings(painter_name):
    """
    Build the gallery entries for one painter from the files in that painter's folder.

    Args:
        painter_name: Name of the painter; also the sub-folder of
            ``PAINTERS_BASE_DIR`` that is scanned.

    Returns:
        A list of ``(file_path, title)`` tuples, sorted by file name, where
        ``title`` is the file name without its extension. The list is empty
        when the painter's folder does not exist.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.gif')
    gallery_dir = os.path.join(PAINTERS_BASE_DIR, painter_name).replace(os.sep, '/')
    print(gallery_dir)

    # A missing folder simply produces an empty gallery.
    entries = sorted(os.listdir(gallery_dir)) if os.path.isdir(gallery_dir) else []

    gallery_items = []
    for entry in entries:
        if not entry.lower().endswith(image_suffixes):
            continue
        image_path = os.path.join(gallery_dir, entry).replace(os.sep, '/')
        print(image_path)
        # The caption is the file name with its extension removed.
        caption, _extension = os.path.splitext(entry)
        gallery_items.append((image_path, caption))

    print(f"Loaded paintings for {painter_name}: {gallery_items}")
    return gallery_items
def generate_my_art(painter, chosen_painting, dom_emotion):
    """
    Stylize a random image of the dominant emotion with the selected painting.

    A content image is picked at random from ``EMOTION_BASE_DIR/<dom_emotion>``
    and neural style transfer is run on it, using
    ``PAINTERS_BASE_DIR/<painter>/<chosen_painting>`` as the style image. The
    result is written to ``stylized_output.jpg`` in the working directory.

    This is a generator, so every exit path must ``yield`` its result: a plain
    ``return value`` inside a generator is never delivered to the caller.

    Args:
        painter: Painter name (sub-folder of ``PAINTERS_BASE_DIR``).
        chosen_painting: File name of the selected painting inside that folder.
        dom_emotion: Dominant emotion name (sub-folder of ``EMOTION_BASE_DIR``).

    Yields:
        ``(status, content_image_path, stylized_image_path)``. The two paths
        are ``None`` when an input is missing or no content image is available.
    """
    print("generating started")
    print(f"painter: {painter}")
    print(f"choosen painting: {chosen_painting}")

    if not painter or not chosen_painting:
        yield "Please select a painter and a painting.", None, None
        return
    if not dom_emotion:
        # Without this guard the path below would point at EMOTION_BASE_DIR itself.
        yield "Please analyze a PSD file first to detect the dominant emotion.", None, None
        return

    # Style image path
    img_style_pth = os.path.join(PAINTERS_BASE_DIR, painter, chosen_painting)
    print(f"img_stype_path: {img_style_pth}")

    # Content image: a random file from the dominant emotion's folder.
    emotion_pth = os.path.join(EMOTION_BASE_DIR, dom_emotion)
    candidates = []
    if os.path.isdir(emotion_pth):
        candidates = [
            name for name in sorted(os.listdir(emotion_pth))
            if os.path.isfile(os.path.join(emotion_pth, name))
        ]
    if not candidates:
        yield f"No source images found for emotion '{dom_emotion}'.", None, None
        return
    image_name = random.choice(candidates)
    original_image_pth = os.path.join(emotion_pth, image_name)
    print(f"original img _path: {original_image_pth}")

    final_message = f"Art generated based on {painter}'s {chosen_painting} style!"

    # --- Neural Style Transfer ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Use a smaller working resolution when no GPU is available.
    imsize = 512 if torch.cuda.is_available() else 256
    loader = transforms.Compose([
        transforms.Resize((imsize, imsize)),
        transforms.ToTensor()
    ])
    content_layers = ['conv_4']
    style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

    # NOTE(review): `pretrained=True` is deprecated in recent torchvision in
    # favour of `weights=`; kept because the installed version is unknown here.
    cnn = models.vgg19(pretrained=True).features.to(device).eval()
    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

    style_img = image_loader(img_style_pth, loader, device)
    content_img = image_loader(original_image_pth, loader, device)
    # Optimisation starts from a copy of the content image.
    input_img = content_img.clone()

    output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                                content_img, style_img, input_img, content_layers, style_layers, device)

    stylized_img_path = 'stylized_output.jpg'
    save_image(output, stylized_img_path)
    print("Stylized image saved as 'stylized_output.jpg'")

    yield gr.Textbox(final_message), original_image_pth, stylized_img_path
def generate_topomap(n_channels=None, n_time=None):
    """
    Render a 2D scalp topomap of the currently loaded PSD data as a PNG file.

    The data comes from the module-level ``np_data`` array that
    ``upload_psd_file`` stores after loading a file. It is reshaped to
    ``(64, 630, 5)`` and one value per sensor is taken at the requested
    position of the second and third axes.

    Args:
        n_channels: 1-based position along the last axis (size 5) of the
            reshaped data. Despite its name, the in-code section title calls
            this the frequency index. Defaults to 4 when either argument is ``None``.
        n_time: 1-based position along the second axis (size 630). Defaults
            to 500 when either argument is ``None``.

    Returns:
        Path of the temporary PNG file holding the plot, or ``None`` when no
        PSD data has been loaded yet.
    """
    # np_data only exists after upload_psd_file has run successfully.
    psd_data = globals().get("np_data")
    if psd_data is None:
        print("no PSD data loaded yet")
        return None

    if n_channels is None or n_time is None:
        print("they are None")
        n_channels = 4
        n_time = 500
    # UI sliders may deliver floats; array indexing needs integers.
    n_channels = int(n_channels)
    n_time = int(n_time)

    # ----------------------------
    # 2. Load standard 10-20 montage
    # ----------------------------
    montage = mne.channels.make_standard_montage('standard_1020')
    # Sensor names, one per row of the data (64 entries).
    # NOTE(review): 'O1' and 'O2' appear twice, so only 62 names are distinct
    # and two rows share a position — confirm the real channel order of the data.
    standard_64_chs = [
        'Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8',
        'FC5', 'FC1', 'FC2', 'FC6', 'T7', 'C3', 'Cz', 'C4', 'T8',
        'CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8',
        'POz', 'O1', 'Oz', 'O2', 'Fpz', 'AF7', 'AF3', 'AF4', 'AF8',
        'F5', 'F1', 'F2', 'F6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2',
        'C6', 'CP3', 'CPz', 'CP4', 'P5', 'P1', 'P2', 'P6', 'PO3', 'PO4',
        'PO7', 'PO8', 'PO9', 'PO10', 'O1', 'O2', 'FT7', 'FT8', 'TP7', 'TP8'
    ]
    ch_pos_dict = montage.get_positions()['ch_pos']
    # Names and positions are both derived from the same list so they stay
    # aligned one-to-one (a dict keyed by name would silently drop duplicates).
    channel_names = list(standard_64_chs)
    ch_pos_array = np.array([ch_pos_dict[ch] for ch in standard_64_chs])  # Nx3
    ch_pos_2d = ch_pos_array[:, :2]  # For 2D topomap

    # ----------------------------
    # 3. Choose a time index and frequency index
    # ----------------------------
    new_data = psd_data.reshape(64, 630, 5)
    print(f"shape: {new_data.shape}")
    print(f"n channels: {n_channels}")
    # Arguments are 1-based, array indices are 0-based.
    psd_snapshot = new_data[:, n_time - 1, n_channels - 1]

    # ----------------------------
    # 4. Plot 2D topomap using MNE
    # ----------------------------
    print(f"shape psd :{psd_snapshot.shape}")
    fig, ax = plt.subplots()
    try:
        mne.viz.plot_topomap(
            psd_snapshot,
            ch_pos_2d,
            names=channel_names,
            show=False,
            axes=ax
        )
        # Reserve a temp file name and close the handle before writing to it,
        # so the descriptor is not leaked and the save also works on Windows.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
            output_path = tmpfile.name
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    return output_path
# NOTE: `predefined_psd_files` is already defined with the module constants
# earlier in this file; the identical re-definition that used to sit here was
# redundant and has been dropped.

# --- Gradio Interface Definition ---
with gr.Blocks(css=".gradio-container { max-width: 2000px; margin: auto; }") as demo:
    # Holds the DataFrame of the most recent emotion distribution.
    current_emotion_df_state = gr.State(value=pd.DataFrame())

    # Header Section
    gr.Markdown(
        """
        <h1 style="text-align: center;font-size: 5em; padding: 20px; font-weight: bold;">Brain Emotion Decoder 🧠🎨</h1>
        <p style="text-align: center; font-size: 1.5em; color: #555;font-weight: bold;">
        Imagine seeing your deepest feelings transform into art. We decode the underlying emotions from your brain activity,
        generating a personalized artwork that comes to life within an interactive 3D brain model. Discover the art of your inner self.
        </p>
        """
    )

    with gr.Row():
        # Left Column: Input and Emotion Distribution
        with gr.Column(scale=1):
            gr.Markdown('<h2 style="font-size: 2em;">1. Choose a PSD file</h2>')
            # Radio buttons to select from predefined files
            psd_file_selection = gr.Radio(
                choices=predefined_psd_files,
                label="Select a PSD file for analysis",
                value=predefined_psd_files[0],  # Default selection
                interactive=True
            )
            # Button to trigger PSD analysis
            analyze_psd_button = gr.Button("Analyze PSD File", variant="secondary")

            gr.Markdown('<h2 style="font-size: 2em;">2. Emotion Distribution</h2>')
            # Bar plot for emotion distribution
            emotion_distribution_plot = gr.BarPlot(
                dummy_emotion_data,
                x="Emotion",
                y="Value",
                label="Emotion Distribution",
                height=300,
                x_title="Emotion Type",
                y_title="Frequency",
                visible=False  # Hidden until analysis is triggered
            )
            # Filled in (and revealed) by upload_psd_file; read by generate_my_art.
            dom_emotion = gr.Textbox(label = "dominant emotion", visible=False)

        # Right Column: Art Museum and Generation
        with gr.Column(scale=1):
            gr.Markdown("<h3>Your Art Museum</h3>")
            gr.Markdown("<h3>3. Choose your favourite painter</h3>")
            painter_dropdown = gr.Dropdown(
                choices=painters,
                value="Pablo Picasso",  # Default selection
                label="Select a Painter"
            )
            gr.Markdown("<h3>4. Choose your favourite painting</h3>")
            # Gallery to display paintings for selection
            painting_gallery = gr.Gallery(
                value=update_paintings("Pablo Picasso"),  # Initial load for Picasso's paintings
                label="Select a Painting",
                height=300,
                columns=3,
                rows=1,
                object_fit="contain",
                preview=True,  # Allows clicking to see larger image
                interactive=True,  # Make it selectable
                elem_id="painting_gallery",
                visible=True,  # Should be visible by default
            )
            # Hidden text box that stores the file name of the selected painting.
            selected_painting_name = gr.Textbox(visible=False)
            # Button to trigger art generation
            generate_button = gr.Button("Generate My Art", variant="primary", scale=0)
            # Status message for image generation
            status_message = gr.Textbox(
                value="Click 'Generate My Art' to begin.",
                label="Generation Status",
                interactive=False,
                show_label=False,
                lines=1
            )

    # Output section
    gr.Markdown(
        """
        <h1 style="text-align: center;">Your Generated Artwork</h1>
        <p style="text-align: center; font-size: 1.5em; color: #555;font-weight: bold;">
        Once your brain's emotional data is processed, we pinpoint the <b>dominant emotion</b>. This single feeling inspires a <b>personalized artwork</b>, generated using <b>diffusion techniques</b> and blended with <b>my AI painting style</b>. You can then <b>download</b> this unique visual representation of your inner self.
        </p>
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("<h3>Generated Image</h3>")
            generated_image_output = gr.Image(label="Generated Image", show_label=False, height=300)
            gr.Markdown("<h3>Blended Style Image</h3>")
            blended_image_output = gr.Image(label="Blended Style Image", show_label=False, height=300)
        with gr.Column(scale=1):
            gr.Markdown("<h3>Brain Topomap</h3>")
            channels_slider = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Channels", interactive=True)
            timestamp_slider = gr.Slider(minimum=1, maximum=630, value=1, step=1, label="Timestamp", interactive=True)
            mne_2d_img = gr.Image(visible=True)

    # Draw the topomap for the current slider positions when art is generated.
    # generate_topomap takes two arguments, so both sliders must be passed in.
    generate_button.click(
        generate_topomap,
        inputs=[channels_slider, timestamp_slider],
        outputs=mne_2d_img
    )

    # --- Event Listeners ---
    analyze_psd_button.click(
        upload_psd_file,
        inputs=[psd_file_selection],  # Input is the selected radio button value (file name)
        outputs=[emotion_distribution_plot, current_emotion_df_state, dom_emotion]
    )

    # When the painter changes, reload the gallery for that painter ...
    painter_dropdown.change(
        update_paintings,
        inputs=[painter_dropdown],
        outputs=[painting_gallery]
    )
    # ... and forget the previous selection, which belongs to another painter's
    # folder and would otherwise produce a style-image path that does not exist.
    painter_dropdown.change(
        lambda: "",
        outputs=[selected_painting_name]
    )

    def on_select(evt: gr.SelectData):
        """Return the original file name of the gallery item that was clicked."""
        print("this function started")
        print(f"Image index: {evt.index}\nImage value: {evt.value['image']['orig_name']}")
        return evt.value['image']['orig_name']

    # Store the clicked painting's file name in the hidden text box.
    painting_gallery.select(
        on_select,
        outputs=[selected_painting_name]
    )

    # Generate the artwork from the painter, the selected painting and the dominant emotion.
    generate_button.click(
        generate_my_art,
        inputs=[painter_dropdown, selected_painting_name, dom_emotion],
        outputs=[status_message, generated_image_output, blended_image_output]
    )

    ## sliders event listener
    channels_slider.change(fn=generate_topomap, inputs=[channels_slider, timestamp_slider], outputs=mne_2d_img)
    timestamp_slider.change(fn=generate_topomap, inputs=[channels_slider, timestamp_slider], outputs=mne_2d_img)
# Launch the demo
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()