Spaces:
Running
Running
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
YOLO-NAS model interface.

Example:
    ```python
    from ultralytics import NAS

    model = NAS("yolo_nas_s")
    results = model.predict("ultralytics/assets/bus.jpg")
    ```
"""
from pathlib import Path | |
import torch | |
from ultralytics.engine.model import Model | |
from ultralytics.utils.downloads import attempt_download_asset | |
from ultralytics.utils.torch_utils import model_info | |
from .predict import NASPredictor | |
from .val import NASValidator | |
class NAS(Model):
    """
    YOLO-NAS model for object detection.

    This class provides an interface for the YOLO-NAS models and extends the `Model` class from the Ultralytics
    engine. It loads either a serialized `.pt` checkpoint or a named architecture with COCO-pretrained weights
    obtained through `super_gradients`, then patches the loaded network with the attributes the rest of this
    class and the export path read (`fuse`, `stride`, `names`, `is_fused`, `yaml`, `pt_path`, `task`).

    Example:
        ```python
        from ultralytics import NAS

        model = NAS("yolo_nas_s")
        results = model.predict("ultralytics/assets/bus.jpg")
        ```

    Attributes:
        model: The loaded YOLO-NAS network, assigned by `_load`.

    Note:
        YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
    """

    def __init__(self, model: str = "yolo_nas_s.pt") -> None:
        """
        Initialize the NAS model with the provided or default 'yolo_nas_s.pt' model.

        Args:
            model (str): Path to a `.pt` checkpoint or a bare YOLO-NAS model name.

        Raises:
            AssertionError: If `model` has a `.yaml` or `.yml` suffix.
        """
        # Raised explicitly instead of via `assert` so the check is not stripped under `python -O`.
        # The exception type stays AssertionError for callers that already catch it.
        if Path(model).suffix in {".yaml", ".yml"}:
            raise AssertionError("YOLO-NAS models only support pre-trained models.")
        super().__init__(model, task="detect")

    def _load(self, weights: str, task=None) -> None:
        """
        Load NAS weights from a `.pt` checkpoint, or build a COCO-pretrained model from a bare model name.

        Args:
            weights (str): Path to a `.pt` checkpoint, or a model name without a suffix (e.g. 'yolo_nas_s').
            task (str | None): Not used by this implementation; the task is always set to 'detect'.

        Raises:
            ValueError: If `weights` has a suffix other than '.pt' or no suffix at all.
        """
        import super_gradients  # imported here rather than at module import time

        suffix = Path(weights).suffix
        if suffix == ".pt":
            # NOTE(review): torch.load unpickles arbitrary Python objects; only load trusted checkpoints.
            self.model = torch.load(attempt_download_asset(weights))
        elif suffix == "":
            self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
        else:
            # Without this branch the code below would patch whatever `self.model` currently holds
            # (possibly nothing, or an already-patched model whose forward would then call itself).
            raise ValueError(
                f"Unsupported YOLO-NAS weights '{weights}': expected a '.pt' checkpoint "
                f"or a model name without a suffix, such as 'yolo_nas_s'."
            )

        # Override the forward method to ignore additional arguments
        def new_forward(x, *args, **kwargs):
            """Ignore additional __call__ arguments."""
            return self.model._original_forward(x)

        # Keep a handle on the real forward before replacing it, so new_forward can delegate to it.
        self.model._original_forward = self.model.forward
        self.model.forward = new_forward

        # Standardize model
        self.model.fuse = lambda verbose=True: self.model  # no-op: returns the model unchanged
        self.model.stride = torch.tensor([32])
        self.model.names = dict(enumerate(self.model._class_names))
        self.model.is_fused = lambda: False  # for info()
        self.model.yaml = {}  # for info()
        self.model.pt_path = weights  # for export()
        self.model.task = "detect"  # for export()

    def info(self, detailed: bool = False, verbose: bool = True):
        """
        Log model info.

        Args:
            detailed (bool): Show detailed information about model.
            verbose (bool): Controls verbosity.

        Returns:
            The value returned by `model_info` for the loaded model at an image size of 640.
        """
        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)

    @property
    def task_map(self) -> dict:
        """Return a dictionary mapping tasks to respective predictor and validator classes."""
        # Exposed as a property because the base `Model` class subscripts `self.task_map` directly.
        return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}