heaven1-base / fix_nms_error.py
Tomas
Add initial project setup with model configuration, requirements, and upload script
58af2e6 unverified
import ast
import importlib.util
import os
import sys
from pathlib import Path
def locate_trl_module():
    """Find the directory of the TRL package on the Python path.

    The package is looked up with ``importlib.util.find_spec``. For a regular
    package the directory containing its origin file is used; when no origin
    is recorded (namespace-style package) the first package search location
    is used instead.

    Returns:
        Path to the TRL package directory, or ``None`` when the package
        cannot be located (a message explaining why is printed).
    """
    # Keep the try body minimal: only the lookup itself is expected to raise.
    try:
        spec = importlib.util.find_spec("trl")
    except (ImportError, ValueError) as e:
        print(f"Error locating TRL module: {e}")
        return None

    if spec is None:
        print("TRL module not found in the Python path")
        return None

    if spec.origin is not None:
        trl_path = Path(spec.origin).parent
    elif spec.submodule_search_locations:
        # No origin file is recorded; fall back to the first directory that
        # the import system searches for this package's submodules.
        trl_path = Path(next(iter(spec.submodule_search_locations)))
    else:
        print("Error locating TRL module: no file location is recorded for it")
        return None

    print(f"Found TRL module at: {trl_path}")
    return trl_path
# Comment that heads the injected NMS code. Its presence in the trainer file
# means the patch has already been applied.
_NMS_PATCH_MARKER = "# Custom NMS implementation to avoid torchvision dependency"

# Source code injected into the trainer file. The outer literal is delimited
# with triple single quotes so the injected function can carry an ordinary
# triple-double-quoted docstring without terminating the literal early.
_CUSTOM_NMS_SOURCE = _NMS_PATCH_MARKER + '''
def nms(boxes, scores, iou_threshold):
    """
    Performs non-maximum suppression (NMS) on the boxes according to their
    intersection-over-union (IoU).

    Args:
        boxes (Tensor[N, 4]): boxes to perform NMS on
        scores (Tensor[N]): scores for each one of the boxes
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept
    """
    import torch

    # Sort boxes by scores
    _, order = scores.sort(0, descending=True)
    keep = []
    while order.numel() > 0:
        if order.numel() == 1:
            keep.append(order.item())
            break
        i = order[0].item()
        keep.append(i)
        # Compute IoU of the remaining boxes with the largest box
        xx1 = torch.max(boxes[i, 0], boxes[order[1:], 0])
        yy1 = torch.max(boxes[i, 1], boxes[order[1:], 1])
        xx2 = torch.min(boxes[i, 2], boxes[order[1:], 2])
        yy2 = torch.min(boxes[i, 3], boxes[order[1:], 3])
        w = torch.clamp(xx2 - xx1, min=0.0)
        h = torch.clamp(yy2 - yy1, min=0.0)
        inter = w * h
        # IoU = intersection / (area1 + area2 - intersection)
        box_area = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        other_area = (boxes[order[1:], 2] - boxes[order[1:], 0]) * (boxes[order[1:], 3] - boxes[order[1:], 1])
        iou = inter / (box_area + other_area - inter)
        # Keep boxes with IoU less than threshold
        inds = torch.where(iou <= iou_threshold)[0]
        order = order[inds + 1]
    return torch.tensor(keep, dtype=torch.int64, device=boxes.device)
'''


def _is_torchvision_import(node):
    """Return True if *node* is an import statement that imports only torchvision."""
    if isinstance(node, ast.Import):
        modules = [alias.name for alias in node.names]
    elif isinstance(node, ast.ImportFrom) and node.level == 0:
        modules = [node.module or ""]
    else:
        return False
    return all(
        name == "torchvision" or name.startswith("torchvision.")
        for name in modules
    )


def _disable_torchvision_imports(lines, tree):
    """Return a copy of *lines* with every torchvision import commented out."""
    patched = list(lines)
    for node in ast.walk(tree):
        if not _is_torchvision_import(node):
            continue
        first = node.lineno - 1
        # ``end_lineno`` covers multi-line (parenthesised) imports; fall back
        # to the first line on interpreters that do not record it.
        last = getattr(node, "end_lineno", None) or node.lineno
        for index in range(first, last):
            stripped = patched[index].lstrip()
            indent = patched[index][: len(patched[index]) - len(stripped)]
            if index == first:
                # An indented import may be the only statement of its block
                # (e.g. a ``try:`` body), so leave a ``pass`` in its place.
                keeper = "pass  " if indent else ""
                patched[index] = (
                    f"{indent}{keeper}# {stripped}"
                    "  # Commented out to fix NMS error"
                )
            else:
                patched[index] = f"{indent}# {stripped}"
    return patched


