Tomas
Add initial project setup with model configuration, requirements, and upload script
58af2e6
unverified
import os | |
import sys | |
import importlib.util | |
from pathlib import Path | |
def locate_trl_module():
    """Find the directory of the TRL package without importing it.

    The lookup goes through ``importlib.util.find_spec('trl')``, so the
    package's own import-time side effects are not triggered.

    Returns:
        Path | None: Directory that contains the ``trl`` package, or ``None``
        when the package cannot be located (a message is printed in that case).
    """
    # Keep the try body minimal: only find_spec can raise here. It raises
    # ValueError when sys.modules['trl'].__spec__ is unset/None and ImportError
    # (ModuleNotFoundError) for broken parent packages.
    try:
        spec = importlib.util.find_spec('trl')
    except (ImportError, ValueError) as e:
        print(f"Error locating TRL module: {e}")
        return None

    if spec is None:
        print("TRL module not found in the Python path")
        return None

    if spec.has_location and spec.origin:
        # Regular package: origin is .../trl/__init__.py
        trl_path = Path(spec.origin).parent
    else:
        # Namespace packages have no origin file; Path(None) would raise a
        # TypeError. Fall back to the first directory the package spans.
        locations = list(spec.submodule_search_locations or [])
        if not locations:
            print("TRL module not found in the Python path")
            return None
        trl_path = Path(locations[0])

    print(f"Found TRL module at: {trl_path}")
    return trl_path
# First line of the injected helper; doubles as the "already patched" marker.
_NMS_PATCH_MARKER = "# Custom NMS implementation to avoid torchvision dependency"

# Note appended to every import statement that gets disabled.
_DISABLED_NOTE = "# Commented out to fix NMS error"

# Source text injected into sft_trainer.py. The outer literal uses ''' so that
# the embedded docstring can keep its """ delimiters without ending the string.
_CUSTOM_NMS_SOURCE = '''

# Custom NMS implementation to avoid torchvision dependency
def nms(boxes, scores, iou_threshold):
    """
    Performs non-maximum suppression (NMS) on the boxes according to their
    intersection-over-union (IoU).
    Args:
        boxes (Tensor[N, 4]): boxes to perform NMS on
        scores (Tensor[N]): scores for each one of the boxes
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold
    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept
    """
    import torch
    # Sort boxes by scores
    _, order = scores.sort(0, descending=True)
    keep = []
    while order.numel() > 0:
        if order.numel() == 1:
            keep.append(order.item())
            break
        i = order[0].item()
        keep.append(i)
        # Compute IoU of the remaining boxes with the largest box
        xx1 = torch.max(boxes[i, 0], boxes[order[1:], 0])
        yy1 = torch.max(boxes[i, 1], boxes[order[1:], 1])
        xx2 = torch.min(boxes[i, 2], boxes[order[1:], 2])
        yy2 = torch.min(boxes[i, 3], boxes[order[1:], 3])
        w = torch.clamp(xx2 - xx1, min=0.0)
        h = torch.clamp(yy2 - yy1, min=0.0)
        inter = w * h
        # IoU = intersection / (area1 + area2 - intersection)
        box_area = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        other_area = (boxes[order[1:], 2] - boxes[order[1:], 0]) * (boxes[order[1:], 3] - boxes[order[1:], 1])
        iou = inter / (box_area + other_area - inter)
        # Keep boxes with IoU less than threshold
        inds = torch.where(iou <= iou_threshold)[0]
        order = order[inds + 1]
    return torch.tensor(keep, dtype=torch.int64, device=boxes.device)

'''


def _is_torchvision_module(name):
    """Return True when *name* is ``torchvision`` or one of its submodules."""
    return name == "torchvision" or name.startswith("torchvision.")


def _comment_out_torchvision_imports(source):
    """Disable every import statement that only imports from torchvision.

    Statements are located with ``ast`` so that strings/comments mentioning
    torchvision are ignored and multi-line imports are disabled as a whole.
    Statements that share a line with other code, or that also import
    non-torchvision modules (``import os, torchvision``), are left untouched.

    Returns:
        tuple[str, int]: the rewritten source and the number of statements
        that were disabled.

    Raises:
        SyntaxError: if *source* cannot be parsed.
    """
    import ast

    lines = source.split("\n")
    disabled = 0
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            modules = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.level == 0:
            modules = [node.module or ""]
        else:
            continue
        if not all(_is_torchvision_module(module) for module in modules):
            continue

        first, last = node.lineno - 1, node.end_lineno - 1
        indent = lines[first][:node.col_offset]
        trailing = lines[last][node.end_col_offset:].strip()
        if indent.strip() or (trailing and not trailing.startswith("#")):
            # Other code lives on the same line; commenting it would remove it too.
            continue

        statement = lines[first].strip()
        if indent:
            # Inside a try/if/def body a bare comment could leave the suite
            # empty, so keep a `pass` statement in place of the import.
            lines[first] = f"{indent}pass  # {statement} {_DISABLED_NOTE}"
        else:
            lines[first] = f"# {statement} {_DISABLED_NOTE}"
        for index in range(first + 1, last + 1):
            lines[index] = f"{indent}# {lines[index].strip()}"
        disabled += 1
    return "\n".join(lines), disabled


