| import os |
| import torch |
| import pandas as pd |
| from PIL import Image |
| from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection |
| from tqdm import tqdm |
|
|
def run_inference(image_path, model, save_path, prompt, box_threshold, text_threshold, device, processor=None):
    """Run zero-shot object detection over a directory of images and write a
    submission CSV.

    Parameters
    ----------
    image_path : str
        Directory of test images. If missing, an empty (header-only)
        submission file is still produced.
    model : transformers zero-shot object-detection model
        Model already moved to *device*.
    save_path : str
        Destination path for the submission CSV.
    prompt : str
        Grounding prompt: "."-separated phrases naming the target classes.
    box_threshold : float
        Box score threshold for post-processing.
    text_threshold : float
        Text/phrase score threshold for post-processing.
    device : str
        Torch device string ("cuda" or "cpu").
    processor : optional
        Processor paired with *model*. Defaults to the module-level
        ``processor`` global for backward compatibility with existing callers.

    The CSV has columns ``file_name``, ``bbox`` (stringified list of COCO
    ``[x, y, w, h]`` boxes) and ``category_id`` (stringified list of 0s).
    """
    # Backward-compatible fallback: the original implementation read the
    # processor from a module-level global created in __main__.
    if processor is None:
        processor = globals()["processor"]

    try:
        test_images = sorted(os.listdir(image_path))
    except FileNotFoundError:
        print(f"Warning: Path {image_path} not found. Creating dummy submission.")
        test_images = []

    rows = []

    print(f"Running inference on {len(test_images)} images...")
    print(f"Prompt: {prompt}")

    for image_name in tqdm(test_images):
        try:
            full_img_path = os.path.join(image_path, image_name)
            img = Image.open(full_img_path).convert("RGB")
        except Exception as e:
            # Unreadable image: keep the row so the submission covers every
            # file name, but with empty detections.
            print(f"Error loading {image_name}: {e}")
            rows.append({"file_name": image_name, "bbox": "[]", "category_id": "[]"})
            continue

        inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        # target_sizes expects (height, width); PIL's Image.size is
        # (width, height), hence the [::-1].
        results = processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[img.size[::-1]],
        )

        bbox = []
        category_id = []
        for result in results:
            for box in result["boxes"]:
                xmin, ymin, xmax, ymax = box.tolist()
                # Convert (xmin, ymin, xmax, ymax) to COCO (x, y, w, h).
                bbox.append([xmin, ymin, xmax - xmin, ymax - ymin])
                category_id.append(0)  # single category for all detections

        rows.append({"file_name": image_name, "bbox": str(bbox), "category_id": str(category_id)})

    # Build the DataFrame once from the accumulated rows; this avoids the
    # quadratic pd.concat-in-a-loop pattern and still yields the required
    # columns even when there are no images.
    df_predictions = pd.DataFrame(rows, columns=["file_name", "bbox", "category_id"])
    df_predictions.to_csv(save_path, index=False)
    print("Submission file generated.")
|
|
|
|
if __name__ == "__main__":

    # Force Hugging Face libraries into fully offline mode; the model and
    # processor are loaded from local directories below, so no network
    # access is needed (the mirror endpoint is a no-op while offline).
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
    os.environ["HF_HUB_OFFLINE"] = "1"
    os.environ["HF_DATASETS_OFFLINE"] = "1"

    # Resolve all artifact paths relative to this script's own directory.
    current_directory = os.path.dirname(os.path.abspath(__file__))
    TEST_IMAGE_PATH = "/tmp/data/test_images"
    SUBMISSION_SAVE_PATH = os.path.join(current_directory, "submission.csv")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE: `processor` must stay module-level — run_inference reads it as a
    # global rather than taking it as a parameter.
    processor = AutoProcessor.from_pretrained(os.path.join(current_directory, "processor"))
    model = AutoModelForZeroShotObjectDetection.from_pretrained(os.path.join(current_directory, "model"))
    model.to(device)

    # Grounding prompt: "."-separated phrases describing the surgical
    # instruments to detect (each phrase is a candidate match).
    PROMPT = (
        "Monopolar Curved Scissors . surgical scissors . "
        "Prograsp Forceps . grasper jaws . "
        "Large Needle Driver . needle holder ."
    )

    # Score thresholds forwarded to the grounded post-processing step.
    BOX_THRESHOLD = 0.30
    TEXT_THRESHOLD = 0.25

    run_inference(TEST_IMAGE_PATH, model, SUBMISSION_SAVE_PATH, PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD, device)