wi-lab commited on
Commit
2b9ed28
·
verified ·
1 Parent(s): 73a79ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -6
app.py CHANGED
@@ -233,13 +233,13 @@ def display_confusion_matrices_los(percentage):
233
  return raw_img, embeddings_img
234
 
235
  # Main function to handle user choice
236
- def handle_user_choice(choice, percentage=None, uploaded_file=None):
237
  if choice == "Use Default Dataset":
238
  raw_img, embeddings_img = display_confusion_matrices_los(percentage)
239
  return raw_img, embeddings_img, "" # Return empty string for console output
240
  elif choice == "Upload Dataset":
241
  if uploaded_file is not None:
242
- raw_img, embeddings_img, console_output = process_hdf5_file(uploaded_file, percentage)
243
  return raw_img, embeddings_img, console_output
244
  else:
245
  return "Please upload a dataset", "Please upload a dataset", "" # Return empty string for console output
@@ -409,7 +409,7 @@ def identical_train_test_split(output_emb, output_raw, labels, train_percentage)
409
  # Store the original working directory when the app starts
410
  original_dir = os.getcwd()
411
 
412
- def process_hdf5_file(uploaded_file, percentage):
413
  capture = PrintCapture()
414
  sys.stdout = capture # Redirect print statements to capture
415
 
@@ -462,7 +462,12 @@ def process_hdf5_file(uploaded_file, percentage):
462
  preprocessed_chs = input_preprocess.tokenizer(manual_data=channels)
463
 
464
  # Step 7: Perform inference using the functions from inference.py
465
- output_emb = inference.lwm_inference(preprocessed_chs, 'channel_emb', model, device)
 
 
 
 
 
466
  output_raw = inference.create_raw_dataset(preprocessed_chs, device)
467
 
468
  print(f"Output Embeddings Shape: {output_emb.shape}")
@@ -648,6 +653,10 @@ with gr.Blocks(css="""
648
  embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300)
649
  output_textbox = gr.Textbox(label="Console Output", lines=10, elem_id="console-output")
650
 
 
 
 
 
651
  # Update the file uploader visibility based on user choice
652
  def toggle_file_input(choice):
653
  return gr.update(visible=(choice == "Upload Dataset"))
@@ -655,13 +664,17 @@ with gr.Blocks(css="""
655
  choice_radio.change(fn=toggle_file_input, inputs=[choice_radio], outputs=file_input)
656
 
657
  # When user makes a choice, update the display
658
- choice_radio.change(fn=handle_user_choice, inputs=[choice_radio, percentage_slider_los, file_input],
659
  outputs=[raw_img_los, embeddings_img_los, output_textbox])
660
 
661
  # When percentage slider changes (for predefined data)
662
- percentage_slider_los.change(fn=handle_user_choice, inputs=[choice_radio, percentage_slider_los, file_input],
663
  outputs=[raw_img_los, embeddings_img_los, output_textbox])
664
 
 
 
 
 
665
  # Add a conclusion section at the bottom
666
  gr.Markdown("""
667
  <div class="explanation-box">
 
233
  return raw_img, embeddings_img
234
 
235
  # Main function to handle user choice
236
+ def handle_user_choice(choice, percentage=None, uploaded_file=None, emb_type='CLS Embedding'):
237
  if choice == "Use Default Dataset":
238
  raw_img, embeddings_img = display_confusion_matrices_los(percentage)
239
  return raw_img, embeddings_img, "" # Return empty string for console output
240
  elif choice == "Upload Dataset":
241
  if uploaded_file is not None:
242
+ raw_img, embeddings_img, console_output = process_hdf5_file(uploaded_file, percentage, emb_type)
243
  return raw_img, embeddings_img, console_output
244
  else:
245
  return "Please upload a dataset", "Please upload a dataset", "" # Return empty string for console output
 
409
  # Store the original working directory when the app starts
410
  original_dir = os.getcwd()
411
 
412
+ def process_hdf5_file(uploaded_file, percentage, emb_type='CLS Embedding'):
413
  capture = PrintCapture()
414
  sys.stdout = capture # Redirect print statements to capture
415
 
 
462
  preprocessed_chs = input_preprocess.tokenizer(manual_data=channels)
463
 
464
  # Step 7: Perform inference using the functions from inference.py
465
+ if emb_type == 'Channel Embedding':
466
+ embedding_type = 'channel_emb'
467
+ elif emb_type == 'CLS Embedding':
468
+ embedding_type = 'cls_emb'
469
+
470
+ output_emb = inference.lwm_inference(preprocessed_chs, embedding_type, model, device)
471
  output_raw = inference.create_raw_dataset(preprocessed_chs, device)
472
 
473
  print(f"Output Embeddings Shape: {output_emb.shape}")
 
653
  embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300)
654
  output_textbox = gr.Textbox(label="Console Output", lines=10, elem_id="console-output")
655
 
656
+ emb_type = gr.Dropdown(choices=["Channel Embedding", "CLS Embedding"],
657
+ value="CLS Embedding",
658
+ label="Embedding Type", interactive=True)
659
+
660
  # Update the file uploader visibility based on user choice
661
  def toggle_file_input(choice):
662
  return gr.update(visible=(choice == "Upload Dataset"))
 
664
  choice_radio.change(fn=toggle_file_input, inputs=[choice_radio], outputs=file_input)
665
 
666
  # When user makes a choice, update the display
667
+ choice_radio.change(fn=handle_user_choice, inputs=[choice_radio, percentage_slider_los, file_input, emb_type],
668
  outputs=[raw_img_los, embeddings_img_los, output_textbox])
669
 
670
  # When percentage slider changes (for predefined data)
671
+ percentage_slider_los.change(fn=handle_user_choice, inputs=[choice_radio, percentage_slider_los, file_input, emb_type],
672
  outputs=[raw_img_los, embeddings_img_los, output_textbox])
673
 
674
+ # When embedding type changes
675
+ emb_type.change(fn=handle_user_choice, inputs=[choice_radio, percentage_slider_los, file_input, emb_type],
676
+ outputs=[raw_img_los, embeddings_img_los, output_textbox])
677
+
678
  # Add a conclusion section at the bottom
679
  gr.Markdown("""
680
  <div class="explanation-box">