#!/usr/bin/env python
"""
DeepSurg Technologies Ltd. (c) 2025
Surgical VLLM - v1
"""
import os
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import BertTokenizer
# Import the VisualBertClassification model (ensure the module is in your PYTHONPATH)
from models.VisualBertClassification_ssgqa import VisualBertClassification
# For SurgVLP encoder
from mmengine.config import Config
from utils.SurgVLP import surgvlp
import random
# For Gradio UI
import gradio as gr
image_files = None
selectedID = 0
# No GPU is available: force CPU-only execution.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
def seed_everything(seed=27):
    torch.manual_seed(seed)
    # torch.cuda.manual_seed_all(seed)  # disabled: CPU-only run
    os.environ["PYTHONHASHSEED"] = str(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
def load_visualbert_model(tokenizer, device, num_class=51, encoder_layers=6, n_heads=8, dropout=0.1, emb_dim=300):
    """
    Initialize the VisualBertClassification model and load the checkpoint.
    Note: dropout and emb_dim are accepted for configurability but are not
    forwarded to the constructor call below.
    """
    model = VisualBertClassification(
        vocab_size=len(tokenizer),
        layers=encoder_layers,
        n_heads=n_heads,
        num_class=num_class,
    )
    checkpoint = torch.load("./checkpoint.tar", map_location=device)
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    model.eval()
    return model
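# Note: the checkpoint is assumed to be a dict with a "model" key holding the
# state_dict (i.e. saved via torch.save({"model": model.state_dict()}, path)),
# which is what the load_state_dict call above expects.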
def load_surgvlp_encoder(device):
    """
    Load the SurgVLP encoder and its preprocessing function.
    """
    config_path = './utils/config_surgvlp.py'
    configs = Config.fromfile(config_path)['config']
    encoder_model, encoder_preprocess = surgvlp.load(configs.model_config, device=device, pretrain='./SurgVLP2.pth')
    encoder_model.eval()
    return encoder_model, encoder_preprocess
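# A minimal smoke test for the two loaders (a sketch, left commented out; it
# assumes ./checkpoint.tar and ./SurgVLP2.pth exist next to this script):
#
#   tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
#   model = load_visualbert_model(tokenizer, "cpu")
#   encoder, preprocess = load_surgvlp_encoder("cpu")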
# Label conversion list (mapping model output indices to text labels)
LABEL_LIST = [
    "0", "1", "10", "2", "3", "4", "5", "6", "7", "8", "9",
    "False", "True", "abdominal_wall_cavity", "adhesion", "anatomy",
    "aspirate", "bipolar", "blood_vessel", "blue", "brown", "clip",
    "clipper", "coagulate", "cut", "cystic_artery", "cystic_duct",
    "cystic_pedicle", "cystic_plate", "dissect", "fluid", "gallbladder",
    "grasp", "grasper", "gut", "hook", "instrument", "irrigate", "irrigator",
    "liver", "omentum", "pack", "peritoneum", "red", "retract", "scissors",
    "silver", "specimen_bag", "specimenbag", "white", "yellow"
]
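# Sanity check: num_class=51 in load_visualbert_model is assumed to match this
# label set; if the two drift apart, the top-k lookup in predict_image would
# map indices to the wrong labels (or raise an IndexError).
assert len(LABEL_LIST) == 51, "LABEL_LIST must match the classifier's num_class"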
def main():
    seed_everything()
    device = "cpu"
    tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    visualbert_model = load_visualbert_model(tokenizer, device)
    encoder_model, encoder_preprocess = load_surgvlp_encoder(device)
    print("Models loaded successfully.")
    # Define the directories containing images and corresponding label files.
    global image_files
    images_dir = "./test_data/images/VID"
    labels_dir = "./test_data/labels/VID/"
    image_files = [os.path.join(images_dir, f) for f in sorted(os.listdir(images_dir)) if f.lower().endswith('.png')]
    random.shuffle(image_files)
    print(f"Found {len(image_files)} images.")
    # Keep only the first 20 images of the shuffled list.
    image_files = image_files[:20]
    # Build the predefined questions by reading the label file for each image.
    # questions_per_image keeps each frame's own questions so the dropdown can
    # be narrowed when the user selects an image in the gallery.
    questions = []
    questions_per_image = []
    for image_path in image_files:
        image_id = int(os.path.basename(image_path).replace('.png', ''))
        label_path = os.path.join(labels_dir, f"{image_id}.txt")
        image_questions = []
        try:
            with open(label_path, 'r') as f:
                for line in f:
                    # Each label line is "question|answer"; keep the question part.
                    image_questions.append(line.split("|")[0])
        except OSError:
            # A missing label file simply leaves that image without questions.
            pass
        questions_per_image.append(image_questions)
        questions.extend(image_questions)
    # Remove duplicates and sort so the initial dropdown lists each question once.
    questions = sorted(set(questions))
    def predict_image(selected_images, question):
        """
        Process the selected image (by file path) along with the surgical question.
        Returns a text summary with the image file name and the top-3 predictions.
        """
        if not selected_images:
            return "Please select an image from the list."
        if question.strip() == "":
            return "Please select a question from the dropdown."
        # Use the global selectedID to pick the image.
        image_path = image_files[selectedID]
        try:
            pil_image = Image.open(image_path).convert("RGB")
        except Exception as e:
            return f"Could not open image: {str(e)}"
        # Encode the frame with SurgVLP and L2-normalize the embedding.
        image_processed = encoder_preprocess(pil_image).unsqueeze(0).to(device)
        with torch.no_grad():
            visual_features = encoder_model(image_processed, None, mode='video')['img_emb']
        visual_features /= visual_features.norm(dim=-1, keepdim=True)
        visual_features = visual_features.unsqueeze(1)
        # Tokenize the question for the VisualBERT classifier.
        inputs = tokenizer(
            [question],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=77,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = visualbert_model(inputs, visual_features)
        # Map the top-3 class indices back to text labels via LABEL_LIST.
        probabilities = F.softmax(outputs, dim=1)
        topk = torch.topk(probabilities, k=3, dim=1)
        topk_scores = topk.values.cpu().numpy().flatten()
        topk_indices = topk.indices.cpu().numpy().flatten()
        top_predictions = [(LABEL_LIST[i], float(score)) for i, score in zip(topk_indices, topk_scores)]
        image_name = os.path.basename(image_path)
        output_str = f"Frame: {image_name}\n\nTop 3 Predictions:\n"
        for rank, (lbl, score) in enumerate(top_predictions, start=1):
            output_str += f"Rank {rank}: {lbl} ({score:.4f})\n"
        print(f"Selected image: {image_name}")
        return output_str
    # Callback: record the selected gallery index and refresh the question
    # dropdown with the questions available for that image. Returning a new
    # gr.Dropdown (with this function wired to outputs=question_dropdown below)
    # is what actually updates the component; reassigning a global would not.
    def update_selected(selection: gr.SelectData):
        global selectedID
        selectedID = selection.index
        choices = questions_per_image[selectedID] or questions
        return gr.Dropdown(choices=choices, label="Select a Question")
    with gr.Blocks() as demo:
        gr.Markdown("# DeepSurg Surgical VQA Demo (V1)")
        gr.Markdown("## Cholecystectomy Surgery VLLM")
        gr.Markdown("### Current version supports label-based answers only.")
        # TODO: add a logo here.
        # Gallery: the user picks one frame; the select event reports its index.
        image_gallery = gr.Gallery(
            value=image_files,
            label="Select an Image",
            interactive=True,
            allow_preview=True,
            preview=True,
            columns=[20],
        )
        # Dropdown for selecting a predefined question.
        question_dropdown = gr.Dropdown(
            choices=questions,
            label="Select a Question"
        )
        # Wire the gallery selection to the dropdown so its choices refresh.
        image_gallery.select(fn=update_selected, inputs=None, outputs=question_dropdown)
        generate_btn = gr.Button("Generate")
        predictions_output = gr.Textbox(label="Predictions", lines=10)
        generate_btn.click(
            fn=predict_image,
            inputs=[image_gallery, question_dropdown],
            outputs=predictions_output
        )
    print("Launching the Gradio UI...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    main()
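# Quick-start sketch (paths are taken from the constants above; the script
# name is hypothetical):
#   python surgical_vqa_demo.py
# The demo serves on http://0.0.0.0:7860 and expects PNG frames under
# ./test_data/images/VID with matching ./test_data/labels/VID/<id>.txt files,
# each line formatted as "question|answer".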