File size: 22,154 Bytes
16d8457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e2fec9
16d8457
8e2fec9
16d8457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
import random
import gradio as gr
import pandas as pd
import numpy as np
from Src.Processing import load_data
from Src.Processing import process_data
from Src.Inference import load_model
import torchvision.models as models
from torchvision import transforms
from Src.Processing_img import (
    get_style_model_and_losses,
    image_loader, 
    run_style_transfer,
    save_image, 
    gram_matrix
    )
import torch
import tempfile
import time # For simulating delays
from PIL import Image, ImageDraw, ImageFont # Ensure ImageDraw and ImageFont are imported
import os # For file operations
import mne
import matplotlib.pyplot as plt
import io
import pyvista as pv
import matplotlib.cm as cm
import gradio as gr
# Global PyVista configuration, applied once when this module is imported.
# NOTE(review): nothing visible in this module renders with PyVista, and the
# 'html' Jupyter backend only affects notebook display — confirm these calls
# are still needed by the app.
pv.set_plot_theme("document") # A simple theme
pv.set_jupyter_backend('html') 

# --- Data for demonstration ---
# Placeholder emotion distribution used for the (hidden) bar plot before any
# PSD file has been analysed. The abbreviations match ``int_to_emotion``.
dummy_emotion_data = pd.DataFrame({
    'Emotion': ['sad', 'dis', 'fear', 'neu', 'joy', 'ten', 'ins'],
    'Value': [0.8, 0.6, 0.1, 0.4, 0.7, 0.2, 0.3]
})

# Model output class index -> emotion abbreviation.
int_to_emotion = {
    0: 'sad',
    1: 'dis',
    2: 'fear',
    3: 'neu',
    4: 'joy',
    5: 'ten',
    6: 'ins'
}

# Emotion abbreviation -> full name shown to the user. The full name is also
# joined with EMOTION_BASE_DIR by generate_my_art to find content images.
# NOTE(review): 'Tenderness' is the only capitalised value; directory lookups
# are case-sensitive on Linux — confirm it matches the folder name.
abr_to_emotion = {
    'sad': "sadness",
    'dis': "disgust",
    'fear': "fear",
    'neu': "neutral",
    'joy': "joy",
    'ten': 'Tenderness',
    'ins': "inspiration"
}

# --- Local asset locations (relative to the working directory) ---
PAINTERS_BASE_DIR = "Painters"  # one sub-folder per painter holding style images
EMOTION_BASE_DIR = "Emotions"   # one sub-folder per emotion name holding content images
# Built with os.path.join: the previous literal "models\lstm_..." depended on the
# invalid escape sequence "\l" (a SyntaxWarning on Python 3.12+) and on a
# Windows-only path separator.
model_path = os.path.join("models", "lstm_emotion_model_state.pth")

# LSTM hyper-parameters passed to load_model; presumably they must match the
# saved checkpoint — confirm when retraining.
input_size = 320
hidden_size = 50
output_size = 7  # number of emotion classes, i.e. len(int_to_emotion)
num_layers = 1

# Painters offered in the dropdown (each is a sub-folder of PAINTERS_BASE_DIR).
painters = ["Pablo Picasso", "Vincent van Gogh", "Salvador Dalí"]
# Directory holding the selectable PSD .npy files.
Base_Dir = "Datasets"

# (file name, caption) pairs per painter.
# NOTE(review): not referenced anywhere in the visible code; the gallery is built
# from the files actually present under PAINTERS_BASE_DIR.
PAINTER_PLACEHOLDER_DATA = {
    "Pablo Picasso": [
        ("Dora Maar with Cat (1941).png", "Dora Maar with Cat (1941)"),
        ("The Weeping Woman (1937).png", "The Weeping Woman (1937)"),
        ("Three Musicians (1921).png", "Three Musicians (1921)"),
    ],
    "Vincent van Gogh": [
        ("Sunflowers (1888).png", "Sunflowers (1888)"),
        ("The Starry Night (1889).png", "The Starry Night (1889)"),
        ("The Potato Eaters (1885).png", "The Potato Eaters (1885)"),
    ],
    "Salvador Dalí": [
        ("Persistence of Memory (1931).png", "Persistence of Memory (1931)"),
        ("Swans Reflecting Elephants (1937).png", "Swans Reflecting Elephants (1937)"),
        ("Sleep (1937).png", "Sleep (1937)"),
    ],
}


# --- PSD files offered in the radio group (names relative to Base_Dir) ---
predefined_psd_files = ["task-emotion_psd_1.npy", "task-emotion_psd_2.npy", "task-emotion_psd_3.npy"]

