Spaces:
Build error
Build error
# Ultralytics YOLO 🚀, GPL-3.0 license
import os | |
import shutil | |
import psutil | |
import requests | |
from IPython import display # to display images and clear console output | |
from ultralytics.hub.auth import Auth | |
from ultralytics.hub.session import HubTrainingSession | |
from ultralytics.hub.utils import PREFIX, split_key | |
from ultralytics.yolo.utils import LOGGER, emojis, is_colab | |
from ultralytics.yolo.utils.torch_utils import select_device | |
from ultralytics.yolo.v8.detect import DetectionTrainer | |
def checks(verbose=True):
    """Prepare the runtime environment and log a one-line setup summary.

    When running on Colab, the bundled ``sample_data`` directory is removed.
    When ``verbose`` is true, the notebook output is cleared and the summary
    includes the CPU count, total RAM, and disk usage of the root filesystem.

    Args:
        verbose (bool): Include the system summary in the final log message.
    """
    if is_colab():
        shutil.rmtree('sample_data', ignore_errors=True)  # remove colab /sample_data directory

    if verbose:
        # System info
        gib = 1 << 30  # bytes per GiB
        ram = psutil.virtual_memory().total
        # Only total and free are needed; "used" is derived as total - free below.
        total, _, free = shutil.disk_usage("/")
        display.clear_output()
        s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
    else:
        s = ''

    # Called for its side effect only; the returned device is not used here.
    select_device(newline=False)
    LOGGER.info(f'Setup complete ✅ {s}')
def start(key=''):
    """Start training a model with Ultralytics HUB.

    Usage: ``start('API_KEY')``. The key is passed through ``split_key`` to obtain
    an API key and a model id. If authentication with the supplied key (or with a
    stored login, when no key is given) fails, the user is prompted for a key.

    Any exception raised while authenticating, building the session or training is
    caught and logged as a warning, so this function does not propagate errors.

    Args:
        key (str): HUB key; may be empty to trigger the interactive login prompt.
    """

    def request_api_key(attempts=0):
        """Prompt the user to input their API key.

        Args:
            attempts (int): Number of failed attempts already made.

        Returns:
            The model id parsed from the key that authenticated successfully.

        Raises:
            ConnectionError: If authentication still fails after the last allowed attempt.
        """
        import getpass

        max_attempts = 3
        while True:
            tries = f"Attempt {attempts + 1} of {max_attempts}" if attempts > 0 else ""
            LOGGER.info(f"{PREFIX}Login. {tries}")
            input_key = getpass.getpass("Enter your Ultralytics HUB API key:\n")
            # `auth` is the enclosing function's Auth instance, bound before this helper is called.
            auth.api_key, model_id = split_key(input_key)
            if auth.authenticate():
                return model_id

            attempts += 1
            LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
            if attempts >= max_attempts:
                raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))

    try:
        api_key, model_id = split_key(key)
        auth = Auth(api_key)  # attempts cookie login if no api key is present
        # A key supplied by the caller counts as the first attempt.
        attempts = 1 if key else 0
        if not auth.get_state():
            if key:
                LOGGER.warning(f"{PREFIX}Invalid API key ⚠️\n")
            model_id = request_api_key(attempts)
        LOGGER.info(f"{PREFIX}Authenticated ✅")

        if not model_id:
            raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌'))

        session = HubTrainingSession(model_id=model_id, auth=auth)
        session.check_disk_space()

        # TODO: refactor, hardcoded for v8
        args = session.model.copy()
        # Drop HUB bookkeeping fields before handing the dict to the trainer as overrides.
        args.pop("id")
        args.pop("status")
        args.pop("weights")
        args["data"] = "coco128.yaml"
        args["model"] = "yolov8n.yaml"
        args["batch_size"] = 16
        args["imgsz"] = 64

        trainer = DetectionTrainer(overrides=args)
        session.register_callbacks(trainer)
        trainer.hub_session = session
        trainer.train()
    except Exception as e:
        # Top-level boundary: report the failure instead of raising into the caller.
        LOGGER.warning(f"{PREFIX}{e}")
def reset_model(key=''):
    """Reset a trained model to an untrained state.

    Posts the API key and model id obtained from ``split_key(key)`` to the HUB
    ``model-reset`` endpoint and logs whether the request succeeded.

    Args:
        key (str): HUB key that ``split_key`` separates into an API key and a model id.
    """
    api_key, model_id = split_key(key)
    payload = {"apiKey": api_key, "modelId": model_id}
    response = requests.post('https://api.ultralytics.com/model-reset', json=payload)

    if response.status_code != 200:
        LOGGER.warning(f"{PREFIX}model reset failure {response.status_code} {response.reason}")
    else:
        LOGGER.info(f"{PREFIX}model reset successfully")
def export_model(key='', format='torchscript'):
    """Request an export of a HUB model to a single target format.

    Posts the API key, model id and requested format to the HUB ``export``
    endpoint and logs that the export has started.

    Args:
        key (str): HUB key that ``split_key`` separates into an API key and a model id.
        format (str): Target export format; must be one of the supported names below.

    Raises:
        AssertionError: If ``format`` is not supported or the response status is not 200.
    """
    api_key, model_id = split_key(key)
    formats = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml', 'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs',
               'ultralytics_tflite', 'ultralytics_coreml')
    # NOTE(review): assert-based validation is skipped under `python -O`; kept so the
    # exception type seen by existing callers does not change.
    assert format in formats, f"ERROR: Unsupported export format '{format}' passed, valid formats are {formats}"
    r = requests.post('https://api.ultralytics.com/export',
                      json={
                          "apiKey": api_key,
                          "modelId": model_id,
                          "format": format})
    assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
    LOGGER.info(f"{PREFIX}{format} export started ✅")
def get_export(key='', format='torchscript'):
    """Fetch the export record for a HUB model in the given format.

    Args:
        key (str): HUB key that ``split_key`` separates into an API key and a model id.
        format (str): Export format to look up; must be one of the supported names.

    Returns:
        The decoded JSON body of the ``get-export`` response (per the original
        comment, a dictionary that includes a download URL).

    Raises:
        AssertionError: If ``format`` is not supported or the response status is not 200.
    """
    api_key, model_id = split_key(key)

    formats = (
        'torchscript', 'onnx', 'openvino', 'engine', 'coreml', 'saved_model',
        'pb', 'tflite', 'edgetpu', 'tfjs',
        'ultralytics_tflite', 'ultralytics_coreml',
    )
    assert format in formats, f"ERROR: Unsupported export format '{format}' passed, valid formats are {formats}"

    payload = {"apiKey": api_key, "modelId": model_id, "format": format}
    response = requests.post('https://api.ultralytics.com/get-export', json=payload)
    assert response.status_code == 200, \
        f"{PREFIX}{format} get_export failure {response.status_code} {response.reason}"
    return response.json()
# Temporary manual check: start a HUB training session when run as a script.
# SECURITY: never commit an API key to source control. The key is read from the
# ULTRALYTICS_HUB_API_KEY environment variable; when it is unset, an empty key is
# passed and start() falls back to its own login handling. Any key that was ever
# committed in this file should be treated as compromised and revoked.
if __name__ == "__main__":
    start(key=os.environ.get("ULTRALYTICS_HUB_API_KEY", ""))