phase_2b / script.py
yusufbardolia's picture
Update script.py
c4ac108 verified
import os
import torch
import pandas as pd
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from tqdm import tqdm
def run_inference(image_path, model, save_path, prompt, box_threshold, text_threshold, device):
# 1. Get list of images
try:
test_images = sorted(os.listdir(image_path))
except FileNotFoundError:
print(f"⚠️ Warning: Path {image_path} not found. Creating dummy submission.")
test_images = []
bboxes = []
category_ids = []
test_images_names = []
print(f"πŸš€ Running inference on {len(test_images)} images...")
print(f"πŸ“ Prompt: {prompt}")
# 2. Loop through all test images
for image_name in tqdm(test_images):
test_images_names.append(image_name)
bbox = []
category_id = []
try:
full_img_path = os.path.join(image_path, image_name)
# Load image and ensure RGB
img = Image.open(full_img_path).convert("RGB")
except Exception as e:
print(f"Error loading {image_name}: {e}")
bboxes.append([])
category_ids.append([])
continue
inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
results = processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
threshold=box_threshold,
text_threshold=text_threshold,
target_sizes=[img.size[::-1]]
)
# 3. Process Results (SAFE MODE: Map all to Class ID 0)
for result in results:
boxes = result["boxes"]
for box in boxes:
xmin, ymin, xmax, ymax = box.tolist()
width = xmax - xmin
height = ymax - ymin
bbox.append([xmin, ymin, width, height])
category_id.append(0)
bboxes.append(bbox)
category_ids.append(category_id)
# 4. Create Submission DataFrame
df_predictions = pd.DataFrame(columns=["file_name", "bbox", "category_id"])
for i in range(len(test_images_names)):
new_row = pd.DataFrame({
"file_name": test_images_names[i],
"bbox": str(bboxes[i]),
"category_id": str(category_ids[i]),
}, index=[0])
df_predictions = pd.concat([df_predictions, new_row], ignore_index=True)
df_predictions.to_csv(save_path, index=False)
print("βœ… Submission file generated.")
if __name__ == "__main__":
    # --- Hugging Face server configuration ---
    # Force the mirror endpoint and fully offline hub/dataset access.
    os.environ.update({
        "HF_ENDPOINT": "https://hf-mirror.com",
        "HF_HUB_OFFLINE": "1",
        "HF_DATASETS_OFFLINE": "1",
    })

    current_directory = os.path.dirname(os.path.abspath(__file__))
    TEST_IMAGE_PATH = "/tmp/data/test_images"
    SUBMISSION_SAVE_PATH = os.path.join(current_directory, "submission.csv")

    # --- Model loading (local snapshots stored next to this script) ---
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = AutoProcessor.from_pretrained(os.path.join(current_directory, "processor"))
    model = AutoModelForZeroShotObjectDetection.from_pretrained(
        os.path.join(current_directory, "model")
    ).to(device)

    # ==========================================
    # Reverted winning configuration
    # ==========================================
    # Prompt strategy: medical instrument names plus generic synonyms —
    # the model recognizes the specific names better than e.g. "silver metal".
    PROMPT = (
        "Monopolar Curved Scissors . surgical scissors . "
        "Prograsp Forceps . grasper jaws . "
        "Large Needle Driver . needle holder ."
    )
    # Threshold strategy: 0.40 missed tools (low recall), 0.25 let in
    # background noise; 0.30 is the sweet spot between the two.
    BOX_THRESHOLD = 0.30
    TEXT_THRESHOLD = 0.25

    run_inference(TEST_IMAGE_PATH, model, SUBMISSION_SAVE_PATH, PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD, device)