| | import entrypoint_setup |
| |
|
| | import os |
| | import tkinter as tk |
| | import argparse |
| | import base64 |
| | import json |
| | import queue |
| | import subprocess |
| | import sys |
| | import traceback |
| | import webbrowser |
| | from types import SimpleNamespace |
| | from tkinter import ttk, messagebox, filedialog |
| | from concurrent.futures import ThreadPoolExecutor |
| |
|
| | from base_models.get_base_models import BaseModelArguments, standard_models |
| | from data.supported_datasets import supported_datasets, standard_data_benchmark, internal_datasets |
| | from embedder import EmbeddingArguments |
| | from probes.get_probe import ProbeArguments |
| | from probes.trainers import TrainerArguments |
| | from main import MainProcess |
| | from data.data_mixin import DataArguments |
| | from modal_utils import parse_modal_api_key |
| | from utils import print_message, print_done, print_title, expand_dms_ids_all |
| | from visualization.plot_result import create_plots |
| | from benchmarks.proteingym.compare_scoring_methods import compare_scoring_methods |
| | from hyperopt_utils import HyperoptModule |
| |
|
| |
|
class BackgroundTask:
    """Deferred call whose outcome can be polled after it has run.

    ``run`` invokes ``target(*args, **kwargs)`` and records the outcome on the
    instance: ``result`` receives the return value, ``error`` receives any
    ``Exception`` raised by the target, and ``complete`` becomes True once
    ``run`` has finished, whether it succeeded or failed.
    """

    def __init__(self, target, *args, **kwargs):
        """Remember the callable and the arguments to invoke it with."""
        self.target = target
        self.args = args
        self.kwargs = kwargs
        # Outcome slots, filled in by run().
        self.result = None
        self.error = None
        self._complete = False

    def run(self):
        """Call the target once, capturing its return value or its exception."""
        try:
            outcome = self.target(*self.args, **self.kwargs)
        except Exception as exc:
            # Failures are recorded and reported, never re-raised, so the
            # caller learns about them through ``error``.
            self.error = exc
            print_message(f"Error in background task: {exc!s}")
            traceback.print_exc()
        else:
            self.result = outcome
        finally:
            # Set on every path so pollers never wait on a failed task.
            self._complete = True

    @property
    def complete(self):
        """True once ``run`` has finished (successfully or not)."""
        return self._complete