# --- Core Functions (Simulated) ---

def upload_psd_file(selected_file_name):
    """
    Load a PSD file, run the LSTM emotion model on it and summarise the predictions.

    Args:
        selected_file_name: File name (relative to ``Base_Dir``) chosen in the
            radio group, or ``None`` when nothing is selected.

    Returns:
        A 3-tuple matching the outputs wired to the "Analyze PSD File" button:
        ``(gr.BarPlot, pandas.DataFrame, gr.Textbox)`` — the emotion-distribution
        plot, the per-emotion frequency table kept in state, and a textbox holding
        the full name of the most frequently predicted emotion. When no file is
        selected or the file is missing, the plot and textbox are hidden and the
        DataFrame is empty.

    Side effects:
        Sets the module-level ``np_data`` to the loaded array; ``generate_topomap``
        reads it afterwards.
    """
    global np_data

    def _hidden_outputs(label):
        # Every exit path must produce three values: the click handler is wired
        # to three output components and Gradio rejects a shorter tuple.
        return (
            gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value", label=label, visible=False),
            pd.DataFrame(),
            gr.Textbox(visible=False),
        )

    if selected_file_name is None:
        return _hidden_outputs("Emotion Distribution")

    # --- Load and process the PSD data ---
    psd_file_path = os.path.join(Base_Dir, selected_file_name).replace(os.sep, '/')
    try:
        np_data = load_data(psd_file_path)
    except FileNotFoundError:
        print(f"Error: PSD file not found at {psd_file_path}")
        return _hidden_outputs("Emotion Distribution (Error: File not found)")
    print(f"np data orig {np_data.shape}")

    final_data = process_data(np_data)
    # Prepend a batch dimension of 1 for the model.
    # Assumes process_data returns (sequence_length, input_size) — TODO confirm.
    torch_data = torch.tensor(final_data, dtype=torch.float32).unsqueeze(0)
    print(f"Processed data shape for model: {torch_data.shape}")

    # --- Inference ---
    # Relative to the working directory the app is launched from.
    model_file = os.path.join("models", "lstm_emotion_model_state.pth")
    loaded_model = load_model(model_file, input_size, hidden_size, output_size, num_layers)
    loaded_model.eval()

    with torch.no_grad():
        # The model returns a pair; only the first element (the logits) is used.
        predicted_logits, _ = loaded_model(torch_data)

    # Most probable class per time step, flattened over batch and sequence so
    # every prediction is counted once.
    all_predicted_indices = torch.argmax(predicted_logits, dim=2).view(-1)
    print(f"All predicted indices (flattened): {all_predicted_indices}")

    # minlength guarantees a bin for every class, even those never predicted.
    values_count = torch.bincount(all_predicted_indices, minlength=output_size)
    print(f"Raw bincount: {values_count}")

    # --- Build the emotion distribution ---
    # Only the first output_size bins correspond to known emotions.
    emotions_count = {
        int_to_emotion[idx].strip(): values_count[idx].item()
        for idx in range(output_size)
    }
    # Ties resolve to the emotion that comes first in int_to_emotion order.
    dom_emotion = max(emotions_count, key=emotions_count.get)

    emotion_data = pd.DataFrame({
        "Emotion": list(emotions_count.keys()),
        "Frequency": list(emotions_count.values())
    })
    # Alphabetical order keeps the bar plot layout stable across files.
    emotion_data = emotion_data.sort_values(by="Emotion").reset_index(drop=True)
    print(f"Final emotion_data DataFrame:\n{emotion_data}")

    return gr.BarPlot(
        emotion_data,
        x="Emotion",
        y="Frequency",
        label="Emotion Distribution",
        visible=True,
        y_title="Frequency"
    ), emotion_data, gr.Textbox(abr_to_emotion[dom_emotion], visible=True)


def update_paintings(painter_name, base_dir=None):
    """
    List the image files in a painter's folder for display in the gallery.

    Args:
        painter_name: Painter name, used as the sub-folder of ``base_dir``.
            A falsy value (e.g. a cleared dropdown) yields an empty gallery.
        base_dir: Root folder holding one sub-folder per painter. Defaults to
            ``PAINTERS_BASE_DIR``.

    Returns:
        A list of ``(file_path, title)`` tuples sorted by file name, where
        ``file_path`` uses forward slashes and ``title`` is the file name without
        its extension. Only .png/.jpg/.jpeg/.gif files (case-insensitive) are
        included. A missing folder yields an empty list.
    """
    if not painter_name:
        # os.path.join raises TypeError on None; show an empty gallery instead.
        return []
    if base_dir is None:
        base_dir = PAINTERS_BASE_DIR

    painter_dir = os.path.join(base_dir, painter_name).replace(os.sep, '/')
    print(painter_dir)

    artist_paintings_for_gallery = []
    if os.path.isdir(painter_dir):
        # Sorted so the gallery order is stable between runs.
        for filename in sorted(os.listdir(painter_dir)):
            if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')):
                continue
            file_path = os.path.join(painter_dir, filename).replace(os.sep, '/')
            print(file_path)
            title = os.path.splitext(filename)[0]
            artist_paintings_for_gallery.append((file_path, title))
    print(f"Loaded paintings for {painter_name}: {artist_paintings_for_gallery}")
    return artist_paintings_for_gallery


