import os

# Set environment variables for Spaces compatibility
# (set before the numeric libraries below are imported)
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

import cv2
import yaml
import torch
import random
import gradio as gr
import numpy as np
import kagglehub
from PIL import Image
from glob import glob
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
from matplotlib import patches
from torchvision import transforms as T
from ultralytics import YOLO
import shutil
import tempfile
from pathlib import Path
import json
from io import BytesIO

# Try to import spaces for Hugging Face Spaces GPU support
try:
    import spaces
    ON_SPACES = True
except ImportError:
    ON_SPACES = False

    # Create a no-op stand-in for `spaces.GPU` when not running on Spaces
    class spaces:
        @staticmethod
        def GPU(duration=60):
            def decorator(func):
                return func
            return decorator

# Set Kaggle API credentials from the KDATA_API environment variable.
# Only a JSON object such as {"username": "...", "key": "..."} is understood.
if os.getenv("KDATA_API"):
    kaggle_key = os.getenv("KDATA_API")
    if "{" in kaggle_key:
        try:
            key_data = json.loads(kaggle_key)
        except json.JSONDecodeError as err:
            # Don't crash the whole app at import time over a bad secret.
            print(f"Warning: KDATA_API is not valid JSON ({err}); Kaggle credentials not set")
        else:
            if isinstance(key_data, dict):
                os.environ["KAGGLE_USERNAME"] = key_data.get("username", "")
                os.environ["KAGGLE_KEY"] = key_data.get("key", "")

# Global variables
model = None                 # currently loaded YOLO model (or None)
dataset_path = None          # dataset root, set by download_dataset()
training_in_progress = False  # guards against starting two trainings at once