def _import_section_end(tree, default):
    """Return the last line number used by a top-level import, else *default*."""
    end = default
    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            end = getattr(node, "end_lineno", None) or node.lineno
    return end


def _build_patched_source(source):
    """Return *source* with torchvision imports disabled and ``nms`` injected.

    Returns ``None`` when *source* contains no torchvision import statement.
    Raises ``SyntaxError`` when *source*, or the patched result, does not
    parse, so that a broken file is never produced silently.
    """
    tree = ast.parse(source)
    if not any(_is_torchvision_import(node) for node in ast.walk(tree)):
        return None

    lines = source.split("\n")
    patched = _disable_torchvision_imports(lines, tree)
    # Commenting out never adds or removes lines, so line numbers taken from
    # the original tree are still valid for the patched lines.
    insert_at = _import_section_end(tree, default=len(lines))
    snippet = _CUSTOM_NMS_SOURCE.strip("\n").split("\n")
    result = "\n".join(
        patched[:insert_at] + ["", ""] + snippet + ["", ""] + patched[insert_at:]
    )
    ast.parse(result)  # validate before the caller writes anything to disk
    return result


def patch_sft_trainer(trl_path=None):
    """Patch the SFTTrainer to avoid using torchvision's NMS operator.

    Every torchvision import statement in ``trainer/sft_trainer.py`` is
    commented out and a pure-PyTorch ``nms`` function is inserted after the
    file's last top-level import. A byte-exact copy of the unpatched file is
    saved next to it as ``sft_trainer.py.bak`` before the file is rewritten.
    Import statements are assumed to sit on their own lines; whole lines are
    commented out.

    Running the patch again on an already patched file changes nothing, so
    the backup keeps holding the original source.

    Args:
        trl_path: Directory of the TRL package to patch. When ``None``
            (the default) the package is located with
            ``locate_trl_module()``.

    Returns:
        ``True`` if the trainer file is patched (now or by an earlier run),
        ``False`` if TRL or the trainer file was not found, the file has no
        torchvision imports, or it could not be patched safely.

    Raises:
        OSError: If the trainer file or its backup cannot be read/written.
    """
    if trl_path is None:
        trl_path = locate_trl_module()
        if trl_path is None:
            return False

    # Path to the trainer file which likely contains the NMS reference
    trainer_path = Path(trl_path) / "trainer" / "sft_trainer.py"
    if not trainer_path.exists():
        print(f"Could not find the SFT trainer file at: {trainer_path}")
        return False
    print(f"Found SFT trainer file at: {trainer_path}")

    # Python source files are UTF-8 by default; do not depend on the locale.
    content = trainer_path.read_text(encoding="utf-8")

    # Cheap pre-check before parsing the file.
    if "torchvision" not in content:
        print("No torchvision imports found in the SFT trainer file.")
        return False

    # Re-patching would overwrite the backup with patched code and inject a
    # second copy of nms, so stop here if the marker is already present.
    if _NMS_PATCH_MARKER in content:
        print("SFT trainer file is already patched; nothing to do.")
        return True

    try:
        patched_content = _build_patched_source(content)
    except (SyntaxError, ValueError) as e:
        print(f"Could not patch the SFT trainer file safely: {e}")
        return False
    if patched_content is None:
        print("No torchvision imports found in the SFT trainer file.")
        return False

    # The file is known to be unpatched at this point, so the backup always
    # receives original source. Copy bytes to keep it exact.
    backup_path = trainer_path.with_suffix(".py.bak")
    print(f"Creating backup at: {backup_path}")
    backup_path.write_bytes(trainer_path.read_bytes())

    # Write the patched file
    trainer_path.write_text(patched_content, encoding="utf-8")
    print(f"Successfully patched {trainer_path}")
    print("The SFTTrainer should now work without requiring torchvision's NMS operator")
    return True
if __name__ == "__main__":
    # Apply the patch and tell the user whether it worked.
    if patch_sft_trainer():
        print("\nPatch applied successfully. You can now run the fine-tuning script.")
    else:
        print("\nFailed to apply the patch. Please check the error messages above.")