| | import torch |
| | from transformers import AutoProcessor, AutoModelForCausalLM |
| | from ultralytics import YOLO |
| | import gdown |
| | import os |
| | from safetensors.torch import load_file |
| |
|
| | |
# Keep download cache data in a "cache" folder under the current working
# directory: build the path, make sure the folder exists, then publish the
# location through the GDOWN_CACHE environment variable.
# NOTE(review): confirm that the installed gdown version reads GDOWN_CACHE.
gdown_cache_dir = os.path.join(os.getcwd(), "cache")
os.makedirs(name=gdown_cache_dir, exist_ok=True)
os.environ.update(GDOWN_CACHE=gdown_cache_dir)
| |
|
def download_model_from_drive(file_id, destination_path):
    """Download a file from Google Drive using gdown.

    Args:
        file_id: Google Drive file identifier; it is interpolated into a
            ``https://drive.google.com/uc?id=...`` download URL.
        destination_path: Path the downloaded file is written to. When the
            path has a directory component, that directory is created first.

    Returns:
        The value returned by ``gdown.download`` (the output path).

    Raises:
        RuntimeError: If ``gdown.download`` returns ``None``, which is how
            some gdown releases signal a failed download instead of raising.
    """
    url = f"https://drive.google.com/uc?id={file_id}"

    # A bare filename has an empty dirname, and os.makedirs("") would raise,
    # so only create the parent directory when there is one.
    directory = os.path.dirname(destination_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    output = gdown.download(url, destination_path, quiet=False)
    if output is None:
        # Fail here with a clear message rather than letting a later step
        # trip over a missing file.
        raise RuntimeError(
            f"Failed to download Google Drive file {file_id!r} "
            f"to {destination_path!r}"
        )
    return output
| |
|
def load_models(
    device='cpu',
    *,
    model_file_path="model.safetensors",
    drive_file_id="1hUCqZ3X8mcM-KcwWFjcsFg7PA0hUvE3k",
    yolo_weights_path="best.pt",
    caption_model_name="microsoft/Florence-2-base",
):
    """Load the YOLO model, the caption processor and the caption model.

    The caption model is first instantiated from ``caption_model_name`` and
    its weights are then replaced by the state dict stored in
    ``model_file_path``. That file is downloaded from Google Drive when it is
    not already present on disk.

    Args:
        device: Device passed to ``.to()`` for both models (default ``'cpu'``).
        model_file_path: Location of the safetensors weights for the caption
            model.
        drive_file_id: Google Drive file id used to fetch ``model_file_path``
            when it is missing.
        yolo_weights_path: Weights file handed to ``YOLO(...)``.
        caption_model_name: Model identifier used for both the processor and
            the caption model.

    Returns:
        A dict with the keys ``'yolo_model'``, ``'processor'`` and
        ``'caption_model'``.

    Raises:
        FileNotFoundError: If the caption weights are still missing after the
            download attempt.
    """
    # Fetch the fine-tuned caption weights only when they are not on disk yet.
    if not os.path.exists(model_file_path):
        print(f"Downloading model to {model_file_path}...")
        download_model_from_drive(drive_file_id, model_file_path)
        if not os.path.exists(model_file_path):
            raise FileNotFoundError(
                f"Caption model weights not found at {model_file_path!r} "
                "after download attempt"
            )

    print("Loading YOLO model...")
    yolo_model = YOLO(yolo_weights_path)
    # Move the model as a separate statement (same pattern as the caption
    # model below) so the handle never depends on the return value of .to().
    yolo_model.to(device)

    print("Loading processor for the caption model...")
    # trust_remote_code=True allows the model repository's own code to run.
    processor = AutoProcessor.from_pretrained(
        caption_model_name,
        trust_remote_code=True,
    )

    print("Loading caption generation model...")
    model_state_dict = load_file(model_file_path)
    caption_model = AutoModelForCausalLM.from_pretrained(
        caption_model_name,
        trust_remote_code=True,
    )
    # Overwrite the pretrained weights with the ones read from disk.
    caption_model.load_state_dict(model_state_dict)
    caption_model.to(device)

    print("Models loaded successfully!")
    return {
        'yolo_model': yolo_model,
        'processor': processor,
        'caption_model': caption_model,
    }
| |
|
| | |
if __name__ == "__main__":
    # Script entry point: use the GPU when CUDA is available, otherwise
    # fall back to the CPU, then load every model once.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    models = load_models(device=device)
    print("All models are ready to use!")
| |
|