def generate_my_art(painter, chosen_painting, dom_emotion):
    """
    Stylise a random content image for the dominant emotion with the chosen painting.

    This is a generator, so Gradio displays whatever it yields. Each yield is a
    3-tuple ``(status, content_image_path, stylized_image_path)`` matching the
    outputs ``[status_message, generated_image_output, blended_image_output]``.

    Args:
        painter: Painter name, which is also the sub-folder of ``PAINTERS_BASE_DIR``.
        chosen_painting: File name of the selected painting inside that folder,
            or an empty value when nothing is selected.
        dom_emotion: Full emotion name produced by ``upload_psd_file``. It is used
            as the sub-folder of ``EMOTION_BASE_DIR`` that holds candidate
            content images.

    Yields:
        A status message with ``None`` images when an input is missing or no
        content image is available. Otherwise it yields the success message, the
        content image path and the path of the stylised result
        (``stylized_output.jpg`` in the working directory).
    """
    print("generating started")
    print(f"painter: {painter}")
    print(f"choosen painting: {chosen_painting}")

    # A bare ``return value`` inside a generator never reaches the UI, so every
    # user-facing message has to be yielded before returning.
    if not painter or not chosen_painting:
        yield "Please select a painter and a painting.", None, None
        return
    if not dom_emotion:
        # The dominant-emotion textbox stays empty until a PSD file is analysed.
        yield "Please analyze a PSD file first to detect the dominant emotion.", None, None
        return

    # Style image: the selected painting of the selected painter.
    img_style_pth = os.path.join(PAINTERS_BASE_DIR, painter, chosen_painting)
    print(f"img_stype_path: {img_style_pth}")

    # Content image: a random file from the folder named after the emotion.
    emotion_pth = os.path.join(EMOTION_BASE_DIR, dom_emotion)
    try:
        candidates = os.listdir(emotion_pth)
    except FileNotFoundError:
        candidates = []
    if not candidates:
        yield f"No images found for emotion '{dom_emotion}'.", None, None
        return
    original_image_pth = os.path.join(emotion_pth, random.choice(candidates))
    print(f"original img _path: {original_image_pth}")
    final_message = f"Art generated based on {painter}'s {chosen_painting} style!"

    # --- Neural style transfer ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Smaller working resolution on CPU to keep the optimisation tractable.
    imsize = 512 if torch.cuda.is_available() else 256
    loader = transforms.Compose([
        transforms.Resize((imsize, imsize)),
        transforms.ToTensor()
    ])
    content_layers = ['conv_4']
    style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
    # NOTE(review): ``pretrained=`` is deprecated in newer torchvision in favour
    # of ``weights=`` — switch once the pinned torchvision version is confirmed.
    cnn = models.vgg19(pretrained=True).features.to(device).eval()
    # ImageNet channel statistics expected by the pretrained VGG.
    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
    style_img = image_loader(img_style_pth, loader, device)
    content_img = image_loader(original_image_pth, loader, device)
    # Start the optimisation from the content image.
    input_img = content_img.clone()
    output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                                content_img, style_img, input_img, content_layers, style_layers, device)

    stylized_img_path = 'stylized_output.jpg'
    save_image(output, stylized_img_path)
    print("Stylized image saved as 'stylized_output.jpg'")

    yield gr.Textbox(final_message), original_image_pth, stylized_img_path