class Visualization:
    """Read a YOLO-format dataset and render sample images / class statistics.

    The dataset root is expected to contain ``data.yaml`` (with a ``names``
    entry) and, for each split in ``data_types``, ``<split>/images/*`` plus
    matching ``<split>/labels/<stem>.txt`` files holding one
    ``class_id x_center y_center width height`` row per object, with the
    coordinates normalized to the image size.

    Args:
        root: Dataset root directory.
        data_types: Split names to load, e.g. ``["train"]``.
        n_ims: Number of images requested by the caller (stored only).
        rows: Grid rows requested by the caller (stored only).
        cmap: Colour-map hint (stored only).
    """

    def __init__(self, root, data_types, n_ims, rows, cmap=None):
        self.n_ims, self.rows = n_ims, rows
        self.cmap, self.data_types = cmap, data_types
        self.colors = ["firebrick", "darkorange", "blueviolet"]
        self.root = root
        self.get_cls_names()
        self.get_bboxes()

    def get_cls_names(self):
        """Populate ``self.class_dict`` (class id -> name) from ``data.yaml``."""
        yaml_path = f"{self.root}/data.yaml"
        if not os.path.exists(yaml_path):
            print(f"Warning: {yaml_path} not found")
            self.class_dict = {}
            return

        with open(yaml_path, 'r') as file:
            data = yaml.safe_load(file) or {}

        class_names = data.get('names', [])
        # Ultralytics accepts `names` either as a list or as an {id: name} mapping.
        if isinstance(class_names, dict):
            self.class_dict = {int(index): name for index, name in class_names.items()}
        else:
            self.class_dict = {index: name for index, name in enumerate(class_names)}

        # Print class names for debugging
        if self.class_dict:
            print(f"Dataset classes: {', '.join(str(name) for name in self.class_dict.values())}")

    def _label_path(self, im_path, data_type):
        """Return the YOLO label file that belongs to image ``im_path``."""
        stem = os.path.splitext(os.path.basename(im_path))[0]
        return os.path.join(self.root, data_type, "labels", f"{stem}.txt")

    def get_bboxes(self):
        """Parse label files for every split in ``self.data_types``.

        Fills, per split:
          * ``self.vis_datas[split]`` - one list of ``[class_name, xc, yc, w, h]``
            rows per labelled image;
          * ``self.im_paths[split]`` - the matching image paths (same order and
            length as ``vis_datas[split]``);
          * ``self.analysis_datas[split]`` - ``{class_name: object_count}``.

        Images without a label file are skipped entirely, so the two parallel
        lists stay index-aligned.
        """
        self.vis_datas, self.analysis_datas, self.im_paths = {}, {}, {}
        for data_type in self.data_types:
            all_bboxes, labelled_paths, class_counts = [], [], {}

            for im_path in sorted(glob(f"{self.root}/{data_type}/images/*")):
                lbl_path = self._label_path(im_path, data_type)
                if not os.path.isfile(lbl_path):
                    continue

                bboxes = []
                with open(lbl_path) as lbl_file:
                    for line in lbl_file:
                        parts = line.split()[:5]
                        if len(parts) < 5:
                            continue  # blank or malformed row
                        cls_id = int(float(parts[0]))
                        cls_name = self.class_dict.get(cls_id, f"Class {cls_id}")
                        bboxes.append([cls_name] + [float(x) for x in parts[1:]])
                        class_counts[cls_name] = class_counts.get(cls_name, 0) + 1

                all_bboxes.append(bboxes)
                labelled_paths.append(im_path)

            self.vis_datas[data_type] = all_bboxes
            self.analysis_datas[data_type] = class_counts
            self.im_paths[data_type] = labelled_paths

    def plot_single(self, im_path, bboxes):
        """Draw ``bboxes`` on the image at ``im_path`` and return a PIL image.

        Each bbox is ``[class_name, x_center, y_center, w, h]`` with normalized
        coordinates. Boxes get a random colour, and the total object count is
        written in the top-left corner.
        """
        with Image.open(im_path) as src:
            image = np.array(src.convert("RGB"))
        height, width = image.shape[:2]

        for _cls_name, x_center, y_center, box_w, box_h in bboxes:
            top_left = (int((x_center - box_w / 2) * width), int((y_center - box_h / 2) * height))
            bottom_right = (int((x_center + box_w / 2) * width), int((y_center + box_h / 2) * height))
            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            cv2.rectangle(img=image, pt1=top_left, pt2=bottom_right, color=color, thickness=3)

        # Add text overlay
        cv2.putText(image, f"Objects: {len(bboxes)}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (255, 255, 255), 2, cv2.LINE_AA)

        # The array was loaded as RGB via PIL, so no BGR->RGB conversion is needed.
        return Image.fromarray(image)

    def vis_samples(self, data_type, n_samples=4):
        """Return up to ``n_samples`` distinct annotated images from ``data_type``.

        Returns ``None`` if the split was not loaded, or an empty list if it has
        no labelled images.
        """
        if data_type not in self.vis_datas:
            return None

        n_available = len(self.vis_datas[data_type])
        # Sample without replacement so the gallery never shows the same image twice.
        indices = random.sample(range(n_available), min(int(n_samples), n_available))
        return [
            self.plot_single(self.im_paths[data_type][idx], self.vis_datas[data_type][idx])
            for idx in indices
        ]

    def data_analysis(self, data_type):
        """Render a bar chart of per-class object counts for ``data_type``.

        Returns the chart as a PIL image, or ``None`` if the split was not loaded.
        """
        if data_type not in self.analysis_datas:
            return None

        plt.style.use('default')
        fig, ax = plt.subplots(figsize=(12, 6))

        cls_names = list(self.analysis_datas[data_type].keys())
        counts = list(self.analysis_datas[data_type].values())

        color_map = {"train": "firebrick", "valid": "darkorange", "test": "blueviolet"}
        color = color_map.get(data_type, "steelblue")

        indices = np.arange(len(counts))
        bars = ax.bar(indices, counts, 0.7, color=color)

        ax.set_xlabel("Class Names", fontsize=12)
        ax.set_xticks(indices)
        ax.set_xticklabels(cls_names, rotation=45, ha='right')
        ax.set_ylabel("Data Counts", fontsize=12)
        ax.set_title(f"{data_type.upper()} Dataset Class Distribution", fontsize=14)

        # Write each count just above its bar.
        for bar, value in zip(bars, counts):
            ax.text(bar.get_x() + bar.get_width() / 2, value + 1, str(value),
                    ha='center', va='bottom', fontsize=10, color='navy')

        plt.tight_layout()

        # Save to BytesIO and convert to PIL Image
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        img = Image.open(buf)
        plt.close(fig)

        return img


def download_dataset():
    """Download the X-ray baggage dataset with kagglehub and copy it to ``./xray_dataset``.

    Sets the module-level ``dataset_path`` and returns a status string. Any
    error is reported in that string rather than raised.
    """
    global dataset_path
    try:
        # Create a local directory to store the dataset
        local_dir = "./xray_dataset"

        # Download dataset
        dataset_path = kagglehub.dataset_download("orvile/x-ray-baggage-anomaly-detection")

        # Copy out of the kagglehub cache into a stable local directory. Compare
        # absolute paths so we never rmtree the directory we are about to copy from.
        if os.path.abspath(dataset_path) != os.path.abspath(local_dir) and os.path.exists(dataset_path):
            if os.path.exists(local_dir):
                shutil.rmtree(local_dir)
            shutil.copytree(dataset_path, local_dir)
            dataset_path = local_dir

        return f"Dataset downloaded successfully to: {dataset_path}"
    except Exception as e:
        return f"Error downloading dataset: {str(e)}\n\nPlease ensure KDATA_API environment variable is set correctly."
def visualize_data(data_type, num_samples): """Visualize sample images from the dataset""" if dataset_path is None: return [], "Please download the dataset first!" try: vis = Visualization(root=dataset_path, data_types=[data_type], n_ims=num_samples, rows=2, cmap="rgb") figs = vis.vis_samples(data_type, num_samples) if figs is None: return [], f"No data found for {data_type} dataset" return figs, f"Showing {len(figs)} samples from {data_type} dataset" except Exception as e: return [], f"Error visualizing data: {str(e)}" def analyze_class_distribution(data_type): """Analyze class distribution in the dataset""" if dataset_path is None: return None, "Please download the dataset first!" try: vis = Visualization(root=dataset_path, data_types=[data_type], n_ims=20, rows=5, cmap="rgb") fig = vis.data_analysis(data_type) if fig is None: return None, f"No data found for {data_type} dataset" return fig, f"Class distribution for {data_type} dataset" except Exception as e: return None, f"Error analyzing data: {str(e)}" @spaces.GPU(duration=300) # Request GPU for 5 minutes for training def train_model(epochs, batch_size, img_size, device_selection): """Train YOLOv11 model""" global model, training_in_progress if dataset_path is None: return [], "Please download the dataset first!" if training_in_progress: return [], "Training already in progress!" 
training_in_progress = True try: # Determine device - on Spaces, always use GPU if available if ON_SPACES and torch.cuda.is_available(): device = 0 elif device_selection == "Auto": device = 0 if torch.cuda.is_available() else "cpu" elif device_selection == "CPU": device = "cpu" else: device = 0 if torch.cuda.is_available() else "cpu" # Read dataset info yaml_path = f"{dataset_path}/data.yaml" with open(yaml_path, 'r') as file: data_config = yaml.safe_load(file) class_names = data_config.get('names', []) print(f"Training on {len(class_names)} classes: {class_names}") # Initialize model - use yolov8n if yolo11n not available try: model = YOLO("yolo11n.pt") except Exception as e: print(f"YOLOv11 not available: {e}, falling back to YOLOv8") model = YOLO("yolov8n.pt") # Fallback to YOLOv8 # Create project directory project_dir = "./xray_detection" os.makedirs(project_dir, exist_ok=True) # Train model with optimized settings for X-ray detection results = model.train( data=yaml_path, epochs=epochs, imgsz=img_size, batch=batch_size, device=device, project=project_dir, name="train", exist_ok=True, verbose=True, patience=5, # Reduce patience for faster training on Spaces save_period=5, # Save checkpoints every 5 epochs workers=0, # Important: Set to 0 to avoid multiprocessing issues single_cls=False, rect=False, cache=False, # Disable caching to avoid memory issues amp=True, # Use automatic mixed precision for faster training # Optimization settings optimizer='AdamW', lr0=0.001, # Initial learning rate lrf=0.01, # Final learning rate factor momentum=0.937, weight_decay=0.0005, warmup_epochs=3.0, warmup_momentum=0.8, warmup_bias_lr=0.1, # Loss weights box=7.5, cls=0.5, dfl=1.5, # Augmentation settings for X-ray images hsv_h=0.0, # No hue augmentation for X-ray hsv_s=0.0, # No saturation augmentation hsv_v=0.1, # Slight value augmentation degrees=0.0, # No rotation translate=0.1, scale=0.5, shear=0.0, perspective=0.0, flipud=0.0, # No vertical flip for X-ray fliplr=0.5, # 
Horizontal flip is okay mosaic=1.0, mixup=0.0, copy_paste=0.0 ) # Collect training result plots results_path = os.path.join(project_dir, "train") plots = [] plot_files = ["results.png", "confusion_matrix.png", "val_batch0_pred.jpg", "train_batch0.jpg", "val_batch0_labels.jpg"] for plot_file in plot_files: plot_path = os.path.join(results_path, plot_file) if os.path.exists(plot_path): plots.append(Image.open(plot_path)) # Save the model path model_path = os.path.join(results_path, "weights", "best.pt") # Load the trained model to ensure it's ready for inference model_loaded = False class_info = "" if os.path.exists(model_path): try: model = YOLO(model_path) model_loaded = True class_info = f"\n✅ Trained on {len(model.names)} classes: {', '.join(list(model.names.values()))}" # Run a test inference to ensure model works test_img = np.zeros((640, 640, 3), dtype=np.uint8) test_results = model(test_img, verbose=False) class_info += "\n✅ Model test passed - ready for inference!" except Exception as e: class_info = f"\n⚠️ Model loaded but test failed: {str(e)}" else: class_info = "\n❌ Model file not found!" training_in_progress = False # Provide instructions for saving the model save_instructions = """ ✅ **Training Complete!** 📥 **Next Steps:** 1. Click "📥 Download Model (.pt)" button below to save your model 2. Keep the downloaded file safe - you'll need it after Space restarts 3. To reuse: Upload the model file in the "Upload & Load Model" section ⚠️ **Important**: This model will be lost when the Space restarts! 
""" return plots, f"Model saved to {model_path}{class_info}{save_instructions}" except Exception as e: training_in_progress = False return [], f"Error during training: {str(e)}" # 🔍 Inference (Modified to highlight bomb, pistol, spring, grenade, eod_gear, battery) @spaces.GPU(duration=60) # Request GPU for 1 minute for inference def run_inference(input_image, conf_threshold): """Run inference on a single image and print detected item names.""" global model # Try to load the trained model if not already loaded if model is None: trained_model_path = "./xray_detection/train/weights/best.pt" if os.path.exists(trained_model_path): try: model = YOLO(trained_model_path) print(f"Loaded trained model from {trained_model_path}") except Exception: pass # If still no model, try default if model is None: for fallback in ("yolo11n.pt", "yolov8n.pt"): try: model = YOLO(fallback) print(f"Loaded fallback model: {fallback}") break except Exception: continue if model is None: return None, "Please train the model first or load a pre-trained model!" if input_image is None: return None, "Please upload an image!" 
try: # Save the input image temporarily with proper format temp_path = "temp_inference.jpg" if input_image.mode != 'RGB': input_image = input_image.convert('RGB') input_image.save(temp_path, format='JPEG', quality=95) # Run inference imgsz = 640 results = model( temp_path, conf=conf_threshold, verbose=False, device=0 if torch.cuda.is_available() else 'cpu', imgsz=imgsz, augment=False, agnostic_nms=False, max_det=300 ) # Draw annotated image annotated_image = results[0].plot( conf=True, labels=True, boxes=True, masks=False, probs=False ) # Prepare detection information detections = [] detection_count = 0 danger_set = {'bomb', 'pistol', 'spring', 'grenade', 'eod_gear', 'battery'} if results[0].boxes is not None: detection_count = len(results[0].boxes) for idx, box in enumerate(results[0].boxes): cls = int(box.cls) conf_val = float(box.conf) xyxy = list(map(int, box.xyxy[0].tolist())) cls_name = model.names.get(cls, f"Class {cls}") # Highlight dangerous items prefix = "‼️ " if cls_name in danger_set else "" detections.append( f"{idx + 1}. {prefix}{cls_name}: {conf_val:.3f} " f"| Box: [{xyxy[0]}, {xyxy[1]}, {xyxy[2]}, {xyxy[3]}]" ) # Clean up temp file if os.path.exists(temp_path): os.remove(temp_path) # Assemble detection text det_text_header = ( f"Model classes ({len(model.names)}): {', '.join(list(model.names.values())[:10])}...\n" f"Confidence threshold: {conf_threshold}\n\n" ) if detections: detection_text = ( det_text_header + f"✅ Found {detection_count} object(s):\n\n" + "\n".join(detections) ) else: detection_text = det_text_header + "❌ No objects detected." 
return Image.fromarray(annotated_image), detection_text except Exception as e: import traceback traceback.print_exc() return None, f"Error during inference: {str(e)}" @spaces.GPU(duration=60) # Request GPU for batch inference def batch_inference(data_type, num_images): """Run inference on multiple images from test set""" global model # Try to load the trained model if not already loaded if model is None: trained_model_path = "./xray_detection/train/weights/best.pt" if os.path.exists(trained_model_path): try: model = YOLO(trained_model_path) print(f"Loaded trained model for batch inference") except: try: model = YOLO("yolo11n.pt") print("Loaded default model for batch inference") except: try: model = YOLO("yolov8n.pt") print("Loaded YOLOv8 model as fallback for batch inference") except: return [], "Please train the model first!" else: return [], "No trained model found. Please train the model first!" if dataset_path is None: return [], "Please download the dataset first!" try: image_dir = f"{dataset_path}/{data_type}/images" if not os.path.exists(image_dir): return [], f"Directory {image_dir} not found!" 
image_files = glob(f"{image_dir}/*")[:num_images] if not image_files: return [], f"No images found in {image_dir}" results_images = [] detection_counts = [] for img_path in image_files: results = model(img_path, verbose=False, conf=0.25, imgsz=640) annotated = results[0].plot() results_images.append(Image.fromarray(annotated)) # Count detections if results[0].boxes is not None: detection_counts.append(len(results[0].boxes)) else: detection_counts.append(0) # Check model type model_type = "X-ray detection model" if len(model.names) != 80 else "General COCO model" avg_detections = sum(detection_counts) / len(detection_counts) if detection_counts else 0 return results_images, f"Processed {len(results_images)} images using {model_type}\nAverage detections per image: {avg_detections:.1f}" except Exception as e: return [], f"Error during batch inference: {str(e)}" def get_dataset_info(): """Get information about the X-ray dataset classes""" if dataset_path is None: return "Dataset not downloaded yet." try: yaml_path = f"{dataset_path}/data.yaml" if not os.path.exists(yaml_path): return "Dataset configuration file not found." 
with open(yaml_path, 'r') as file: data = yaml.safe_load(file) class_names = data.get('names', []) num_classes = len(class_names) # Count images train_images = len(glob(f"{dataset_path}/train/images/*")) if os.path.exists(f"{dataset_path}/train/images") else 0 valid_images = len(glob(f"{dataset_path}/valid/images/*")) if os.path.exists(f"{dataset_path}/valid/images") else 0 test_images = len(glob(f"{dataset_path}/test/images/*")) if os.path.exists(f"{dataset_path}/test/images") else 0 info = f"### 📊 X-ray Baggage Dataset Info\n\n" info += f"**Classes ({num_classes}):** {', '.join(class_names)}\n\n" info += f"**Dataset Split:**\n" info += f"- Training: {train_images} images\n" info += f"- Validation: {valid_images} images\n" info += f"- Test: {test_images} images\n" info += f"- Total: {train_images + valid_images + test_images} images\n\n" info += f"**What to expect:** The model will learn to detect these prohibited items in X-ray scans." return info except Exception as e: return f"Error reading dataset info: {str(e)}" """Load a pre-trained model""" global model try: # Check if it's a HuggingFace model path if model_path.startswith("hf://") or "/" in model_path and not os.path.exists(model_path): # Load from HuggingFace Hub model = YOLO(model_path) return f"Model loaded successfully from HuggingFace: {model_path}" if not os.path.exists(model_path): # Try default paths default_paths = [ "./xray_detection/train/weights/best.pt", "./xray_detection/train/weights/last.pt", "yolo11n.pt", "yolov8n.pt" ] for path in default_paths: if os.path.exists(path): model_path = path break if os.path.exists(model_path): model = YOLO(model_path) # Check if it's a trained model by looking at class names try: if hasattr(model, 'names') and len(model.names) > 0: class_names = ", ".join([f"{i}: {name}" for i, name in model.names.items()][:5]) if len(model.names) > 5: class_names += f"... 
(총 {len(model.names)} 클래스)" return f"Model loaded successfully from {model_path}\n클래스: {class_names}" except: pass return f"Model loaded successfully from {model_path}" else: return "Model file not found. Please train a model first or provide a valid path." except Exception as e: return f"Error loading model: {str(e)}" def load_pretrained_model(model_file): """Load a pre-trained model from uploaded file""" global model if model_file is None: return "Please upload a model file (.pt)" try: # model_file is already a filepath string when type="filepath" temp_path = model_file # Load the model model = YOLO(temp_path) # Check model info try: if hasattr(model, 'names') and len(model.names) > 0: num_classes = len(model.names) class_names = ", ".join([f"{name}" for name in list(model.names.values())[:5]]) if len(model.names) > 5: class_names += f"... (총 {num_classes} 클래스)" if num_classes == 80: return f"⚠️ Loaded COCO model with {num_classes} classes. This is not trained for X-ray detection.\nClasses: {class_names}" else: return f"✅ Model loaded successfully!\nClasses ({num_classes}): {class_names}" else: return "✅ Model loaded successfully!" except: return "✅ Model loaded successfully!" except Exception as e: return f"Error loading model: {str(e)}" def check_model_status(): """Check current model status""" global model if model is None: # Try to load trained model trained_path = "./xray_detection/train/weights/best.pt" if os.path.exists(trained_path): try: model = YOLO(trained_path) num_classes = len(model.names) class_names = ', '.join(list(model.names.values())) return f"✅ Trained model loaded: {num_classes} classes\n📋 Classes: {class_names}" except: return "❌ No model loaded. Please train or load a model first." return "❌ No model loaded. Please train or load a model first." else: try: num_classes = len(model.names) class_names = ', '.join(list(model.names.values())) if num_classes == 80: return f"⚠️ Default COCO model loaded ({num_classes} classes). 
For X-ray detection, please train on the X-ray dataset." else: return f"✅ Model loaded: {num_classes} classes\n📋 Classes: {class_names}" except: return "✅ Model loaded" # Create Gradio interface with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft()) as demo: gr.Markdown(""" # 🎯 X-ray Baggage Anomaly Detection with YOLO This application allows you to: 1. Download and visualize the X-ray baggage dataset 2. Analyze class distributions 3. Train a YOLO model for object detection 4. Run inference on new images **Note:** GPU will be automatically allocated when needed for training and inference. """) # Check if there's a pre-existing model initial_model_status = "🔍 Checking for existing models..." if os.path.exists("./xray_detection/train/weights/best.pt"): try: model = YOLO("./xray_detection/train/weights/best.pt") initial_model_status = "✅ Found previously trained model! Ready to use." except: initial_model_status = "❌ No model loaded. Please train or upload a model." else: initial_model_status = "❌ No model loaded. Please train or upload a model." gr.Markdown(f"**Model Status:** {initial_model_status}") # Add instructions for Kaggle API setup with gr.Accordion("📝 Setup Instructions", open=False): gr.Markdown(""" ### Kaggle API Setup 1. Get your Kaggle API credentials from https://www.kaggle.com/settings 2. Set the KDATA_API environment variable in Hugging Face Spaces settings: ``` KDATA_API={"username":"your_username","key":"your_api_key"} ``` ### Model Persistence on Hugging Face Spaces - Models trained on Spaces are **temporary** and will be lost when the Space restarts - After training, download your model using the "📥 Download Model" button - Upload the downloaded model file to reuse it after Space restarts - No need for HuggingFace Hub or complex setups! 
""") with gr.Tab("📊 Dataset"): with gr.Row(): download_btn = gr.Button("Download Dataset", variant="primary", scale=1) download_status = gr.Textbox(label="Status", interactive=False, scale=3) download_btn.click(download_dataset, outputs=download_status) # Dataset info section with gr.Row(): dataset_info = gr.Markdown(value="Dataset not downloaded yet.") info_btn = gr.Button("🔄 Refresh Dataset Info", scale=0) def update_dataset_info(): return get_dataset_info() info_btn.click(update_dataset_info, outputs=dataset_info) gr.Markdown("### Visualize Dataset Samples") with gr.Row(): data_type_viz = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type") num_samples = gr.Slider(1, 8, 4, step=1, label="Number of Samples") viz_btn = gr.Button("Visualize Samples") viz_gallery = gr.Gallery(label="Sample Images", columns=2, height="auto") viz_status = gr.Textbox(label="Status", interactive=False) viz_btn.click(visualize_data, inputs=[data_type_viz, num_samples], outputs=[viz_gallery, viz_status]) gr.Markdown("### Analyze Class Distribution") with gr.Row(): data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type") analyze_btn = gr.Button("Analyze Distribution") distribution_plot = gr.Image(label="Class Distribution", type="pil") analysis_status = gr.Textbox(label="Status", interactive=False) analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis, outputs=[distribution_plot, analysis_status]) gr.Markdown("### Visualize Dataset Samples") with gr.Row(): data_type_viz = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type") num_samples = gr.Slider(1, 8, 4, step=1, label="Number of Samples") viz_btn = gr.Button("Visualize Samples") viz_gallery = gr.Gallery(label="Sample Images", columns=2, height="auto") viz_status = gr.Textbox(label="Status", interactive=False) viz_btn.click(visualize_data, inputs=[data_type_viz, num_samples], outputs=[viz_gallery, viz_status]) gr.Markdown("### 
Analyze Class Distribution") with gr.Row(): data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type") analyze_btn = gr.Button("Analyze Distribution") distribution_plot = gr.Image(label="Class Distribution", type="pil") analysis_status = gr.Textbox(label="Status", interactive=False) analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis, outputs=[distribution_plot, analysis_status]) with gr.Tab("🚀 Training"): gr.Markdown("### Train YOLO Model") gr.Markdown(""" **Note:** Training will automatically use GPU if available. This may take several minutes. **Recommended Settings for X-ray Detection:** - **Epochs:** 20-30 for good results - **Batch Size:** 2-4 for better convergence - **Image Size:** 640 for best quality - **Expected time:** ~2-5 minutes for 20 epochs ⚠️ **Important**: Models are temporary on Spaces! Download your model after training. """) with gr.Row(): epochs_input = gr.Slider(1, 50, 20, step=1, label="Epochs (20+ recommended)") batch_size_input = gr.Slider(2, 16, 4, step=2, label="Batch Size (lower for better results)") img_size_input = gr.Slider(320, 640, 640, step=32, label="Image Size (640 recommended)") device_input = gr.Radio(["Auto", "GPU", "CPU"], value="Auto", label="Device") train_btn = gr.Button("Start Training", variant="primary") training_gallery = gr.Gallery(label="Training Results", columns=3, height="auto") training_status = gr.Textbox(label="Training Status", interactive=False) train_btn.click(train_model, inputs=[epochs_input, batch_size_input, img_size_input, device_input], outputs=[training_gallery, training_status]) gr.Markdown("### 📥 Model Management") with gr.Row(): with gr.Column(): gr.Markdown("#### 1️⃣ Download Trained Model") gr.Markdown("After training, download your model to save it permanently.") # Function to prepare model for download def prepare_model_download(): model_path = "./xray_detection/train/weights/best.pt" if os.path.exists(model_path): return 
gr.update(value=model_path, visible=True), "✅ Model ready for download!" else: return gr.update(value=None, visible=False), "❌ No trained model found. Please train a model first." download_btn = gr.Button("📥 Download Model (.pt)", variant="secondary") download_file = gr.File(label="Download Model File", visible=False) download_status = gr.Textbox(label="Download Status", interactive=False) download_btn.click(prepare_model_download, outputs=[download_file, download_status]) with gr.Column(): gr.Markdown("#### 2️⃣ Upload & Load Model") gr.Markdown("Upload a previously trained model file to continue using it.") model_upload = gr.File( label="Upload Model File (.pt)", file_types=[".pt"], type="filepath" ) load_btn = gr.Button("📤 Load Uploaded Model", variant="secondary") load_status = gr.Textbox(label="Load Status", interactive=False) load_btn.click(load_pretrained_model, inputs=model_upload, outputs=load_status) # Auto-load when file is uploaded model_upload.change(load_pretrained_model, inputs=model_upload, outputs=load_status) with gr.Tab("🔍 Inference"): # Model status check with gr.Row(): model_status = gr.Textbox(label="Model Status", value=check_model_status(), interactive=False) refresh_status_btn = gr.Button("🔄 Refresh Status", scale=0) refresh_status_btn.click(check_model_status, outputs=model_status) gr.Markdown(""" ## 🎯 모델이 객체를 감지하지 못하나요? **권장 학습 설정:** - **Epochs: 30** (최소 20 이상) - **Batch Size: 2 또는 4** - **Image Size: 640** **체크리스트:** 1. ✅ X-ray 이미지인가? (일반 사진은 작동 안 함) 2. ✅ 충분히 학습했나? (20+ epochs) 3. ✅ Confidence threshold를 0.01로 낮춰봤나? 4. ✅ 모델이 제대로 로드되었나? 
(상태 확인) **성공적인 학습 후 예상 결과:** - Firearm (총기류) 감지 - Knife (칼) 감지 - Pliers (펜치) 감지 - Scissors (가위) 감지 - Wrench (렌치) 감지 """) gr.Markdown("### Single Image Inference") gr.Markdown("Upload an X-ray baggage image to detect prohibited items.") with gr.Row(): with gr.Column(): input_image = gr.Image(type="pil", label="Upload X-ray Image") conf_threshold = gr.Slider(0.01, 0.9, 0.25, step=0.01, label="Confidence Threshold (낮을수록 더 많이 감지)") # Debug options with gr.Row(): inference_btn = gr.Button("Run Detection", variant="primary") test_btn = gr.Button("Test with 0.01 threshold", variant="secondary", scale=0) # Add example images if dataset is available example_images = [] if dataset_path and os.path.exists(f"{dataset_path}/test/images"): test_images = glob(f"{dataset_path}/test/images/*")[:5] example_images.extend(test_images) if example_images: gr.Examples( examples=[[img] for img in example_images], inputs=input_image, label="Example X-ray Images (Click to load)" ) with gr.Column(): output_image = gr.Image(type="pil", label="Detection Result") detection_info = gr.Textbox(label="Detection Info", lines=8) inference_btn.click(run_inference, inputs=[input_image, conf_threshold], outputs=[output_image, detection_info]) # Test with very low threshold test_btn.click( lambda img: run_inference(img, 0.01), inputs=[input_image], outputs=[output_image, detection_info] ) # Auto-refresh model status after inference inference_btn.click(check_model_status, outputs=model_status) gr.Markdown("### Batch Inference") gr.Markdown("Run detection on multiple images from the test dataset.") with gr.Row(): batch_data_type = gr.Dropdown(["test", "valid"], value="test", label="Dataset Type") batch_num_images = gr.Slider(1, 10, 5, step=1, label="Number of Images") batch_btn = gr.Button("Run Batch Inference") batch_gallery = gr.Gallery(label="Batch Results", columns=3, height="auto") batch_status = gr.Textbox(label="Status", interactive=False) batch_btn.click(batch_inference, inputs=[batch_data_type, 
batch_num_images], outputs=[batch_gallery, batch_status]) # Footer gr.Markdown("---") gr.Markdown("""