import os
import logging
import urllib.request

import torch
import spaces
from diffusers import FluxFillPipeline, FluxTransformer2DModel
from segment_anything import sam_model_registry, SamPredictor, build_sam
from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)

OUTPUT_DIR = "outputs"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam_predictor = None

# Fetch the SAM checkpoint from the Hugging Face Hub at import time.
sam_ckpt = hf_hub_download(
    repo_id="SnapwearAI/sam_model",
    filename="sam_vit_h_4b8939.pth",
)


@spaces.GPU
def get_sam_predictor():
    """Build a SAM predictor from the checkpoint downloaded above."""
    sam = build_sam(checkpoint=sam_ckpt)
    sam.to(device=device)
    predictor = SamPredictor(sam)
    return predictor


@spaces.GPU
def get_flux_pipeline():
    """Load FLUX.1-Fill with the SnapwearAI/bg-transformer weights swapped in."""
    transformer = FluxTransformer2DModel.from_pretrained(
        "SnapwearAI/bg-transformer",
        subfolder="transformer",  # <-- tell HF to look inside transformer/
        torch_dtype=torch.bfloat16,
    )
    pipe_flux = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    return pipe_flux


def download_models():
    """Download models from official sources."""
    # Create models directory
    os.makedirs("models", exist_ok=True)

    # ═══════════════════════════════════════════════════════════════
    # SAM MODEL - Download from Facebook's official release
    # ═══════════════════════════════════════════════════════════════
    sam_path = "models/sam_vit_h_4b8939.pth"
    if not os.path.exists(sam_path):
        logger.info("📥 Downloading SAM model from Facebook (2.6GB)...")
        try:
            urllib.request.urlretrieve(
                "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
                sam_path
            )
            logger.info(f"✅ SAM model downloaded to: {sam_path}")
        except Exception as e:
            logger.error(f"❌ Failed to download SAM model: {e}")
            raise
    else:
        logger.info(f"✅ SAM model already exists: {sam_path}")

    # ═══════════════════════════════════════════════════════════════
    # GROUNDING DINO MODEL - Download from GitHub releases
    # ═══════════════════════════════════════════════════════════════
    grounding_path = "models/groundingdino_swint_ogc.pth"
    if not os.path.exists(grounding_path):
        logger.info("📥 Downloading GroundingDINO model from GitHub (694MB)...")
        try:
            urllib.request.urlretrieve(
                "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
                grounding_path
            )
            logger.info(f"✅ GroundingDINO model downloaded to: {grounding_path}")
        except Exception as e:
            logger.error(f"❌ Failed to download GroundingDINO model: {e}")
            raise
    else:
        logger.info(f"✅ GroundingDINO model already exists: {grounding_path}")

    # ═══════════════════════════════════════════════════════════════
    # CONFIG FILE - Create if doesn't exist
    # ═══════════════════════════════════════════════════════════════
    config_path = "models/GroundingDINO_SwinT_OGC.py"
    if not os.path.exists(config_path):
        logger.info("📝 Creating GroundingDINO config file...")

        # Minimal mmdetection-style GroundingDINO config. Note that
        # groundingdino.util.inference.Model loads configs via SLConfig; if
        # loading fails, use the official GroundingDINO_SwinT_OGC.py from the
        # IDEA-Research repo instead.
        config_content = '''import os.path as osp
import sys

# Add current directory to path for imports
sys.path.insert(0, osp.dirname(__file__))

# Model configuration
model = dict(
    type='GroundingDINO',
    num_queries=900,
    with_box_refine=True,
    as_two_stage=True,
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_mask=False,
    ),
    language_model=dict(
        type='BertModel',
        name='bert-base-uncased',
        max_tokens=256,
        pad_to_max=False,
        use_sub_sentence_represent=True,
        special_tokens_list=["[CLS]", "[SEP]", ".", "?"],
        add_pooling_layer=False,
    ),
    backbone=dict(
        type='SwinTransformer',
        pretrain_img_size=384,
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=12,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        patch_norm=True,
        out_indices=(1, 2, 3),
        with_cp=True,
        convert_weights=True,
    ),
    neck=dict(
        type='ChannelMapper',
        in_channels=[192, 384, 768],
        kernel_size=1,
        out_channels=256,
        act_cfg=None,
        norm_cfg=dict(type='GN', num_groups=32),
        num_outs=4),
    encoder=dict(
        type='DetrTransformerEncoder',
        num_layers=6,
        transformerlayers=dict(
            type='BaseTransformerLayer',
            attn_cfgs=dict(
                type='MultiScaleDeformableAttention',
                embed_dims=256,
                num_heads=8,
                num_levels=4,
                num_points=4,
                im2col_step=64,
                dropout=0.0,
                batch_first=False,
                norm_cfg=None,
                init_cfg=None),
            feedforward_channels=2048,
            ffn_dropout=0.0,
            operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
    decoder=dict(
        type='GroundingDINOTransformerDecoder',
        num_layers=6,
        return_intermediate=True,
        transformerlayers=dict(
            type='GroundingDINOTransformerDecoderLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.0,
                    batch_first=False),
                dict(
                    type='MultiScaleDeformableAttention',
                    embed_dims=256,
                    num_heads=8,
                    num_levels=4,
                    num_points=4,
                    im2col_step=64,
                    dropout=0.0,
                    batch_first=False,
                    norm_cfg=None,
                    init_cfg=None)
            ],
            feedforward_channels=2048,
            ffn_dropout=0.0,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'))),
    positional_encoding=dict(
        type='SinePositionalEncoding',
        num_feats=128,
        normalize=True,
        offset=-0.5),
    bbox_head=dict(
        type='GroundingDINOHead',
        num_queries=900,
        num_classes=256,
        in_channels=2048,
        sync_cls_avg_factor=True,
        as_two_stage=True,
        with_box_refine=True,
        dn_cfg=dict(
            type='CdnQueryGenerator',
            noise_scale=dict(label=0.5, box=1.0),
            group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100)),
        transformer=dict(
            type='GroundingDINOTransformer',
            embed_dims=256,
            num_feature_levels=4,
            encoder=dict(
                type='DetrTransformerEncoder',
                num_layers=6,
                transformerlayers=dict(
                    type='BaseTransformerLayer',
                    attn_cfgs=dict(
                        type='MultiScaleDeformableAttention',
                        embed_dims=256,
                        num_heads=8,
                        num_levels=4,
                        num_points=4,
                        im2col_step=64,
                        dropout=0.0,
                        batch_first=False,
                        norm_cfg=None,
                        init_cfg=None),
                    feedforward_channels=2048,
                    ffn_dropout=0.0,
                    operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
            decoder=dict(
                type='GroundingDINOTransformerDecoder',
                num_layers=6,
                return_intermediate=True,
                transformerlayers=dict(
                    type='GroundingDINOTransformerDecoderLayer',
                    attn_cfgs=[
                        dict(
                            type='MultiheadAttention',
                            embed_dims=256,
                            num_heads=8,
                            dropout=0.0,
                            batch_first=False),
                        dict(
                            type='MultiScaleDeformableAttention',
                            embed_dims=256,
                            num_heads=8,
                            num_levels=4,
                            num_points=4,
                            im2col_step=64,
                            dropout=0.0,
                            batch_first=False,
                            norm_cfg=None,
                            init_cfg=None)
                    ],
                    feedforward_channels=2048,
                    ffn_dropout=0.0,
                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'))),
            positional_encoding=dict(
                type='SinePositionalEncoding',
                num_feats=128,
                normalize=True,
                offset=-0.5)),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
    dn_cfg=dict(
        label_noise_scale=0.5,
        box_noise_scale=1.0,
        group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            cls_cost=dict(type='FocalLossCost', weight=2.0),
            reg_cost=dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
            iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0))),
    test_cfg=dict(max_per_img=300))

# Dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
'''
        with open(config_path, 'w') as f:
            f.write(config_content)
        logger.info(f"✅ Config file created at: {config_path}")
    else:
        logger.info(f"✅ Config already exists: {config_path}")

    return sam_path, grounding_path, config_path


def initialize_pipeline():
    """Initialize GroundingDINO and SAM models."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"🚀 Using device: {device}")

    try:
        # Download models
        sam_path, grounding_path, config_path = download_models()

        # ═══════════════════════════════════════════════════════════════
        # Initialize GroundingDINO
        # ═══════════════════════════════════════════════════════════════
        logger.info("🔧 Loading GroundingDINO model...")

        # Import here to avoid issues if not installed
        from groundingdino.util.inference import Model

        grounding_dino_model = Model(
            model_config_path=config_path,
            model_checkpoint_path=grounding_path,
            device=device
        )
        logger.info("✅ GroundingDINO loaded successfully!")

        # ═══════════════════════════════════════════════════════════════
        # Initialize SAM
        # ═══════════════════════════════════════════════════════════════
        logger.info("🔧 Loading SAM model...")
        sam = sam_model_registry["vit_h"](checkpoint=sam_path)
        sam.to(device=device)
        sam_predictor = SamPredictor(sam)
        logger.info("✅ SAM loaded successfully!")

        logger.info("🎉 All models initialized successfully!")

        return {
            "grounding_dino": grounding_dino_model,
            "sam_predictor": sam_predictor,
            "device": device
        }

    except Exception as e:
        logger.error(f"❌ Failed to initialize models: {e}")
        import traceback
        logger.error(f"Full traceback: {traceback.format_exc()}")
        raise
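

# ═══════════════════════════════════════════════════════════════
# Usage sketch (illustrative only): how the objects returned by
# initialize_pipeline() would typically be chained. The prompt text and
# thresholds are placeholder values, and the call assumes the standard
# groundingdino.util.inference.Model API (BGR image in, supervision-style
# detections with an .xyxy array out).
# ═══════════════════════════════════════════════════════════════
def example_text_to_mask(image_bgr, text_prompt="a dress"):
    """Hedged example: text-prompted box from GroundingDINO, refined into a mask by SAM."""
    models = initialize_pipeline()

    # 1) Text-conditioned detection with GroundingDINO
    detections, _ = models["grounding_dino"].predict_with_caption(
        image=image_bgr,            # BGR np.ndarray (OpenCV convention)
        caption=text_prompt,
        box_threshold=0.35,
        text_threshold=0.25,
    )
    if len(detections.xyxy) == 0:
        return None

    # 2) Refine the first detected box into a pixel mask with SAM
    predictor = models["sam_predictor"]
    predictor.set_image(image_bgr[:, :, ::-1].copy())  # SAM expects RGB
    masks, _, _ = predictor.predict(
        box=detections.xyxy[0],     # first box only, for brevity
        multimask_output=False,
    )
    return masks[0]                 # boolean H×W mask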
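

# ═══════════════════════════════════════════════════════════════
# Usage sketch for the FLUX fill stage (illustrative only): the prompt,
# seed, and step count below are placeholder values, and the mask is
# assumed to be a boolean array marking the region to repaint. FLUX
# expects height/width divisible by 16, so resize inputs accordingly.
# ═══════════════════════════════════════════════════════════════
def example_fill_background(image_pil, mask_bool, prompt="clean studio background, soft lighting"):
    """Hedged example: repaint the masked region with FLUX.1-Fill and save the result."""
    from PIL import Image

    pipe = get_flux_pipeline()
    mask_image = Image.fromarray(mask_bool.astype("uint8") * 255)  # 0/255 "L" mask

    result = pipe(
        prompt=prompt,
        image=image_pil,
        mask_image=mask_image,
        height=image_pil.height,
        width=image_pil.width,
        guidance_scale=30.0,
        num_inference_steps=28,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    out_path = os.path.join(OUTPUT_DIR, "filled.png")
    result.save(out_path)
    return out_path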