| |
|
| |
|
| | class GUI(MainProcess): |
| | def __init__(self, master): |
| | super().__init__(argparse.Namespace(), GUI=True) |
| | self.master = master |
| | self.master.title("Settings GUI") |
| | self.master.geometry("600x800") |
| |
|
| | icon = tk.PhotoImage(file="protify_logo.png") |
| | |
| | self.master.iconphoto(True, icon) |
| |
|
| | |
| | self.settings_vars = {} |
| |
|
| | |
| | self.notebook = ttk.Notebook(master) |
| | self.notebook.pack(fill='both', expand=True) |
| |
|
| | |
| | self.info_tab = ttk.Frame(self.notebook) |
| | self.data_tab = ttk.Frame(self.notebook) |
| | self.embed_tab = ttk.Frame(self.notebook) |
| | self.model_tab = ttk.Frame(self.notebook) |
| | self.probe_tab = ttk.Frame(self.notebook) |
| | self.trainer_tab = ttk.Frame(self.notebook) |
| | self.wandb_tab = ttk.Frame(self.notebook) |
| | self.modal_tab = ttk.Frame(self.notebook) |
| | self.scikit_tab = ttk.Frame(self.notebook) |
| | self.replay_tab = ttk.Frame(self.notebook) |
| | self.viz_tab = ttk.Frame(self.notebook) |
| | self.proteingym_tab = ttk.Frame(self.notebook) |
| |
|
| | |
| | self.notebook.add(self.info_tab, text="Info") |
| | self.notebook.add(self.model_tab, text="Model") |
| | self.notebook.add(self.data_tab, text="Data") |
| | self.notebook.add(self.embed_tab, text="Embedding") |
| | self.notebook.add(self.probe_tab, text="Probe") |
| | self.notebook.add(self.trainer_tab, text="Trainer") |
| | self.notebook.add(self.wandb_tab, text="W&B Sweep") |
| | self.notebook.add(self.modal_tab, text="Modal") |
| | self.notebook.add(self.proteingym_tab, text="ProteinGym") |
| | self.notebook.add(self.scikit_tab, text="Scikit") |
| | self.notebook.add(self.replay_tab, text="Replay") |
| | self.notebook.add(self.viz_tab, text="Visualization") |
| |
|
| | |
| | self.task_queue = queue.Queue() |
| | self.thread_pool = ThreadPoolExecutor(max_workers=1) |
| | self.current_task = None |
| | self.modal_polling_active = False |
| | |
| | |
| | self.check_task_queue() |
| |
|
| | |
| | self.build_info_tab() |
| | self.build_model_tab() |
| | self.build_data_tab() |
| | self.build_embed_tab() |
| | self.build_probe_tab() |
| | self.build_trainer_tab() |
| | self.build_wandb_tab() |
| | self.build_modal_tab() |
| | self.build_proteingym_tab() |
| | self.build_scikit_tab() |
| | self.build_replay_tab() |
| | self.build_viz_tab() |
| |
|
| | def check_task_queue(self): |
| | """Periodically check for completed background tasks""" |
| | if self.current_task and self.current_task.complete: |
| | if self.current_task.error: |
| | print_message(f"Task failed: {self.current_task.error}") |
| | self.current_task = None |
| | |
| | if not self.current_task and not self.task_queue.empty(): |
| | self.current_task = self.task_queue.get() |
| | self.thread_pool.submit(self.current_task.run) |
| | |
| | |
| | self.master.after(100, self.check_task_queue) |
| | |
| | def run_in_background(self, target, *args, **kwargs): |
| | """Queue a task to run in background""" |
| | task = BackgroundTask(target, *args, **kwargs) |
| | self.task_queue.put(task) |
| | return task |
| |
|
| | def _open_url(self, url): |
| | """Open a URL in the default web browser""" |
| | webbrowser.open_new_tab(url) |
| | |
| | def build_info_tab(self): |
| | |
| | id_frame = ttk.LabelFrame(self.info_tab, text="Identification") |
| | id_frame.pack(fill="x", padx=10, pady=5) |
| |
|
| | |
| | ttk.Label(id_frame, text="Huggingface Username:").grid(row=0, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["huggingface_username"] = tk.StringVar(value="Synthyra") |
| | entry_huggingface_username = ttk.Entry(id_frame, textvariable=self.settings_vars["huggingface_username"], width=30) |
| | entry_huggingface_username.grid(row=0, column=1, padx=10, pady=5) |
| | self.add_help_button(id_frame, 0, 2, "Your Hugging Face username for model downloads and uploads.") |
| |
|
| | |
| | ttk.Label(id_frame, text="Huggingface Token:").grid(row=1, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["huggingface_token"] = tk.StringVar(value="") |
| | entry_huggingface_token = ttk.Entry(id_frame, textvariable=self.settings_vars["huggingface_token"], width=30) |
| | entry_huggingface_token.grid(row=1, column=1, padx=10, pady=5) |
| | self.add_help_button(id_frame, 1, 2, "Your Hugging Face API token for accessing gated or private models.") |
| |
|
| | |
| | ttk.Label(id_frame, text="Wandb API Key:").grid(row=2, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["wandb_api_key"] = tk.StringVar(value="") |
| | entry_wandb_api_key = ttk.Entry(id_frame, textvariable=self.settings_vars["wandb_api_key"], width=30) |
| | entry_wandb_api_key.grid(row=2, column=1, padx=10, pady=5) |
| | self.add_help_button(id_frame, 2, 2, "Your Weights & Biases API key for experiment tracking.") |
| |
|
| | |
| | ttk.Label(id_frame, text="Synthyra API Key:").grid(row=3, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["synthyra_api_key"] = tk.StringVar(value="") |
| | entry_synthyra_api_key = ttk.Entry(id_frame, textvariable=self.settings_vars["synthyra_api_key"], width=30) |
| | entry_synthyra_api_key.grid(row=3, column=1, padx=10, pady=5) |
| | self.add_help_button(id_frame, 3, 2, "Your Synthyra API key for accessing premium features.") |
| |
|
| | |
| | ttk.Label(id_frame, text="Modal API Key (legacy):").grid(row=4, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["modal_api_key"] = tk.StringVar(value="") |
| | entry_modal_api_key = ttk.Entry(id_frame, textvariable=self.settings_vars["modal_api_key"], width=30, show="*") |
| | entry_modal_api_key.grid(row=4, column=1, padx=10, pady=5) |
| | self.add_help_button(id_frame, 4, 2, "Legacy format '<modal_token_id>:<modal_token_secret>'.") |
| |
|
| | |
| | ttk.Label(id_frame, text="Modal Token ID:").grid(row=5, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["modal_token_id"] = tk.StringVar(value="") |
| | entry_modal_token_id = ttk.Entry(id_frame, textvariable=self.settings_vars["modal_token_id"], width=30) |
| | entry_modal_token_id.grid(row=5, column=1, padx=10, pady=5) |
| | self.add_help_button(id_frame, 5, 2, "Modal token ID used for CLI/SDK authentication.") |
| |
|
| | |
| | ttk.Label(id_frame, text="Modal Token Secret:").grid(row=6, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["modal_token_secret"] = tk.StringVar(value="") |
| | entry_modal_token_secret = ttk.Entry(id_frame, textvariable=self.settings_vars["modal_token_secret"], width=30, show="*") |
| | entry_modal_token_secret.grid(row=6, column=1, padx=10, pady=5) |
| | self.add_help_button(id_frame, 6, 2, "Modal token secret used for CLI/SDK authentication.") |
| |
|
| | |
| | paths_frame = ttk.LabelFrame(self.info_tab, text="Paths") |
| | paths_frame.pack(fill="x", padx=10, pady=5) |
| |
|
| | ttk.Label(paths_frame, text='Home Directory:').grid(row=0, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["home_dir"] = tk.StringVar(value=os.getcwd()) |
| | entry_home_dir = ttk.Entry(paths_frame, textvariable=self.settings_vars["home_dir"], width=30) |
| | entry_home_dir.grid(row=0, column=1, padx=10, pady=5) |
| | self.add_help_button(paths_frame, 0, 2, "Home directory for Protify.") |
| |
|
| | |
| | ttk.Label(paths_frame, text="HF Home Directory:").grid(row=1, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["hf_home"] = tk.StringVar(value="") |
| | entry_hf_home = ttk.Entry(paths_frame, textvariable=self.settings_vars["hf_home"], width=30) |
| | entry_hf_home.grid(row=1, column=1, padx=10, pady=5) |
| | self.add_help_button(paths_frame, 1, 2, "Customize the HuggingFace cache directory. Leave empty to use default.") |
| |
|
| | |
| | ttk.Label(paths_frame, text="Log Directory:").grid(row=2, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["log_dir"] = tk.StringVar(value="logs") |
| | entry_log_dir = ttk.Entry(paths_frame, textvariable=self.settings_vars["log_dir"], width=30) |
| | entry_log_dir.grid(row=2, column=1, padx=10, pady=5) |
| | self.add_help_button(paths_frame, 2, 2, "Directory where log files will be stored.") |
| |
|
| | |
| | ttk.Label(paths_frame, text="Results Directory:").grid(row=3, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["results_dir"] = tk.StringVar(value="results") |
| | entry_results_dir = ttk.Entry(paths_frame, textvariable=self.settings_vars["results_dir"], width=30) |
| | entry_results_dir.grid(row=3, column=1, padx=10, pady=5) |
| | self.add_help_button(paths_frame, 3, 2, "Directory where results data will be stored.") |
| |
|
| | |
| | ttk.Label(paths_frame, text="Model Save Directory:").grid(row=4, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["model_save_dir"] = tk.StringVar(value="weights") |
| | entry_model_save = ttk.Entry(paths_frame, textvariable=self.settings_vars["model_save_dir"], width=30) |
| | entry_model_save.grid(row=4, column=1, padx=10, pady=5) |
| | self.add_help_button(paths_frame, 4, 2, "Directory where trained models will be saved.") |
| |
|
| | ttk.Label(paths_frame, text="Plots Directory:").grid(row=5, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["plots_dir"] = tk.StringVar(value="plots") |
| | entry_plots_dir = ttk.Entry(paths_frame, textvariable=self.settings_vars["plots_dir"], width=30) |
| | entry_plots_dir.grid(row=5, column=1, padx=10, pady=5) |
| | self.add_help_button(paths_frame, 5, 2, "Directory where plots and visualizations will be saved.") |
| |
|
| | |
| | ttk.Label(paths_frame, text="Embedding Save Directory:").grid(row=6, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["embedding_save_dir"] = tk.StringVar(value="embeddings") |
| | entry_embed_save = ttk.Entry(paths_frame, textvariable=self.settings_vars["embedding_save_dir"], width=30) |
| | entry_embed_save.grid(row=6, column=1, padx=10, pady=5) |
| | self.add_help_button(paths_frame, 6, 2, "Directory where computed embeddings will be saved.") |
| |
|
| | |
| | ttk.Label(paths_frame, text="Download Directory:").grid(row=7, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["download_dir"] = tk.StringVar(value="Synthyra/vector_embeddings") |
| | entry_download = ttk.Entry(paths_frame, textvariable=self.settings_vars["download_dir"], width=30) |
| | entry_download.grid(row=7, column=1, padx=10, pady=5) |
| | self.add_help_button(paths_frame, 7, 2, "HuggingFace repository path for downloading pre-computed embeddings.") |
| |
|
| | |
| | start_logging_button = ttk.Button(self.info_tab, text="Start session", command=self._session_start) |
| | start_logging_button.pack(pady=10) |
| | |
| | |
| | try: |
| | original_logo = tk.PhotoImage(file="synthyra_logo.png") |
| | |
| | logo = original_logo.subsample(3, 3) |
| | |
| | |
| | bottom_frame = ttk.Frame(self.info_tab) |
| | bottom_frame.pack(pady=(10, 20), fill="x") |
| | |
| | |
| | logo_label = ttk.Label(bottom_frame, image=logo, cursor="hand2") |
| | logo_label.image = logo |
| | logo_label.pack(side=tk.LEFT, padx=(20, 10)) |
| | |
| | logo_label.bind("<Button-1>", lambda e: self._open_url("https://synthyra.com")) |
| | |
| | |
| | visit_btn = ttk.Button( |
| | bottom_frame, |
| | text="Visit Synthyra.com", |
| | command=lambda: self._open_url("https://synthyra.com"), |
| | style="Link.TButton" |
| | ) |
| | |
| | |
| | style = ttk.Style() |
| | style.configure("Link.TButton", font=("Helvetica", 12), foreground="blue") |
| | |
| | visit_btn.pack(side=tk.LEFT, padx=(10, 20), pady=10) |
| | |
| | except Exception as e: |
| | print_message(f"Error setting up logo and link: {str(e)}") |
| |
|
| | def build_model_tab(self): |
| | ttk.Label(self.model_tab, text="Model Names:").grid(row=0, column=0, padx=10, pady=5, sticky="nw") |
| |
|
| | self.model_listbox = tk.Listbox(self.model_tab, selectmode="extended", height=24) |
| | for model_name in standard_models: |
| | self.model_listbox.insert(tk.END, model_name) |
| | self.model_listbox.grid(row=0, column=1, padx=10, pady=5, sticky="nw") |
| | self.add_help_button(self.model_tab, 0, 2, "Select the language models to use for embedding. Multiple models can be selected.") |
| |
|
| | ttk.Label(self.model_tab, text="Model DType:").grid(row=1, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["model_dtype"] = tk.StringVar(value="bf16") |
| | combo_model_dtype = ttk.Combobox( |
| | self.model_tab, |
| | textvariable=self.settings_vars["model_dtype"], |
| | values=["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"], |
| | state="readonly", |
| | ) |
| | combo_model_dtype.grid(row=1, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.model_tab, 1, 2, "Data type used when loading base models.") |
| |
|
| | ttk.Label(self.model_tab, text="Use xformers:").grid(row=2, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["use_xformers"] = tk.BooleanVar(value=False) |
| | check_use_xformers = ttk.Checkbutton(self.model_tab, variable=self.settings_vars["use_xformers"]) |
| | check_use_xformers.grid(row=2, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.model_tab, 2, 2, "Enable memory-efficient xformers attention where supported.") |
| |
|
| | run_button = ttk.Button(self.model_tab, text="Select Models", command=self._select_models) |
| | run_button.grid(row=99, column=0, columnspan=2, pady=(10, 10)) |
| |
|
| | def build_data_tab(self): |
| | ttk.Label(self.data_tab, text="Max Sequence Length:").grid(row=0, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["max_length"] = tk.IntVar(value=2048) |
| | spin_max_length = ttk.Spinbox(self.data_tab, from_=1, to=32768, textvariable=self.settings_vars["max_length"]) |
| | spin_max_length.grid(row=0, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.data_tab, 0, 2, "Maximum length of sequences (in tokens) to process.") |
| |
|
| | ttk.Label(self.data_tab, text="Trim Sequences:").grid(row=1, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["trim"] = tk.BooleanVar(value=False) |
| | check_trim = ttk.Checkbutton(self.data_tab, variable=self.settings_vars["trim"]) |
| | check_trim.grid(row=1, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.data_tab, 1, 2, "Whether to trim sequences to the specified max length.") |
| |
|
| | ttk.Label(self.data_tab, text="Delimiter:").grid(row=2, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["delimiter"] = tk.StringVar(value=",") |
| | entry_delimiter = ttk.Entry(self.data_tab, textvariable=self.settings_vars["delimiter"], width=5) |
| | entry_delimiter.grid(row=2, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.data_tab, 2, 2, "Character used to separate columns in CSV data files.") |
| |
|
| | ttk.Label(self.data_tab, text="Column Names (comma-separated):").grid(row=3, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["col_names"] = tk.StringVar(value="seqs,labels") |
| | entry_col_names = ttk.Entry(self.data_tab, textvariable=self.settings_vars["col_names"], width=20) |
| | entry_col_names.grid(row=3, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.data_tab, 3, 2, "Names of columns in data files, separate with commas.") |
| |
|
| | ttk.Label(self.data_tab, text="Multi-Column Sequences (space-separated):").grid(row=4, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["multi_column"] = tk.StringVar(value="") |
| | entry_multi_column = ttk.Entry(self.data_tab, textvariable=self.settings_vars["multi_column"], width=20) |
| | entry_multi_column.grid(row=4, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.data_tab, 4, 2, "If set, list of sequence column names to combine per sample (space-separated). Leave empty if not using multi-column sequences.") |
| |
|
| | ttk.Label(self.data_tab, text="Local Data Directories (comma-separated):").grid(row=5, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["data_dirs"] = tk.StringVar(value="") |
| | entry_data_dirs = ttk.Entry(self.data_tab, textvariable=self.settings_vars["data_dirs"], width=30) |
| | entry_data_dirs.grid(row=5, column=1, padx=10, pady=5, sticky="w") |
| | browse_data_dir_button = ttk.Button(self.data_tab, text="Browse", command=self._browse_data_dir) |
| | browse_data_dir_button.grid(row=5, column=2, padx=5, pady=5) |
| | self.add_help_button(self.data_tab, 5, 3, "Optional local dataset directories. Multiple paths can be comma-separated.") |
| |
|
| | ttk.Label(self.data_tab, text="AA -> DNA:").grid(row=6, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["aa_to_dna"] = tk.BooleanVar(value=False) |
| | ttk.Checkbutton(self.data_tab, variable=self.settings_vars["aa_to_dna"]).grid(row=6, column=1, padx=10, pady=5, sticky="w") |
| |
|
| | ttk.Label(self.data_tab, text="AA -> RNA:").grid(row=7, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["aa_to_rna"] = tk.BooleanVar(value=False) |
| | ttk.Checkbutton(self.data_tab, variable=self.settings_vars["aa_to_rna"]).grid(row=7, column=1, padx=10, pady=5, sticky="w") |
| |
|
| | ttk.Label(self.data_tab, text="DNA -> AA:").grid(row=8, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["dna_to_aa"] = tk.BooleanVar(value=False) |
| | ttk.Checkbutton(self.data_tab, variable=self.settings_vars["dna_to_aa"]).grid(row=8, column=1, padx=10, pady=5, sticky="w") |
| |
|
| | ttk.Label(self.data_tab, text="RNA -> AA:").grid(row=9, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["rna_to_aa"] = tk.BooleanVar(value=False) |
| | ttk.Checkbutton(self.data_tab, variable=self.settings_vars["rna_to_aa"]).grid(row=9, column=1, padx=10, pady=5, sticky="w") |
| |
|
| | ttk.Label(self.data_tab, text="Codon -> AA:").grid(row=10, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["codon_to_aa"] = tk.BooleanVar(value=False) |
| | ttk.Checkbutton(self.data_tab, variable=self.settings_vars["codon_to_aa"]).grid(row=10, column=1, padx=10, pady=5, sticky="w") |
| |
|
| | ttk.Label(self.data_tab, text="AA -> Codon:").grid(row=11, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["aa_to_codon"] = tk.BooleanVar(value=False) |
| | ttk.Checkbutton(self.data_tab, variable=self.settings_vars["aa_to_codon"]).grid(row=11, column=1, padx=10, pady=5, sticky="w") |
| |
|
| | ttk.Label(self.data_tab, text="Random Pair Flipping:").grid(row=12, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["random_pair_flipping"] = tk.BooleanVar(value=False) |
| | ttk.Checkbutton(self.data_tab, variable=self.settings_vars["random_pair_flipping"]).grid(row=12, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.data_tab, 12, 2, "Randomly flip paired inputs during training for pair datasets.") |
| |
|
| | ttk.Label(self.data_tab, text="Dataset Names:").grid(row=13, column=0, padx=10, pady=5, sticky="nw") |
| | self.data_listbox = tk.Listbox(self.data_tab, selectmode="extended", height=20, width=25) |
| | for dataset_name in supported_datasets: |
| | if dataset_name not in internal_datasets: |
| | self.data_listbox.insert(tk.END, dataset_name) |
| | self.data_listbox.grid(row=13, column=1, padx=10, pady=5, sticky="nw") |
| | self.add_help_button(self.data_tab, 13, 2, "Select datasets to use. Multiple datasets can be selected.") |
| |
|
| | run_button = ttk.Button(self.data_tab, text="Get Data", command=self._get_data) |
| | run_button.grid(row=99, column=0, columnspan=2, pady=(10, 10)) |
| |
|
| | def build_embed_tab(self): |
| | |
| | ttk.Label(self.embed_tab, text="Batch Size:").grid(row=1, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["batch_size"] = tk.IntVar(value=4) |
| | spin_batch_size = ttk.Spinbox(self.embed_tab, from_=1, to=1024, textvariable=self.settings_vars["batch_size"]) |
| | spin_batch_size.grid(row=1, column=1, padx=10, pady=5) |
| | self.add_help_button(self.embed_tab, 1, 2, "Number of sequences to process at once during embedding.") |
| |
|
| | |
| | ttk.Label(self.embed_tab, text="Num Workers:").grid(row=2, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["num_workers"] = tk.IntVar(value=0) |
| | spin_num_workers = ttk.Spinbox(self.embed_tab, from_=0, to=64, textvariable=self.settings_vars["num_workers"]) |
| | spin_num_workers.grid(row=2, column=1, padx=10, pady=5) |
| | self.add_help_button(self.embed_tab, 2, 2, "Number of worker processes for data loading. 0 means main process only.") |
| |
|
| | |
| | ttk.Label(self.embed_tab, text="Download Embeddings:").grid(row=3, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["download_embeddings"] = tk.BooleanVar(value=False) |
| | check_download = ttk.Checkbutton(self.embed_tab, variable=self.settings_vars["download_embeddings"]) |
| | check_download.grid(row=3, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.embed_tab, 3, 2, "Whether to download pre-computed embeddings from HuggingFace instead of computing them.") |
| |
|
| | |
| | ttk.Label(self.embed_tab, text="Matrix Embedding:").grid(row=4, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["matrix_embed"] = tk.BooleanVar(value=False) |
| | check_matrix = ttk.Checkbutton(self.embed_tab, variable=self.settings_vars["matrix_embed"]) |
| | check_matrix.grid(row=4, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.embed_tab, 4, 2, "Whether to use matrix embedding (full embedding matrices) instead of pooled embeddings.") |
| |
|
| | |
| | ttk.Label(self.embed_tab, text="Pooling Types (comma-separated):").grid(row=5, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["embedding_pooling_types"] = tk.StringVar(value="mean, var") |
| | entry_pooling = ttk.Entry(self.embed_tab, textvariable=self.settings_vars["embedding_pooling_types"], width=20) |
| | entry_pooling.grid(row=5, column=1, padx=10, pady=5) |
| | self.add_help_button(self.embed_tab, 5, 2, "Types of pooling to apply to embeddings, separate with commas.") |
| | |
| | ttk.Label(self.embed_tab, text="Options: mean, max, min, norm, prod, median, std, var, cls, parti").grid(row=6, column=0, columnspan=2, padx=10, pady=2, sticky="w") |
| |
|
| | |
| | ttk.Label(self.embed_tab, text="Embedding DType:").grid(row=7, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["embed_dtype"] = tk.StringVar(value="float32") |
| | combo_dtype = ttk.Combobox( |
| | self.embed_tab, |
| | textvariable=self.settings_vars["embed_dtype"], |
| | values=["float32", "float16", "bfloat16", "float8_e4m3fn", "float8_e5m2"] |
| | ) |
| | combo_dtype.grid(row=7, column=1, padx=10, pady=5) |
| | self.add_help_button(self.embed_tab, 7, 2, "Data type to use for storing embeddings (affects precision and size).") |
| |
|
| | |
| | ttk.Label(self.embed_tab, text="Use SQL:").grid(row=8, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["sql"] = tk.BooleanVar(value=False) |
| | check_sql = ttk.Checkbutton(self.embed_tab, variable=self.settings_vars["sql"]) |
| | check_sql.grid(row=8, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.embed_tab, 8, 2, "Whether to use SQL database for storing embeddings instead of files.") |
| |
|
| | run_button = ttk.Button(self.embed_tab, text="Embed sequences to disk", command=self._get_embeddings) |
| | run_button.grid(row=99, column=0, columnspan=2, pady=(10, 10)) |
| |
|
| | def build_probe_tab(self): |
| | |
| | ttk.Label(self.probe_tab, text="Probe Type:").grid(row=0, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["probe_type"] = tk.StringVar(value="linear") |
| | combo_probe = ttk.Combobox( |
| | self.probe_tab, |
| | textvariable=self.settings_vars["probe_type"], |
| | values=["linear", "transformer", "retrievalnet", "lyra"] |
| | ) |
| | combo_probe.grid(row=0, column=1, padx=10, pady=5) |
| | self.add_help_button(self.probe_tab, 0, 2, "Type of probe architecture to use (linear, transformer, or retrievalnet).") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Tokenwise:").grid(row=1, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["tokenwise"] = tk.BooleanVar(value=False) |
| | check_tokenwise = ttk.Checkbutton(self.probe_tab, variable=self.settings_vars["tokenwise"]) |
| | check_tokenwise.grid(row=1, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.probe_tab, 1, 2, "Whether to use token-wise prediction (operate on each token) instead of sequence-level.") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Pre Layer Norm:").grid(row=2, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["pre_ln"] = tk.BooleanVar(value=True) |
| | check_pre_ln = ttk.Checkbutton(self.probe_tab, variable=self.settings_vars["pre_ln"]) |
| | check_pre_ln.grid(row=2, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.probe_tab, 2, 2, "Whether to use pre-layer normalization in transformer architecture.") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Number of Layers:").grid(row=3, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["n_layers"] = tk.IntVar(value=1) |
| | spin_n_layers = ttk.Spinbox(self.probe_tab, from_=1, to=100, textvariable=self.settings_vars["n_layers"]) |
| | spin_n_layers.grid(row=3, column=1, padx=10, pady=5) |
| | self.add_help_button(self.probe_tab, 3, 2, "Number of layers in the probe architecture.") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Hidden Dimension:").grid(row=4, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["hidden_size"] = tk.IntVar(value=8192) |
| | spin_hidden_size = ttk.Spinbox(self.probe_tab, from_=1, to=10000, textvariable=self.settings_vars["hidden_size"]) |
| | spin_hidden_size.grid(row=4, column=1, padx=10, pady=5) |
| | self.add_help_button(self.probe_tab, 4, 2, "Size of hidden dimension in the probe model.") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Dropout:").grid(row=5, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["dropout"] = tk.DoubleVar(value=0.2) |
| | spin_dropout = ttk.Spinbox(self.probe_tab, from_=0.0, to=1.0, increment=0.1, textvariable=self.settings_vars["dropout"]) |
| | spin_dropout.grid(row=5, column=1, padx=10, pady=5) |
| | self.add_help_button(self.probe_tab, 5, 2, "Dropout probability for regularization (0.0-1.0).") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="=== Transformer Probe Settings ===").grid(row=6, column=0, columnspan=2, pady=10) |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Transformer Hidden Dimension:").grid(row=7, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["transformer_hidden_size"] = tk.IntVar(value=512) |
| | spin_transformer_hidden_size = ttk.Spinbox(self.probe_tab, from_=64, to=4096, textvariable=self.settings_vars["transformer_hidden_size"]) |
| | spin_transformer_hidden_size.grid(row=7, column=1, padx=10, pady=5) |
| | self.add_help_button(self.probe_tab, 7, 2, "Internal hidden dimension for transformer probe (512 recommended).") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Classifier Dimension:").grid(row=8, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["classifier_size"] = tk.IntVar(value=4096) |
| | spin_classifier_size = ttk.Spinbox(self.probe_tab, from_=1, to=10000, textvariable=self.settings_vars["classifier_size"]) |
| | spin_classifier_size.grid(row=8, column=1, padx=10, pady=5) |
| | self.add_help_button(self.probe_tab, 8, 2, "Dimension of the classifier/feedforward layer in transformer probe.") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Classifier Dropout:").grid(row=9, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["classifier_dropout"] = tk.DoubleVar(value=0.2) |
| | spin_class_dropout = ttk.Spinbox(self.probe_tab, from_=0.0, to=1.0, increment=0.1, textvariable=self.settings_vars["classifier_dropout"]) |
| | spin_class_dropout.grid(row=9, column=1, padx=10, pady=5) |
| | self.add_help_button(self.probe_tab, 9, 2, "Dropout probability in the classifier layer (0.0-1.0).") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Number of Heads:").grid(row=10, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["n_heads"] = tk.IntVar(value=4) |
| | spin_n_heads = ttk.Spinbox(self.probe_tab, from_=1, to=32, textvariable=self.settings_vars["n_heads"]) |
| | spin_n_heads.grid(row=10, column=1, padx=10, pady=5) |
| | self.add_help_button(self.probe_tab, 10, 2, "Number of attention heads in transformer probe.") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Rotary:").grid(row=11, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["rotary"] = tk.BooleanVar(value=True) |
| | check_rotary = ttk.Checkbutton(self.probe_tab, variable=self.settings_vars["rotary"]) |
| | check_rotary.grid(row=11, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.probe_tab, 11, 2, "Whether to use rotary position embeddings in transformer.") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Pooling Types (comma-separated):").grid(row=12, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["probe_pooling_types"] = tk.StringVar(value="mean, var") |
| | entry_pooling = ttk.Entry(self.probe_tab, textvariable=self.settings_vars["probe_pooling_types"], width=20) |
| | entry_pooling.grid(row=12, column=1, padx=10, pady=5) |
| | self.add_help_button(self.probe_tab, 12, 2, "Types of pooling to use in the probe model, separate with commas.") |
| | |
| | |
| | ttk.Label(self.probe_tab, text="Transformer Dropout:").grid(row=13, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["transformer_dropout"] = tk.DoubleVar(value=0.1) |
| | spin_transformer_dropout = ttk.Spinbox(self.probe_tab, from_=0.0, to=1.0, increment=0.1, textvariable=self.settings_vars["transformer_dropout"]) |
| | spin_transformer_dropout.grid(row=13, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.probe_tab, 13, 2, "Dropout probability in the transformer layers (0.0-1.0).") |
| | |
| | |
| | ttk.Label(self.probe_tab, text="Token Attention:").grid(row=14, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["token_attention"] = tk.BooleanVar(value=False) |
| | check_token_attention = ttk.Checkbutton(self.probe_tab, variable=self.settings_vars["token_attention"]) |
| | check_token_attention.grid(row=14, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.probe_tab, 14, 2, "If true, use TokenFormer instead of Transformer blocks.") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Use Bias:").grid(row=15, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["use_bias"] = tk.BooleanVar(value=False) |
| | check_use_bias = ttk.Checkbutton(self.probe_tab, variable=self.settings_vars["use_bias"]) |
| | check_use_bias.grid(row=15, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.probe_tab, 15, 2, "Use bias terms in probe linear layers.") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Add Token IDs:").grid(row=16, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["add_token_ids"] = tk.BooleanVar(value=False) |
| | check_add_token_ids = ttk.Checkbutton(self.probe_tab, variable=self.settings_vars["add_token_ids"]) |
| | check_add_token_ids.grid(row=16, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.probe_tab, 16, 2, "Add learned token type IDs for pair tasks.") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="=== RetrievalNet Settings ===").grid(row=17, column=0, columnspan=2, pady=10) |
| | |
| | |
| | ttk.Label(self.probe_tab, text="Similarity Type:").grid(row=18, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["sim_type"] = tk.StringVar(value="dot") |
| | combo_sim_type = ttk.Combobox( |
| | self.probe_tab, |
| | textvariable=self.settings_vars["sim_type"], |
| | values=["dot", "euclidean", "cosine"] |
| | ) |
| | combo_sim_type.grid(row=18, column=1, padx=10, pady=5) |
| | self.add_help_button(self.probe_tab, 18, 2, "Cross-attention mechanism for token-parameter-attention (dot, euclidean, or cosine).") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Save Model:").grid(row=19, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["save_model"] = tk.BooleanVar(value=False) |
| | check_save_model = ttk.Checkbutton(self.probe_tab, variable=self.settings_vars["save_model"]) |
| | check_save_model.grid(row=19, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.probe_tab, 19, 2, "Whether to save the trained probe model to disk.") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="Production Model:").grid(row=20, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["production_model"] = tk.BooleanVar(value=False) |
| | check_prod_model = ttk.Checkbutton(self.probe_tab, variable=self.settings_vars["production_model"]) |
| | check_prod_model.grid(row=20, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.probe_tab, 20, 2, "Whether to prepare the model for production deployment.") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="=== LoRA Settings ===").grid(row=21, column=0, columnspan=2, pady=10) |
| | |
| | |
| | ttk.Label(self.probe_tab, text="Use LoRA:").grid(row=22, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["lora"] = tk.BooleanVar(value=False) |
| | check_lora = ttk.Checkbutton(self.probe_tab, variable=self.settings_vars["lora"]) |
| | check_lora.grid(row=22, column=1, padx=10, pady=5, sticky="w") |
| | self.add_help_button(self.probe_tab, 22, 2, "Whether to use Low-Rank Adaptation (LoRA) for fine-tuning.") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="LoRA r:").grid(row=23, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["lora_r"] = tk.IntVar(value=8) |
| | spin_lora_r = ttk.Spinbox(self.probe_tab, from_=1, to=128, textvariable=self.settings_vars["lora_r"]) |
| | spin_lora_r.grid(row=23, column=1, padx=10, pady=5) |
| | self.add_help_button(self.probe_tab, 23, 2, "Rank parameter r for LoRA (lower = more efficient, higher = more expressive).") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="LoRA alpha:").grid(row=24, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["lora_alpha"] = tk.DoubleVar(value=32.0) |
| | spin_lora_alpha = ttk.Spinbox(self.probe_tab, from_=1.0, to=128.0, increment=1.0, textvariable=self.settings_vars["lora_alpha"]) |
| | spin_lora_alpha.grid(row=24, column=1, padx=10, pady=5) |
| | self.add_help_button(self.probe_tab, 24, 2, "Alpha parameter for LoRA, controls update scale.") |
| |
|
| | |
| | ttk.Label(self.probe_tab, text="LoRA dropout:").grid(row=25, column=0, padx=10, pady=5, sticky="w") |
| | self.settings_vars["lora_dropout"] = tk.DoubleVar(value=0.01) |
| | spin_lora_dropout = ttk.Spinbox(self.probe_tab, from_=0.0, to=0.5, increment=0.01, textvariable=self.settings_vars["lora_dropout"]) |
| | spin_lora_dropout.grid(row=25, column=1, padx=10, pady=5) |
| | self.add_help_button(self.probe_tab, 25, 2, "Dropout probability for LoRA layers (0.0-0.5).") |
| | |
| | |
| | run_button = ttk.Button(self.probe_tab, text="Save Probe Arguments", command=self._create_probe_args) |
| | run_button.grid(row=99, column=0, columnspan=2, pady=(10, 10)) |
| |
|
def build_trainer_tab(self):
    """Populate the trainer tab with one labelled input row per training option.

    Each row is a caption in column 0, an input widget in column 1 bound to a
    new Tk variable stored in ``self.settings_vars``, and a "?" help button in
    column 2. A "Run trainer" button wired to ``self._run_trainer`` is placed
    on grid row 99.
    """
    tab = self.trainer_tab

    # One tuple per grid row, in display order:
    #   (settings key, caption, variable class, default value,
    #    Spinbox options or None for a Checkbutton,
    #    whether the input is gridded with sticky="w", help text)
    row_specs = [
        ("hybrid_probe", "Hybrid Probe:", tk.BooleanVar, False, None, True,
         "Whether to use hybrid probe (combines neural and linear probes)."),
        ("full_finetuning", "Full Finetuning:", tk.BooleanVar, False, None, True,
         "Whether to perform full finetuning of the entire model."),
        ("num_epochs", "Number of Epochs:", tk.IntVar, 200, {"from_": 1, "to": 1000}, False,
         "Number of training epochs (complete passes through the dataset)."),
        ("probe_batch_size", "Probe Batch Size:", tk.IntVar, 64, {"from_": 1, "to": 1000}, False,
         "Batch size for probe training."),
        ("base_batch_size", "Base Batch Size:", tk.IntVar, 4, {"from_": 1, "to": 1000}, False,
         "Batch size for base model training."),
        ("probe_grad_accum", "Probe Grad Accum:", tk.IntVar, 1, {"from_": 1, "to": 100}, False,
         "Gradient accumulation steps for probe training."),
        ("base_grad_accum", "Base Grad Accum:", tk.IntVar, 8, {"from_": 1, "to": 100}, False,
         "Gradient accumulation steps for base model training."),
        ("lr", "Learning Rate:", tk.DoubleVar, 1e-4,
         {"from_": 1e-6, "to": 1e-2, "increment": 1e-5}, False,
         "Learning rate for optimizer. Controls step size during training."),
        ("weight_decay", "Weight Decay:", tk.DoubleVar, 0.0,
         {"from_": 0.0, "to": 1.0, "increment": 0.01}, False,
         "L2 regularization factor to prevent overfitting (0.0-1.0)."),
        ("patience", "Patience:", tk.IntVar, 1, {"from_": 1, "to": 100}, True,
         "Number of epochs with no improvement after which training will stop."),
        ("seed", "Random Seed:", tk.IntVar, 42, {"from_": 0, "to": 10000}, True,
         "Random seed for reproducibility of experiments."),
        ("read_scaler", "Read Scaler:", tk.IntVar, 100, {"from_": 1, "to": 1000}, True,
         "Read scaler for SQL storage (multiplier for batch size when reading from SQL database)."),
        ("deterministic", "Deterministic:", tk.BooleanVar, False, None, True,
         "Enable deterministic behavior for reproducibility (will slow down training)."),
        ("num_runs", "Number of Runs:", tk.IntVar, 1, {"from_": 1, "to": 100}, True,
         "Train multiple runs with different seeds and aggregate metrics."),
    ]

    for row, spec in enumerate(row_specs):
        key, caption, var_cls, default, spin_opts, anchor_west, help_text = spec

        ttk.Label(tab, text=caption).grid(row=row, column=0, padx=10, pady=5, sticky="w")

        variable = var_cls(value=default)
        self.settings_vars[key] = variable

        if spin_opts is None:
            field = ttk.Checkbutton(tab, variable=variable)
        else:
            field = ttk.Spinbox(tab, textvariable=variable, **spin_opts)

        # Only some inputs are anchored west; the rest keep grid's default.
        placement = {"row": row, "column": 1, "padx": 10, "pady": 5}
        if anchor_west:
            placement["sticky"] = "w"
        field.grid(**placement)

        self.add_help_button(tab, row, 2, help_text)

    trainer_button = ttk.Button(tab, text="Run trainer", command=self._run_trainer)
    trainer_button.grid(row=99, column=0, columnspan=2, pady=(10, 10))
| |
|
def build_wandb_tab(self):
    """Populate the W&B tab with hyperparameter-sweep settings.

    Every input is bound to a new Tk variable stored in
    ``self.settings_vars``. Captions go in column 0 and inputs in column 1,
    all anchored west. Only the first row gets a help button; the sweep config
    row gets a Browse button wired to ``self._browse_sweep_config``. The
    "Save W&B Settings" button on row 99 calls ``self._save_wandb_settings``.
    """
    tab = self.wandb_tab
    store = self.settings_vars

    def caption(row, text):
        # Left-hand label for a settings row.
        ttk.Label(tab, text=text).grid(row=row, column=0, padx=10, pady=5, sticky="w")

    def place(widget, row):
        # All inputs on this tab share the same column-1 placement.
        widget.grid(row=row, column=1, padx=10, pady=5, sticky="w")

    def text_row(row, text, key, default):
        # Caption plus a free-text entry backed by a StringVar.
        caption(row, text)
        store[key] = tk.StringVar(value=default)
        place(ttk.Entry(tab, textvariable=store[key], width=30), row)

    def choice_row(row, text, key, default, options):
        # Caption plus a read-only dropdown backed by a StringVar.
        caption(row, text)
        store[key] = tk.StringVar(value=default)
        place(
            ttk.Combobox(tab, textvariable=store[key], values=options, state="readonly"),
            row,
        )

    caption(0, "Use W&B Hyperopt:")
    store["use_wandb_hyperopt"] = tk.BooleanVar(value=False)
    place(ttk.Checkbutton(tab, variable=store["use_wandb_hyperopt"]), 0)
    self.add_help_button(tab, 0, 2, "Enable Weights & Biases hyperparameter sweeps.")

    text_row(1, "W&B Project:", "wandb_project", "Protify")
    text_row(2, "W&B Entity (optional):", "wandb_entity", "")

    text_row(3, "Sweep Config Path:", "sweep_config_path", "yamls/sweep.yaml")
    browse = ttk.Button(tab, text="Browse", command=self._browse_sweep_config)
    browse.grid(row=3, column=2, padx=5, pady=5)

    caption(4, "Sweep Count:")
    store["sweep_count"] = tk.IntVar(value=10)
    place(ttk.Spinbox(tab, from_=1, to=10000, textvariable=store["sweep_count"]), 4)

    choice_row(5, "Sweep Method:", "sweep_method", "bayes", ["bayes", "grid", "random"])
    text_row(6, "Sweep Metric (Classification):", "sweep_metric_cls", "eval_loss")
    text_row(7, "Sweep Metric (Regression):", "sweep_metric_reg", "eval_loss")
    choice_row(8, "Sweep Goal:", "sweep_goal", "minimize", ["maximize", "minimize"])

    save_button = ttk.Button(tab, text="Save W&B Settings", command=self._save_wandb_settings)
    save_button.grid(row=99, column=0, columnspan=2, pady=(10, 10))
| |
|
def build_modal_tab(self):
    """Populate the Modal tab: settings rows followed by action buttons.

    Rows 0-11 each pair a caption (column 0) with an input (column 1) bound
    to a new Tk variable registered in ``self.settings_vars``. Rows 12-15
    hold buttons wired to the ``self._modal_*`` handlers.
    """
    tab = self.modal_tab

    def text_box(var):
        return ttk.Entry(tab, textvariable=var, width=30)

    def tick_box(var):
        return ttk.Checkbutton(tab, variable=var)

    def spinner(low, high):
        # Returns a factory so the range can be set per row.
        return lambda var: ttk.Spinbox(tab, from_=low, to=high, textvariable=var)

    def gpu_picker(var):
        return ttk.Combobox(
            tab,
            textvariable=var,
            values=["H200", "H100", "A100-80GB", "A100", "L40S", "A10", "L4", "T4"],
            state="readonly",
        )

    # (settings key, caption, variable class, default value, input factory)
    settings = [
        ("modal_app_name", "Modal App Name:", tk.StringVar, "protify-backend", text_box),
        ("modal_environment", "Modal Environment (optional):", tk.StringVar, "", text_box),
        ("modal_tag", "Modal Deploy Tag (optional):", tk.StringVar, "", text_box),
        ("modal_backend_path", "Backend Module Path:", tk.StringVar,
         "src/protify/modal_backend.py", text_box),
        ("modal_gpu_type", "GPU Type:", tk.StringVar, "A10", gpu_picker),
        ("modal_timeout_seconds", "Runtime Timeout (seconds):", tk.IntVar, 86400,
         spinner(60, 604800)),
        ("modal_poll_interval_seconds", "Poll Interval (seconds):", tk.IntVar, 5,
         spinner(1, 600)),
        ("modal_log_tail_chars", "Log Tail Length (chars):", tk.IntVar, 5000,
         spinner(500, 100000)),
        ("modal_job_id", "Current Job ID:", tk.StringVar, "", text_box),
        ("modal_call_id", "Current Call ID:", tk.StringVar, "", text_box),
        ("modal_artifacts_dir", "Artifact Output Directory:", tk.StringVar,
         "modal_artifacts", text_box),
        ("modal_auto_poll", "Auto Poll Health:", tk.BooleanVar, True, tick_box),
    ]
    for row, (key, caption, var_cls, default, make_input) in enumerate(settings):
        ttk.Label(tab, text=caption).grid(row=row, column=0, padx=10, pady=5, sticky="w")
        variable = var_cls(value=default)
        self.settings_vars[key] = variable
        make_input(variable).grid(row=row, column=1, padx=10, pady=5, sticky="w")

    # (button text, handler, grid options that vary per button)
    actions = [
        ("Deploy Modal Backend", self._modal_deploy_backend,
         {"row": 12, "column": 0, "pady": 10}),
        ("Submit Remote Run", self._modal_submit_run,
         {"row": 12, "column": 1, "pady": 10}),
        ("Poll Status", self._modal_poll_status,
         {"row": 13, "column": 0, "pady": 5}),
        ("Cancel Run", self._modal_cancel_run,
         {"row": 13, "column": 1, "pady": 5}),
        ("Start Auto Poll", self._modal_start_auto_poll,
         {"row": 14, "column": 0, "pady": 5}),
        ("Stop Auto Poll", self._modal_stop_auto_poll,
         {"row": 14, "column": 1, "pady": 5}),
        ("Fetch Logs/Results/Plots", self._modal_fetch_artifacts,
         {"row": 15, "column": 0, "columnspan": 2, "pady": 10}),
    ]
    for text, handler, cell in actions:
        ttk.Button(tab, text=text, command=handler).grid(padx=10, sticky="w", **cell)
| |
|
def build_proteingym_tab(self):
    """Populate the ProteinGym tab with its evaluation settings.

    Each row is a caption (column 0), an input (column 1) bound to a new Tk
    variable stored in ``self.settings_vars``, and a help button (column 2).
    The "Run ProteinGym" button on row 99 calls ``self._run_proteingym``.
    """
    tab = self.proteingym_tab

    def tick_box(var):
        return ttk.Checkbutton(tab, variable=var)

    def text_box(var):
        return ttk.Entry(tab, textvariable=var, width=30)

    def dropdown(options):
        # Editable combobox (no ``state`` option is set) offering ``options``.
        return lambda var: ttk.Combobox(tab, textvariable=var, values=options)

    def batch_spinner(var):
        return ttk.Spinbox(tab, from_=1, to=1024, textvariable=var)

    # (settings key, caption, variable class, default value, input factory,
    #  whether the input is gridded with sticky="w", help text)
    row_specs = [
        ("proteingym", "Run ProteinGym:", tk.BooleanVar, False, tick_box, True,
         "Enable ProteinGym zero-shot evaluation."),
        ("dms_ids", "DMS IDs (space-separated or 'all'):", tk.StringVar, "all", text_box, True,
         "List of DMS IDs to evaluate, or 'all'."),
        ("mode", "Mode:", tk.StringVar, "benchmark",
         dropdown(["benchmark", "indels", "multiples", "singles"]), False,
         "ProteinGym zero-shot mode."),
        ("scoring_method", "Scoring Method:", tk.StringVar, "masked_marginal",
         dropdown(["masked_marginal", "mutant_marginal", "wildtype_marginal", "pll", "global_log_prob"]),
         False,
         "Scoring method for zero-shot evaluation."),
        ("scoring_window", "Scoring Window:", tk.StringVar, "optimal",
         dropdown(["optimal", "sliding"]), False,
         "Windowing strategy for scoring."),
        ("pg_batch_size", "Batch Size:", tk.IntVar, 32, batch_spinner, False,
         "Batch size for ProteinGym scoring."),
        ("compare_scoring_methods", "Compare Scoring Methods:", tk.BooleanVar, False, tick_box, True,
         "Compare different scoring methods across models and DMS assays."),
        ("score_only", "Score Only:", tk.BooleanVar, False, tick_box, True,
         "Skip scoring and run benchmark report generation on existing results."),
    ]

    for row, spec in enumerate(row_specs):
        key, caption, var_cls, default, make_input, anchor_west, help_text = spec

        ttk.Label(tab, text=caption).grid(row=row, column=0, padx=10, pady=5, sticky="w")

        variable = var_cls(value=default)
        self.settings_vars[key] = variable

        placement = {"row": row, "column": 1, "padx": 10, "pady": 5}
        if anchor_west:
            placement["sticky"] = "w"
        make_input(variable).grid(**placement)

        self.add_help_button(tab, row, 2, help_text)

    launch_button = ttk.Button(tab, text="Run ProteinGym", command=self._run_proteingym)
    launch_button.grid(row=99, column=0, columnspan=2, pady=(10, 10))
| |
|
def build_scikit_tab(self):
    """Populate the scikit-learn tab with its settings frame and run button.

    A "Scikit-Learn Settings" frame is packed into ``self.scikit_tab``. Each
    row inside it is a caption (column 0), an input (column 1) bound to a new
    Tk variable stored in ``self.settings_vars``, and a help button
    (column 2). The "Run Scikit Models" button is packed below the frame and
    calls ``self._run_scikit``.
    """
    scikit_frame = ttk.LabelFrame(self.scikit_tab, text="Scikit-Learn Settings")
    scikit_frame.pack(fill="x", padx=10, pady=5)

    def caption(row, text):
        # Left-hand label for a settings row.
        ttk.Label(scikit_frame, text=text).grid(row=row, column=0, padx=10, pady=5, sticky="w")

    def place(widget, row):
        # All inputs in this frame share the same column-1 placement.
        widget.grid(row=row, column=1, padx=10, pady=5, sticky="w")

    caption(0, "Use Scikit:")
    self.settings_vars["use_scikit"] = tk.BooleanVar(value=False)
    place(ttk.Checkbutton(scikit_frame, variable=self.settings_vars["use_scikit"]), 0)
    self.add_help_button(scikit_frame, 0, 2, "Whether to use scikit-learn models instead of neural networks.")

    caption(1, "Scikit Iterations:")
    self.settings_vars["scikit_n_iter"] = tk.IntVar(value=10)
    place(ttk.Spinbox(scikit_frame, from_=1, to=1000, textvariable=self.settings_vars["scikit_n_iter"]), 1)
    self.add_help_button(scikit_frame, 1, 2, "Number of iterations for iterative scikit-learn models.")

    caption(2, "Scikit CV Folds:")
    self.settings_vars["scikit_cv"] = tk.IntVar(value=3)
    place(ttk.Spinbox(scikit_frame, from_=1, to=10, textvariable=self.settings_vars["scikit_cv"]), 2)
    self.add_help_button(scikit_frame, 2, 2, "Number of cross-validation folds for model evaluation.")

    caption(3, "Scikit Random State:")
    self.settings_vars["scikit_random_state"] = tk.IntVar(value=42)
    place(ttk.Spinbox(scikit_frame, from_=0, to=10000, textvariable=self.settings_vars["scikit_random_state"]), 3)
    self.add_help_button(scikit_frame, 3, 2, "Random seed for scikit-learn models to ensure reproducibility.")

    caption(4, "Scikit Model Name (optional):")
    self.settings_vars["scikit_model_name"] = tk.StringVar(value="")
    place(ttk.Entry(scikit_frame, textvariable=self.settings_vars["scikit_model_name"], width=30), 4)
    self.add_help_button(scikit_frame, 4, 2, "Optional name for the scikit-learn model. Leave blank to use default.")

    caption(5, "Number of Jobs:")
    self.settings_vars["n_jobs"] = tk.IntVar(value=1)
    # The help text advertises -1 ("all cores"), which a 1..32 range can never
    # reach with the arrows. Offer an explicit value list instead: -1 followed
    # by 1..32. 0 is left out because it is presumably not a meaningful job
    # count for the downstream consumer -- TODO confirm.
    n_jobs_choices = (-1, *range(1, 33))
    place(ttk.Spinbox(scikit_frame, values=n_jobs_choices, textvariable=self.settings_vars["n_jobs"]), 5)
    self.add_help_button(scikit_frame, 5, 2, "Number of CPU cores to use for parallel processing. Use -1 for all cores.")

    run_button = ttk.Button(self.scikit_tab, text="Run Scikit Models", command=self._run_scikit)
    run_button.pack(pady=(20, 10))
| |
|
def build_replay_tab(self):
    """Populate the replay tab with the log-path row and the start button.

    A "Log Replay Settings" frame is packed into ``self.replay_tab``. Row 0
    holds the caption, the path entry (bound to a new ``StringVar`` stored as
    ``self.settings_vars["replay_path"]``), a help button and a Browse button
    wired to ``self._browse_replay_log``. Row 1 holds the "Start Replay"
    button, which calls ``self._start_replay``.
    """
    replay_frame = ttk.LabelFrame(self.replay_tab, text="Log Replay Settings")
    replay_frame.pack(fill="x", padx=10, pady=5)

    ttk.Label(replay_frame, text="Replay Log Path:").grid(row=0, column=0, padx=10, pady=5, sticky="w")
    self.settings_vars["replay_path"] = tk.StringVar(value="")
    entry_replay = ttk.Entry(replay_frame, textvariable=self.settings_vars["replay_path"], width=40)
    entry_replay.grid(row=0, column=1, padx=10, pady=5)
    self.add_help_button(replay_frame, 0, 2, "Path to the log file to replay. Use Browse button to select a file.")

    # The help button already occupies row 0, column 2. Gridding Browse into
    # that same cell stacks the two widgets on top of each other, so Browse
    # gets its own column.
    browse_button = ttk.Button(replay_frame, text="Browse", command=self._browse_replay_log)
    browse_button.grid(row=0, column=3, padx=5, pady=5)

    # Span all four columns so the button stays centred under the row above.
    replay_button = ttk.Button(replay_frame, text="Start Replay", command=self._start_replay)
    replay_button.grid(row=1, column=0, columnspan=4, pady=20)
| |
|
def build_viz_tab(self):
    """Populate the visualization tab with its settings frame.

    A "Visualization Settings" frame is packed into ``self.viz_tab``. Its rows
    bind inputs to new Tk variables stored in ``self.settings_vars``. The
    results-file row has a Browse button wired to
    ``self._browse_results_file`` instead of a help button. The
    "Generate Plots" button on row 99 calls ``self._generate_plots``.
    """
    frame = ttk.LabelFrame(self.viz_tab, text="Visualization Settings")
    frame.pack(fill="x", padx=10, pady=5)

    def caption(row, text):
        # Left-hand label for a settings row.
        ttk.Label(frame, text=text).grid(row=row, column=0, padx=10, pady=5, sticky="w")

    def text_field(row, key, default):
        # Free-text entry backed by a StringVar registered under ``key``.
        self.settings_vars[key] = tk.StringVar(value=default)
        box = ttk.Entry(frame, textvariable=self.settings_vars[key], width=30)
        box.grid(row=row, column=1, padx=10, pady=5)

    caption(0, "Result ID:")
    text_field(0, "result_id", "")
    self.add_help_button(frame, 0, 2, "ID of the result to visualize. Will look for results/{result_id}.tsv")

    caption(1, "Results File:")
    text_field(1, "results_file", "")
    file_picker = ttk.Button(frame, text="Browse", command=self._browse_results_file)
    file_picker.grid(row=1, column=2, padx=5, pady=5)

    caption(2, "Use Current Run:")
    self.settings_vars["use_current_run"] = tk.BooleanVar(value=True)
    current_run_box = ttk.Checkbutton(frame, variable=self.settings_vars["use_current_run"])
    current_run_box.grid(row=2, column=1, padx=10, pady=5, sticky="w")
    self.add_help_button(frame, 2, 2, "Use results from the current run.")

    caption(3, "Output Directory:")
    text_field(3, "viz_output_dir", "plots")
    self.add_help_button(frame, 3, 2, "Directory where plots will be saved.")

    plot_button = ttk.Button(frame, text="Generate Plots", command=self._generate_plots)
    plot_button.grid(row=99, column=0, columnspan=3, pady=20)
| |
|
def add_help_button(self, parent, row, column, help_text):
    """Place a small "?" button that shows ``help_text`` in an info dialog.

    Args:
        parent: Container widget that receives the button.
        row: Grid row for the button.
        column: Grid column for the button.
        help_text: Message shown in the "Help" dialog when clicked.

    Returns:
        The created ``ttk.Button``.
    """
    def show_help():
        return messagebox.showinfo("Help", help_text)

    button = ttk.Button(parent, text="?", width=2, command=show_help)
    button.grid(row=row, column=column, padx=(0, 5), pady=5)
    return button
| |
|
def _selected_model_dtype(self):
    """Return the ``self.dtype_map`` entry for the selected model dtype.

    Raises:
        AssertionError: If the selected name is not a key of ``self.dtype_map``.
    """
    known_dtypes = self.dtype_map
    dtype_name = self.settings_vars["model_dtype"].get()
    assert dtype_name in known_dtypes, f"Unsupported model dtype: {dtype_name}"
    return known_dtypes[dtype_name]
| |
|
def _selected_embed_dtype(self):
    """Return the ``self.dtype_map`` entry for the selected embedding dtype.

    Raises:
        AssertionError: If the selected name is not a key of ``self.dtype_map``.
    """
    known_dtypes = self.dtype_map
    dtype_name = self.settings_vars["embed_dtype"].get()
    assert dtype_name in known_dtypes, f"Unsupported embedding dtype: {dtype_name}"
    return known_dtypes[dtype_name]
| |
|
def _browse_data_dir(self):
    """Ask for a directory and add it to the comma-separated ``data_dirs`` setting.

    Does nothing when the dialog is cancelled. A directory that is already
    listed is not added twice; the stored value is still rewritten as a
    ", "-joined list of the stripped entries.
    """
    chosen = filedialog.askdirectory(title="Select Data Directory")
    if not chosen:
        return

    dirs_var = self.settings_vars["data_dirs"]
    current = dirs_var.get().strip()
    if current:
        entries = list(filter(None, (piece.strip() for piece in current.split(","))))
        if chosen not in entries:
            entries.append(chosen)
        dirs_var.set(", ".join(entries))
    else:
        # First directory: store it as-is, without list formatting.
        dirs_var.set(chosen)
| |
|
def _browse_sweep_config(self):
    """Ask for a W&B sweep config file and store it in ``sweep_config_path``.

    The setting is left untouched when the dialog is cancelled.
    """
    yaml_first = (("YAML files", "*.yaml *.yml"), ("All files", "*.*"))
    selected = filedialog.askopenfilename(title="Select W&B Sweep Config", filetypes=yaml_first)
    if not selected:
        return
    self.settings_vars["sweep_config_path"].set(selected)
| |
|
def _save_wandb_settings(self):
    """Copy the W&B sweep settings from the GUI into ``self.full_args``.

    Blank text fields fall back to their defaults ("Protify",
    "yamls/sweep.yaml", "eval_loss"); a blank entity becomes ``None``.
    ``self.logger_args`` is then rebuilt without ``all_seqs`` and without
    any key containing "token" or "api", and written out via
    ``self._write_args()`` if this instance already has a ``log_file``
    attribute.
    """
    print_message("Saving W&B sweep settings...")
    settings = self.settings_vars
    args = self.full_args

    def text(key):
        return settings[key].get().strip()

    args.use_wandb_hyperopt = settings["use_wandb_hyperopt"].get()
    args.wandb_project = text("wandb_project") or "Protify"
    args.wandb_entity = text("wandb_entity") or None
    args.sweep_config_path = text("sweep_config_path") or "yamls/sweep.yaml"
    args.sweep_count = settings["sweep_count"].get()
    args.sweep_method = settings["sweep_method"].get()
    args.sweep_metric_cls = text("sweep_metric_cls") or "eval_loss"
    args.sweep_metric_reg = text("sweep_metric_reg") or "eval_loss"
    args.sweep_goal = settings["sweep_goal"].get()

    # Keep the sequence payload and anything that looks like a credential
    # out of the logged arguments.
    loggable = {}
    for key, value in args.__dict__.items():
        lowered = key.lower()
        if key == 'all_seqs' or 'token' in lowered or 'api' in lowered:
            continue
        loggable[key] = value
    self.logger_args = SimpleNamespace(**loggable)

    if "log_file" in vars(self):
        self._write_args()
    print_message("W&B sweep settings saved")
    print_done()
| |
|
def _resolve_repo_root(self):
    """Return the absolute path two directory levels above this file's folder."""
    module_dir = os.path.dirname(__file__)
    two_levels_up = os.path.join(module_dir, os.pardir, os.pardir)
    return os.path.abspath(two_levels_up)
| |
|
def _resolve_modal_backend_path(self):
    """Locate the Modal backend script on disk.

    Uses the ``modal_backend_path`` setting, or
    "src/protify/modal_backend.py" when it is blank. An absolute path is
    used directly. A relative path is tried under the ``home_dir`` setting
    first, then under the repository root.

    Returns:
        The resolved backend path.

    Raises:
        AssertionError: If the resolved path does not exist.
    """
    requested = self.settings_vars["modal_backend_path"].get().strip() or "src/protify/modal_backend.py"

    if os.path.isabs(requested):
        resolved = requested
    else:
        home_dir = self.settings_vars["home_dir"].get().strip()
        under_home = os.path.abspath(os.path.join(home_dir, requested))
        under_repo = os.path.abspath(os.path.join(self._resolve_repo_root(), requested))
        # The home directory wins whenever the file is actually there.
        resolved = under_home if os.path.exists(under_home) else under_repo

    assert os.path.exists(resolved), f"Modal backend path not found: {resolved}"
    return resolved
| |
|
def _resolve_modal_credentials(self):
    """Return the Modal ``(token_id, token_secret)`` pair from the settings.

    When an API key is present and either token field is blank, both tokens
    are derived with ``parse_modal_api_key`` and written back to the token
    settings. Empty strings are returned as ``None``.
    """
    settings = self.settings_vars
    api_key = settings["modal_api_key"].get().strip()
    token_id = settings["modal_token_id"].get().strip()
    token_secret = settings["modal_token_secret"].get().strip()

    tokens_incomplete = not (token_id and token_secret)
    if api_key and tokens_incomplete:
        token_id, token_secret = parse_modal_api_key(api_key)
        settings["modal_token_id"].set(token_id)
        settings["modal_token_secret"].set(token_secret)

    return (
        None if token_id == "" else token_id,
        None if token_secret == "" else token_secret,
    )
| |
|
def _build_modal_env(self):
    """Build an environment mapping for Modal calls and subprocesses.

    Starts from a copy of ``os.environ``, forces UTF-8 I/O, and adds the
    Modal token ID, token secret and environment name when they are set.
    Each of those three values is also written into ``os.environ`` itself.

    Returns:
        The environment copy with the additions applied.
    """
    env = os.environ.copy()
    env.update(PYTHONIOENCODING="utf-8", PYTHONUTF8="1")

    token_id, token_secret = self._resolve_modal_credentials()
    for name, value in (("MODAL_TOKEN_ID", token_id), ("MODAL_TOKEN_SECRET", token_secret)):
        if value is None:
            continue
        env[name] = value
        os.environ[name] = value

    environment_name = self.settings_vars["modal_environment"].get().strip()
    if environment_name:
        env["MODAL_ENVIRONMENT"] = environment_name
        os.environ["MODAL_ENVIRONMENT"] = environment_name
    return env
| |
|
def _get_modal_sdk(self):
    """Import and return the ``modal`` module.

    Raises:
        RuntimeError: If importing ``modal`` fails for any reason; the
            original exception is chained as the cause.
    """
    try:
        import modal as modal_sdk
    except Exception as import_failure:
        message = "Modal SDK is not installed. Install it with: py -m pip install modal"
        raise RuntimeError(message) from import_failure
    else:
        return modal_sdk
| |
|
def _get_modal_function(self, function_name):
    """Look up a function of the Modal app by name.

    The app name comes from the ``modal_app_name`` setting and defaults to
    "protify-backend" when that setting is blank.

    Args:
        function_name: Name of the function to look up in the app.

    Returns:
        Whatever ``modal.Function.from_name(app_name, function_name)`` returns.
    """
    sdk = self._get_modal_sdk()
    app_name = self.settings_vars["modal_app_name"].get().strip() or "protify-backend"
    return sdk.Function.from_name(app_name, function_name)
| |
|
def _collect_modal_run_config(self):
    """Collect the current GUI settings into a plain config dictionary.

    With no model selected, ``standard_models`` is used. With no dataset
    selected, no data directory given and ProteinGym off,
    ``standard_data_benchmark`` is used. Blank text fields fall back to
    their defaults or to ``None``. The keys ``replay_path``,
    ``pretrained_probe_path``, ``model_paths`` and ``model_types`` are
    always ``None``, and ``save_embeddings`` is always ``True``.

    Returns:
        dict: Setting name to value, in a fixed key order.
    """
    settings = self.settings_vars

    def raw(key):
        # Value of the Tk variable exactly as stored.
        return settings[key].get()

    def text(key):
        # Stripped string value of the Tk variable.
        return settings[key].get().strip()

    def comma_list(key):
        # Comma-separated field as a list of non-empty, stripped items.
        return [piece.strip() for piece in settings[key].get().split(",") if piece.strip()]

    def passthrough(*keys):
        # Settings whose config key equals the variable name, copied unchanged.
        return {key: raw(key) for key in keys}

    selected_models = [self.model_listbox.get(index) for index in self.model_listbox.curselection()]
    if not selected_models:
        selected_models = standard_models

    selected_datasets = [self.data_listbox.get(index) for index in self.data_listbox.curselection()]
    data_dirs = comma_list("data_dirs")

    run_proteingym = raw("proteingym")
    nothing_requested = not (selected_datasets or data_dirs or run_proteingym)
    if nothing_requested:
        selected_datasets = standard_data_benchmark

    col_names = comma_list("col_names")
    # Whitespace-separated; a blank field means "no multi-column input".
    multi_column = text("multi_column").split() or None

    embedding_pooling = comma_list("embedding_pooling_types")
    probe_pooling = comma_list("probe_pooling_types")

    dms_text = text("dms_ids")
    dms_ids = ["all"] if dms_text.lower() == "all" else dms_text.split()

    wandb_entity = text("wandb_entity") or None
    scikit_model_name = text("scikit_model_name") or None
    hf_home = text("hf_home") or None

    config = {
        "hf_username": text("huggingface_username") or "Synthyra",
        "hf_token": text("huggingface_token") or None,
        "wandb_api_key": text("wandb_api_key") or None,
        "synthyra_api_key": text("synthyra_api_key") or None,
        "hf_home": hf_home,
        "log_dir": text("log_dir") or "logs",
        "results_dir": text("results_dir") or "results",
        "model_save_dir": text("model_save_dir") or "weights",
        "embedding_save_dir": text("embedding_save_dir") or "embeddings",
        "download_dir": text("download_dir") or "Synthyra/vector_embeddings",
        "plots_dir": text("plots_dir") or "plots",
        "replay_path": None,
        "pretrained_probe_path": None,
        "data_names": selected_datasets,
        "data_dirs": data_dirs,
        "delimiter": raw("delimiter"),
        "col_names": col_names,
        **passthrough("max_length", "trim"),
        "multi_column": multi_column,
        **passthrough(
            "aa_to_dna",
            "aa_to_rna",
            "dna_to_aa",
            "rna_to_aa",
            "codon_to_aa",
            "aa_to_codon",
            "random_pair_flipping",
        ),
        "model_names": selected_models,
        "model_paths": None,
        "model_types": None,
        **passthrough("model_dtype", "use_xformers"),
        # The embedding keys read the generic batch-size / worker settings.
        "embedding_batch_size": raw("batch_size"),
        "embedding_num_workers": raw("num_workers"),
        **passthrough("num_workers", "download_embeddings", "matrix_embed"),
        "embedding_pooling_types": embedding_pooling,
        "save_embeddings": True,
        **passthrough(
            "embed_dtype",
            "sql",
            "probe_type",
            "tokenwise",
            "hidden_size",
            "transformer_hidden_size",
            "dropout",
            "n_layers",
            "pre_ln",
            "classifier_size",
            "transformer_dropout",
            "classifier_dropout",
            "n_heads",
            "rotary",
        ),
        "probe_pooling_types": probe_pooling,
        **passthrough(
            "use_bias",
            "save_model",
            "production_model",
            "lora",
            "lora_r",
            "lora_alpha",
            "lora_dropout",
            "sim_type",
            "token_attention",
            "add_token_ids",
            "num_epochs",
            "probe_batch_size",
            "base_batch_size",
            "probe_grad_accum",
            "base_grad_accum",
            "lr",
            "weight_decay",
            "patience",
            "seed",
            "deterministic",
            "full_finetuning",
            "hybrid_probe",
            "num_runs",
            "read_scaler",
        ),
        "dms_ids": dms_ids,
        "proteingym": run_proteingym,
        **passthrough(
            "mode",
            "scoring_method",
            "scoring_window",
            "pg_batch_size",
            "compare_scoring_methods",
            "score_only",
            "use_wandb_hyperopt",
        ),
        "wandb_project": text("wandb_project") or "Protify",
        "wandb_entity": wandb_entity,
        "sweep_config_path": text("sweep_config_path") or "yamls/sweep.yaml",
        **passthrough("sweep_count", "sweep_method"),
        "sweep_metric_cls": text("sweep_metric_cls") or "eval_loss",
        "sweep_metric_reg": text("sweep_metric_reg") or "eval_loss",
        **passthrough(
            "sweep_goal",
            "use_scikit",
            "scikit_n_iter",
            "scikit_cv",
            "scikit_random_state",
        ),
        "scikit_model_name": scikit_model_name,
        "n_jobs": raw("n_jobs"),
    }
    return config
| |
|
def _modal_deploy_backend(self):
    """Deploy the Modal backend script as a background task.

    The task runs ``<python> -m modal deploy <backend> --name <app>``,
    adding ``--env`` and ``--tag`` when those settings are filled in, from
    the repository root and with the environment from
    ``self._build_modal_env()``. If that raises ``FileNotFoundError`` the
    same arguments are retried through the ``modal`` executable. A
    non-zero exit status raises ``RuntimeError`` inside the task.
    """
    print_message("Deploying Modal backend...")

    def background_deploy():
        backend_path = self._resolve_modal_backend_path()
        repo_root = self._resolve_repo_root()
        env = self._build_modal_env()

        app_name = self.settings_vars["modal_app_name"].get().strip() or "protify-backend"
        modal_environment = self.settings_vars["modal_environment"].get().strip()
        modal_tag = self.settings_vars["modal_tag"].get().strip()

        # Arguments shared by the primary and the fallback launcher.
        deploy_args = ["deploy", backend_path, "--name", app_name]
        if modal_environment:
            deploy_args += ["--env", modal_environment]
        if modal_tag:
            deploy_args += ["--tag", modal_tag]

        run_options = dict(cwd=repo_root, env=env, capture_output=True, text=True)
        try:
            process = subprocess.run([sys.executable, "-m", "modal", *deploy_args], **run_options)
        except FileNotFoundError:
            process = subprocess.run(["modal", *deploy_args], **run_options)

        if process.returncode != 0:
            if "No module named modal" in process.stderr:
                raise RuntimeError("Modal is not installed in this Python environment. Install it with: py -m pip install modal")
            raise RuntimeError(f"Modal deploy failed:\n{process.stderr}")

        # Only the last 4000 characters of the deploy output are shown.
        stdout_tail = process.stdout[-4000:] if process.stdout else "Deployment completed."
        print_message(stdout_tail)
        print_done()

    self.run_in_background(background_deploy)
| |
|
def _modal_submit_run(self):
    """Submit the current configuration as a remote Modal job.

    The work is handed to ``self.run_in_background``. The task calls the
    app's ``submit_protify_job`` function with the config from
    ``self._collect_modal_run_config()``, the selected GPU type and
    timeout, and the HF / W&B / Synthyra credentials (blank ones are sent
    as ``None``). The returned ``job_id`` and ``function_call_id`` are
    stored in the matching settings and on ``self.full_args``.

    If auto-polling is enabled and no polling loop is active yet, one is
    started. A loop that is already active is left alone: it reads the job
    ID from the settings on each tick, so it picks up the new job, and
    scheduling a second loop would poll every interval twice.

    Raises (inside the background task):
        AssertionError: If the response is not a dict or lacks ``job_id``
            or ``function_call_id``.
    """
    print_message("Submitting remote Modal run...")

    def background_submit():
        self._build_modal_env()
        submit_fn = self._get_modal_function("submit_protify_job")
        config = self._collect_modal_run_config()

        gpu_type = self.settings_vars["modal_gpu_type"].get()
        timeout_seconds = self.settings_vars["modal_timeout_seconds"].get()
        hf_token = self.settings_vars["huggingface_token"].get().strip() or None
        wandb_api_key = self.settings_vars["wandb_api_key"].get().strip() or None
        synthyra_api_key = self.settings_vars["synthyra_api_key"].get().strip() or None

        result = submit_fn.remote(
            config=config,
            gpu_type=gpu_type,
            hf_token=hf_token,
            wandb_api_key=wandb_api_key,
            synthyra_api_key=synthyra_api_key,
            timeout_seconds=timeout_seconds,
        )
        assert isinstance(result, dict), "submit_protify_job returned a non-dict response."
        assert "job_id" in result, "submit_protify_job response missing job_id."
        assert "function_call_id" in result, "submit_protify_job response missing function_call_id."

        job_id = result["job_id"]
        function_call_id = result["function_call_id"]
        self.settings_vars["modal_job_id"].set(job_id)
        self.settings_vars["modal_call_id"].set(function_call_id)
        self.full_args.modal_job_id = job_id
        self.full_args.modal_call_id = function_call_id

        print_message(f"Modal job submitted.\nJob ID: {job_id}\nCall ID: {function_call_id}")
        # Start polling only when no loop is pending, mirroring the guard
        # in _modal_start_auto_poll.
        if self.settings_vars["modal_auto_poll"].get() and not self.modal_polling_active:
            self.modal_polling_active = True
            self.master.after(0, self._modal_auto_poll_loop)
        print_done()

    self.run_in_background(background_submit)
| |
|
def _modal_start_auto_poll(self):
    """Start the Modal auto-polling loop unless it is already active."""
    already_polling = self.modal_polling_active
    if already_polling:
        print_message("Auto polling is already active.")
    else:
        self.modal_polling_active = True
        print_message("Started Modal auto polling.")
        # The first tick runs immediately; the loop reschedules itself.
        self._modal_auto_poll_loop()
| |
|
def _modal_stop_auto_poll(self):
    """Clear the polling flag so the auto-poll loop exits on its next tick."""
    # _modal_auto_poll_loop checks this flag first and returns without
    # rescheduling itself once it is False.
    self.modal_polling_active = False
    print_message("Stopped Modal auto polling.")
| |
|
def _modal_auto_poll_loop(self):
    """Run one auto-poll tick and schedule the next one.

    Returns without polling when the loop has been deactivated. When the
    auto-poll setting is off or no job ID is set, the loop deactivates
    itself. Otherwise it polls the job status and re-arms itself through
    ``self.master.after`` using the configured interval, with a minimum of
    one second.
    """
    if not self.modal_polling_active:
        return

    # The job ID is only read when auto-polling is enabled.
    can_poll = (
        self.settings_vars["modal_auto_poll"].get()
        and self.settings_vars["modal_job_id"].get().strip()
    )
    if not can_poll:
        self.modal_polling_active = False
        return

    self._modal_poll_status()
    interval_seconds = self.settings_vars["modal_poll_interval_seconds"].get()
    delay_ms = max(1, interval_seconds) * 1000
    self.master.after(delay_ms, self._modal_auto_poll_loop)
| |
|
def _modal_poll_status(self):
    """Fetch and print the status and log tail of the current Modal job.

    Prints a hint and returns when no job ID is set. Otherwise a background
    task calls the app's ``get_job_status`` and ``get_job_log_tail``
    functions, stores the status payload on
    ``self.full_args.modal_last_status``, updates the call-ID setting when
    the payload carries one, and prints a summary. Polling is switched off
    once the status is SUCCESS, FAILED, TERMINATED or TIMEOUT.

    Raises (inside the background task):
        AssertionError: If the status payload is not a dict.
    """
    job_id = self.settings_vars["modal_job_id"].get().strip()
    if not job_id:
        print_message("No Modal job ID set. Submit a remote run first.")
        return
    print_message(f"Polling Modal status for job {job_id}...")

    def background_poll():
        self._build_modal_env()
        status_fn = self._get_modal_function("get_job_status")
        log_tail_fn = self._get_modal_function("get_job_log_tail")

        status = status_fn.remote(job_id=job_id)
        tail_chars = self.settings_vars["modal_log_tail_chars"].get()
        logs = log_tail_fn.remote(job_id=job_id, max_chars=tail_chars)

        assert isinstance(status, dict), "get_job_status returned a non-dict response."
        call_id = status.get("function_call_id")
        if call_id:
            self.settings_vars["modal_call_id"].set(call_id)

        self.full_args.modal_last_status = status
        state = status.get("status", "UNKNOWN")
        phase = status.get("phase", "N/A")
        last_heartbeat = status.get("last_heartbeat_utc", "N/A")
        heartbeat_age = status.get("heartbeat_age_seconds")
        failure = status.get("error")
        age_text = f"{heartbeat_age:.1f}s" if heartbeat_age is not None else "N/A"
        summary_lines = (
            f"Modal Status: {state}",
            f"Phase: {phase}",
            f"Last Heartbeat: {last_heartbeat}",
            f"Heartbeat Age: {age_text}",
        )
        print_message("\n".join(summary_lines))
        if failure:
            print_message(f"Failure Reason: {failure}")

        if isinstance(logs, dict) and logs.get("log_tail"):
            print_message(f"Latest Logs (tail):\n{logs['log_tail']}")

        # A terminal state ends the auto-poll loop on its next tick.
        if state in ("SUCCESS", "FAILED", "TERMINATED", "TIMEOUT"):
            self.modal_polling_active = False
        print_done()

    self.run_in_background(background_poll)
| |
|
def _modal_cancel_run(self):
    """Cancel the remote Modal run identified by the stored call ID.

    Prints a hint and returns when no call ID is set. Otherwise polling is
    switched off and a background task calls the app's
    ``cancel_protify_job`` function, passing the job ID as ``None`` when it
    is blank, then prints the result.
    """
    call_id = self.settings_vars["modal_call_id"].get().strip()
    if not call_id:
        print_message("No Modal call ID set. Poll status or submit a run first.")
        return

    job_id = self.settings_vars["modal_job_id"].get().strip()
    print_message(f"Cancelling Modal run {call_id}...")
    self.modal_polling_active = False

    def background_cancel():
        self._build_modal_env()
        cancel_fn = self._get_modal_function("cancel_protify_job")
        outcome = cancel_fn.remote(function_call_id=call_id, job_id=job_id or None)
        print_message(f"Cancel result: {outcome}")
        print_done()

    self.run_in_background(background_cancel)
| |
|
def _modal_fetch_artifacts(self):
    """Download the artifacts of the current Modal job to a local folder.

    Prints a hint and returns when no job ID is set. Otherwise a background
    task calls the app's ``get_results`` function and writes the payload
    under ``<artifacts dir>/<job_id>``:

    * each entry of ``files`` as a UTF-8 text file,
    * each entry of ``images`` that has a ``data`` field as a
      base64-decoded binary file,
    * the whole payload as ``modal_fetch_summary.json``.

    The artifacts directory comes from the ``modal_artifacts_dir`` setting
    (default "modal_artifacts"). A relative value is resolved against the
    ``home_dir`` setting, or the current working directory when that is
    blank.

    Relative paths come from the remote payload, so each one is resolved
    and checked to stay inside the job directory. Entries that would land
    outside it (``..`` segments or absolute paths) are skipped with a
    message instead of being written.

    Raises (inside the background task):
        AssertionError: If the payload is not a dict, lacks ``success``,
            or reports ``success`` as falsy.
    """
    job_id = self.settings_vars["modal_job_id"].get().strip()
    if not job_id:
        print_message("No Modal job ID set. Submit a run first.")
        return
    print_message(f"Fetching Modal artifacts for job {job_id}...")

    def background_fetch():
        self._build_modal_env()
        results_fn = self._get_modal_function("get_results")
        result_payload = results_fn.remote(job_id=job_id)
        assert isinstance(result_payload, dict), "get_results returned a non-dict response."
        assert "success" in result_payload, "get_results response missing success field."
        assert result_payload["success"], f"Modal get_results failed: {result_payload}"

        output_dir_raw = self.settings_vars["modal_artifacts_dir"].get().strip() or "modal_artifacts"
        home_dir = self.settings_vars["home_dir"].get().strip() or os.getcwd()
        if os.path.isabs(output_dir_raw):
            output_dir = output_dir_raw
        else:
            output_dir = os.path.abspath(os.path.join(home_dir, output_dir_raw))
        job_dir = os.path.join(output_dir, job_id)
        os.makedirs(job_dir, exist_ok=True)
        job_root = os.path.abspath(job_dir)

        def prepare_local_path(rel_path):
            # Resolve the remote-relative path and refuse anything that
            # escapes the job directory.
            candidate = os.path.abspath(os.path.join(job_root, rel_path.replace("/", os.sep)))
            if not candidate.startswith(job_root + os.sep):
                print_message(f"Skipping artifact outside the job directory: {rel_path}")
                return None
            os.makedirs(os.path.dirname(candidate), exist_ok=True)
            return candidate

        text_file_count = 0
        image_file_count = 0

        files_payload = result_payload.get("files", {})
        for rel_path, content in files_payload.items():
            local_path = prepare_local_path(rel_path)
            if local_path is None:
                continue
            with open(local_path, "w", encoding="utf-8") as file:
                file.write(content)
            text_file_count += 1

        images_payload = result_payload.get("images", {})
        for rel_path, image_info in images_payload.items():
            if "data" not in image_info:
                continue
            local_path = prepare_local_path(rel_path)
            if local_path is None:
                continue
            image_bytes = base64.b64decode(image_info["data"])
            with open(local_path, "wb") as file:
                file.write(image_bytes)
            image_file_count += 1

        # Keep the raw payload next to the extracted files for reference.
        metadata_path = os.path.join(job_dir, "modal_fetch_summary.json")
        with open(metadata_path, "w", encoding="utf-8") as file:
            json.dump(result_payload, file, indent=2)

        print_message(
            f"Saved Modal artifacts to {job_dir}\n"
            f"Text files: {text_file_count}\n"
            f"Images: {image_file_count}"
        )
        print_done()

    self.run_in_background(background_fetch)
| |
|
def _session_start(self):
    """Start a Protify session: log in to services and set up directories.

    The credential fields are read immediately; everything else runs in a
    task handed to ``self.run_in_background``. That task:

    * derives the Modal token pair from the API key when either token
      field is blank,
    * logs in to Hugging Face and Weights & Biases when their credentials
      are set (a failed W&B login is reported, not raised),
    * copies session-level settings onto ``self.full_args``,
    * exports the Modal tokens and the xformers flag via ``os.environ``,
    * creates the log, results, model, plots, embedding and download
      directories under ``home_dir``,
    * builds ``self.logger_args`` and calls ``self.start_log_gui()``.

    ``self.logger_args`` excludes ``all_seqs`` and every key containing
    "token" or "api", the same filter ``_save_wandb_settings`` and
    ``_create_probe_args`` use, so credentials are not handed to the
    logger. ``self.full_args`` still holds them.
    """
    print_message("Starting Protify session...")

    settings = self.settings_vars
    hf_token = settings["huggingface_token"].get()
    synthyra_api_key = settings["synthyra_api_key"].get()
    wandb_api_key = settings["wandb_api_key"].get()
    modal_api_key = settings["modal_api_key"].get().strip()
    modal_token_id = settings["modal_token_id"].get().strip()
    modal_token_secret = settings["modal_token_secret"].get().strip()

    def background_login():
        args = self.full_args

        local_modal_token_id = modal_token_id
        local_modal_token_secret = modal_token_secret
        if modal_api_key and ((not local_modal_token_id) or (not local_modal_token_secret)):
            local_modal_token_id, local_modal_token_secret = parse_modal_api_key(modal_api_key)

        if hf_token:
            from huggingface_hub import login
            login(hf_token)
            print_message('Logged in to Hugging Face')
        if wandb_api_key:
            # A failed W&B login is reported but does not abort the session.
            try:
                import wandb
                wandb.login(key=wandb_api_key)
                print_message('Logged in to Weights & Biases')
            except Exception as error:
                print_message(f'W&B login failed: {error}')
        if synthyra_api_key:
            print_message('Synthyra API not integrated yet')

        args.hf_username = settings["huggingface_username"].get()
        args.hf_token = hf_token
        args.synthyra_api_key = synthyra_api_key
        args.wandb_api_key = wandb_api_key
        args.modal_api_key = modal_api_key if modal_api_key else None
        args.modal_token_id = local_modal_token_id if local_modal_token_id else None
        args.modal_token_secret = local_modal_token_secret if local_modal_token_secret else None
        args.home_dir = settings["home_dir"].get()
        args.model_dtype = self._selected_model_dtype()
        args.use_xformers = settings["use_xformers"].get()
        args.num_runs = settings["num_runs"].get()
        args.use_wandb_hyperopt = settings["use_wandb_hyperopt"].get()
        args.wandb_project = settings["wandb_project"].get().strip() or "Protify"
        wandb_entity = settings["wandb_entity"].get().strip()
        args.wandb_entity = wandb_entity if wandb_entity else None
        args.sweep_config_path = settings["sweep_config_path"].get().strip() or "yamls/sweep.yaml"
        args.sweep_count = settings["sweep_count"].get()
        args.sweep_method = settings["sweep_method"].get()
        args.sweep_metric_cls = settings["sweep_metric_cls"].get().strip() or "eval_loss"
        args.sweep_metric_reg = settings["sweep_metric_reg"].get().strip() or "eval_loss"
        args.sweep_goal = settings["sweep_goal"].get()
        args.score_only = settings["score_only"].get()
        args.aa_to_dna = settings["aa_to_dna"].get()
        args.aa_to_rna = settings["aa_to_rna"].get()
        args.dna_to_aa = settings["dna_to_aa"].get()
        args.rna_to_aa = settings["rna_to_aa"].get()
        args.codon_to_aa = settings["codon_to_aa"].get()
        args.aa_to_codon = settings["aa_to_codon"].get()
        args.random_pair_flipping = settings["random_pair_flipping"].get()
        args.data_dirs = []

        if args.modal_token_id:
            os.environ["MODAL_TOKEN_ID"] = args.modal_token_id
        if args.modal_token_secret:
            os.environ["MODAL_TOKEN_SECRET"] = args.modal_token_secret

        # The xformers choice is exported through an environment flag and
        # removed again when the option is switched off.
        if args.use_xformers:
            os.environ["_USE_XFORMERS"] = "1"
        elif "_USE_XFORMERS" in os.environ:
            del os.environ["_USE_XFORMERS"]

        hf_home_value = settings["hf_home"].get().strip()
        args.hf_home = hf_home_value if hf_home_value else None

        def _make_true_dir(path):
            # Resolve a setting against home_dir and make sure it exists.
            true_path = os.path.join(args.home_dir, path)
            os.makedirs(true_path, exist_ok=True)
            return true_path

        args.log_dir = _make_true_dir(settings["log_dir"].get())
        args.results_dir = _make_true_dir(settings["results_dir"].get())
        args.model_save_dir = _make_true_dir(settings["model_save_dir"].get())
        args.plots_dir = _make_true_dir(settings["plots_dir"].get())
        args.embedding_save_dir = _make_true_dir(settings["embedding_save_dir"].get())
        args.download_dir = _make_true_dir(settings["download_dir"].get())

        args.replay_path = None
        # Same filter as the other settings writers: drop the sequence
        # payload and anything that looks like a credential.
        args_dict = {
            k: v
            for k, v in args.__dict__.items()
            if k != 'all_seqs' and 'token' not in k.lower() and 'api' not in k.lower()
        }
        self.logger_args = SimpleNamespace(**args_dict)
        self.start_log_gui()

        print_message(f"Session and logging started for id {self.random_id}")
        print_done()

    self.run_in_background(background_login)
| |
|
def _create_probe_args(self):
    """Copy the probe settings from the GUI and build ``self.probe_args``.

    Each probe-related setting is written to ``self.full_args``.
    ``probe_pooling_types`` is parsed from a comma-separated string into a
    list. ``self.probe_args`` is then built from all of ``self.full_args``,
    and a filtered copy without ``all_seqs`` and without keys containing
    "token" or "api" is stored as ``self.logger_args`` and written out via
    ``self._write_args()``.
    """
    print_message("Configuring probe...")

    settings = self.settings_vars
    args = self.full_args

    def copy_settings(*names):
        # Copy settings whose argument name equals the variable name.
        for name in names:
            setattr(args, name, settings[name].get())

    copy_settings(
        "probe_type",
        "tokenwise",
        "pre_ln",
        "n_layers",
        "hidden_size",
        "dropout",
        "transformer_hidden_size",
        "classifier_size",
        "classifier_dropout",
        "n_heads",
        "rotary",
    )

    pooling_text = settings["probe_pooling_types"].get().strip()
    args.probe_pooling_types = [piece.strip() for piece in pooling_text.split(",") if piece.strip()]

    copy_settings(
        "transformer_dropout",
        "token_attention",
        "use_bias",
        "add_token_ids",
        "sim_type",
        "save_model",
        "production_model",
        "lora",
        "lora_r",
        "lora_alpha",
        "lora_dropout",
    )

    self.probe_args = ProbeArguments(**args.__dict__)

    # NOTE(review): the "token" substring filter also drops tokenwise,
    # token_attention and add_token_ids from the logged arguments. Confirm
    # that this is intended.
    loggable = {}
    for key, value in args.__dict__.items():
        lowered = key.lower()
        if key == 'all_seqs' or 'token' in lowered or 'api' in lowered:
            continue
        loggable[key] = value
    self.logger_args = SimpleNamespace(**loggable)
    self._write_args()

    print_message("Probe configuration saved")
    print_done()
| |
|
def _run_trainer(self):
    """Copy the trainer settings into ``full_args`` and launch the training run.

    The settings are read from ``self.settings_vars``, ``self.trainer_args``
    and ``self.logger_args`` are rebuilt, ``self._write_args()`` is called,
    and the training routine is handed to ``self.run_in_background``.
    """
    print_message("Starting training...")

    settings = self.settings_vars
    args = self.full_args

    def stripped_or(option, fallback):
        # Text setting with surrounding whitespace removed; empty -> fallback.
        return settings[option].get().strip() or fallback

    for option in (
        "hybrid_probe",
        "full_finetuning",
        "num_epochs",
        "probe_batch_size",
        "base_batch_size",
        "probe_grad_accum",
        "base_grad_accum",
        "lr",
        "weight_decay",
        "patience",
        "seed",
        "read_scaler",
        "deterministic",
        "num_runs",
        "use_wandb_hyperopt",
    ):
        setattr(args, option, settings[option].get())

    args.wandb_project = stripped_or("wandb_project", "Protify")
    # An empty entity is stored as None rather than as an empty string.
    args.wandb_entity = settings["wandb_entity"].get().strip() or None
    args.sweep_config_path = stripped_or("sweep_config_path", "yamls/sweep.yaml")
    args.sweep_count = settings["sweep_count"].get()
    args.sweep_method = settings["sweep_method"].get()
    args.sweep_metric_cls = stripped_or("sweep_metric_cls", "eval_loss")
    args.sweep_metric_reg = stripped_or("sweep_metric_reg", "eval_loss")
    args.sweep_goal = settings["sweep_goal"].get()
    args.use_xformers = settings["use_xformers"].get()

    # Keep the _USE_XFORMERS environment flag in step with the setting.
    if args.use_xformers:
        os.environ["_USE_XFORMERS"] = "1"
    else:
        os.environ.pop("_USE_XFORMERS", None)

    self.trainer_args = TrainerArguments(**vars(args))

    # Logged copy: skips 'all_seqs' and keys containing 'token' or 'api'.
    loggable = {}
    for key, value in vars(args).items():
        lowered = key.lower()
        if key == 'all_seqs' or 'token' in lowered or 'api' in lowered:
            continue
        loggable[key] = value
    self.logger_args = SimpleNamespace(**loggable)
    self._write_args()

    def background_train():
        run_args = self.full_args
        if run_args.use_wandb_hyperopt:
            # Embeddings are saved first unless the run is a full fine-tune.
            if not run_args.full_finetuning:
                self.save_embeddings_to_disk()
            HyperoptModule.run_wandb_hyperopt(self)
        else:
            if run_args.full_finetuning:
                runner = self.run_full_finetuning
            elif run_args.hybrid_probe:
                runner = self.run_hybrid_probes
            else:
                runner = self.run_nn_probes
            runner()
        print_done()

    self.run_in_background(background_train)
| |
|
def _run_proteingym(self):
    """Copy the ProteinGym settings into ``full_args`` and launch the run.

    When both ``compare_scoring_methods`` and ``proteingym`` are set, the
    scoring-method comparison is run and its CSV path is reported. When only
    ``proteingym`` is set, ``self.run_proteingym_zero_shot()`` is called.
    The work is handed to ``self.run_in_background``.
    """
    print_message("Starting ProteinGym...")

    settings = self.settings_vars
    args = self.full_args

    args.proteingym = settings["proteingym"].get()
    # Whitespace-separated ids. The literal "all" already splits to ["all"],
    # so it needs no special case.
    args.dms_ids = settings["dms_ids"].get().strip().split()

    for option in (
        "mode",
        "scoring_method",
        "scoring_window",
        "pg_batch_size",
        "compare_scoring_methods",
        "score_only",
    ):
        setattr(args, option, settings[option].get())

    # Logged copy: skips 'all_seqs' and keys containing 'token' or 'api'.
    args_dict = {
        k: v
        for k, v in args.__dict__.items()
        if k != 'all_seqs' and 'token' not in k.lower() and 'api' not in k.lower()
    }
    self.logger_args = SimpleNamespace(**args_dict)
    self._write_args()

    def background_proteingym():
        run_args = self.full_args
        if run_args.compare_scoring_methods and run_args.proteingym:
            print_message("Running scoring method comparison...")

            # model_names may be absent or None when no model selection has
            # been stored yet; treat that the same as an empty selection
            # instead of failing with AttributeError/TypeError.
            model_names = getattr(run_args, "model_names", None)
            if not model_names:
                print_message("Error: No models selected for comparison")
                return

            dms_ids = expand_dms_ids_all(run_args.dms_ids, mode=run_args.mode)
            output_csv = os.path.join(run_args.results_dir, 'scoring_methods_comparison.csv')

            compare_scoring_methods(
                model_names=model_names,
                device=None,
                methods=None,
                dms_ids=dms_ids,
                progress=True,
                output_csv=output_csv
            )
            print_message(f"Scoring method comparison complete. Results saved to {output_csv}")

        elif run_args.proteingym:
            self.run_proteingym_zero_shot()

        print_done()

    self.run_in_background(background_proteingym)
| |
|
def _run_scikit(self):
    """Configure model, embedding and scikit settings, then launch the scikit run.

    Requires ``self.datasets`` and ``self.all_seqs`` to be present and
    non-empty; otherwise a message is printed and nothing is started.
    The background routine saves embeddings to disk and then calls
    ``self.run_scikit_scheme()``.
    """
    print_message("Starting Scikit-learn models...")

    # Precondition checks. These report through print_message and return
    # (like the other handlers in this class) rather than using ``assert``,
    # which is removed under ``python -O`` and would raise inside the Tk
    # callback.
    preconditions = (
        (
            "datasets",
            "Datasets are not loaded. Run the Data tab first.",
            "No datasets are loaded. Run the Data tab first.",
        ),
        (
            "all_seqs",
            "Sequences are not loaded. Run the Data tab first.",
            "No sequences are loaded. Run the Data tab first.",
        ),
    )
    for attr_name, missing_message, empty_message in preconditions:
        loaded = getattr(self, attr_name, None)
        if loaded is None:
            print_message(missing_message)
            return
        if len(loaded) == 0:
            print_message(empty_message)
            return

    settings = self.settings_vars
    args = self.full_args

    # --- model selection (falls back to standard_models when nothing is selected)
    selected_models = [self.model_listbox.get(i) for i in self.model_listbox.curselection()]
    if not selected_models:
        selected_models = standard_models
    args.model_names = selected_models
    args.model_paths = None
    args.model_types = None
    args.model_dtype = self._selected_model_dtype()
    args.use_xformers = settings["use_xformers"].get()
    # Keep the _USE_XFORMERS environment flag in step with the setting, as
    # _select_models and _run_trainer do when they store use_xformers.
    if args.use_xformers:
        os.environ["_USE_XFORMERS"] = "1"
    else:
        os.environ.pop("_USE_XFORMERS", None)
    self.model_args = BaseModelArguments(**args.__dict__)

    # --- embedding settings
    pooling_str = settings["embedding_pooling_types"].get().strip()
    pooling_list = [p.strip() for p in pooling_str.split(",") if p.strip()]
    dtype_val = self._selected_embed_dtype()

    args.embedding_batch_size = settings["batch_size"].get()
    args.embedding_num_workers = settings["num_workers"].get()
    args.download_embeddings = settings["download_embeddings"].get()
    args.matrix_embed = settings["matrix_embed"].get()
    args.embedding_pooling_types = pooling_list
    args.save_embeddings = True
    args.embed_dtype = dtype_val
    args.sql = settings["sql"].get()
    self._sql = args.sql
    self._full = args.matrix_embed
    self.embedding_args = EmbeddingArguments(**args.__dict__)

    # --- scikit settings
    args.use_scikit = settings["use_scikit"].get()
    args.scikit_n_iter = settings["scikit_n_iter"].get()
    args.scikit_cv = settings["scikit_cv"].get()
    args.scikit_random_state = settings["scikit_random_state"].get()
    # A blank model name is stored as None.
    args.scikit_model_name = settings["scikit_model_name"].get().strip() or None
    args.n_jobs = settings["n_jobs"].get()
    # The same values are also stored under the un-prefixed names.
    args.n_iter = args.scikit_n_iter
    args.cv = args.scikit_cv
    args.random_state = args.scikit_random_state
    args.model_name = args.scikit_model_name
    self.scikit_args = self._build_scikit_args()

    # Logged copy: skips 'all_seqs' and keys containing 'token' or 'api'.
    args_dict = {
        k: v
        for k, v in args.__dict__.items()
        if k != 'all_seqs' and 'token' not in k.lower() and 'api' not in k.lower()
    }
    self.logger_args = SimpleNamespace(**args_dict)
    self._write_args()

    def background_scikit():
        self.save_embeddings_to_disk()
        self.run_scikit_scheme()
        print_done()

    self.run_in_background(background_scikit)
| |
|
def _select_models(self):
    """Store the model selection on ``full_args`` and rebuild ``model_args``.

    The models highlighted in ``self.model_listbox`` are used; when none is
    highlighted, ``standard_models`` is used instead. The resulting model
    arguments are printed, ``self.logger_args`` is rebuilt and
    ``self._write_args()`` is called.
    """
    print_message("Selecting models...")

    listbox = self.model_listbox
    chosen_models = [listbox.get(index) for index in listbox.curselection()]

    args = self.full_args
    args.model_names = chosen_models or standard_models
    args.model_paths = None
    args.model_types = None
    args.model_dtype = self._selected_model_dtype()
    args.use_xformers = self.settings_vars["use_xformers"].get()

    # Keep the _USE_XFORMERS environment flag in step with the setting.
    if args.use_xformers:
        os.environ["_USE_XFORMERS"] = "1"
    else:
        os.environ.pop("_USE_XFORMERS", None)
    print_message(args.model_names)

    self.model_args = BaseModelArguments(**vars(args))

    print("Model Args:")
    for field, value in vars(self.model_args).items():
        # model_names was already printed above.
        if field == 'model_names':
            continue
        print(f"{field}:\n{value}")
    print("=========================\n")

    # Logged copy: skips 'all_seqs' and keys containing 'token' or 'api'.
    loggable = {}
    for key, value in vars(args).items():
        lowered = key.lower()
        if key == 'all_seqs' or 'token' in lowered or 'api' in lowered:
            continue
        loggable[key] = value
    self.logger_args = SimpleNamespace(**loggable)
    self._write_args()
    print_done()
| |
|
def _get_data(self):
    """Collect the data settings and launch dataset loading in the background.

    The datasets highlighted in ``self.data_listbox`` and the comma-separated
    ``data_dirs`` setting select what is loaded; when both are empty,
    ``standard_data_benchmark`` is used. The background routine stores the
    settings on ``self.full_args``, rebuilds ``self.data_args`` and
    ``self.logger_args``, calls ``self._write_args()`` and then
    ``self.get_datasets()``.
    """
    print_message("=== Getting Data ===")
    print_message("Loading and preparing datasets...")

    settings = self.settings_vars

    selected_datasets = [self.data_listbox.get(i) for i in self.data_listbox.curselection()]
    data_dirs = [path.strip() for path in settings["data_dirs"].get().strip().split(",") if path.strip()]

    if not selected_datasets and not data_dirs:
        selected_datasets = standard_data_benchmark

    # Read every Tk variable here, before the work is handed off, so that the
    # whole configuration is one consistent snapshot taken at the same time
    # as the dataset selection above. The other run handlers in this class
    # also read their settings before calling run_in_background.
    # NOTE(review): run_in_background presumably executes the closure off the
    # Tk thread (not visible here); reading Tk variables there is avoided.
    # Key order is the order in which the values are stored on full_args.
    snapshot = {
        "data_names": selected_datasets,
        "data_dirs": data_dirs,
        "max_length": settings["max_length"].get(),
        "trim": settings["trim"].get(),
        "delimiter": settings["delimiter"].get(),
        "col_names": [name.strip() for name in settings["col_names"].get().split(",") if name.strip()],
        "aa_to_dna": settings["aa_to_dna"].get(),
        "aa_to_rna": settings["aa_to_rna"].get(),
        "dna_to_aa": settings["dna_to_aa"].get(),
        "rna_to_aa": settings["rna_to_aa"].get(),
        "codon_to_aa": settings["codon_to_aa"].get(),
        "aa_to_codon": settings["aa_to_codon"].get(),
        "random_pair_flipping": settings["random_pair_flipping"].get(),
        # Whitespace-separated names; blank text is stored as None.
        "multi_column": settings["multi_column"].get().strip().split() or None,
    }

    # Settings that are additionally mirrored onto ``self`` with a leading
    # underscore (random_pair_flipping is not mirrored).
    mirrored = (
        "max_length",
        "trim",
        "delimiter",
        "col_names",
        "multi_column",
        "aa_to_dna",
        "aa_to_rna",
        "dna_to_aa",
        "rna_to_aa",
        "codon_to_aa",
        "aa_to_codon",
    )

    def background_get_data():
        args = self.full_args
        for option, value in snapshot.items():
            setattr(args, option, value)

        for option in mirrored:
            setattr(self, f"_{option}", getattr(args, option))

        self.data_args = DataArguments(**args.__dict__)
        # Logged copy: skips 'all_seqs' and keys containing 'token' or 'api'.
        args_dict = {
            k: v
            for k, v in args.__dict__.items()
            if k != 'all_seqs' and 'token' not in k.lower() and 'api' not in k.lower()
        }
        self.logger_args = SimpleNamespace(**args_dict)

        self._write_args()
        self.get_datasets()
        print_message("Data downloaded and stored")
        print_done()

    self.run_in_background(background_get_data)
| |
|
def _get_embeddings(self):
    """Collect the embedding settings and save embeddings in the background.

    Does nothing except print a message when no sequences are available.
    Otherwise the background routine stores the settings on
    ``self.full_args``, rebuilds ``self.embedding_args`` and
    ``self.logger_args``, calls ``self._write_args()`` and then
    ``self.save_embeddings_to_disk()``.
    """
    # getattr: the attribute may not exist yet (the scikit handler checks for
    # its presence explicitly), and a missing attribute should produce the
    # same message as an empty one rather than an AttributeError.
    if not getattr(self, "all_seqs", None):
        print_message('Sequences are not loaded yet. Please run the data tab first.')
        return

    print_message("Computing embeddings...")

    settings = self.settings_vars
    pooling_str = settings["embedding_pooling_types"].get().strip()
    pooling_list = [p.strip() for p in pooling_str.split(",") if p.strip()]

    # Read every GUI value here, before the work is handed off, so that the
    # configuration is one consistent snapshot (part of it was previously read
    # here and part inside the background closure).
    # NOTE(review): run_in_background presumably executes the closure off the
    # Tk thread, and _selected_model_dtype presumably reads GUI state like
    # _selected_embed_dtype - neither is visible here.
    # Key order is the order in which the values are stored on full_args.
    snapshot = {
        "model_dtype": self._selected_model_dtype(),
        "embedding_batch_size": settings["batch_size"].get(),
        "embedding_num_workers": settings["num_workers"].get(),
        "download_embeddings": settings["download_embeddings"].get(),
        "matrix_embed": settings["matrix_embed"].get(),
        "embedding_pooling_types": pooling_list,
        "save_embeddings": True,
        "embed_dtype": self._selected_embed_dtype(),
        "sql": settings["sql"].get(),
    }

    def background_get_embeddings():
        args = self.full_args
        args.all_seqs = self.all_seqs
        for option, value in snapshot.items():
            setattr(args, option, value)
        self._sql = args.sql
        self._full = args.matrix_embed

        self.embedding_args = EmbeddingArguments(**args.__dict__)
        # Logged copy: skips 'all_seqs' and keys containing 'token' or 'api'.
        args_dict = {
            k: v
            for k, v in args.__dict__.items()
            if k != 'all_seqs' and 'token' not in k.lower() and 'api' not in k.lower()
        }
        self.logger_args = SimpleNamespace(**args_dict)
        self._write_args()

        print_message("Saving embeddings to disk")
        self.save_embeddings_to_disk()
        print_message("Embeddings saved to disk")
        print_done()

    self.run_in_background(background_get_embeddings)
| |
|
def _browse_replay_log(self):
    """Open a file dialog and store the chosen path in the ``replay_path`` setting.

    Nothing is stored when the dialog returns an empty value.
    """
    file_patterns = (("Txt files", "*.txt"), ("All files", "*.*"))
    chosen_path = filedialog.askopenfilename(
        title="Select Replay Log",
        filetypes=file_patterns,
    )
    if not chosen_path:
        return
    self.settings_vars["replay_path"].set(chosen_path)
| |
|
def _start_replay(self):
    """Replay the log file named by the ``replay_path`` setting in the background.

    Prints a message and returns when no replay path has been set.
    """
    log_path = self.settings_vars["replay_path"].get()
    if not log_path:
        print_message("Please select a replay log file first")
        return

    print_message("Starting replay from log file...")

    def background_replay():
        from logger import LogReplayer

        replayer = LogReplayer(log_path)
        parsed_args = replayer.parse_log()
        parsed_args.replay_path = log_path

        # A separate, non-GUI MainProcess is built from the parsed arguments.
        replay_process = MainProcess(parsed_args, GUI=False)
        for key, value in vars(replay_process.full_args).items():
            print(f"{key}:\t{value}")

        replayer.run_replay(replay_process)
        print_done()

    self.run_in_background(background_replay)
| | |
def _browse_results_file(self):
    """Open a file dialog and store the chosen path in the ``results_file`` setting.

    When a file is chosen, the ``use_current_run`` setting is also set to
    False. Nothing changes when the dialog returns an empty value.
    """
    file_patterns = (("TSV files", "*.tsv"), ("All files", "*.*"))
    chosen_path = filedialog.askopenfilename(
        title="Select Results File",
        filetypes=file_patterns,
    )
    if not chosen_path:
        return

    settings = self.settings_vars
    settings["results_file"].set(chosen_path)
    settings["use_current_run"].set(False)
| | |
def _generate_plots(self):
    """Pick a results file and generate plots from it in the background.

    The results file is chosen in this order of precedence:
      1. ``<results_dir>/<random_id>.tsv`` when ``use_current_run`` is set
         and this object has a ``random_id``;
      2. the ``results_file`` setting, when non-empty;
      3. ``<results_dir>/<result_id>.tsv`` when ``result_id`` is non-empty.
    A message is printed and nothing is started when no source is given or
    the chosen file does not exist.
    """
    print_message("Generating visualization plots...")

    settings = self.settings_vars

    if settings["use_current_run"].get() and hasattr(self, 'random_id'):
        results_path = os.path.join(settings["results_dir"].get(), f"{self.random_id}.tsv")
        print_message(f"Using current run results: {results_path}")
    else:
        browsed_file = settings["results_file"].get()
        if browsed_file:
            results_path = browsed_file
            print_message(f"Using selected results file: {results_path}")
        else:
            run_id = settings["result_id"].get()
            if not run_id:
                print_message("No results file specified. Please enter a Result ID, browse for a file, or complete a run first.")
                return
            results_path = os.path.join(settings["results_dir"].get(), f"{run_id}.tsv")
            print_message(f"Using results file for ID {run_id}: {results_path}")

    if not os.path.exists(results_path):
        print_message(f"Results file not found: {results_path}")
        return

    plot_dir = settings["viz_output_dir"].get()

    def background_generate_plots():
        print_message(f"Generating plots in {plot_dir}...")
        create_plots(results_path, plot_dir)
        print_message("Plots generated successfully!")
        print_done()

    self.run_in_background(background_generate_plots)
| |
|
| |
|
def main():
    """Create the Tk root window, build the GUI on it and run the event loop."""
    window = tk.Tk()
    # Bound to a local name, as before, so the GUI object stays referenced
    # while the event loop runs.
    gui = GUI(window)
    print_title("Protify")
    window.mainloop()


if __name__ == "__main__":
    main()
| |
|