import gradio as gr
import huggingface_hub
import onnxruntime as rt
import numpy as np
import cv2
import os
import csv
import datetime
import time

# --- Constants ---
LOG_FILE = "processing_log.csv"
LOG_HEADER = [
    "Timestamp", "Repository", "Model Filename", "Model Size (MB)",
    "Image Resolution (WxH)", "Execution Provider", "Processing Time (s)"
]

# Global variables for model and providers
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
model_repo_default = "skytnt/anime-seg"
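# Note: the providers list is a priority order, not a requirement. If the CUDA
# provider is unavailable in the installed onnxruntime build, InferenceSession
# falls back to the CPU provider. A quick standalone check of what this
# environment actually offers (a minimal sketch, safe to run on its own):
#
#   import onnxruntime as rt
#   print(rt.get_available_providers())
#   # e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider'] on a GPU build,
#   # or just ['CPUExecutionProvider'] with the default pip package.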
# --- Logging Functions ---
def initialize_log_file():
    """Creates the log file and writes the header if it doesn't exist."""
    if not os.path.exists(LOG_FILE):
        try:
            with open(LOG_FILE, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow(LOG_HEADER)
            print(f"Log file initialized: {LOG_FILE}")
        except IOError as e:
            print(f"Error initializing log file {LOG_FILE}: {e}")


def log_processing_event(timestamp, repo, model_filename, model_size_mb,
                         resolution, provider, processing_time):
    """Appends a processing event to the CSV log file."""
    try:
        with open(LOG_FILE, 'a', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow([
                timestamp, repo, model_filename, f"{model_size_mb:.2f}",
                resolution, provider, f"{processing_time:.4f}"
            ])
    except IOError as e:
        print(f"Error writing to log file {LOG_FILE}: {e}")
    except Exception as e:
        print(f"An unexpected error occurred during logging: {e}")
def read_log_file():
    """Reads the entire log file content."""
    try:
        if not os.path.exists(LOG_FILE):
            return "Log file not found."
        with open(LOG_FILE, 'r', encoding='utf-8') as f:
            return f.read()
        # Alternatively, for cleaner display of CSV in a textbox:
        # reader = csv.reader(f)
        # rows = list(reader)
        # header, data_rows = rows[0], rows[1:]
        # formatted_rows = [", ".join(header)]
        # for row in data_rows:
        #     formatted_rows.append(", ".join(row))
        # return "\n".join(formatted_rows)
    except IOError as e:
        print(f"Error reading log file {LOG_FILE}: {e}")
        return f"Error reading log file: {e}"
    except Exception as e:
        print(f"An unexpected error occurred reading log file: {e}")
        return f"Error reading log file: {e}"
# --- Helper Functions ---
def get_model_details_from_choice(choice_string: str) -> tuple[str, float | None]:
    """
    Extracts filename and size (MB) from the dropdown choice string.
    Returns (filename, size_mb) or (filename, None) if size is not parseable.
    """
    if not choice_string:
        return "", None
    parts = choice_string.split(" (")
    filename = parts[0]
    size_mb = None
    if len(parts) > 1 and parts[1].endswith(" MB)"):
        try:
            size_str = parts[1].replace(" MB)", "")
            size_mb = float(size_str)
        except ValueError:
            pass  # Size couldn't be parsed
    return filename, size_mb
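# Example calls (hypothetical choice strings, matching the format built in
# update_onnx_files below):
#   get_model_details_from_choice("isnetis.onnx (168.95 MB)")  -> ("isnetis.onnx", 168.95)
#   get_model_details_from_choice("isnetis.onnx (Size N/A)")   -> ("isnetis.onnx", None)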
# --- Model Loading and UI Functions (Mostly unchanged, modifications marked) ---
def update_onnx_files(repo: str):
    """
    Lists .onnx files in the Hugging Face repository and updates the Dropdown with file sizes.
    """
    onnx_files_with_size = []
    try:
        api = huggingface_hub.HfApi()
        repo_info = api.model_info(repo_id=repo, files_metadata=True)
        for file_info in repo_info.siblings:
            if file_info.rfilename.endswith('.onnx'):
                try:
                    # file_info.size is in bytes
                    size_mb = file_info.size / (1024 * 1024) if file_info.size else 0
                    onnx_files_with_size.append(f"{file_info.rfilename} ({size_mb:.2f} MB)")
                except Exception:
                    onnx_files_with_size.append(f"{file_info.rfilename} (Size N/A)")
        if onnx_files_with_size:
            onnx_files_with_size.sort()
            return gr.update(choices=onnx_files_with_size, value=onnx_files_with_size[0])
        else:
            # gr.update() takes no warning/error kwargs; surface the message
            # via gr.Warning and return a plain dropdown update instead.
            gr.Warning(f"No .onnx files found in repository '{repo}'")
            return gr.update(choices=[], value="")
    except huggingface_hub.utils.RepositoryNotFoundError:
        gr.Warning(f"Repository '{repo}' not found or access denied.")
        return gr.update(choices=[], value="")
    except Exception as e:
        print(f"Error fetching repo files for {repo}: {e}")
        gr.Warning(f"Error fetching files: {str(e)}")
        return gr.update(choices=[], value="")
# Get default choices and filename
default_onnx_files_with_size = []
default_model_filename = ""
try:
    initial_update = update_onnx_files(model_repo_default)
    # gr.update(...) returns a dict (gr.update is a function, not a class, so
    # the original isinstance(initial_update, gr.update) check could never
    # succeed). Handle both dict-style and attribute-style updates.
    if isinstance(initial_update, dict):
        initial_choices = initial_update.get("choices")
    else:
        initial_choices = getattr(initial_update, "choices", None)
    if initial_choices:
        default_onnx_files_with_size = list(initial_choices)
        default_model_filename, _ = get_model_details_from_choice(default_onnx_files_with_size[0])  # Use helper
    else:
        default_onnx_files_with_size = ["isnetis.onnx (Size N/A)"]
        default_model_filename = "isnetis.onnx"
        print(f"Warning: Could not fetch initial ONNX files from {model_repo_default}. Using fallback '{default_model_filename}'.")
except Exception as e:
    default_onnx_files_with_size = ["isnetis.onnx (Size N/A)"]
    default_model_filename = "isnetis.onnx"
    print(f"Error during initial model fetch: {e}. Using fallback '{default_model_filename}'.")
# Global variables for current model state
current_model_repo = model_repo_default
current_model_filename = default_model_filename

# Initial download and model load
model_path = None
rmbg_model = None
try:
    print(f"Attempting initial download: {current_model_repo}/{current_model_filename}")
    if current_model_filename:  # Only download if we have a filename
        model_path = huggingface_hub.hf_hub_download(current_model_repo, current_model_filename)
        rmbg_model = rt.InferenceSession(model_path, providers=providers)
        print(f"Initial model loaded successfully: {model_path}")
        print(f"Available Execution Providers: {rt.get_available_providers()}")
        print(f"Using Provider(s): {rmbg_model.get_providers()}")
    else:
        print("FATAL: No default model filename determined. Cannot load initial model.")
except Exception as e:
    print(f"FATAL: Could not download or load initial model '{current_model_repo}/{current_model_filename}'. Error: {e}")
# --- Inference Functions (Unchanged get_mask, rmbg_fn) ---
def get_mask(img, s=1024):
    if rmbg_model is None:
        raise gr.Error("Model is not loaded. Please check model selection and update status.")
    img_normalized = (img / 255.0).astype(np.float32)
    h0, w0 = img.shape[:2]
    # Letterbox: scale the longer side to s, keep the aspect ratio, and pad
    # the remainder with zeros so the network always sees an s x s input.
    if h0 >= w0:
        h, w = s, int(s * w0 / h0)
    else:
        h, w = int(s * h0 / w0), s
    ph, pw = s - h, s - w
    img_input = np.zeros([s, s, 3], dtype=np.float32)
    resized_img = cv2.resize(img_normalized, (w, h), interpolation=cv2.INTER_AREA)
    img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = resized_img
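    # Worked example of the letterbox math above (illustrative numbers): a
    # 1920x1080 input has w0 > h0, so w = 1024 and h = int(1024 * 1080 / 1920)
    # = 576; then ph = 448 and pw = 0, i.e. the resized image is centered
    # vertically with 224 rows of zero padding above and below.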
    img_input = np.transpose(img_input, (2, 0, 1))[np.newaxis, :]
    input_name = rmbg_model.get_inputs()[0].name
    mask_output = rmbg_model.run(None, {input_name: img_input})[0][0]
    # Undo the padding and resize the mask back to the original resolution.
    mask_processed = np.transpose(mask_output, (1, 2, 0))
    mask_processed = mask_processed[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
    mask_resized = cv2.resize(mask_processed, (w0, h0), interpolation=cv2.INTER_LINEAR)
    if mask_resized.ndim == 2:
        mask_resized = mask_resized[:, :, np.newaxis]
    mask_final = np.clip(mask_resized, 0, 1)
    return mask_final
def rmbg_fn(img):
    if img is None:
        raise gr.Error("Please provide an input image.")
    mask = get_mask(img)
    # Ensure uint8 RGB before attaching the alpha channel.
    if img.dtype != np.uint8:
        if img.max() <= 1.0:
            img = (img * 255).clip(0, 255).astype(np.uint8)
        else:
            img = img.clip(0, 255).astype(np.uint8)
    alpha_channel = (mask * 255).astype(np.uint8)
    if img.shape[2] == 3:
        img_out_rgba = np.concatenate([img, alpha_channel], axis=2)
    else:
        img_out_rgba = img.copy()
        img_out_rgba[:, :, 3] = alpha_channel[:, :, 0]
    mask_img_display = (mask * 255).astype(np.uint8).repeat(3, axis=2)
    return mask_img_display, img_out_rgba
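# The RGBA output keeps the original pixels and encodes the segmentation as
# transparency. To flatten it onto a solid background instead (hypothetical
# post-processing, not used by this app), something like this would work:
#
#   alpha = img_out_rgba[:, :, 3:4].astype(np.float32) / 255.0
#   on_white = (img_out_rgba[:, :, :3] * alpha + 255.0 * (1.0 - alpha)).astype(np.uint8)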
# --- Model Update Function ---
def update_model(model_repo, model_filename_with_size):
    global rmbg_model, current_model_repo, current_model_filename
    model_filename, _ = get_model_details_from_choice(model_filename_with_size)  # Use helper
    if not model_filename:
        return "Error: No model filename selected or extracted."
    if model_repo == current_model_repo and model_filename == current_model_filename:
        # Even if it's the same model, report the provider being used
        current_provider = rmbg_model.get_providers()[0] if rmbg_model else "N/A"
        return f"Model already loaded: {current_model_repo}/{current_model_filename}\nUsing Provider: {current_provider}"
    try:
        print(f"Updating model to: {model_repo}/{model_filename}")
        model_path = huggingface_hub.hf_hub_download(model_repo, model_filename)
        new_rmbg_model = rt.InferenceSession(model_path, providers=providers)
        rmbg_model = new_rmbg_model
        current_model_repo = model_repo
        current_model_filename = model_filename
        active_provider = rmbg_model.get_providers()[0]  # The provider actually used
        print(f"Model updated successfully: {model_path}")
        print(f"Using Provider: {active_provider}")
        return f"Model updated: {current_model_repo}/{current_model_filename}\nUsing Provider: {active_provider}"
    except huggingface_hub.utils.HfHubHTTPError as e:
        print(f"Error downloading model: {e}")
        return f"Error downloading model: {model_repo}/{model_filename}. ({e.response.status_code})"
    except Exception as e:
        # onnxruntime does not expose a top-level ONNXRuntimeException (the
        # original except clause would itself fail), so load errors are caught
        # here and distinguished by message instead.
        print(f"Error updating model: {e}")
        if "CUDAExecutionProvider" in str(e):
            return f"Error loading ONNX model '{model_filename}'. CUDA unavailable or setup issue? Error: {e}"
        return f"Error updating model: {str(e)}"
# --- Main Processing Function (MODIFIED FOR LOGGING) ---
def process_and_update(img, model_repo, model_filename_with_size, history):
    global current_model_repo, current_model_filename, rmbg_model
    # --- Pre-checks ---
    if img is None:
        return None, [], history, "generated", "Please upload an image first.", read_log_file()  # Return current log
    if rmbg_model is None:
        return None, [], history, "generated", "ERROR: Model not loaded. Update model first.", read_log_file()
    selected_model_filename, selected_model_size_mb = get_model_details_from_choice(model_filename_with_size)  # Use helper
    status_message = ""
    # --- Model Update Check ---
    if model_repo != current_model_repo or selected_model_filename != current_model_filename:
        status_message = update_model(model_repo, model_filename_with_size)
        if "Error" in status_message:
            return None, [], history, "generated", f"Model Update Failed:\n{status_message}", read_log_file()
        if rmbg_model is None:
            return None, [], history, "generated", "ERROR: Model failed to load after update.", read_log_file()
    # --- Processing & Logging ---
    try:
        start_time = time.time()
        mask_img, generated_img_rgba = rmbg_fn(img)  # Run inference
        processing_time = time.time() - start_time
        # --- Gather Log Information ---
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        h, w = img.shape[:2]
        resolution = f"{w}x{h}"
        # Get the *actually used* provider from the loaded session
        active_provider = rmbg_model.get_providers()[0]
        # Log the event
        log_processing_event(
            timestamp=timestamp,
            repo=current_model_repo,  # Use the confirmed current repo
            model_filename=current_model_filename,  # Use the confirmed current filename
            model_size_mb=selected_model_size_mb if selected_model_size_mb is not None else 0.0,
            resolution=resolution,
            provider=active_provider,
            processing_time=processing_time
        )
        # --- Prepare Outputs ---
        new_history = history + [generated_img_rgba]
        output_pair = [mask_img, generated_img_rgba]
        current_log_content = read_log_file()  # Read updated log
        status_message = f"{status_message}\nProcessing complete ({processing_time:.2f}s)".strip()
        return generated_img_rgba, output_pair, new_history, "generated", status_message, current_log_content
    except Exception as e:
        print(f"Error during processing: {e}")
        import traceback
        traceback.print_exc()
        # Still return the log content even if processing fails
        return None, [], history, "generated", f"Error during processing: {str(e)}", read_log_file()
# --- UI Interaction Functions (toggle_view unchanged; clear_all slightly modified) ---
def toggle_view(view_state, output_pair):
    if not output_pair or len(output_pair) != 2:
        return None, view_state, "View Mask" if view_state == "generated" else "View Generated"
    if view_state == "generated":
        return output_pair[0], "mask", "View Generated"
    else:
        return output_pair[1], "generated", "View Mask"


def clear_all():
    """Resets inputs, outputs, states, and status, but keeps the log view."""
    # The log viewer keeps its content: clearing the interface shouldn't wipe
    # the processing history recorded on disk.
    initial_log_content = read_log_file()  # Read log to display upon clearing
    return None, None, [], [], "generated", "Interface cleared.", "View Mask", [], initial_log_content
# --- Gradio UI Definition ---
if __name__ == "__main__":
    initialize_log_file()  # Ensure log file exists before launching app
    app = gr.Blocks(css=".gradio-container { max-width: 95% !important; }")  # Wider layout
    with app:
        gr.Markdown("# Image Background Removal (Segmentation) with Logging")
        gr.Markdown("Test ONNX models, view performance logs.")
        with gr.Row():
            # Left Column: Controls and Input
            with gr.Column(scale=2):
                with gr.Group():
                    gr.Markdown("### Model Selection")
                    model_repo_input = gr.Textbox(value=model_repo_default, label="Hugging Face Repository")
                    model_filename_dropdown = gr.Dropdown(
                        choices=default_onnx_files_with_size,
                        value=default_onnx_files_with_size[0] if default_onnx_files_with_size else "",
                        label="ONNX Model File (.onnx)"
                    )
                    update_btn = gr.Button("🔄 Update/Load Model")
                    model_status_textbox = gr.Textbox(label="Status", value="Initial model loaded." if rmbg_model else "ERROR: Initial model failed to load.", interactive=False, lines=2)
                gr.Markdown("#### Source Image")
                input_img = gr.Image(label="Upload Image", type="numpy")
                with gr.Row():
                    run_btn = gr.Button("▶️ Run Background Removal", variant="primary")
                    clear_btn = gr.Button("🗑️ Clear Inputs/Outputs")
            # Right Column: Output and Logs
            with gr.Column(scale=3):
                gr.Markdown("#### Output Image")
                output_img = gr.Image(label="Output", image_mode="RGBA", format="png", type="numpy")
                toggle_btn = gr.Button("View Mask")
                gr.Markdown("---")
                gr.Markdown("### Processing History")
                history_gallery = gr.Gallery(label="Generated Image History", show_label=False, columns=8, object_fit="contain", height="auto")
                gr.Markdown("---")
                gr.Markdown("### Processing Log (`processing_log.csv`)")
                # Use gr.Code for better viewing of CSV/text data
                log_display = gr.Code(
                    value=read_log_file(),  # Initial content
                    label="Log Viewer",
                    lines=10,
                    interactive=False
                )
                # Optional: Add a manual refresh button if auto-update isn't sufficient
                # refresh_log_btn = gr.Button("🔄 Refresh Log View")

        # Hidden states
        output_pair_state = gr.State([])
        view_state = gr.State("generated")
        history_state = gr.State([])

        # --- Event Listeners ---
        model_repo_input.submit(fn=update_onnx_files, inputs=model_repo_input, outputs=model_filename_dropdown)
        model_repo_input.blur(fn=update_onnx_files, inputs=model_repo_input, outputs=model_filename_dropdown)
        update_btn.click(fn=update_model, inputs=[model_repo_input, model_filename_dropdown], outputs=model_status_textbox)
        # Run includes updating the log display (log_display added to outputs)
        run_btn.click(
            fn=process_and_update,
            inputs=[input_img, model_repo_input, model_filename_dropdown, history_state],
            outputs=[output_img, output_pair_state, history_state, view_state, model_status_textbox, log_display]
        )
        toggle_btn.click(fn=toggle_view, inputs=[view_state, output_pair_state], outputs=[output_img, view_state, toggle_btn])
        # Clear resets inputs/outputs/status, but re-reads log for display (log_display added to outputs)
        clear_btn.click(
            fn=clear_all,
            outputs=[input_img, output_img, output_pair_state, history_state, view_state, model_status_textbox, toggle_btn, history_gallery, log_display]
        )
        # Manual log refresh button (optional, as run/clear update it)
        # refresh_log_btn.click(fn=read_log_file, inputs=None, outputs=log_display)
        history_state.change(fn=lambda history: history, inputs=history_state, outputs=history_gallery)

    app.launch(debug=True)
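# To try this locally (assumed setup, not part of the original file):
#   pip install gradio onnxruntime huggingface_hub opencv-python numpy
#   python app.py
# On Hugging Face Spaces with the Gradio SDK configured, naming the file
# app.py is typically enough for the Space to launch it automatically.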