Spaces:
Runtime error
Runtime error
File size: 3,414 Bytes
9afcee2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
from functools import partial
import torch
from .transformer import INTR
from .sam_transformer import SamTransformer
from .sam import ImageEncoderViT, MaskDecoder, PromptEncoder, TwoWayTransformer
def build_demo_model():
    """Assemble the demo ``SamTransformer`` around a SAM ViT-B image encoder.

    The model is composed of one ``ImageEncoderViT``, one ``PromptEncoder``
    and three ``MaskDecoder`` heads (mask, affordance, depth).  The three
    heads are configured identically except that only the mask head is
    created with ``properties_on=True``.

    Returns:
        The freshly constructed ``SamTransformer``.  No weights are loaded
        by this function.
    """
    # NOTE: an INTR-based variant (resnet50 backbone) used to be sketched
    # here as commented-out code; only the SAM ViT-B path below is built.

    # ViT-B encoder hyper-parameters ("sam_vit_b").
    vit_width = 768
    vit_layers = 12
    vit_heads = 12
    global_attn_layers = [2, 5, 8, 11]

    # Sizes shared between the encoder, the prompt encoder and the decoders.
    embed_dim = 256
    input_side = 1024
    patch_side = 16
    grid_side = input_side // patch_side  # patches along one image side

    def make_decoder(properties_on):
        # Every head gets its own two-way transformer; it is built before
        # the MaskDecoder, matching the order of the inline construction.
        two_way = TwoWayTransformer(
            depth=2,
            embedding_dim=embed_dim,
            mlp_dim=2048,
            num_heads=8,
        )
        return MaskDecoder(
            num_multimask_outputs=3,
            transformer=two_way,
            transformer_dim=embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            properties_on=properties_on,
        )

    # Sub-modules are created in the same order as they are passed to
    # SamTransformer so that module construction order is preserved.
    encoder = ImageEncoderViT(
        depth=vit_layers,
        embed_dim=vit_width,
        img_size=input_side,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=vit_heads,
        patch_size=patch_side,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=global_attn_layers,
        window_size=14,
        out_chans=embed_dim,
    )
    prompts = PromptEncoder(
        embed_dim=embed_dim,
        image_embedding_size=(grid_side, grid_side),
        input_image_size=(input_side, input_side),
        mask_in_chans=16,
    )
    mask_head = make_decoder(True)
    affordance_head = make_decoder(False)
    depth_head = make_decoder(False)

    return SamTransformer(
        image_encoder=encoder,
        prompt_encoder=prompts,
        mask_decoder=mask_head,
        affordance_decoder=affordance_head,
        depth_decoder=depth_head,
        transformer_hidden_dim=embed_dim,
        backbone_name='vit_b',
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
|