def generate_topomap(n_channels=None, n_time=None):
    """
    Render a 2-D scalp topomap of one slice of the loaded PSD data.

    Args:
        n_channels: 1-based index into the last axis of the data after it is
            reshaped to ``(64, 630, 5)``, driven by the "Channels" slider.
            NOTE(review): despite the name, this selects one of 5 values per
            sensor and time step (perhaps frequency bands) — confirm.
        n_time: 1-based index into the 630-step second axis ("Timestamp" slider).

        When either argument is ``None``, they fall back to ``4`` and ``500``.
        Defaults allow the function to be wired to an event with no inputs.

    Returns:
        Path of a temporary PNG containing the topomap, or ``None`` when no PSD
        data has been loaded yet (``upload_psd_file`` sets the module-level
        ``np_data``).
    """
    if n_channels is None or n_time is None:
        print("they are None")
        n_channels = 4
        n_time = 500

    # np_data only exists after upload_psd_file has run. Before that (e.g. a
    # slider is moved first), show no image instead of raising NameError.
    data = globals().get("np_data")
    if data is None:
        print("No PSD data loaded yet; analyze a PSD file first.")
        return None

    # --- Electrode positions from the standard 10-20 montage ---
    montage = mne.channels.make_standard_montage('standard_1020')
    # One entry per row of the PSD data (64 entries).
    # NOTE(review): 'O1' and 'O2' appear twice, so only 62 distinct electrodes
    # are listed. Confirm the recording's real channel order and replace the
    # duplicates.
    standard_64_chs = [
        'Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8',
        'FC5', 'FC1', 'FC2', 'FC6', 'T7', 'C3', 'Cz', 'C4', 'T8',
        'CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8',
        'POz', 'O1', 'Oz', 'O2', 'Fpz', 'AF7', 'AF3', 'AF4', 'AF8',
        'F5', 'F1', 'F2', 'F6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2',
        'C6', 'CP3', 'CPz', 'CP4', 'P5', 'P1', 'P2', 'P6', 'PO3', 'PO4',
        'PO7', 'PO8', 'PO9', 'PO10', 'O1', 'O2', 'FT7', 'FT8', 'TP7', 'TP8'
    ]
    ch_pos_dict = montage.get_positions()['ch_pos']
    ch_pos_array = np.array([ch_pos_dict[ch] for ch in standard_64_chs])  # (64, 3)
    ch_pos_2d = ch_pos_array[:, :2]  # x/y only, for a 2-D topomap
    # Names must stay index-aligned with the positions. Deriving them from a
    # dict dropped the duplicate entries and shifted every later label.
    channel_names = list(standard_64_chs)

    # --- Select one (time, last-axis) slice for every sensor ---
    new_data = data.reshape(64, 630, 5)
    print(f"shape: {new_data.shape}")
    print(f"n channels: {n_channels}")
    psd_snapshot = new_data[:, n_time - 1, n_channels - 1]
    print(f"shape psd :{psd_snapshot.shape}")

    # --- Plot and save to a temporary PNG ---
    fig, ax = plt.subplots()
    try:
        mne.viz.plot_topomap(
            psd_snapshot,
            ch_pos_2d,
            names=channel_names,
            show=False,
            axes=ax
        )
        # delete=False: the file must outlive this function because the
        # returned path is read by the image component afterwards.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
            fig.savefig(tmpfile, format="png", dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even when plotting fails.
        plt.close(fig)
    return tmpfile.name
    
# (predefined_psd_files is defined once, with the other constants near the top
# of the module; the duplicate redefinition that stood here was removed.)

# --- Gradio Interface Definition ---

with gr.Blocks(css=".gradio-container { max-width: 2000px; margin: auto; }") as demo:
    # Holds the emotion-distribution DataFrame returned by upload_psd_file.
    current_emotion_df_state = gr.State(value=pd.DataFrame())


    # Header Section
    gr.Markdown(
        """
        <h1 style="text-align: center;font-size: 5em; padding: 20px;  font-weight: bold;">Brain Emotion Decoder 🧠🎨</h1>
        <p style="text-align: center; font-size: 1.5em; color: #555;font-weight: bold;">
        Imagine seeing your deepest feelings transform into art. We decode the underlying emotions from your brain activity,
        generating a personalized artwork that comes to life within an interactive 3D brain model. Discover the art of your inner self.
        </p>
        """
    )

    with gr.Row():
        # Left Column: Input and Emotion Distribution
        with gr.Column(scale=1):
            gr.Markdown('<h2 style="font-size: 2em;">1. Choose a PSD file</h2>')
            # Radio buttons to select from predefined files
            psd_file_selection = gr.Radio(
                choices=predefined_psd_files,
                label="Select a PSD file for analysis",
                value=predefined_psd_files[0], # Default selection
                interactive=True
            )

            # Button to trigger PSD analysis
            analyze_psd_button = gr.Button("Analyze PSD File", variant="secondary")

            gr.Markdown('<h2 style="font-size: 2em;">2. Emotion Distribution</h2>')

            # Bar plot for emotion distribution
            emotion_distribution_plot = gr.BarPlot(
                dummy_emotion_data,
                x="Emotion",
                y="Value",
                label="Emotion Distribution",
                height=300,
                x_title="Emotion Type",
                y_title="Frequency",
                visible=False # Hidden until analysis is triggered
            )

            # Full name of the dominant emotion; filled in by upload_psd_file and
            # passed to generate_my_art.
            dom_emotion = gr.Textbox(label = "dominant emotion", visible=False)


        # Right Column: Art Museum and Generation
        with gr.Column(scale=1):
            gr.Markdown("<h3>Your Art Museum</h3>")

            gr.Markdown("<h3>3. Choose your favourite painter</h3>")
            painter_dropdown = gr.Dropdown(
                choices=painters,
                value="Pablo Picasso", # Default selection
                label="Select a Painter"
            )

            gr.Markdown("<h3>4. Choose your favourite painting</h3>")
            # Gallery to display paintings for selection
            painting_gallery = gr.Gallery(
                value=update_paintings("Pablo Picasso"), # Initial load for Picasso's paintings
                label="Select a Painting",
                height=300,
                columns=3,
                rows=1,
                object_fit="contain",
                preview=True, # Allows clicking to see larger image
                interactive=True, # Make it selectable
                elem_id="painting_gallery",
                visible=True,
            )


            # Hidden holder for the file name of the painting clicked in the gallery.
            selected_painting_name = gr.Textbox(visible=False)
            # Button to trigger art generation
            generate_button = gr.Button("Generate My Art", variant="primary", scale=0)
            # Status message for image generation
            status_message = gr.Textbox(
                value="Click 'Generate My Art' to begin.",
                label="Generation Status",
                interactive=False,
                show_label=False,
                lines=1
            )

    # Output section
    gr.Markdown(
        """
        <h1 style="text-align: center;">Your Generated Artwork</h1>
        <p style="text-align: center; font-size: 1.5em; color: #555;font-weight: bold;">
        Once your brain's emotional data is processed, we pinpoint the <b>dominant emotion</b>. This single feeling inspires a <b>personalized artwork</b>, generated using <b>diffusion techniques</b> and blended with <b>my AI painting style</b>. You can then <b>download</b> this unique visual representation of your inner self.
        </p>
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("<h3>Generated Image</h3>")
            generated_image_output = gr.Image(label="Generated Image", show_label=False, height=300)
            gr.Markdown("<h3>Blended Style Image</h3>")
            blended_image_output = gr.Image(label="Blended Style Image", show_label=False, height=300)

        with gr.Column(scale=1):
            gr.Markdown("<h3>Brain Topomap</h3>")
            channels_slider = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Channels", interactive=True)
            timestamp_slider = gr.Slider(minimum=1, maximum=630, value=1, step=1, label="Timestamp", interactive=True)
            mne_2d_img = gr.Image(visible=True)
            # Also draw the topomap when generating art. The slider values are
            # passed explicitly because generate_topomap takes them as arguments.
            generate_button.click(
                generate_topomap,
                inputs=[channels_slider, timestamp_slider],
                outputs=mne_2d_img,
            )

    # --- Event Listeners ---
    analyze_psd_button.click(
        upload_psd_file,
        inputs=[psd_file_selection], # Input is the selected radio button value (file name)
        outputs=[emotion_distribution_plot, current_emotion_df_state, dom_emotion]
    )

    # When the painter changes, refresh the gallery and clear the stored
    # selection: the selected file name belongs to the previous painter's
    # folder, but generate_my_art joins it with the *current* painter's folder.
    painter_dropdown.change(
        update_paintings,
        inputs=[painter_dropdown],
        outputs=[painting_gallery]
    )
    painter_dropdown.change(lambda: "", outputs=[selected_painting_name])

    def on_select(evt: gr.SelectData):
        """Return the original file name of the image clicked in the gallery."""
        print("this function started")
        print(f"Image index: {evt.index}\nImage value: {evt.value['image']['orig_name']}")
        return evt.value['image']['orig_name']

    # Gallery clicks store the chosen file name in the hidden textbox, which
    # generate_my_art receives as ``chosen_painting``.
    painting_gallery.select(
        on_select,
        outputs=[selected_painting_name]
    )

    generate_button.click(
        generate_my_art,
        inputs=[painter_dropdown, selected_painting_name, dom_emotion],
        outputs=[status_message, generated_image_output, blended_image_output]
    )

    # Redraw the topomap whenever either slider moves.
    channels_slider.change(fn=generate_topomap, inputs=[channels_slider, timestamp_slider], outputs=mne_2d_img)
    timestamp_slider.change(fn=generate_topomap, inputs=[channels_slider, timestamp_slider], outputs=mne_2d_img)
# Launch the Gradio app only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()