def _nms_insertion_line(source):
    """Return the number of leading lines after which the helper is inserted.

    That is the end of the module docstring plus the leading run of top-level
    imports, so the helper never lands before ``from __future__`` imports or
    inside a docstring.
    """
    import ast

    insert_after = 0
    for position, node in enumerate(ast.parse(source).body):
        is_docstring = (
            position == 0
            and isinstance(node, ast.Expr)
            and isinstance(node.value, ast.Constant)
            and isinstance(node.value.value, str)
        )
        if not is_docstring and not isinstance(node, (ast.Import, ast.ImportFrom)):
            break
        insert_after = node.end_lineno
    return insert_after


def _build_patched_source(source):
    """Return the patched trainer source, or None if there is nothing to disable.

    Raises:
        SyntaxError, ValueError: if the original or the patched text is not
        valid Python; callers must not write anything to disk in that case.
    """
    disabled_source, disabled_count = _comment_out_torchvision_imports(source)
    if disabled_count == 0:
        return None
    insert_after = _nms_insertion_line(disabled_source)
    lines = disabled_source.split("\n")
    patched = (
        "\n".join(lines[:insert_after])
        + _CUSTOM_NMS_SOURCE
        + "\n".join(lines[insert_after:])
    )
    # Final safety net: never hand back text that does not compile.
    compile(patched, "<patched sft_trainer>", "exec")
    return patched


def patch_sft_trainer(trl_path=None):
    """Patch the SFTTrainer to avoid using torchvision's NMS operator.

    The torchvision import statements in ``<trl>/trainer/sft_trainer.py`` are
    disabled and a pure-PyTorch ``nms`` function is inserted after the leading
    imports. The unpatched file is saved next to it as ``sft_trainer.py.bak``.
    Running the patch again on an already patched file changes nothing.

    Args:
        trl_path: Directory of the TRL package to patch (``str`` or ``Path``).
            When ``None`` (the default) it is discovered with
            ``locate_trl_module()``.

    Returns:
        bool: ``True`` if the file is patched (now or by an earlier run),
        ``False`` if TRL or the trainer file was not found, the file has no
        torchvision imports, or the file could not be patched safely.
    """
    if trl_path is None:
        trl_path = locate_trl_module()
    if trl_path is None:
        return False

    # Path to the trainer.py file which likely contains the NMS reference
    trainer_path = Path(trl_path) / "trainer" / "sft_trainer.py"
    if not trainer_path.exists():
        print(f"Could not find the SFT trainer file at: {trainer_path}")
        return False
    print(f"Found SFT trainer file at: {trainer_path}")

    # Python source files are UTF-8 by default; do not depend on the locale.
    with open(trainer_path, "r", encoding="utf-8") as f:
        content = f.read()

    # A previous run leaves the marker behind. Patching again would insert a
    # second copy of nms and overwrite the backup with patched content.
    if _NMS_PATCH_MARKER in content:
        print("SFT trainer file is already patched; nothing to do.")
        return True

    # Cheap pre-check before parsing the file
    if "torchvision" not in content:
        print("No torchvision imports found in the SFT trainer file.")
        return False

    try:
        patched_content = _build_patched_source(content)
    except (SyntaxError, ValueError) as e:
        print(f"Could not safely patch the SFT trainer file: {e}")
        return False
    if patched_content is None:
        print("No torchvision imports found in the SFT trainer file.")
        return False

    # Create backup of the unpatched file
    backup_path = trainer_path.with_suffix(".py.bak")
    print(f"Creating backup at: {backup_path}")
    with open(backup_path, "w", encoding="utf-8") as f:
        f.write(content)

    # Write the patched file
    with open(trainer_path, "w", encoding="utf-8") as f:
        f.write(patched_content)
    print(f"Successfully patched {trainer_path}")
    print("The SFTTrainer should now work without requiring torchvision's NMS operator")
    return True
if __name__ == "__main__":
    # Script entry point: apply the patch and tell the user how it went.
    if patch_sft_trainer():
        print("\nPatch applied successfully. You can now run the fine-tuning script.")
    else:
        print("\nFailed to apply the patch. Please check the error messages above.")