Spaces:
Sleeping
Sleeping
๐๏ธ [Update] the path for saving models
Browse files- yolo/config/general.yaml +1 -1
- yolo/model/yolo.py +1 -1
- yolo/tools/dataset_preparation.py +8 -4
yolo/config/general.yaml
CHANGED
|
@@ -11,4 +11,4 @@ lucky_number: 10
|
|
| 11 |
use_wandb: False
|
| 12 |
use_TensorBoard: False
|
| 13 |
|
| 14 |
-
weight: v9-c.pt
|
|
|
|
| 11 |
use_wandb: False
|
| 12 |
use_TensorBoard: False
|
| 13 |
|
| 14 |
+
weight: weights/v9-c.pt
|
yolo/model/yolo.py
CHANGED
|
@@ -134,7 +134,7 @@ def create_model(cfg: Config) -> YOLO:
|
|
| 134 |
logger.info("โ
Success load model weight")
|
| 135 |
else:
|
| 136 |
logger.info(f"๐ Weight {cfg.weight} not found, try downloading")
|
| 137 |
-
prepare_weight(
|
| 138 |
|
| 139 |
log_model_structure(model.model)
|
| 140 |
draw_model(model=model)
|
|
|
|
| 134 |
logger.info("โ
Success load model weight")
|
| 135 |
else:
|
| 136 |
logger.info(f"๐ Weight {cfg.weight} not found, try downloading")
|
| 137 |
+
prepare_weight(weight_path=cfg.weight)
|
| 138 |
|
| 139 |
log_model_structure(model.model)
|
| 140 |
draw_model(model=model)
|
yolo/tools/dataset_preparation.py
CHANGED
|
@@ -81,15 +81,19 @@ def prepare_dataset(cfg: DatasetConfig):
|
|
| 81 |
logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
|
| 82 |
|
| 83 |
|
| 84 |
-
def prepare_weight(downlaod_link: Optional[str] = None,
|
|
|
|
| 85 |
if downlaod_link is None:
|
| 86 |
downlaod_link = "https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/"
|
| 87 |
weight_link = f"{downlaod_link}{weight_name}"
|
| 88 |
|
| 89 |
-
if os.path.
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
| 91 |
try:
|
| 92 |
-
download_file(weight_link,
|
| 93 |
except requests.exceptions.RequestException as e:
|
| 94 |
logger.warning(f"Failed to download the weight file: {e}")
|
| 95 |
|
|
|
|
| 81 |
logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
|
| 82 |
|
| 83 |
|
| 84 |
+
def prepare_weight(downlaod_link: Optional[str] = None, weight_path: str = "v9-c.pt"):
|
| 85 |
+
weight_name = os.path.basename(weight_path)
|
| 86 |
if downlaod_link is None:
|
| 87 |
downlaod_link = "https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/"
|
| 88 |
weight_link = f"{downlaod_link}{weight_name}"
|
| 89 |
|
| 90 |
+
if not os.path.isdir(os.path.dirname(weight_path)):
|
| 91 |
+
os.makedirs(os.path.dirname(weight_path))
|
| 92 |
+
|
| 93 |
+
if os.path.exists(weight_path):
|
| 94 |
+
logger.info(f"Weight file '{weight_path}' already exists.")
|
| 95 |
try:
|
| 96 |
+
download_file(weight_link, weight_path)
|
| 97 |
except requests.exceptions.RequestException as e:
|
| 98 |
logger.warning(f"Failed to download the weight file: {e}")
|
| 99 |
|