#!/usr/bin/env python
"""
DeepSurg Technologies Ltd. (c) 2025
Surgical VLLM - v1
"""
import os
import random

import torch
import torch.nn.functional as F
from PIL import Image
from transformers import BertTokenizer

# Import the VisualBertClassification model (ensure the module is in your PYTHONPATH).
from models.VisualBertClassification_ssgqa import VisualBertClassification

# For the SurgVLP encoder.
from mmengine.config import Config
from utils.SurgVLP import surgvlp

# For the Gradio UI.
import gradio as gr

image_files = None
selectedID = 0

# No GPU is available; force CPU execution.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


def seed_everything(seed=27):
    torch.manual_seed(seed)
    random.seed(seed)  # random.shuffle() below relies on this seed.
    os.environ["PYTHONHASHSEED"] = str(seed)
    # CUDA-specific seeding is disabled because this demo runs on CPU only:
    # torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False


def load_visualbert_model(tokenizer, device, num_class=51, encoder_layers=6,
                          n_heads=8, dropout=0.1, emb_dim=300):
    """
    Initialize the VisualBertClassification model and load the checkpoint.
    """
    model = VisualBertClassification(
        vocab_size=len(tokenizer),
        layers=encoder_layers,
        n_heads=n_heads,
        num_class=num_class,
    )
    checkpoint = torch.load("./checkpoint.tar", map_location=device)
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    model.eval()
    return model


def load_surgvlp_encoder(device):
    """
    Load the SurgVLP encoder and its preprocessing function.
    """
    config_path = './utils/config_surgvlp.py'
    configs = Config.fromfile(config_path)['config']
    encoder_model, encoder_preprocess = surgvlp.load(
        configs.model_config, device=device, pretrain='./SurgVLP2.pth'
    )
    encoder_model.eval()
    return encoder_model, encoder_preprocess


# Label conversion list (mapping model output indices to text labels).
LABEL_LIST = [
    "0", "1", "10", "2", "3", "4", "5", "6", "7", "8", "9",
    "False", "True",
    "abdominal_wall_cavity", "adhesion", "anatomy", "aspirate", "bipolar",
    "blood_vessel", "blue", "brown", "clip", "clipper", "coagulate", "cut",
    "cystic_artery", "cystic_duct", "cystic_pedicle", "cystic_plate",
    "dissect", "fluid", "gallbladder", "grasp", "grasper", "gut", "hook",
    "instrument", "irrigate", "irrigator", "liver", "omentum", "pack",
    "peritoneum", "red", "retract", "scissors", "silver", "specimen_bag",
    "specimenbag", "white", "yellow",
]


def main():
    seed_everything()
    device = "cpu"

    tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    visualbert_model = load_visualbert_model(tokenizer, device)
    encoder_model, encoder_preprocess = load_surgvlp_encoder(device)
    print("Models loaded successfully.")

    # Define the directories containing images and corresponding label files.
    global image_files
    images_dir = "./test_data/images/VID"
    labels_dir = "./test_data/labels/VID/"
    image_files = [os.path.join(images_dir, f)
                   for f in sorted(os.listdir(images_dir))
                   if f.lower().endswith('.png')]
    random.shuffle(image_files)
    print(f"Found {len(image_files)} images.")

    # Keep the first 20 images for the demo.
    image_files = image_files[:20]

    # Build per-image question lists by reading the label file for each image.
    # questions_per_image[i] holds the questions available for image_files[i];
    # `questions` is the deduplicated, sorted union used for the initial dropdown.
    questions_per_image = []
    for image_path in image_files:
        image_id = int(os.path.basename(image_path).replace('.png', ''))
        label_path = os.path.join(labels_dir, f"{image_id}.txt")
        image_questions = []
        try:
            with open(label_path, 'r') as f:
                for line in f:
                    # Split each line at '|' and take the first part as the question.
                    image_questions.append(line.split("|")[0])
        except OSError:
            # If a label file is missing, this image simply has no questions.
            pass
        questions_per_image.append(sorted(set(image_questions)))

    # Remove duplicates across all images and sort.
    questions = sorted({q for qs in questions_per_image for q in qs})
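    # Sanity check (an added sketch, not part of the original flow): the demo
    # needs at least one question to populate the dropdown, so fail early with
    # a clear message if the label files yielded nothing.
    if not questions:
        raise RuntimeError(
            f"No questions found under {labels_dir}; check the label files."
        )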
    def predict_image(selected_images, question):
        """
        Process the selected image (by file path) along with the surgical
        question. Returns a text summary that includes the image file name
        and the top-3 predictions.
        """
        if not selected_images:
            return "Please select an image from the list."
        if question.strip() == "":
            return "Please select a question from the dropdown."

        # Use the global selectedID (set by the gallery callback) to pick the image.
        image_path = image_files[selectedID]
        try:
            pil_image = Image.open(image_path).convert("RGB")
        except Exception as e:
            return f"Could not open image: {str(e)}"

        # Encode the frame with SurgVLP and L2-normalize the visual embedding.
        image_processed = encoder_preprocess(pil_image).unsqueeze(0).to(device)
        with torch.no_grad():
            visual_features = encoder_model(image_processed, None, mode='video')['img_emb']
            visual_features /= visual_features.norm(dim=-1, keepdim=True)
        visual_features = visual_features.unsqueeze(1)

        # Tokenize the question for the VisualBert classifier.
        inputs = tokenizer(
            [question],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=77,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = visualbert_model(inputs, visual_features)
        probabilities = F.softmax(outputs, dim=1)
        topk = torch.topk(probabilities, k=3, dim=1)
        topk_scores = topk.values.cpu().numpy().flatten()
        topk_indices = topk.indices.cpu().numpy().flatten()
        top_predictions = [(LABEL_LIST[i], float(score))
                           for i, score in zip(topk_indices, topk_scores)]

        image_name = os.path.basename(image_path)
        output_str = f"Frame: {image_name}\n\nTop 3 Predictions:\n"
        for rank, (lbl, score) in enumerate(top_predictions, start=1):
            output_str += f"Rank {rank}: {lbl} ({score:.4f})\n"
        print(f"Selected image: {image_name}")
        return output_str

    # Callback: when the user selects an image in the gallery, remember its
    # index and repopulate the dropdown with that image's questions. Returning
    # a gr.Dropdown from the callback updates the existing component.
    def update_selected(selection: gr.SelectData):
        global selectedID
        selectedID = selection.index
        return gr.Dropdown(choices=questions_per_image[selectedID],
                           label="Select a Question")

    with gr.Blocks() as demo:
        gr.Markdown("# DeepSurg Surgical VQA Demo (V1)")
        gr.Markdown("## Cholecystectomy Surgery VLLM")
        gr.Markdown("### Current version supports label-based answers only.")
        # TODO: add a logo here.

        # Gallery emitting gr.SelectData so the user can choose one image.
        image_gallery = gr.Gallery(
            value=image_files,
            label="Select an Image",
            interactive=True,
            allow_preview=True,
            preview=True,
            columns=20,
        )

        # Dropdown for selecting a predefined question.
        question_dropdown = gr.Dropdown(
            choices=questions,
            label="Select a Question",
        )
        # Wire the gallery selection to the dropdown after both exist.
        image_gallery.select(fn=update_selected, inputs=None,
                             outputs=question_dropdown)

        generate_btn = gr.Button("Generate")
        predictions_output = gr.Textbox(label="Predictions", lines=10)
        generate_btn.click(
            fn=predict_image,
            inputs=[image_gallery, question_dropdown],
            outputs=predictions_output,
        )

    print("Launching the Gradio UI...")
    demo.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()
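# ---------------------------------------------------------------------------
# Expected on-disk layout (a sketch inferred from the paths used above; none
# of these files ship with this script, so adjust to your checkout):
#   ./checkpoint.tar                   VisualBertClassification weights
#   ./SurgVLP2.pth                     SurgVLP encoder weights
#   ./utils/config_surgvlp.py          SurgVLP config read via mmengine
#   ./test_data/images/VID/<id>.png    video frames named by integer id
#   ./test_data/labels/VID/<id>.txt    "question|answer" pairs, one per line
#                                      (the question is taken before the '|')
# Launch with `python <this_script>.py`; the UI is served on port 7860.
# ---------------------------------------------------------------------------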