diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..c957848293c654a36ed7309c83d1f5d1a02b9997 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..f4b62dbc2efc502109f0c41c9c4da30eff26b6ec
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+__pycache__
+*.ckpt
+assets/ckpts
+__pycache__/
+*.sh
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..d441a3264c05e72a7a7a04be17881485d121224a
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,57 @@
+FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update && \
+    apt-get upgrade -y && \
+    apt-get install -y --no-install-recommends \
+    git \
+    git-lfs \
+    wget \
+    curl \
+    # ffmpeg \
+    ffmpeg \
+    x264 \
+    # python build dependencies \
+    build-essential \
+    libssl-dev \
+    zlib1g-dev \
+    libbz2-dev \
+    libreadline-dev \
+    libsqlite3-dev \
+    libncursesw5-dev \
+    xz-utils \
+    tk-dev \
+    libxml2-dev \
+    libxmlsec1-dev \
+    libffi-dev \
+    liblzma-dev && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN useradd -m -u 1000 user
+USER user
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:${PATH}
+WORKDIR ${HOME}/app
+
+RUN curl https://pyenv.run | bash
+ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
+ENV PYTHON_VERSION=3.8.16
+RUN pyenv install ${PYTHON_VERSION} && \
+    pyenv global ${PYTHON_VERSION} && \
+    pyenv rehash && \
+    pip install --no-cache-dir -U pip setuptools wheel
+
+RUN pip install --no-cache-dir -U torch==1.12.1 torchvision==0.13.1
+COPY --chown=1000 requirements.txt /tmp/requirements.txt
+RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
+
+COPY --chown=1000 .
${HOME}/app +# RUN cd Tune-A-Video && patch -p1 < ../patch +ENV PYTHONPATH=${HOME}/app \ + PYTHONUNBUFFERED=1 \ + GRADIO_ALLOW_FLAGGING=never \ + GRADIO_NUM_PORTS=1 \ + GRADIO_SERVER_NAME=0.0.0.0 \ + GRADIO_THEME=huggingface \ + SYSTEM=spaces +CMD ["python", "app.py"] \ No newline at end of file diff --git a/README copy.md b/README copy.md new file mode 100644 index 0000000000000000000000000000000000000000..f2dade29e6dfb1211d36d8172f671187ec3f3a6c --- /dev/null +++ b/README copy.md @@ -0,0 +1,13 @@ +--- +title: StyleDrop Pytorch +emoji: 📊 +colorFrom: purple +colorTo: pink +sdk: gradio +sdk_version: 3.35.2 +app_file: app.py +pinned: false +license: mit +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..470b4ae708e58f29bd5fd5fbcbfd8dcd67ee6c69 --- /dev/null +++ b/app.py @@ -0,0 +1,264 @@ +import os +import gradio as gr +import open_clip +import torch +import taming.models.vqgan +import ml_collections +import einops +import random +import pathlib +import subprocess +import shlex +import wget +# Model +from libs.muse import MUSE +import utils +import numpy as np +from PIL import Image +print("cuda available:",torch.cuda.is_available()) +print("cuda device count:",torch.cuda.device_count()) +print("cuda device name:",torch.cuda.get_device_name(0)) +print(os.system("nvidia-smi")) +print(os.system("nvcc --version")) + +empty_context = np.load("assets/contexts/empty_context.npy") + +print("downloading cc3m-285000.ckpt") +os.makedirs("assets/ckpts/cc3m-285000.ckpt",exist_ok=True) +os.system("ls") +wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/lr_scheduler.pth","assets/ckpts/cc3m-285000.ckpt/lr_scheduler.pth") +wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/optimizer.pth","assets/ckpts/cc3m-285000.ckpt/optimizer.pth") +wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/nnet.pth","assets/ckpts/cc3m-285000.ckpt/nnet.pth") +wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/nnet_ema.pth","assets/ckpts/cc3m-285000.ckpt/nnet_ema.pth") +wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/step.pth","assets/ckpts/cc3m-285000.ckpt/step.pth") +wget.download("https://huggingface.co/zideliu/vqgan/resolve/main/vqgan_jax_strongaug.ckpt","assets/vqgan_jax_strongaug.ckpt") + +def set_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def d(**kwargs): + """Helper of creating a config dict.""" + return ml_collections.ConfigDict(initial_dictionary=kwargs) + +def get_config(): + config = ml_collections.ConfigDict() + config.seed = 1234 + config.z_shape = (8, 16, 16) + + config.autoencoder = d( + config_file='vq-f16-jax.yaml', + ) + config.resume_root="assets/ckpts/cc3m-285000.ckpt" + config.adapter_path=None + config.optimizer = d( + name='adamw', + lr=0.0002, + weight_decay=0.03, + betas=(0.99, 0.99), + ) + config.lr_scheduler = d( + name='customized', + warmup_steps=5000 + ) + config.nnet = d( + name='uvit_t2i_vq', + img_size=16, + codebook_size=1024, + in_chans=4, + embed_dim=1152, + depth=28, + num_heads=16, + mlp_ratio=4, + qkv_bias=False, + clip_dim=1280, + num_clip_token=77, + use_checkpoint=True, + skip=True, + d_prj=32, + is_shared=False + ) + config.muse = d( + 
ignore_ind=-1, + smoothing=0.1, + gen_temp=4.5 + ) + config.sample = d( + sample_steps=36, + n_samples=50, + mini_batch_size=8, + cfg=True, + linear_inc_scale=True, + scale=10., + path='', + lambdaA=2.0, # Stage I: 2.0; Stage II: TODO + lambdaB=5.0, # Stage I: 5.0; Stage II: TODO + ) + return config + +def cfg_nnet(x, context, scale=None,lambdaA=None,lambdaB=None): + _cond = nnet_ema(x, context=context) + _cond_w_adapter = nnet_ema(x,context=context,use_adapter=True) + _empty_context = torch.tensor(empty_context, device=device) + _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0)) + _uncond = nnet_ema(x, context=_empty_context) + res = _cond + scale * (_cond - _uncond) + if lambdaA is not None: + res = _cond_w_adapter + lambdaA*(_cond_w_adapter - _cond) + lambdaB*(_cond - _uncond) + return res + +def unprocess(x): + x.clamp_(0., 1.) + return x + +config = get_config() +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + +# Load open_clip and vq model +prompt_model,_,_ = open_clip.create_model_and_transforms('ViT-bigG-14', 'laion2b_s39b_b160k') +prompt_model = prompt_model.to(device) +prompt_model.eval() +tokenizer = open_clip.get_tokenizer('ViT-bigG-14') + +vq_model = taming.models.vqgan.get_model('vq-f16-jax.yaml') +vq_model.eval() +vq_model.requires_grad_(False) +vq_model.to(device) + +## config + +muse = MUSE(codebook_size=vq_model.n_embed, device=device, **config.muse) + +train_state = utils.initialize_train_state(config, device) +train_state.resume(ckpt_root=config.resume_root) +nnet_ema = train_state.nnet_ema +nnet_ema.eval() +nnet_ema.requires_grad_(False) +nnet_ema.to(device) +style_ref = { + "None":None, + "0102":"style_adapter/0102.pth", + "0103":"style_adapter/0103.pth", + "0106":"style_adapter/0106.pth", + "0108":"style_adapter/0108.pth", + "0301":"style_adapter/0301.pth", + "0305":"style_adapter/0305.pth", + } +style_postfix ={ + "None":"", + "0102":" in watercolor painting style", + "0103":" in watercolor painting style", + "0106":" in line drawing style", + "0108":" in oil painting style", + "0301":" in 3d rendering style", + "0305":" in kid crayon drawing style", +} + +def decode(_batch): + return vq_model.decode_code(_batch) + +def process(prompt,num_samples,lambdaA,lambdaB,style,seed,sample_steps,image=None): + config.sample.lambdaA = lambdaA + config.sample.lambdaB = lambdaB + config.sample.sample_steps = sample_steps + print(style) + adapter_path = style_ref[style] + adapter_postfix = style_postfix[style] + print(f"load adapter path: {adapter_path}") + if adapter_path is not None: + nnet_ema.adapter.load_state_dict(torch.load(adapter_path)) + else: + config.sample.lambdaA=None + config.sample.lambdaB=None + print("load adapter Done!") + # Encode prompt + prompt = prompt+adapter_postfix + text_tokens = tokenizer(prompt).to(device) + text_embedding = prompt_model.encode_text(text_tokens) + text_embedding = text_embedding.repeat(num_samples, 1, 1) # B 77 1280 + print(text_embedding.shape) + + print(f"lambdaA: {lambdaA}, lambdaB: {lambdaB}, sample_steps: {sample_steps}") + if seed==-1: + seed = random.randint(0,65535) + config.seed = seed + print(f"seed: {seed}") + set_seed(config.seed) + res = muse.generate(config,num_samples,cfg_nnet,decode,is_eval=True,context=text_embedding) + print(res.shape) + res = (res*255+0.5).clamp_(0,255).permute(0,2,3,1).to('cpu',torch.uint8).numpy() + im = [res[i] for i in range(num_samples)] + return im + +block = gr.Blocks() +with block: + with gr.Row(): + gr.Markdown("## StyleDrop based on Muse 
(Inference Only) ") + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox(label="Prompt") + run_button = gr.Button(label="Run") + num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1) + seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=1234) + style = gr.Radio(choices=["0102","0103","0106","0108","0305","None"],type="value",value="None",label="Style") + + with gr.Accordion("Advanced options",open=False): + lambdaA = gr.Slider(label="lambdaA", minimum=0.0, maximum=5.0, value=2.0, step=0.01) + lambdaB = gr.Slider(label="lambdaB", minimum=0.0, maximum=10.0, value=5.0, step=0.01) + sample_steps = gr.Slider(label="Sample steps", minimum=1, maximum=50, value=36, step=1) + image=gr.Image(value=None) + with gr.Column(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(columns=2, height='auto') + + with gr.Row(): + examples = [ + [ + "A banana on the table", + 1,2.0,5.0,"0103",1234,36, + "data/image_01_03.jpg", + ], + [ + + "A cow", + 1,2.0,5.0,"0102",1234,36, + "data/image_01_02.jpg", + ], + [ + + "A portrait of tabby cat", + 1,2.0,5.0,"0106",1234,36, + "data/image_01_06.jpg", + ], + [ + + "A church in the field", + 1,2.0,5.0,"0108",1234,36, + "data/image_01_08.jpg", + ], + [ + + "A Christmas tree", + 1,2.0,5.0,"0305",1234,36, + "data/image_03_05.jpg", + ] + + ] + gr.Examples(examples=examples, + fn=process, + inputs=[ + prompt, + num_samples,lambdaA,lambdaB,style,seed,sample_steps,image, + ], + outputs=result_gallery, + cache_examples=os.getenv('SYSTEM') == 'spaces' + ) + ips = [prompt,num_samples,lambdaA,lambdaB,style,seed,sample_steps,image] + run_button.click( + fn=process, + inputs=ips, + outputs=[result_gallery] + ) +block.queue().launch(share=False) + diff --git a/assets/contexts/empty_context.npy b/assets/contexts/empty_context.npy new file mode 100755 index 0000000000000000000000000000000000000000..b92ad3aeea065997afb64cf4883a441e31f31743 --- /dev/null +++ b/assets/contexts/empty_context.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf06c46310efa57d47e34e5221ffa757dc6c60e91c8758fcb1d19040ee61e9fc +size 394368 diff --git a/assets/fid_stats/fid_stats_cc3m_val.npz b/assets/fid_stats/fid_stats_cc3m_val.npz new file mode 100755 index 0000000000000000000000000000000000000000..4b05efa8e5ab48ded3bbba7119c6f9db87407607 --- /dev/null +++ b/assets/fid_stats/fid_stats_cc3m_val.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84605eaad681c8fdb13c5f96f9bcc7a7d8648e4e03023f2498aec7deb3ea3179 +size 33571316 diff --git a/assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz b/assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz new file mode 100755 index 0000000000000000000000000000000000000000..39ed305fa2b2b54438a9b25f5ce24c0124cc511f --- /dev/null +++ b/assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:374aa549982adbfd595eaecc8a014eea6566156f8b227fc2d9052c0482bb4a2f +size 33571316 diff --git a/assets/pipeline.png b/assets/pipeline.png new file mode 100644 index 0000000000000000000000000000000000000000..89913e26455d748af53f9172b07e66a8607af251 Binary files /dev/null and b/assets/pipeline.png differ diff --git a/configs/cc3m_xl_vqf16_jax_2048bs_featset_CLIP_G.py b/configs/cc3m_xl_vqf16_jax_2048bs_featset_CLIP_G.py new file mode 100644 index 0000000000000000000000000000000000000000..24f1b97a0942cea3ff062b12fba6510a194154fe --- /dev/null +++ 
b/configs/cc3m_xl_vqf16_jax_2048bs_featset_CLIP_G.py @@ -0,0 +1,92 @@ +import ml_collections + + +def d(**kwargs): + """Helper of creating a config dict.""" + return ml_collections.ConfigDict(initial_dictionary=kwargs) + + +def get_config(): + config = ml_collections.ConfigDict() + + config.seed = 1234 + config.z_shape = (8, 16, 16) + + config.autoencoder = d( + config_file='vq-f16-jax.yaml', + ) + + config.train = d( + n_steps=999999999, + batch_size=2048, + log_interval=10, + eval_interval=5000, + save_interval=5000, + fid_interval=50000, + num_workers=8, + resampled=False, + ) + + config.eval = d( + n_samples=10000, + sample_steps=18, + ) + + config.optimizer = d( + name='adamw', + lr=0.0002, + weight_decay=0.03, + betas=(0.99, 0.99), + ) + + config.lr_scheduler = d( + name='customized', + warmup_steps=5000 + ) + + config.nnet = d( + name='uvit_t2i_vq', + img_size=16, + codebook_size=1024, + in_chans=4, + embed_dim=1152, + depth=28, + num_heads=16, + mlp_ratio=4, + qkv_bias=False, + clip_dim=1280, + num_clip_token=77, + use_checkpoint=True, + skip=True, + ) + + config.muse = d( + ignore_ind=-1, + smoothing=0.1, + gen_temp=4.5 + ) + + config.dataset = d( + name='cc3m_web', + cfg=True, + p_uncond=0.15, + ) + + config.wds = d( + train_data='assets/datasets/cc3m/vq_f16_jax_clipG_cc3m_train_emb/{00000..03044}.tar', + val_data='assets/datasets/cc3m/vq_f16_jax_clipG_cc3m_val_emb/{00000..00012}.tar', + ctx_path='assets/contexts', + dist_eval=True, + ) + + config.sample = d( + sample_steps=18, + n_samples=30000, + mini_batch_size=2, + cfg=True, + linear_inc_scale=True, + scale=10., + path='', + ) + + return config diff --git a/configs/custom.py b/configs/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4a1f8f41f089718ec95b0e87c7cdafede84e49 --- /dev/null +++ b/configs/custom.py @@ -0,0 +1,83 @@ +import ml_collections + + +def d(**kwargs): + """Helper of creating a config dict.""" + return ml_collections.ConfigDict(initial_dictionary=kwargs) + + +def get_config(): + config = ml_collections.ConfigDict() + + + config.seed = 1234 + config.z_shape = (8, 16, 16) + + config.autoencoder = d( + config_file='vq-f16-jax.yaml', + ) + config.data_path="data/one_style.json" + config.resume_root="assets/ckpts/cc3m-285000.ckpt" + config.adapter_path=None + config.sample_interval=True + config.train = d( + n_steps=1000, + batch_size=8, + log_interval=20, + eval_interval=100, + save_interval=100, + fid_interval=20000, + num_workers=8, + resampled=False, + ) + + config.optimizer = d( + name='adamw', + lr=0.0003, + weight_decay=0.03, + betas=(0.99, 0.99), + ) + + config.lr_scheduler = d( + name='customized', + warmup_steps=-1, # 5000 + ) + + config.nnet = d( + name='uvit_t2i_vq', + img_size=16, + codebook_size=1024, + in_chans=4, + embed_dim=1152, + depth=28, + num_heads=16, + mlp_ratio=4, + qkv_bias=False, + clip_dim=1280, + num_clip_token=77, + use_checkpoint=False, + skip=True, + d_prj=32,# Stage I: 32; Stage II: TODO + is_shared=False, # Stage I: False; Stage II: False + ) + + config.muse = d( + ignore_ind=-1, + smoothing=0.1, + gen_temp=4.5 + ) + + + config.sample = d( + sample_steps=36, + n_samples=50, + mini_batch_size=8, + cfg=True, + linear_inc_scale=True, + scale=10., + path='', + lambdaA=2.0, # Stage I: 2.0; Stage II: TODO + lambdaB=5.0, # Stage I: 5.0; Stage II: TODO + ) + + return config diff --git a/configs/imagenet256_base_vq_jax.py b/configs/imagenet256_base_vq_jax.py new file mode 100644 index 
0000000000000000000000000000000000000000..81c5def95cf6341b178fa9ee2ffb9682a9bef72b --- /dev/null +++ b/configs/imagenet256_base_vq_jax.py @@ -0,0 +1,84 @@ +import ml_collections + + +def d(**kwargs): + """Helper of creating a config dict.""" + return ml_collections.ConfigDict(initial_dictionary=kwargs) + + +def get_config(): + config = ml_collections.ConfigDict() + + config.seed = 1234 + config.z_shape = (8, 16, 16) + + config.autoencoder = d( + config_file='vq-f16-jax.yaml', + ) + + config.train = d( + n_steps=99999999, + batch_size=2048, + log_interval=10, + eval_interval=5000, + save_interval=5000, + fid_interval=50000, + ) + + config.eval = d( + n_samples=10000, + sample_steps=12, + ) + + config.optimizer = d( + name='adamw', + lr=0.0004, + weight_decay=0.03, + betas=(0.99, 0.99), + ) + + config.lr_scheduler = d( + name='customized', + warmup_steps=5000 + ) + + config.nnet = d( + name='uvit_vq', + img_size=16, + codebook_size=1024, + in_chans=256, + patch_size=1, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=False, + num_classes=1001, + use_checkpoint=False, + skip=True, + ) + + config.muse = d( + ignore_ind=-1, + smoothing=0.1, + gen_temp=4.5 + ) + + config.dataset = d( + name='imagenet256_features', + path='assets/datasets/imagenet256_vq_features/vq-f16-jax', + cfg=True, + p_uncond=0.15, + ) + + config.sample = d( + sample_steps=12, + n_samples=50000, + mini_batch_size=50, + cfg=True, + linear_inc_scale=True, + scale=3., + path='' + ) + + return config diff --git a/configs/vae_configs/vq-f16-jax.yaml b/configs/vae_configs/vq-f16-jax.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3ceafd596794746d06263b83d7a00438a8bf74ed --- /dev/null +++ b/configs/vae_configs/vq-f16-jax.yaml @@ -0,0 +1,42 @@ +model: + base_learning_rate: 4.5e-6 + target: taming.models.vqgan.VQModel + params: + embed_dim: 256 + n_embed: 1024 + ddconfig: + double_z: False + z_channels: 256 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [16] + dropout: 0.0 + + lossconfig: + target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator + params: + disc_conditional: False + disc_in_channels: 3 + disc_start: 250001 + disc_weight: 0.8 + codebook_weight: 1.0 + +data: + target: main.DataModuleFromConfig + params: + batch_size: 8 + num_workers: 24 + train: + target: taming.data.imagenet.ImageNetTrain + params: + config: + size: 256 + validation: + target: taming.data.imagenet.ImageNetValidation + params: + config: + size: 256 diff --git a/custom/custom_dataset.py b/custom/custom_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cd02ac6f77b542dcd5267c9041f4d04bafbf7040 --- /dev/null +++ b/custom/custom_dataset.py @@ -0,0 +1,233 @@ + +from torch.utils.data import Dataset + +import os +import numpy as np +import taming.models.vqgan +import open_clip +import random +from PIL import Image +import torch +import math +import json +import torchvision.transforms as transforms +torch.manual_seed(0) +np.random.seed(0) + +class test_custom_dataset(Dataset): + + def __init__(self, style: str = None): + self.empty_context = np.load("assets/contexts/empty_context.npy") + self.object=[ + "A chihuahua ", + "A tabby cat ", + "A portrait of chihuahua ", + "An apple on the table ", + "A banana on the table ", + "A church on the street ", + "A church in the mountain ", + "A church in the field ", + "A church on the beach ", + "A chihuahua walking on the street 
", + "A tabby cat walking on the street", + "A portrait of tabby cat ", + "An apple on the dish ", + "A banana on the dish ", + "A human walking on the street ", + "A temple on the street ", + "A temple in the mountain ", + "A temple in the field ", + "A temple on the beach ", + "A chihuahua walking in the forest ", + "A tabby cat walking in the forest ", + "A portrait of human face ", + "An apple on the ground ", + "A banana on the ground ", + "A human walking in the forest ", + "A cabin on the street ", + "A cabin in the mountain ", + "A cabin in the field ", + "A cabin on the beach ", + ] + self.style = [ + "in 3d rendering style", + ] + if style is not None: + self.style = [style] + + def __getitem__(self, index): + prompt = self.object[index]+self.style[0] + + return prompt, prompt + + def __len__(self): + return len(self.object) + + def unpreprocess(self, v): # to B C H W and [0, 1] + v.clamp_(0., 1.) + return v + + @property + def fid_stat(self): + return f'assets/fid_stats/fid_stats_cc3m_val.npz' + + +class train_custom_dataset(Dataset): + + def __init__(self, train_file: str=None, ): + + self.train_img = json.load(open(train_file, 'r')) + self.path_preffix = "/".join(train_file.split("/")[:-1]) + self.prompt = [] + self.image = [] + self.style = [] + for im in self.train_img.keys(): + im_path = os.path.join(self.path_preffix, im) + self.object = self.train_img[im][0] + self.style = self.train_img[im][1] + im_prompt = self.object +" "+self.style + self.image.append(im_path) + self.prompt.append(im_prompt) + self.empty_context = np.load("assets/contexts/empty_context.npy") + + self.transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.RandomHorizontalFlip(), + # transforms.RandomVerticalFlip(), + transforms.ToTensor(), + ]) + print("-----------------"*3) + print("train dataset length: ", len(self.prompt)) + print("train dataset length: ", len(self.image)) + print(self.prompt[0]) + print(self.image[0]) + print("-----------------"*3) + def __getitem__(self, index): + prompt = self.prompt[0] + image = Image.open(self.image[0]).convert("RGB") + image = self.transform(image) + + return image,prompt + # return dict(img=image_embedding, text=text_embedding) + + def __len__(self): + return 24 + + def unpreprocess(self, v): # to B C H W and [0, 1] + v.clamp_(0., 1.) 
+ return v + + @property + def fid_stat(self): + return f'assets/fid_stats/fid_stats_cc3m_val.npz' + + + + + +class Discriptor(Dataset): + def __init__(self,style: str=None): + self.object =[ + # "A parrot ", + # "A bird ", + # "A chihuahua in the snow", + # "A towel ", + # "A number '1' ", + # "A number '2' ", + # "A number '3' ", + # "A number '6' ", + # "A letter 'L' ", + # "A letter 'Z' ", + # "A letter 'D' ", + # "A rabbit ", + # "A train ", + # "A table ", + # "A dish ", + # "A large boat ", + # "A puppy ", + # "A cup ", + # "A watermelon ", + # "An apple ", + # "A banana ", + # "A chair ", + # "A Welsh Corgi ", + # "A cat ", + # "A house ", + # "A flower ", + # "A sunflower ", + # "A car ", + # "A jeep car ", + # "A truck ", + # "A Posche car ", + # "A vase ", + # "A chihuahua ", + # "A tabby cat ", + "A portrait of chihuahua ", + "An apple on the table ", + "A banana on the table ", + "A human ", + "A church on the street ", + "A church in the mountain ", + "A church in the field ", + "A church on the beach ", + "A chihuahua walking on the street ", + "A tabby cat walking on the street", + "A portrait of tabby cat ", + "An apple on the dish ", + "A banana on the dish ", + "A human walking on the street ", + "A temple on the street ", + "A temple in the mountain ", + "A temple in the field ", + "A temple on the beach ", + "A chihuahua walking in the forest ", + "A tabby cat walking in the forest ", + "A portrait of human face ", + "An apple on the ground ", + "A banana on the ground ", + "A human walking in the forest ", + "A cabin on the street ", + "A cabin in the mountain ", + "A cabin in the field ", + "A cabin on the beach ", + "A letter 'A' ", + "A letter 'B' ", + "A letter 'C' ", + "A letter 'D' ", + "A letter 'E' ", + "A letter 'F' ", + "A letter 'G' ", + "A butterfly ", + " A baby penguin ", + "A bench ", + "A boat ", + "A cow ", + "A hat ", + "A piano ", + "A robot ", + "A christmas tree ", + "A dog ", + "A moose ", + ] + + self.style =[ + "in 3d rendering style", + ] + if style is not None: + self.style = [style] + + def __getitem__(self, index): + prompt = self.object[index]+self.style[0] + return prompt + + def __len__(self): + return len(self.object) + + def unpreprocess(self, v): # to B C H W and [0, 1] + v.clamp_(0., 1.) 
+ return v + + @property + def fid_stat(self): + return f'assets/fid_stats/fid_stats_cc3m_val.npz' + \ No newline at end of file diff --git a/data/data.json b/data/data.json new file mode 100644 index 0000000000000000000000000000000000000000..bd72eb46115dc926d759482d55690c7ad7423704 --- /dev/null +++ b/data/data.json @@ -0,0 +1,22 @@ +{ + "image_01_01.jpg":["A bay","in watercolor painting style"], + "image_01_02.jpg":["A house", "in watercolor painting style"], + "image_01_03.jpg":["A cat", "in watercolor painting style"], + "image_01_04.jpg":["Flowers", "in watercolor painting style"], + "image_01_05.jpg":["A village", "in oil painting style"], + "image_01_06.jpg":["A village", "in line drawing style"], + "image_01_07.jpg":["A portrait of a person", "in oil painting style"], + "image_01_08.jpg":["A portrait of a person wearing a hat", "in oil painting style"], + "image_02_01.jpg":["A person drwoning into th phone", "in cartoon line drawing style"], + "image_02_02.jpg":["A woman walking a dog", "in flat cartoon illustration style"], + "image_02_03.jpg":["A woman working on a laptop", "in flat cartoon illustration style"], + "image_02_04.jpg":["A Christmas tree", "in sticker style"], + "image_02_05.jpg":["A wave", "in abstract rainbow colored flowing smoke wave design"], + "image_02_06.jpg":["A mushroom", "in glowing style"], + "image_03_01.jpg":["Slice of watermelon and clouds in the background", "in 3d rendering style"], + "image_03_03.jpg":["A thumbs up", "in glowing 3d rendering style"], + "image_03_04.jpg":["A woman", "in 3d rendering style"], + "image_03_05.jpg":["A bear", "in kid crayon drawing style"], + "image_03_07.jpg":["A flower", "in melting golden 3d rendering style"], + "image_03_08.jpg":["A Viking face with beard", "in wooden sculpture"] +} \ No newline at end of file diff --git a/data/image_01_01.jpg b/data/image_01_01.jpg new file mode 100644 index 0000000000000000000000000000000000000000..68a6e5bb3b870bc236666c8da0ed3e59ba8fd27c --- /dev/null +++ b/data/image_01_01.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b467d766af07216c77d933abfbd8fbf97efc69604f6d98f57da207609f5322b +size 119071 diff --git a/data/image_01_02.jpg b/data/image_01_02.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0595736c4877da6ca3a0d573ddd8dd13f4a9033d --- /dev/null +++ b/data/image_01_02.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:426033b83f52843be0552d4b94453ad07141b29c7f21e0555ec9e3304d73e8ad +size 177371 diff --git a/data/image_01_03.jpg b/data/image_01_03.jpg new file mode 100644 index 0000000000000000000000000000000000000000..47822ebd7fbd49c62fba850f9109b9569483ce01 --- /dev/null +++ b/data/image_01_03.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2335c5df1ee92c60229fb5198ba0ceb02dc157fb4c3aaa3e191466577cc80eae +size 662604 diff --git a/data/image_01_04.jpg b/data/image_01_04.jpg new file mode 100644 index 0000000000000000000000000000000000000000..59c379694d788c088ccbc97107f4db2b1390d037 --- /dev/null +++ b/data/image_01_04.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:92a4544523e35cbe5a23b67820f2e6257c5703d8edced66a584b002ec1865c02 +size 35117 diff --git a/data/image_01_05.jpg b/data/image_01_05.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0f6427762ab20049ed94146f41d92ae93ac7c858 --- /dev/null +++ b/data/image_01_05.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:4d06b8a46a2878a25573c618f912929beffc0441f5a8d3f2e9ac3ae3217df94f +size 250662 diff --git a/data/image_01_06.jpg b/data/image_01_06.jpg new file mode 100644 index 0000000000000000000000000000000000000000..209e66aee568a6c5199255fba29d48986fecbeb8 --- /dev/null +++ b/data/image_01_06.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d02c652a5836154ceab17aec342dea76d06c4f6a23c964c45244426bf87fd0af +size 157661 diff --git a/data/image_01_07.jpg b/data/image_01_07.jpg new file mode 100644 index 0000000000000000000000000000000000000000..58587049efd6245cacef4f0898b14f0b11eebc49 --- /dev/null +++ b/data/image_01_07.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:688a5e48e1208de644f2163a2b44d46a54b1ce3627407bebcf1f389c58a34c46 +size 1482274 diff --git a/data/image_01_08.jpg b/data/image_01_08.jpg new file mode 100644 index 0000000000000000000000000000000000000000..adc2ddd23f2e39464bc8073fddbb60598046abc6 --- /dev/null +++ b/data/image_01_08.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47632bf3a07a6c7630d032d64371ae58fb08469900eff849ae52b256948b6930 +size 626273 diff --git a/data/image_02_01.jpg b/data/image_02_01.jpg new file mode 100644 index 0000000000000000000000000000000000000000..77ee4299976e6d730065288fe113d0a581a51ca3 --- /dev/null +++ b/data/image_02_01.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e3550da99d36ec1568f313c45401a72a17c42ac32801a2c507ff7d85d874716 +size 71890 diff --git a/data/image_02_02.jpg b/data/image_02_02.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ff3cb81cddcf3b9858b14fbf87014a1fb11faea8 --- /dev/null +++ b/data/image_02_02.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9768bcda5ec0953f20a954542232d0a0d630e681ffe96c92d05d49d2f8a22183 +size 464529 diff --git a/data/image_02_03.jpg b/data/image_02_03.jpg new file mode 100644 index 0000000000000000000000000000000000000000..adabb77ce4cca97629df99c19a5359dd3ef572ca --- /dev/null +++ b/data/image_02_03.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f07fe073d140d6dc2d4af9609ba73ba4750f46aa2304d2ffc171989d8c4fba78 +size 1096476 diff --git a/data/image_02_04.jpg b/data/image_02_04.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b927e3a81967662dbb17b2eced1754dc5f9e67b8 --- /dev/null +++ b/data/image_02_04.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57e5dcf39366c4da8727fff4c48214151b4d427033402f28e91e1a5e5384eeb8 +size 481009 diff --git a/data/image_02_05.jpg b/data/image_02_05.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2b490d95b09439d0dd2db20e7f3d99c13966ed89 --- /dev/null +++ b/data/image_02_05.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42c439155b17df9bab951a56f9e88e46c7c0109d345fc07553f62d7ccefbbc05 +size 65755 diff --git a/data/image_02_06.jpg b/data/image_02_06.jpg new file mode 100644 index 0000000000000000000000000000000000000000..67963951bec8d79100a4d3f02d07a12d644a8b0d --- /dev/null +++ b/data/image_02_06.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:efb5a021a7fb5fdcb6e6ed7f8aa282e6a9ae50177a9d8199f82bba748f54d172 +size 175720 diff --git a/data/image_03_01.jpg b/data/image_03_01.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5cc16e7dac344a2d011acaba2d3f97aba9ce44ca --- /dev/null +++ b/data/image_03_01.jpg @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:b490adc5a556bd5d2f68ef3a28d0ca85fbc8b0d04212df2f19d8a10001eb09a8 +size 140079 diff --git a/data/image_03_03.jpg b/data/image_03_03.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8ed55dfccff1937a06e2f6095b60519da38a7ae9 --- /dev/null +++ b/data/image_03_03.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1cdc7fa8d2c8ac873140c4b9c06d0df911063a9a8535d429ad0ddd50e8e7175 +size 123571 diff --git a/data/image_03_04.jpg b/data/image_03_04.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e2c3d0fbeca21a0353c3cfcccfa18e9e9aeb11e9 --- /dev/null +++ b/data/image_03_04.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1bce51718f7a09b4e647df9a0e95f19ec2a18678c6d1f057a798828365a4c64 +size 212809 diff --git a/data/image_03_05.jpg b/data/image_03_05.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5280206942f91bfbf6273b0fb057ae3525db3fb2 --- /dev/null +++ b/data/image_03_05.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53e41c6832e722d45170958160ffc4a632da969dc84a98d9fd608620e183825b +size 531555 diff --git a/data/image_03_07.jpg b/data/image_03_07.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f1d1469bd5bb72f34bd460377cec14f1a56375bc --- /dev/null +++ b/data/image_03_07.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d41a949ddb0d7683c27dfd9be52b0dce62f7492a443bc6bdfa4a0e038af949a4 +size 79970 diff --git a/data/image_03_08.jpg b/data/image_03_08.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f61a06056490b13d69de9267b02842803fe4b97c --- /dev/null +++ b/data/image_03_08.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:17c9388900a405ffbd387114965c61b008b235c900393a99feecae4bb02675b5 +size 418919 diff --git a/data/one_style.json b/data/one_style.json new file mode 100644 index 0000000000000000000000000000000000000000..c6c48a02d0d9c05f7fc1fdd8f9f8e94858c15381 --- /dev/null +++ b/data/one_style.json @@ -0,0 +1,3 @@ +{ + "image_01_02.jpg":["A house", "in watercolor painting style"] +} \ No newline at end of file diff --git a/libs/__init__.py b/libs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d539774e00df75d2b2075d53ac242d974f597835 --- /dev/null +++ b/libs/__init__.py @@ -0,0 +1 @@ +# codes from third party diff --git a/libs/muse.py b/libs/muse.py new file mode 100644 index 0000000000000000000000000000000000000000..3556442e0cebc0cadb402bbf2c7afb7e4412f498 --- /dev/null +++ b/libs/muse.py @@ -0,0 +1,107 @@ +import numpy as np +import torch +import math +from einops import rearrange +from torch.nn import functional as F + + +def add_gumbel_noise(t, temperature, device): + return (t + torch.Tensor(temperature * np.random.gumbel(size=t.shape)).to(device)) + + +class MUSE(object): + def __init__(self, codebook_size, device, ignore_ind=-1, smoothing=0., gen_temp=4.5): + self.mask_ind = codebook_size # for input masking + self.ignore_ind = ignore_ind # for ce loss, excluding visible + self.device = device + self.smoothing = smoothing + self.gen_temp = gen_temp + + @staticmethod + def cosine_schedule(t): + return torch.cos(t * math.pi * 0.5) + + def sample(self, x0): + N, L, device = *x0.shape, self.device + timesteps = torch.zeros((N,), device=device).float().uniform_(0, 1) + rand_mask_probs = self.cosine_schedule(timesteps) # cosine schedule + num_token_masked = (L * 
rand_mask_probs).round().clamp(min=1) + batch_randperm = torch.rand(N, L, device=device).argsort(dim=-1) + mask = batch_randperm < rearrange(num_token_masked, 'b -> b 1') + masked_ids = torch.where(mask, self.mask_ind, x0) + labels = torch.where(mask, x0, self.ignore_ind) + return labels, masked_ids + + def loss(self, pred, label): + return F.cross_entropy(pred.transpose(1, 2), label.long(), + ignore_index=self.ignore_ind, label_smoothing=self.smoothing) + + @torch.no_grad() + def generate(self, config, _n_samples, nnet, decode_fn, is_eval=False, **kwargs): + fmap_size, _sample_steps, device = config.z_shape[-1], config.sample.sample_steps, self.device + + seq_len = fmap_size ** 2 + ids = torch.full((_n_samples, seq_len), self.mask_ind, dtype=torch.long, device=device) + cfg_scale = 0. + for step in range(_sample_steps): + ratio = 1. * (step + 1) / _sample_steps + annealed_temp = self.gen_temp * (1 - ratio) + is_mask = (ids == self.mask_ind) + logits = nnet(ids, **kwargs, scale=cfg_scale) + # sampling & scoring + sampled_ids = add_gumbel_noise(logits, annealed_temp, device).argmax(dim=-1) + sampled_logits = torch.squeeze( + torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1) + sampled_ids = torch.where(is_mask, sampled_ids, ids) + sampled_logits = torch.where(is_mask, sampled_logits, +np.inf).float() + # masking + mask_ratio = np.cos(ratio * math.pi * 0.5) + mask_len = torch.Tensor([np.floor(seq_len * mask_ratio)]).to(device) + mask_len = torch.maximum(torch.Tensor([1]).to(device), + torch.minimum(torch.sum(is_mask, dim=-1, keepdims=True) - 1, + mask_len))[0].squeeze() + confidence = add_gumbel_noise(sampled_logits, annealed_temp, device) + sorted_confidence, _ = torch.sort(confidence, axis=-1) + cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()] + masking = (confidence <= cut_off) + ids = torch.where(masking, self.mask_ind, sampled_ids) + cfg_scale = ratio * config.sample.scale + + _z1 = rearrange(sampled_ids, 'b (i j) -> b i j', i=fmap_size, j=fmap_size) + + # with adapter + ids = torch.full((_n_samples, seq_len), self.mask_ind, dtype=torch.long, device=device) + cfg_scale = 0. + lambdaA=0. + lambdaB=0. + for step in range(_sample_steps): + ratio = 1. 
* (step + 1) / _sample_steps + annealed_temp = self.gen_temp * (1 - ratio) + is_mask = (ids == self.mask_ind) + # 尝试使用 *ratio + logits = nnet(ids, **kwargs, scale=cfg_scale,lambdaA=lambdaA,lambdaB=lambdaB) + # sampling & scoring + sampled_ids = add_gumbel_noise(logits, annealed_temp, device).argmax(dim=-1) + sampled_logits = torch.squeeze( + torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1) + sampled_ids = torch.where(is_mask, sampled_ids, ids) + sampled_logits = torch.where(is_mask, sampled_logits, +np.inf).float() + # masking + mask_ratio = np.cos(ratio * math.pi * 0.5) + mask_len = torch.Tensor([np.floor(seq_len * mask_ratio)]).to(device) + mask_len = torch.maximum(torch.Tensor([1]).to(device), + torch.minimum(torch.sum(is_mask, dim=-1, keepdims=True) - 1, + mask_len))[0].squeeze() + confidence = add_gumbel_noise(sampled_logits, annealed_temp, device) + sorted_confidence, _ = torch.sort(confidence, axis=-1) + cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()] + masking = (confidence <= cut_off) + ids = torch.where(masking, self.mask_ind, sampled_ids) + cfg_scale = ratio * config.sample.scale + lambdaA = config.sample.lambdaA + lambdaB = config.sample.lambdaB + + _z2 = rearrange(sampled_ids, 'b (i j) -> b i j', i=fmap_size, j=fmap_size) + _z = _z2 if is_eval else torch.cat([_z1,_z2],dim=0) + out = decode_fn(_z) + return out diff --git a/libs/uvit_t2i_vq.py b/libs/uvit_t2i_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..cf380e1b93ed224c93b8467f1c0b18489fa849ec --- /dev/null +++ b/libs/uvit_t2i_vq.py @@ -0,0 +1,282 @@ +import torch +import torch.nn as nn +import math + +from loguru import logger + +import timm +from timm.models.layers import trunc_normal_ +from timm.models.vision_transformer import PatchEmbed, Mlp + +assert timm.__version__ == "0.3.2" # version check +import einops +import torch.utils.checkpoint +import torch.nn.functional as F + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILBLE = True + print("xformers available, will use xformers attention") +except: + XFORMERS_IS_AVAILBLE = False + print("xformers not available, will use pytorch attention instead") + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1): + super().__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6) + self.dropout = nn.Dropout(dropout) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1))) + + torch.nn.init.normal_(self.word_embeddings.weight, std=.02) + torch.nn.init.normal_(self.position_embeddings.weight, std=.02) + + def forward( + self, input_ids + ): + input_shape = input_ids.size() + + seq_length = input_shape[1] + + position_ids = self.position_ids[:, :seq_length] + + inputs_embeds = self.word_embeddings(input_ids) + + position_embeddings = self.position_embeddings(position_ids) + embeddings = inputs_embeds + position_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class 
MlmLayer(nn.Module): + + def __init__(self, feat_emb_dim, word_emb_dim, vocab_size): + super().__init__() + self.fc = nn.Linear(feat_emb_dim, word_emb_dim) + self.gelu = nn.GELU() + self.ln = nn.LayerNorm(word_emb_dim) + self.bias = nn.Parameter(torch.zeros(1, 1, vocab_size)) + + def forward(self, x, word_embeddings): + mlm_hidden = self.fc(x) + mlm_hidden = self.gelu(mlm_hidden) + mlm_hidden = self.ln(mlm_hidden) + word_embeddings = word_embeddings.transpose(0, 1) + logits = torch.matmul(mlm_hidden, word_embeddings) + logits = logits + self.bias + return logits + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + if XFORMERS_IS_AVAILBLE: + qkv = self.qkv(x) + qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) + q, k, v = qkv[0], qkv[1], qkv[2] # B L H D + x = xformers.ops.memory_efficient_attention(q, k, v) + x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) + else: + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + +class Adapter(nn.Module): + def __init__(self, d_emb:int, d_prj:int,n_layer: int, is_shared: bool): + super().__init__() + self.D = d_emb + self.H = d_prj + self.L = n_layer + self.is_shared = is_shared + if self.is_shared: + self.DD = nn.Embedding(self.L,self.H) + self.DU = nn.Embedding(self.L,self.D) + self.WD = nn.Embedding(1,self.D*self.H) + self.WU = nn.Embedding(1,self.H*self.D) + else: + self.WD = nn.Embedding(self.L,self.D*self.H) + self.WU = nn.Embedding(self.L,self.H*self.D) + self.activate = nn.GELU() + + self._init_weights() + def _init_weights(self): + for p in self.WU.parameters(): + p.detach().zero_() + nn.init.trunc_normal_(self.WD.weight,mean=0,std=0.02) + + if self.is_shared: + nn.init.trunc_normal_(self.DD.weight,mean=0,std=0.02) + for p in self.DU.parameters(): + p.detach().zero_() + + def forward(self, emb, layer): + idx = torch.arange(self.L).to(emb.device) + layer = torch.tensor(layer).to(emb.device) + if self.is_shared: + idx0 = torch.zeros_like(idx).to(emb.device) + dd = self.DD(idx).reshape(self.L, 1,self.H) + du = self.DU(idx).reshape(self.L, 1,self.D) + wd = self.WD(idx0).reshape(self.L, self.D,self.H) + dd + wu = self.WU(idx0).reshape(self.L, self.H,self.D) + du + else: + wd = self.WD(idx).reshape(self.L, self.D,self.H) + wu = self.WU(idx).reshape(self.L, self.H,self.D) + + prj = torch.einsum('...d,dh->...h',emb,wd[layer]) + prj = self.activate(prj) + prj = torch.einsum('...h,hd->...d',prj,wu[layer]) + return emb + prj +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, + act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): + super().__init__() + self.norm1 = 
norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + self.skip_linear = nn.Linear(2 * dim, dim) if skip else None + self.use_checkpoint = use_checkpoint + + def forward(self, x, skip=None, adapter=None, layer=None): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, skip, adapter, layer) + else: + return self._forward(x, skip, adapter, layer) + + def _forward(self, x, skip=None,adapter=None, layer=None): + if self.skip_linear is not None: + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + + attn = self.attn(self.norm1(x)) + if adapter is not None: + attn = adapter(attn, layer) + + x = x + attn + x = x + self.mlp(self.norm2(x)) + return x + + +class UViT(nn.Module): + def __init__(self, img_size=16, in_chans=8, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., + qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, use_checkpoint=False, + clip_dim=768, num_clip_token=77, skip=True, codebook_size=1024,d_prj=4,is_shared=True): + super().__init__() + logger.debug(f'codebook size in nnet: {codebook_size}') + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.in_chans = in_chans + self.skip = skip + + self.codebook_size = codebook_size + vocab_size = codebook_size + 1 + self.time_embed = None + self.extras = num_clip_token + self.num_vis_tokens = int((img_size) ** 2) + self.token_emb = BertEmbeddings(vocab_size=vocab_size, + hidden_size=embed_dim, + max_position_embeddings=self.num_vis_tokens, + dropout=0.1) + print(f'num vis tokens: {self.num_vis_tokens}') + + self.context_embed = nn.Linear(clip_dim, embed_dim) + + self.in_blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + norm_layer=norm_layer, use_checkpoint=use_checkpoint) + for _ in range(depth // 2)]) + + self.mid_block = Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + norm_layer=norm_layer, use_checkpoint=use_checkpoint) + + self.out_blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) + for _ in range(depth // 2)]) + + self.norm = norm_layer(embed_dim) + self.mlm_layer = MlmLayer(feat_emb_dim=embed_dim, word_emb_dim=embed_dim, vocab_size=vocab_size) + self.adapter = Adapter(d_emb=embed_dim, d_prj=d_prj, n_layer=depth, is_shared=is_shared) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore # type: ignore + def no_weight_decay(self): + return {'pos_embed'} + + def forward(self, masked_ids, context,use_adapter=False): + assert len(masked_ids.shape) == 2 + x = self.token_emb(masked_ids) + context_token = self.context_embed(context.type_as(x)) + x = torch.cat((context_token, x), dim=1) + + layer=0 + + if self.skip: + skips = [] + for blk in self.in_blocks: + # 将adapter放在attention之后 + x = blk(x,adapter=self.adapter if use_adapter else None,layer=layer) + if self.skip: + 
skips.append(x)# type: ignore + layer+=1 + + x = self.mid_block(x) + + for blk in self.out_blocks: + if self.skip: + x = blk(x, skips.pop(),adapter = self.adapter if use_adapter else None,layer=layer)# type: ignore + else: + x = blk(x,adapter = self.adapter if use_adapter else None,layer=layer) + + x = self.norm(x) + + word_embeddings = self.token_emb.word_embeddings.weight.data.detach() + x = self.mlm_layer(x, word_embeddings) + x = x[:, self.extras:, :self.codebook_size] + return x diff --git a/libs/uvit_vq.py b/libs/uvit_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..62cbe7fd72c07cc0dcaafed4228ead7153286cea --- /dev/null +++ b/libs/uvit_vq.py @@ -0,0 +1,264 @@ +import os + +import torch +import torch.nn as nn +import math + +from loguru import logger + +import timm +from timm.models.layers import trunc_normal_ +from timm.models.vision_transformer import PatchEmbed, Mlp + +assert timm.__version__ == "0.3.2" # version check +import einops +import torch.utils.checkpoint +import torch.nn.functional as F + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1): + super().__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6) + self.dropout = nn.Dropout(dropout) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1))) + + torch.nn.init.normal_(self.word_embeddings.weight, std=.02) + torch.nn.init.normal_(self.position_embeddings.weight, std=.02) + + def forward( + self, input_ids + ): + input_shape = input_ids.size() + + seq_length = input_shape[1] + + position_ids = self.position_ids[:, :seq_length] + + inputs_embeds = self.word_embeddings(input_ids) + + position_embeddings = self.position_embeddings(position_ids) + embeddings = inputs_embeds + position_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class MlmLayer(nn.Module): + + def __init__(self, feat_emb_dim, word_emb_dim, vocab_size): + super().__init__() + self.fc = nn.Linear(feat_emb_dim, word_emb_dim) + self.gelu = nn.GELU() + self.ln = nn.LayerNorm(word_emb_dim) + self.bias = nn.Parameter(torch.zeros(1, 1, vocab_size)) + + def forward(self, x, word_embeddings): + mlm_hidden = self.fc(x) + mlm_hidden = self.gelu(mlm_hidden) + mlm_hidden = self.ln(mlm_hidden) + word_embeddings = word_embeddings.transpose(0, 1) + logits = torch.matmul(mlm_hidden, word_embeddings) + logits = logits + self.bias + return logits + + +def patchify(imgs, patch_size): + x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) + return x + + +def unpatchify(x, channels=3, flatten=False): + patch_size = int((x.shape[2] // channels) ** 0.5) + h = w = int(x.shape[1] ** .5) + assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] + if flatten: + x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B (h p1 w p2) C', h=h, p1=patch_size, 
p2=patch_size) + else: + x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + if XFORMERS_IS_AVAILBLE: + qkv = self.qkv(x) + qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) + q, k, v = qkv[0], qkv[1], qkv[2] # B L H D + x = xformers.ops.memory_efficient_attention(q, k, v) + x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) + else: + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, + act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + self.skip_linear = nn.Linear(2 * dim, dim) if skip else None + self.use_checkpoint = use_checkpoint + + def forward(self, x, skip=None): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, skip) + else: + return self._forward(x, skip) + + def _forward(self, x, skip=None): + if self.skip_linear is not None: + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class UViT(nn.Module): + def __init__(self, img_size=16, patch_size=1, in_chans=8, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., + qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, num_classes=-1, + use_checkpoint=False, skip=True, codebook_size=1024): + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_classes = num_classes + self.in_chans = in_chans + self.skip = skip + + logger.debug(f'codebook size in nnet: {codebook_size}') + self.codebook_size = codebook_size + if num_classes > 0: + self.extras = 1 + vocab_size = codebook_size + num_classes + 1 + else: + self.extras = 0 + vocab_size = codebook_size + 1 + + self.token_emb = BertEmbeddings(vocab_size=vocab_size, + hidden_size=embed_dim, + max_position_embeddings=int(img_size ** 2) + self.extras, + dropout=0.1) + logger.debug(f'token emb weight shape: {self.token_emb.word_embeddings.weight.shape}') + + if patch_size != 1: # downsamp + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, input_shape='bhwc') + logger.debug(f'patch emb 
weight shape: {self.patch_embed.proj.weight.shape}') + self.decoder_pred = nn.Linear(embed_dim, patch_size ** 2 * embed_dim, bias=True) + else: + self.patch_embed = None + self.decoder_pred = None + + self.in_blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + norm_layer=norm_layer, use_checkpoint=use_checkpoint) + for _ in range(depth // 2)]) + + self.mid_block = Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + norm_layer=norm_layer, use_checkpoint=use_checkpoint) + + self.out_blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint) + for _ in range(depth // 2)]) + + self.norm = norm_layer(embed_dim) + self.mlm_layer = MlmLayer(feat_emb_dim=embed_dim, word_emb_dim=embed_dim, vocab_size=vocab_size) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed'} + + def forward(self, x, context=None): + assert len(x.shape) == 2 + if context is not None: + context = context + self.codebook_size + 1 # shift, mask token is self.codebook_size + x = torch.cat((context, x), dim=1) + x = self.token_emb(x.long()) + if self.patch_embed is not None: + featmap_downsampled = self.patch_embed( + x[:, self.extras:].reshape(-1, *self.patch_embed.img_size, self.embed_dim)).reshape(x.shape[0], -1, self.embed_dim) + x = torch.cat((x[:, :self.extras], featmap_downsampled), dim=1) + + if self.skip: + skips = [] + for blk in self.in_blocks: + x = blk(x) + if self.skip: + skips.append(x) + + x = self.mid_block(x) + + for blk in self.out_blocks: + if self.skip: + x = blk(x, skips.pop()) + else: + x = blk(x) + + x = self.norm(x) + if self.decoder_pred is not None: + featmap_upsampled = unpatchify(self.decoder_pred(x[:, self.extras:]), self.embed_dim, flatten=True) + x = torch.cat((x[:, :self.extras], featmap_upsampled), dim=1) + word_embeddings = self.token_emb.word_embeddings.weight.data.detach() + x = self.mlm_layer(x, word_embeddings) + x = x[:, self.extras:, :self.codebook_size] + return x diff --git a/open_clip/__init__.py b/open_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..088c86441ec71a241320de79b7b66a6afeb3a049 --- /dev/null +++ b/open_clip/__init__.py @@ -0,0 +1,13 @@ +from .coca_model import CoCa +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss +from .factory import list_models, add_model_config, get_model_config, load_checkpoint +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ + convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype +from .openai import load_openai_model, list_openai_models +from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ + get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained +from .push_to_hf_hub import 
push_pretrained_to_hf_hub, push_to_hf_hub +from .tokenizer import SimpleTokenizer, tokenize, decode +from .transform import image_transform, AugmentationCfg diff --git a/open_clip/bpe_simple_vocab_16e6.txt.gz b/open_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/open_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/open_clip/coca_model.py b/open_clip/coca_model.py new file mode 100644 index 0000000000000000000000000000000000000000..039453af70d1c865dd7cc6016f732aff2f7dc3d2 --- /dev/null +++ b/open_clip/coca_model.py @@ -0,0 +1,458 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np +from dataclasses import dataclass + +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + MultimodalTransformer, +) +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower + +try: + from transformers import ( + BeamSearchScorer, + LogitsProcessorList, + TopPLogitsWarper, + TopKLogitsWarper, + RepetitionPenaltyLogitsProcessor, + MinLengthLogitsProcessor, + MaxLengthCriteria, + StoppingCriteriaList + ) + + GENERATION_TYPES = { + "top_k": TopKLogitsWarper, + "top_p": TopPLogitsWarper, + "beam_search": "beam_search" + } + _has_transformers = True +except ImportError as e: + GENERATION_TYPES = { + "top_k": None, + "top_p": None, + "beam_search": "beam_search" + } + _has_transformers = False + + +@dataclass +class MultimodalCfg(CLIPTextCfg): + mlp_ratio: int = 4 + dim_head: int = 64 + heads: int = 8 + n_queries: int = 256 + attn_pooler_heads: int = 8 + + +def _build_text_decoder_tower( + embed_dim, + multimodal_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) + + decoder = MultimodalTransformer( + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return decoder + + +class CoCa(nn.Module): + def __init__( + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0, + ): + super().__init__() + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg + + self.text = _build_text_tower( + embed_dim=embed_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + vocab_size = ( + text_cfg.vocab_size # for hf models + if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None + else text_cfg.vocab_size + ) + + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + 
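+        # The multimodal text decoder predicts caption tokens, so it is built with
+        # vocab_size (rather than embed_dim) as its output width.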
self.text_decoder = _build_text_decoder_tower( + vocab_size, + multimodal_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.pad_id = pad_id + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + self.text_decoder.set_grad_checkpointing(enable) + + def _encode_image(self, images, normalize=True): + image_latent, tokens_embs = self.visual(images) + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + return image_latent, tokens_embs + + def _encode_text(self, text, normalize=True, embed_cls=True): + text = text[:, :-1] if embed_cls else text # make space for CLS token + text_latent, token_emb = self.text(text) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + return text_latent, token_emb + + def encode_image(self, images, normalize=True): + image_latent, _ = self._encode_image(images, normalize=normalize) + return image_latent + + def encode_text(self, text, normalize=True, embed_cls=True): + text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) + return text_latent + + def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None): + text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) + if image_latent is None or image_embs is None: + image_latent, image_embs = self._encode_image(image) + + # TODO: add assertion to avoid bugs? + labels = text[:, -token_embs.shape[1]:] + + logits = self.text_decoder(image_embs, token_embs) + return { + "image_features": image_latent, + "text_features": text_latent, + "logits": logits, + "labels": labels, + "logit_scale": self.logit_scale.exp() + } + + def generate( + self, + image, + text=None, + seq_len=30, + max_seq_len=77, + temperature=1., + generation_type="beam_search", + top_p=0.1, # keep tokens in the 1 - top_p quantile + top_k=1, # keeps the top_k most probable tokens + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + repetition_penalty=1.0, + fixed_output_length=False # if True output.shape == (batch_size, seq_len) + ): + # taking many ideas and components from HuggingFace GenerationMixin + # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation + assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." 
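+        # All decoding paths below (beam_search, top_k, top_p) rely on helpers from the
+        # optional `transformers` dependency checked by the assert above; the default
+        # sot_token_id (49406) and eos_token_id (49407) set further down correspond to
+        # the start/end-of-text tokens of the bundled SimpleTokenizer.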
+ assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" + + with torch.no_grad(): + sot_token_id = 49406 if sot_token_id is None else sot_token_id + eos_token_id = 49407 if eos_token_id is None else eos_token_id + pad_token_id = self.pad_id if pad_token_id is None else pad_token_id + logit_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(min_seq_len, eos_token_id), + RepetitionPenaltyLogitsProcessor(repetition_penalty), + ] + ) + + if stopping_criteria is None: + stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] + + stopping_criteria = StoppingCriteriaList( + stopping_criteria + ) + + device = image.device + + if generation_type == "beam_search": + output = self._generate_beamsearch( + image_inputs = image, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + sot_token_id=sot_token_id, + num_beams=num_beams, + num_beam_groups=num_beam_groups, + min_seq_len=min_seq_len, + stopping_criteria=stopping_criteria, + logit_processor=logit_processor, + ) + if fixed_output_length and output.shape[1] < seq_len: + return torch.cat( + (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id), + dim=1 + ) + return output + + elif generation_type == "top_p": + logit_warper = GENERATION_TYPES[generation_type](top_p) + elif generation_type == "top_k": + logit_warper = GENERATION_TYPES[generation_type](top_k) + else: + raise ValueError( + f"generation_type has to be one of " + f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." + ) + + image_latent, image_embs = self._encode_image(image) + + if text is None: + text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id + + was_training = self.training + num_dims = len(text.shape) + + if num_dims == 1: + text = text[None, :] + + cur_len = text.shape[1] + self.eval() + out = text + + while True: + x = out[:, -max_seq_len:] + cur_len = x.shape[1] + logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1] + mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) + sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id + + if mask.all(): + if not fixed_output_length: + break + else: + logits = logits[~mask, :] + filtered_logits = logit_processor(x[~mask, :], logits) + filtered_logits = logit_warper(x[~mask, :], filtered_logits) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + if (cur_len + 1 == seq_len): + sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id + else: + sample[~mask, :] = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + cur_len += 1 + + if stopping_criteria(out, None): + break + + if num_dims == 1: + out = out.squeeze(0) + + self.train(was_training) + return out + + def _generate_beamsearch( + self, + image_inputs, + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + logit_processor=None, + logit_warper=None, + ): + device = image_inputs.device + batch_size = image_inputs.shape[0] + image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) + image_latent, image_embs = self._encode_image(image_inputs) + + input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) + input_ids = input_ids * sot_token_id + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=device, + 
num_beam_groups=num_beam_groups, + ) + # instantiate logits processors + logits_processor = ( + LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) + if logit_processor is None + else logit_processor + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + batch_beam_size, cur_len = input_ids.shape + beam_indices = None + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in + # the same group don't produce same tokens every time. + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + while True: + + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) + outputs = self( + model_inputs['images'], + model_inputs['text'], + embed_cls=False, + image_latent=image_latent, + image_embs=image_embs + ) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of current group only + next_token_logits = outputs['logits'][batch_group_indices, -1, :] + vocab_size = next_token_logits.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=process_beam_indices, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = 
torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + # increase cur_len + cur_len = cur_len + 1 + if beam_scorer.is_done or stopping_criteria(input_ids, None): + break + + final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, + ) + return sequence_outputs['sequences'] + + +def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + position_ids = None + return { + "text": input_ids, + "images": image_inputs, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask, + } diff --git a/open_clip/constants.py b/open_clip/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a670bb3fab442baeb9af53b91c312e6982af57ee --- /dev/null +++ b/open_clip/constants.py @@ -0,0 +1,2 @@ +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) diff --git a/open_clip/factory.py b/open_clip/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..14011f9340fc6e54876c3c5bcb9e23a8cd57849d --- /dev/null +++ b/open_clip/factory.py @@ -0,0 +1,366 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union + +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ + resize_pos_embed, get_cast_dtype +from .coca_model import CoCa +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .openai import load_openai_model +from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf +from .transform import image_transform, AugmentationCfg +from .tokenizer import HFTokenizer, tokenize + + +HF_HUB_PREFIX = 'hf-hub:' +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = ('.json',) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for 
ext in config_ext: + config_files.extend(config_path.glob(f'*{ext}')) + + for cf in config_files: + with open(cf, 'r') as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """ enumerate available model architectures based on config files """ + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """ add model config path or file and update registry """ + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + + +def get_tokenizer(model_name): + if model_name.startswith(HF_HUB_PREFIX): + tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):]) + else: + config = get_model_config(model_name) + tokenizer = HFTokenizer( + config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize + return tokenizer + + +def load_state_dict(checkpoint_path: str, map_location='cpu'): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith('module'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + return state_dict + + +def load_checkpoint(model, checkpoint_path, strict=True): + state_dict = load_state_dict(checkpoint_path) + # detect old format and make compatible with new format + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + state_dict = convert_to_custom_text_state_dict(state_dict) + resize_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + +def create_model( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + require_pretrained: bool = False, +): + has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) + if has_hf_hub_prefix: + model_id = model_name[len(HF_HUB_PREFIX):] + checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) + + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + pretrained_cfg = config['preprocess_cfg'] + model_cfg = config['model_cfg'] + else: + model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + checkpoint_path = None + pretrained_cfg = {} + model_cfg = None + + if isinstance(device, str): + device = torch.device(device) + + if pretrained and pretrained.lower() == 'openai': + logging.info(f'Loading pretrained {model_name} from OpenAI.') + model = load_openai_model( + model_name, + precision=precision, + 
device=device, + jit=jit, + cache_dir=cache_dir, + ) + + # to always output dict even if it is clip + if output_dict and hasattr(model, "output_dict"): + model.output_dict = True + else: + model_cfg = model_cfg or get_model_config(model_name) + if model_cfg is not None: + logging.info(f'Loaded {model_name} model config.') + else: + logging.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + if force_image_size is not None: + # override model config's image size + model_cfg["vision_cfg"]["image_size"] = force_image_size + + if pretrained_image: + if 'timm_model_name' in model_cfg.get('vision_cfg', {}): + # pretrained weight loading for timm models set via vision_cfg + model_cfg['vision_cfg']['timm_model_pretrained'] = True + else: + assert False, 'pretrained image towers currently only supported for timm models' + + cast_dtype = get_cast_dtype(precision) + is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) + custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model + + if custom_text: + if is_hf_model: + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf + if "coca" in model_name: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + pretrained_loaded = False + if pretrained: + checkpoint_path = '' + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + else: + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' 
+ f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + logging.warning(error_str) + raise RuntimeError(error_str) + pretrained_loaded = True + elif has_hf_hub_prefix: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + pretrained_loaded = True + + if require_pretrained and not pretrained_loaded: + # callers of create_model_from_pretrained always expect pretrained weights + raise RuntimeError( + f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') + + model.to(device=device) + if precision in ("fp16", "bf16"): + convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16) + + # set image / mean metadata from pretrained_cfg if available, or use default + model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN + model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD + + # to always output dict even if it is clip + if output_dict and hasattr(model, "output_dict"): + model.output_dict = True + + if jit: + model = torch.jit.script(model) + + return model + + +def create_loss(args): + if args.distill: + return DistillClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + elif "coca" in args.model.lower(): + return CoCaLoss( + caption_loss_weight=args.coca_caption_loss_weight, + clip_loss_weight=args.coca_contrastive_loss_weight, + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + return ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + + +def create_model_and_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_patch_dropout=force_patch_dropout, + force_image_size=force_image_size, + pretrained_image=pretrained_image, + pretrained_hf=pretrained_hf, + cache_dir=cache_dir, + output_dict=output_dict, + ) + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess_train = image_transform( + model.visual.image_size, + is_train=True, + mean=image_mean, + std=image_std, + aug_cfg=aug_cfg, + ) + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) + + return model, preprocess_train, preprocess_val + + +def create_model_from_pretrained( + model_name: str, + pretrained: 
Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + return_transform: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_image_size=force_image_size, + cache_dir=cache_dir, + require_pretrained=True, + ) + + if not return_transform: + return model + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) + + return model, preprocess diff --git a/open_clip/generation_utils.py b/open_clip/generation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/open_clip/hf_configs.py b/open_clip/hf_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..e236222bafce0358445ea16953ca0b2d5a84758a --- /dev/null +++ b/open_clip/hf_configs.py @@ -0,0 +1,45 @@ +# HF architecture dict: +arch_dict = { + # https://huggingface.co/docs/transformers/model_doc/roberta#roberta + "roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig + "xlm-roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 + "mt5": { + "config_names": { + # unlimited seqlen + # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 + "context_length": "", + "vocab_size": "vocab_size", + "width": "d_model", + "heads": "num_heads", + "layers": "num_layers", + "layer_attr": "block", + "token_embeddings_attr": "embed_tokens" + }, + "pooler": "mean_pooler", + }, +} diff --git a/open_clip/hf_model.py b/open_clip/hf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fbccc812757bf10b122ff14096980e0e38d1d221 --- /dev/null +++ b/open_clip/hf_model.py @@ -0,0 +1,176 @@ +""" huggingface model adapter + +Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 
+""" + +import re + +import torch +import torch.nn as nn +from torch import TensorType + +try: + import transformers + from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ + BaseModelOutputWithPoolingAndCrossAttentions +except ImportError as e: + transformers = None + + + class BaseModelOutput: + pass + + + class PretrainedConfig: + pass + +from .hf_configs import arch_dict + + +# utils +def _camel2snake(s): + return re.sub(r'(? torch.Tensor: + # calculated ground-truth and cache if enabled + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + return labels + + def get_logits(self, image_features, text_features, logit_scale): + if self.world_size > 1: + all_image_features, all_text_features = gather_features( + image_features, text_features, + self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + return logits_per_image, logits_per_text + + def forward(self, image_features, text_features, logit_scale, output_dict=False): + device = image_features.device + logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) + + labels = self.get_ground_truth(device, logits_per_image.shape[0]) + + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + return {"contrastive_loss": total_loss} if output_dict else total_loss + + +class CoCaLoss(ClipLoss): + def __init__( + self, + caption_loss_weight, + clip_loss_weight, + pad_id=0, # pad_token for open_clip custom tokenizer + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=cache_labels, + rank=rank, + world_size=world_size, + use_horovod=use_horovod + ) + + self.clip_loss_weight = clip_loss_weight + self.caption_loss_weight = caption_loss_weight + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + + def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): + clip_loss = super().forward(image_features, text_features, logit_scale) + clip_loss = self.clip_loss_weight * clip_loss + + caption_loss = self.caption_loss( + logits.permute(0, 2, 1), + labels, + ) + caption_loss = caption_loss * self.caption_loss_weight + + if output_dict: + return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} + + return clip_loss, caption_loss + + +class DistillClipLoss(ClipLoss): + + def dist_loss(self, teacher_logits, student_logits): + return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) + + def forward( + self, + image_features, 
+ text_features, + logit_scale, + dist_image_features, + dist_text_features, + dist_logit_scale, + output_dict=False, + ): + logits_per_image, logits_per_text = \ + self.get_logits(image_features, text_features, logit_scale) + + dist_logits_per_image, dist_logits_per_text = \ + self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) + + labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) + + contrastive_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + distill_loss = ( + self.dist_loss(dist_logits_per_image, logits_per_image) + + self.dist_loss(dist_logits_per_text, logits_per_text) + ) / 2 + + if output_dict: + return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} + + return contrastive_loss, distill_loss diff --git a/open_clip/model.py b/open_clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..230f36e41143346e9fc2c1c80d4c9740fffe7a81 --- /dev/null +++ b/open_clip/model.py @@ -0,0 +1,445 @@ +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +from dataclasses import dataclass +import logging +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from .hf_model import HFTextEncoder +from .modified_resnet import ModifiedResNet +from .timm_model import TimmModel +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer +from .utils import to_2tuple + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer + n_queries: int = 256 # n_queries for attentional pooler + attn_pooler_heads: int = 8 # n heads for attentional_pooling + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + timm_drop: float = 0. 
# head dropout + timm_drop_path: Optional[float] = None # backbone stochastic depth + output_tokens: bool = False + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + width: int = 512 + heads: int = 8 + layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + hf_model_name: str = None + hf_tokenizer_name: str = None + hf_model_pretrained: bool = True + proj: str = 'mlp' + pooler_type: str = 'mean_pooler' + embed_cls: bool = False + pad_id: int = 0 + output_tokens: bool = False + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, + pretrained=vision_cfg.timm_model_pretrained, + pool=vision_cfg.timm_pool, + proj=vision_cfg.timm_proj, + proj_bias=vision_cfg.timm_proj_bias, + drop=vision_cfg.timm_drop, + drop_path=vision_cfg.timm_drop_path, + embed_dim=embed_dim, + image_size=vision_cfg.image_size, + ) + act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet( + layers=vision_cfg.layers, + output_dim=embed_dim, + heads=vision_heads, + image_size=vision_cfg.image_size, + width=vision_cfg.width, + ) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + input_patchnorm=vision_cfg.input_patchnorm, + global_average_pool=vision_cfg.global_average_pool, + attentional_pool=vision_cfg.attentional_pool, + n_queries=vision_cfg.n_queries, + attn_pooler_heads=vision_cfg.attn_pooler_heads, + output_tokens=vision_cfg.output_tokens, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + if text_cfg.hf_model_name: + text = HFTextEncoder( + text_cfg.hf_model_name, + output_dim=embed_dim, + proj=text_cfg.proj, + pooler_type=text_cfg.pooler_type, + pretrained=text_cfg.hf_model_pretrained, + output_tokens=text_cfg.output_tokens, + ) + else: + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, 
+ heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + embed_cls=text_cfg.embed_cls, + output_tokens=text_cfg.output_tokens, + pad_id=text_cfg.pad_id, + act_layer=act_layer, + norm_layer=norm_layer, + ) + return text + + +class CLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + + +class CustomTextCLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + def lock_text_tower(self, 
unlocked_layers: int = 0, freeze_layer_norm: bool = True): + self.text.lock(unlocked_layers, freeze_layer_norm) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + + +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" + + def _convert_weights(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(dtype) + if l.bias is not None: + l.bias.data = l.bias.data.to(dtype) + + if isinstance(l, (nn.MultiheadAttention, Attention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(dtype) + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if 'text_projection' in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ( + 'text_projection', + 'positional_embedding', + 'token_embedding', + 'transformer', + 'ln_final', + )): + k = 'text.' 
+ k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_size = vision_patch_size * grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg( + context_length=context_length, + vocab_size=vocab_size, + width=transformer_width, + heads=transformer_heads, + layers=transformer_layers, + ) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device('cpu')): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_images, example_text), + encode_text=(example_text,), + encode_image=(example_images,) + )) + model.visual.image_size = image_size + return model + + +def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = 
to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + antialias=antialias, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['visual.positional_embedding'] = new_pos_embed diff --git a/open_clip/model_configs/RN101-quickgelu.json b/open_clip/model_configs/RN101-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/open_clip/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/RN101.json b/open_clip/model_configs/RN101.json new file mode 100644 index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32 --- /dev/null +++ b/open_clip/model_configs/RN101.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/RN50-quickgelu.json b/open_clip/model_configs/RN50-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/open_clip/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/open_clip/model_configs/RN50.json b/open_clip/model_configs/RN50.json new file mode 100644 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/open_clip/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/RN50x16.json b/open_clip/model_configs/RN50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db --- /dev/null +++ b/open_clip/model_configs/RN50x16.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/RN50x4.json 
b/open_clip/model_configs/RN50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb --- /dev/null +++ b/open_clip/model_configs/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/RN50x64.json b/open_clip/model_configs/RN50x64.json new file mode 100644 index 0000000000000000000000000000000000000000..f5aaa2ee3de21ddb03cbd12766a3419bf34898c7 --- /dev/null +++ b/open_clip/model_configs/RN50x64.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 448, + "layers": [ + 3, + 15, + 36, + 10 + ], + "width": 128, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-B-16-plus-240.json b/open_clip/model_configs/ViT-B-16-plus-240.json new file mode 100644 index 0000000000000000000000000000000000000000..5bbd12bcd01f64d6d0a0aa8316b129327a0d169a --- /dev/null +++ b/open_clip/model_configs/ViT-B-16-plus-240.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 240, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-B-16-plus.json b/open_clip/model_configs/ViT-B-16-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..5dc1e09baccef2b15055c1bffeb9903e760101c6 --- /dev/null +++ b/open_clip/model_configs/ViT-B-16-plus.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-B-16.json b/open_clip/model_configs/ViT-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/open_clip/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-B-32-plus-256.json b/open_clip/model_configs/ViT-B-32-plus-256.json new file mode 100644 index 0000000000000000000000000000000000000000..2f09c857de9a4c01ae51297a7e2451984879f9de --- /dev/null +++ b/open_clip/model_configs/ViT-B-32-plus-256.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 256, + "layers": 12, + "width": 896, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-B-32-quickgelu.json b/open_clip/model_configs/ViT-B-32-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5 --- /dev/null +++ 
b/open_clip/model_configs/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-B-32.json b/open_clip/model_configs/ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/open_clip/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-H-14.json b/open_clip/model_configs/ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3e3a7e934e7f02e41f4829996c4950e05f015a74 --- /dev/null +++ b/open_clip/model_configs/ViT-H-14.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-H-16.json b/open_clip/model_configs/ViT-H-16.json new file mode 100644 index 0000000000000000000000000000000000000000..588485455fdf8193ec16474450b94e31c91ea93c --- /dev/null +++ b/open_clip/model_configs/ViT-H-16.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-L-14-280.json b/open_clip/model_configs/ViT-L-14-280.json new file mode 100644 index 0000000000000000000000000000000000000000..2262deaefa82792d35d73c0d7c8e620525092581 --- /dev/null +++ b/open_clip/model_configs/ViT-L-14-280.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 280, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-L-14-336.json b/open_clip/model_configs/ViT-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..8d1f74c2639c3a3705df9865b9c08215675ddc97 --- /dev/null +++ b/open_clip/model_configs/ViT-L-14-336.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-L-14.json b/open_clip/model_configs/ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/open_clip/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + 
"vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-L-16-320.json b/open_clip/model_configs/ViT-L-16-320.json new file mode 100644 index 0000000000000000000000000000000000000000..fc2d13ca9ec7f0b56a886ddaf66c4a7ba7a442ba --- /dev/null +++ b/open_clip/model_configs/ViT-L-16-320.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 320, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-L-16.json b/open_clip/model_configs/ViT-L-16.json new file mode 100644 index 0000000000000000000000000000000000000000..82a1cedfa290adacbbdc02bc5d589734c22d41d3 --- /dev/null +++ b/open_clip/model_configs/ViT-L-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-M-16-alt.json b/open_clip/model_configs/ViT-M-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..1a317aad8e02d9c26d2decc7cc49a18dfdf9e0d8 --- /dev/null +++ b/open_clip/model_configs/ViT-M-16-alt.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16, + "ls_init_value": 1e-4 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-M-16.json b/open_clip/model_configs/ViT-M-16.json new file mode 100644 index 0000000000000000000000000000000000000000..f2f3225a46e09237730a151d161f70c86b985172 --- /dev/null +++ b/open_clip/model_configs/ViT-M-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-M-32-alt.json b/open_clip/model_configs/ViT-M-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..fd222aeac0f582ef6a1a33f1b3fec70a5b386ac0 --- /dev/null +++ b/open_clip/model_configs/ViT-M-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-M-32.json b/open_clip/model_configs/ViT-M-32.json new file mode 100644 index 0000000000000000000000000000000000000000..4f718642821035d9776d1e006817d65ede074366 --- /dev/null +++ b/open_clip/model_configs/ViT-M-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-S-16-alt.json b/open_clip/model_configs/ViT-S-16-alt.json new file mode 100644 index 
0000000000000000000000000000000000000000..a8c056555e4da3ba0d1475a61fc316362ecce76f --- /dev/null +++ b/open_clip/model_configs/ViT-S-16-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-S-16.json b/open_clip/model_configs/ViT-S-16.json new file mode 100644 index 0000000000000000000000000000000000000000..1d8504e59658803f3093e5b05de45f30a09b8185 --- /dev/null +++ b/open_clip/model_configs/ViT-S-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-S-32-alt.json b/open_clip/model_configs/ViT-S-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..e1dfdec9824df09a2010e991ccfa1d9ee2f45807 --- /dev/null +++ b/open_clip/model_configs/ViT-S-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-S-32.json b/open_clip/model_configs/ViT-S-32.json new file mode 100644 index 0000000000000000000000000000000000000000..9b8b4191b268de267268cfcb90fc01c6b9df07d8 --- /dev/null +++ b/open_clip/model_configs/ViT-S-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-bigG-14.json b/open_clip/model_configs/ViT-bigG-14.json new file mode 100644 index 0000000000000000000000000000000000000000..2cfba479a2e8f3737e71ce240732bf3bc743d8b7 --- /dev/null +++ b/open_clip/model_configs/ViT-bigG-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-e-14.json b/open_clip/model_configs/ViT-e-14.json new file mode 100644 index 0000000000000000000000000000000000000000..91a0fe14d25a107fb8ec48dd7faae313fd26ed7b --- /dev/null +++ b/open_clip/model_configs/ViT-e-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 56, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.5715, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 36 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/ViT-g-14.json b/open_clip/model_configs/ViT-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..8c4b7325cc75b6112be7107d36ae2cb5762d9091 --- /dev/null +++ b/open_clip/model_configs/ViT-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + 
"image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/coca_ViT-B-32.json b/open_clip/model_configs/coca_ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..7e7eb520a6a0096e5602d509ecd6186e278f4725 --- /dev/null +++ b/open_clip/model_configs/coca_ViT-B-32.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "attn_pooler_heads": 8 + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/model_configs/coca_ViT-L-14.json b/open_clip/model_configs/coca_ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3d5ca4ca2338540f06852df5ff35ea6277e64555 --- /dev/null +++ b/open_clip/model_configs/coca_ViT-L-14.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "attn_pooler_heads": 12 + }, + "custom_text": true +} diff --git a/open_clip/model_configs/coca_base.json b/open_clip/model_configs/coca_base.json new file mode 100644 index 0000000000000000000000000000000000000000..cf8c6cecb78a49d7e7140145a0307cbd561077c2 --- /dev/null +++ b/open_clip/model_configs/coca_base.json @@ -0,0 +1,31 @@ +{ + "embed_dim": 512, + "multimodal_cfg": { + "width": 768, + "context_length": 76, + "vocab_size": 64000, + "mlp_ratio": 4, + "layers": 12, + "dim_head": 64, + "heads": 12, + "n_queries": 256, + "attn_pooler_heads": 8 + }, + "vision_cfg": { + "image_size": 288, + "layers": 12, + "width": 768, + "patch_size": 18, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 64000, + "layers": 12, + "heads": 12, + "width": 768, + "embed_cls": true, + "output_tokens": true + }, + "custom_text": true +} \ No newline at end of file diff --git a/open_clip/model_configs/coca_roberta-ViT-B-32.json b/open_clip/model_configs/coca_roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..fb46354b95a17a46d7fcfd9d504e917ee6c1608c --- /dev/null +++ b/open_clip/model_configs/coca_roberta-ViT-B-32.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "output_tokens": true + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "linear", + "width": 768, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "width": 768, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} diff --git 
a/open_clip/model_configs/convnext_base.json b/open_clip/model_configs/convnext_base.json new file mode 100644 index 0000000000000000000000000000000000000000..bb6dba181d950ea5081155c90d47e72c94816b80 --- /dev/null +++ b/open_clip/model_configs/convnext_base.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_base_w.json b/open_clip/model_configs/convnext_base_w.json new file mode 100644 index 0000000000000000000000000000000000000000..82ea7ae3659e5514f37ff982f0ab1141dff4bd18 --- /dev/null +++ b/open_clip/model_configs/convnext_base_w.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_base_w_320.json b/open_clip/model_configs/convnext_base_w_320.json new file mode 100644 index 0000000000000000000000000000000000000000..0a07c4e16abaa4015ecc5f82ec845de16e1f9d88 --- /dev/null +++ b/open_clip/model_configs/convnext_base_w_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_large.json b/open_clip/model_configs/convnext_large.json new file mode 100644 index 0000000000000000000000000000000000000000..c4a1fea73dbead71c218a0e74b9b15f9b252e3ef --- /dev/null +++ b/open_clip/model_configs/convnext_large.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_large_d.json b/open_clip/model_configs/convnext_large_d.json new file mode 100644 index 0000000000000000000000000000000000000000..ae8fed21b58e1a6a411daf8b792ee50f0ab42346 --- /dev/null +++ b/open_clip/model_configs/convnext_large_d.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_large_d_320.json b/open_clip/model_configs/convnext_large_d_320.json new file mode 100644 index 
0000000000000000000000000000000000000000..54c3df36a6f56ace0b12ada24c13058de96feed8 --- /dev/null +++ b/open_clip/model_configs/convnext_large_d_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_small.json b/open_clip/model_configs/convnext_small.json new file mode 100644 index 0000000000000000000000000000000000000000..3592c2a5cd21aae8d2544931773cf7603f67ea28 --- /dev/null +++ b/open_clip/model_configs/convnext_small.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_small", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_tiny.json b/open_clip/model_configs/convnext_tiny.json new file mode 100644 index 0000000000000000000000000000000000000000..ad11470f5ec40ffec771096971ce58d3d5b9249b --- /dev/null +++ b/open_clip/model_configs/convnext_tiny.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_tiny", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_xlarge.json b/open_clip/model_configs/convnext_xlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..2a909965932eef994177c829fefc2bdc1c219b3f --- /dev/null +++ b/open_clip/model_configs/convnext_xlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 20 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_xxlarge.json b/open_clip/model_configs/convnext_xxlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..23a55a681c346d1a315d8a163c1cb6ad495e6a91 --- /dev/null +++ b/open_clip/model_configs/convnext_xxlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/convnext_xxlarge_320.json b/open_clip/model_configs/convnext_xxlarge_320.json new file mode 100644 index 0000000000000000000000000000000000000000..ac5134ca12cbaa97772cde059270d345386a74c7 --- /dev/null +++ b/open_clip/model_configs/convnext_xxlarge_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + 
"vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/mt5-base-ViT-B-32.json b/open_clip/model_configs/mt5-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..58cad89cf0f446bbe15e4e25b1ac43424a828017 --- /dev/null +++ b/open_clip/model_configs/mt5-base-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "google/mt5-base", + "hf_tokenizer_name": "google/mt5-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/open_clip/model_configs/mt5-xl-ViT-H-14.json b/open_clip/model_configs/mt5-xl-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..b432810777ba7269dbb0e89edfe65cdd27e7d255 --- /dev/null +++ b/open_clip/model_configs/mt5-xl-ViT-H-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "google/mt5-xl", + "hf_tokenizer_name": "google/mt5-xl", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/open_clip/model_configs/roberta-ViT-B-32.json b/open_clip/model_configs/roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..ed687d472a73bb2ac96025f355f80437ab14c260 --- /dev/null +++ b/open_clip/model_configs/roberta-ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/open_clip/model_configs/swin_base_patch4_window7_224.json b/open_clip/model_configs/swin_base_patch4_window7_224.json new file mode 100644 index 0000000000000000000000000000000000000000..bd6820f0cf2aa655e0a2723287f4b78895a58e6a --- /dev/null +++ b/open_clip/model_configs/swin_base_patch4_window7_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "swin_base_patch4_window7_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/vit_medium_patch16_gap_256.json b/open_clip/model_configs/vit_medium_patch16_gap_256.json new file mode 100644 index 0000000000000000000000000000000000000000..8843eaf08cad16c3e7b5f496fd650715c9573f65 --- /dev/null +++ b/open_clip/model_configs/vit_medium_patch16_gap_256.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_medium_patch16_gap_256", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json 
b/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json new file mode 100644 index 0000000000000000000000000000000000000000..ed217b202d5e6071c5307f4547c97ff4cfe2abd1 --- /dev/null +++ b/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_relpos_medium_patch16_cls_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json b/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..751bccc2c6fc41bc4ff20182de88d86739d518d9 --- /dev/null +++ b/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-base", + "hf_tokenizer_name": "xlm-roberta-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json b/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..31f271faa9bbb7a9da53900b483a4c00a16f3c4a --- /dev/null +++ b/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-large", + "hf_tokenizer_name": "xlm-roberta-large", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/open_clip/modified_resnet.py b/open_clip/modified_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c0b033a80e7d08a20a367050c5b1bc5d5292e7 --- /dev/null +++ b/open_clip/modified_resnet.py @@ -0,0 +1,181 @@ +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional as F + +from open_clip.utils import freeze_batch_norm_2d + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0., + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x diff --git a/open_clip/openai.py b/open_clip/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4e13e876d6a7a3463b457e62c517cb063b1356 --- /dev/null +++ b/open_clip/openai.py @@ -0,0 +1,144 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
+""" + +import os +import warnings +from typing import List, Optional, Union + +import torch + +from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype +from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_models_by_tag('openai') + + +def load_openai_model( + name: str, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + jit: bool = True, + cache_dir: Optional[str] = None, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + precision: str + Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + cache_dir : Optional[str] + The directory to cache the downloaded model weights + + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if precision is None: + precision = 'fp32' if device == 'cpu' else 'fp16' + + if get_pretrained_url(name, 'openai'): + model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) + + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + if precision.startswith('amp') or precision == 'fp32': + model.float() + elif precision == 'bf16': + convert_weights_to_lp(model, dtype=torch.bfloat16) + + return model + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 (typically for CPU) + if precision == 'fp32': + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + model.float() + + # ensure image_size attr available at consistent location for both jit and non-jit + model.visual.image_size = model.input_resolution.item() + return model diff --git a/open_clip/pretrained.py b/open_clip/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..84f779c257627cb9075cb242a5baa51023a3a96a --- /dev/null +++ b/open_clip/pretrained.py @@ -0,0 +1,375 @@ +import hashlib +import os +import urllib +import warnings +from functools import partial +from typing import Dict, Union + +from tqdm import tqdm + +from .version import __version__ + +try: + from huggingface_hub import hf_hub_download + hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url='', hf_hub='', mean=None, std=None): + return dict( + url=url, + hf_hub=hf_hub, + mean=mean, + std=std, + ) + + +_RN50 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + 
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN50_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN101 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN101_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN50x4 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), +) + +_RN50x16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), +) + +_RN50x64 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), +) + +_VITB32 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + laion2b_e16=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), + laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') +) + +_VITB32_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), +) + +_VITB16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + # laion400m_32k=_pcfg( + # url="", + # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + # laion400m_64k=_pcfg( + # url="", + # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), +) + +_VITB16_PLUS_240 = dict( + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), +) + +_VITL14 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), + laion2b_s32b_b82k=_pcfg( + hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), +) + +_VITL14_336 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), +) + +_VITH14 = dict( + laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), +) + +_VITg14 = dict( + laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), +) + +_VITbigG14 = dict( + laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), +) + +_robertaViTB32 = dict( + laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), +) + +_xlmRobertaBaseViTB32 = dict( + laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), +) + +_xlmRobertaLargeFrozenViTH14 = dict( + frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), +) + +_convnext_base = dict( + laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), +) + +_convnext_base_w = dict( + laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), + laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), +) + +_convnext_base_w_320 = dict( + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), + laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), +) + +_convnext_large_d = dict( + laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), +) + +_convnext_large_d_320 = dict( + laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), + laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), +) + +_convnext_xxlarge = dict( + laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), + laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), + laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), +) + +_coca_VITB32 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') +) + +_coca_VITL14 = dict( + 
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') +) + + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "RN50x64": _RN50x64, + "ViT-B-32": _VITB32, + "ViT-B-32-quickgelu": _VITB32_quickgelu, + "ViT-B-16": _VITB16, + "ViT-B-16-plus-240": _VITB16_PLUS_240, + "ViT-L-14": _VITL14, + "ViT-L-14-336": _VITL14_336, + "ViT-H-14": _VITH14, + "ViT-g-14": _VITg14, + "ViT-bigG-14": _VITbigG14, + "roberta-ViT-B-32": _robertaViTB32, + "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, + "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, + "convnext_base": _convnext_base, + "convnext_base_w": _convnext_base_w, + "convnext_base_w_320": _convnext_base_w_320, + "convnext_large_d": _convnext_large_d, + "convnext_large_d_320": _convnext_large_d_320, + "convnext_xxlarge": _convnext_xxlarge, + "coca_ViT-B-32": _coca_VITB32, + "coca_ViT-L-14": _coca_VITL14, +} + + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace('-', '_') + + +def list_pretrained(as_str: bool = False): + """ returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """ return all models having the specified pretrain tag """ + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """ return all pretrain tags for the specified model architecture """ + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return _clean_tag(tag) in _PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get('url', '') + + +def download_pretrained_from_url( + url: str, + cache_dir: Union[str, None] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + elif 'mlfoundations' in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = '' + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with 
tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match") + + return download_target + + + def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + + def download_pretrained_from_hf( + model_id: str, + filename: str = 'open_clip_pytorch_model.bin', + revision=None, + cache_dir: Union[str, None] = None, +): + has_hf_hub(True) + cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) + return cached_file + + + def download_pretrained( + cfg: Dict, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, +): + target = '' + if not cfg: + return target + + download_url = cfg.get('url', '') + download_hf_hub = cfg.get('hf_hub', '') + if download_hf_hub and force_hf_hub: + # use HF hub even if url exists + download_url = '' + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. 
+ model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target diff --git a/open_clip/push_to_hf_hub.py b/open_clip/push_to_hf_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..23c0631c81dcb43829b7374fac09406ecefcb436 --- /dev/null +++ b/open_clip/push_to_hf_hub.py @@ -0,0 +1,243 @@ +import argparse +import json +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Tuple + +import torch + +try: + from huggingface_hub import ( + create_repo, + get_hf_file_metadata, + hf_hub_download, + hf_hub_url, + repo_type_and_id_from_hf_id, + upload_folder, + ) + from huggingface_hub.utils import EntryNotFoundError + _has_hf_hub = True +except ImportError: + _has_hf_hub = False + +from .factory import create_model_from_pretrained, get_model_config, get_tokenizer +from .tokenizer import HFTokenizer + + +def save_config_for_hf( + model, + config_path: str, + model_config: Optional[dict] +): + preprocess_cfg = { + 'mean': model.visual.image_mean, + 'std': model.visual.image_std, + } + hf_config = { + 'model_cfg': model_config, + 'preprocess_cfg': preprocess_cfg, + } + + with config_path.open('w') as f: + json.dump(hf_config, f, indent=2) + + +def save_for_hf( + model, + tokenizer: HFTokenizer, + model_config: dict, + save_directory: str, + weights_filename='open_clip_pytorch_model.bin', + config_filename='open_clip_config.json', +): + save_directory = Path(save_directory) + save_directory.mkdir(exist_ok=True, parents=True) + + weights_path = save_directory / weights_filename + torch.save(model.state_dict(), weights_path) + + tokenizer.save_pretrained(save_directory) + + config_path = save_directory / config_filename + save_config_for_hf(model, config_path, model_config=model_config) + + +def push_to_hf_hub( + model, + tokenizer, + model_config: Optional[dict], + repo_id: str, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, +): + if not isinstance(tokenizer, HFTokenizer): + # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 + tokenizer = HFTokenizer('openai/clip-vit-large-patch14') + + # Create repo if it doesn't exist yet + repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) + + # Infer complete repo_id from repo_url + # Can be different from the input `repo_id` if repo_owner was implicit + _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) + repo_id = f"{repo_owner}/{repo_name}" + + # Check if README file already exist in repo + try: + get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) + has_readme = True + except EntryNotFoundError: + has_readme = False + + # Dump model and push to Hub + with TemporaryDirectory() as tmpdir: + # Save model weights and config. 
+ save_for_hf( + model, + tokenizer=tokenizer, + model_config=model_config, + save_directory=tmpdir, + ) + + # Add readme if it does not exist + if not has_readme: + model_card = model_card or {} + model_name = repo_id.split('/')[-1] + readme_path = Path(tmpdir) / "README.md" + readme_text = generate_readme(model_card, model_name) + readme_path.write_text(readme_text) + + # Upload model and return + return upload_folder( + repo_id=repo_id, + folder_path=tmpdir, + revision=revision, + create_pr=create_pr, + commit_message=commit_message, + ) + + +def push_pretrained_to_hf_hub( + model_name, + pretrained: str, + repo_id: str, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, +): + model, preprocess_eval = create_model_from_pretrained( + model_name, + pretrained=pretrained, + image_mean=image_mean, + image_std=image_std, + ) + + model_config = get_model_config(model_name) + assert model_config + + tokenizer = get_tokenizer(model_name) + + push_to_hf_hub( + model=model, + tokenizer=tokenizer, + model_config=model_config, + repo_id=repo_id, + commit_message=commit_message, + token=token, + revision=revision, + private=private, + create_pr=create_pr, + model_card=model_card, + ) + + +def generate_readme(model_card: dict, model_name: str): + readme_text = "---\n" + readme_text += "tags:\n- zero-shot-image-classification\n- clip\n" + readme_text += "library_tag: open_clip\n" + readme_text += f"license: {model_card.get('license', 'mit')}\n" + if 'details' in model_card and 'Dataset' in model_card['details']: + readme_text += 'datasets:\n' + readme_text += f"- {model_card['details']['Dataset'].lower()}\n" + readme_text += "---\n" + readme_text += f"# Model card for {model_name}\n" + if 'description' in model_card: + readme_text += f"\n{model_card['description']}\n" + if 'details' in model_card: + readme_text += f"\n## Model Details\n" + for k, v in model_card['details'].items(): + if isinstance(v, (list, tuple)): + readme_text += f"- **{k}:**\n" + for vi in v: + readme_text += f" - {vi}\n" + elif isinstance(v, dict): + readme_text += f"- **{k}:**\n" + for ki, vi in v.items(): + readme_text += f" - {ki}: {vi}\n" + else: + readme_text += f"- **{k}:** {v}\n" + if 'usage' in model_card: + readme_text += f"\n## Model Usage\n" + readme_text += model_card['usage'] + readme_text += '\n' + + if 'comparison' in model_card: + readme_text += f"\n## Model Comparison\n" + readme_text += model_card['comparison'] + readme_text += '\n' + + if 'citation' in model_card: + readme_text += f"\n## Citation\n" + if not isinstance(model_card['citation'], (list, tuple)): + citations = [model_card['citation']] + else: + citations = model_card['citation'] + for c in citations: + readme_text += f"```bibtex\n{c}\n```\n" + + return readme_text + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") + parser.add_argument( + "--model", type=str, help="Name of the model to use.", + ) + parser.add_argument( + "--pretrained", type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--repo-id", type=str, + help="Destination HF Hub repo-id ie 'organization/model_id'.", + ) + parser.add_argument( + '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override 
default image mean value of dataset') + parser.add_argument( + '--image-std', type=float, nargs='+', default=None, metavar='STD', + help='Override default image std deviation of of dataset') + args = parser.parse_args() + + print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') + + # FIXME add support to pass model_card json / template from file via cmd line + + push_pretrained_to_hf_hub( + args.model, + args.pretrained, + args.repo_id, + image_mean=args.image_mean, # override image mean/std if trained w/ non defaults + image_std=args.image_std, + ) + + print(f'{args.model} saved.') diff --git a/open_clip/timm_model.py b/open_clip/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dc71a693f9a42ec01fd88d307661bc382b4d05bc --- /dev/null +++ b/open_clip/timm_model.py @@ -0,0 +1,127 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. +""" +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + try: + # old timm imports < 0.8.1 + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + except ImportError: + # new timm imports >= 0.8.1 + from timm.layers import RotAttentionPool2d + from timm.layers import AttentionPool2d as AbsAttentionPool2d +except ImportError: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """ timm model adapter + # FIXME this adapter is a work in progress, may change in ways that break weight compat + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool='avg', + proj='linear', + proj_bias=False, + drop=0., + drop_path=None, + pretrained=False, + ): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + + self.image_size = to_2tuple(image_size) + timm_kwargs = {} + if drop_path is not None: + timm_kwargs['drop_path_rate'] = drop_path + self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs) + feat_size = self.trunk.default_cfg.get('pool_size', None) + feature_ndim = 1 if not feat_size else 2 + if pool in ('abs_attn', 'rot_attn'): + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool='') + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + if pool == 'abs_attn': + head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + prev_chs = embed_dim + elif pool == 'rot_attn': + head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + else: + assert proj, 'projection layer needed if non-attention pooling is used.' 
+ + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == 'linear': + head_layers['drop'] = nn.Dropout(drop) + head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) + elif proj == 'mlp': + head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """ lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + try: + self.trunk.set_grad_checkpointing(enable) + except Exception as e: + logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/open_clip/tokenizer.py b/open_clip/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..23fcfcbcb4ca051ba5bba7520918693001999282 --- /dev/null +++ b/open_clip/tokenizer.py @@ -0,0 +1,214 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + +# https://stackoverflow.com/q/62691279 +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. 
+ """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['', ''] + else: + special_tokens = ['', ''] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + +_tokenizer = SimpleTokenizer() + +def decode(output_ids: torch.Tensor): + output_ids = output_ids.cpu().numpy() + return _tokenizer.decode(output_ids) + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns 
the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + """HuggingFace tokenizer wrapper""" + + def __init__(self, tokenizer_name: str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids + return input_ids diff --git a/open_clip/transform.py b/open_clip/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..748884a3c7cb7ece1ca521ca1dbf40bb74855007 --- /dev/null +++ b/open_clip/transform.py @@ -0,0 +1,133 @@ +import warnings +from dataclasses import dataclass, asdict +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F + +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ + CenterCrop + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + +@dataclass +class AugmentationCfg: + scale: Tuple[float, float] = (0.9, 1.0) + ratio: Optional[Tuple[float, float]] = None + color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None + interpolation: Optional[str] = None + re_prob: Optional[float] = None + re_count: Optional[int] = None + use_timm: bool = False + + +class ResizeMaxSize(nn.Module): + + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): + super().__init__() + if not isinstance(max_size, int): + raise TypeError(f"Size should be int. 
Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == 'min' else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[:2] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) + return img + + +def _convert_to_rgb(image): + return image.convert('RGB') + + +def image_transform( + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + + if isinstance(aug_cfg, dict): + aug_cfg = AugmentationCfg(**aug_cfg) + else: + aug_cfg = aug_cfg or AugmentationCfg() + normalize = Normalize(mean=mean, std=std) + if is_train: + aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} + use_timm = aug_cfg_dict.pop('use_timm', False) + if use_timm: + from timm.data import create_transform # timm can still be optional + if isinstance(image_size, (tuple, list)): + assert len(image_size) >= 2 + input_size = (3,) + image_size[-2:] + else: + input_size = (3, image_size, image_size) + # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time + aug_cfg_dict.setdefault('interpolation', 'random') + aug_cfg_dict.setdefault('color_jitter', None) # disable by default + train_transform = create_transform( + input_size=input_size, + is_training=True, + hflip=0., + mean=mean, + std=std, + re_mode='pixel', + **aug_cfg_dict, + ) + else: + train_transform = Compose([ + RandomResizedCrop( + image_size, + scale=aug_cfg_dict.pop('scale'), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ToTensor(), + normalize, + ]) + if aug_cfg_dict: + warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') + return train_transform + else: + if resize_longest_max: + transforms = [ + ResizeMaxSize(image_size, fill=fill_color) + ] + else: + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ] + transforms.extend([ + _convert_to_rgb, + ToTensor(), + normalize, + ]) + return Compose(transforms) diff --git a/open_clip/transformer.py b/open_clip/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..94620a4205a8b0b802616cbe675f0895c80691ec --- /dev/null +++ b/open_clip/transformer.py @@ -0,0 +1,742 @@ +import os +from collections import OrderedDict +import math +from typing import Callable, Optional, Sequence, Tuple + +import torch +from loguru import logger +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + +from .utils import to_2tuple 
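+# NOTE (added comment): xformers is an optional dependency in this module; if the import
+# below fails, `xops` is left as None and only the default (non-xattn) attention code
+# paths in the classes that follow remain usable.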
+try: + import xformers.ops as xops +except ImportError: + xops = None + print("Please 'pip install xformers'") + + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. 
/ 0.01), + attn_drop=0., + proj_drop=0., + xattn=False, + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + self.xattn_drop = attn_drop + self.xattn = xattn + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + if self.xattn: + q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + + # logger.debug(f'using memory efficient attention') + x = xops.memory_efficient_attention( + q, k, v, + p=self.xattn_drop, + scale=self.scale if self.logit_scale is None else None, + attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None, + # op=xops.MemoryEfficientAttentionFlashAttentionOp + ) + else: + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + n_head: int = 8, + n_queries: int = 256, + norm_layer: Callable = LayerNorm + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) + + def forward(self, x: torch.Tensor): + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> 
NLD + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + if xattn: + self.attn = Attention(d_model, n_head, xattn=True) + else: + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + self.xattn = xattn + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None + if self.xattn: + return self.attn(x, attn_mask=attn_mask) + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + # t = time.time() + x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = Attention( + d_model, n_head, + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + ) + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn=False, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + logger.debug(f'xattn in transformer of CLIP is {xattn}') + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, + xattn=xattn) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: 
Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + global_average_pool: bool = False, + attentional_pool: bool = False, + n_queries: int = 256, + attn_pooler_heads: int = 8, + output_dim: int = 512, + patch_dropout: float = 0., + input_patchnorm: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False + ): + super().__init__() + self.output_tokens = output_tokens + image_height, image_width = self.image_size = to_2tuple(image_size) + patch_height, patch_width = self.patch_size = to_2tuple(patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.output_dim = output_dim + + # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1 + self.input_patchnorm = input_patchnorm + + if input_patchnorm: + patch_input_dim = patch_height * patch_width * 3 + self.patchnorm_pre_ln = LayerNorm(patch_input_dim) + self.conv1 = nn.Linear(patch_input_dim, width) + else: + self.patchnorm_pre_ln = nn.Identity() + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. 
else nn.Identity() + + self.ln_pre = norm_layer(width) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.global_average_pool = global_average_pool + if attentional_pool: + self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) + else: + self.attn_pool = None + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. + + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_average_pool: + return x.mean(dim=1), x + else: + return x[:, 0], x[:, 1:] + + def forward(self, x: torch.Tensor): + + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1]) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) + x = self.patchnorm_pre_ln(x) + x = self.conv1(x) + else: + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + if self.attn_pool is not None: + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + else: + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) + + if self.proj is not None: + pooled = pooled @ self.proj + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class TextTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + embed_cls: bool = False, + pad_id: int = 0, + output_tokens: bool = False, + ): + super().__init__() + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pad_id = pad_id + + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + if embed_cls: + self.cls_emb = nn.Parameter(torch.empty(width)) + self.num_pos += 1 + else: + self.cls_emb = None + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + + xattn = (os.getenv('FLASH_TXT', 'f') == 't') + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + xattn=xattn + ) + + self.xattn = xattn + self.ln_final = norm_layer(width) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + if self.cls_emb is not None: + nn.init.normal_(self.cls_emb, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def build_cls_mask(self, text, cast_dtype: torch.dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = 
torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def _repeat(self, t, N: int): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask + if self.cls_emb is not None: + seq_len += 1 + x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len].to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + if self.cls_emb is not None: + pooled, tokens = x[:, -1], x[:, :-1] + pooled = self.ln_final(pooled) + else: + x = self.ln_final(x) + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + + if self.text_projection is not None: + pooled = pooled @ self.text_projection + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + ): + + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, image_embs, text_embs): + text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq + 
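+        # (added comment) the loop below interleaves causal self-attention over the text
+        # tokens with cross-attention from the text queries into the image token sequence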
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + seq_len = text_embs.shape[0] + + for resblock, cross_attn in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) + text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) + else: + text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + + x = text_embs.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if self.text_projection is not None: + x = x @ self.text_projection + + return x + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable diff --git a/open_clip/utils.py b/open_clip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51e80c5e296b24cae130ab0459baf268e0db7673 --- /dev/null +++ b/open_clip/utils.py @@ -0,0 +1,60 @@ +from itertools import repeat +import collections.abc + +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d + + +def freeze_batch_norm_2d(module, module_match={}, name=''): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = '.'.join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) diff --git a/open_clip/version.py b/open_clip/version.py new file mode 100644 index 0000000000000000000000000000000000000000..754dd42c2a768c881ebf544251fb6374a32f9b6a --- /dev/null +++ b/open_clip/version.py @@ -0,0 +1 @@ +__version__ = '2.15.0' diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 
0000000000000000000000000000000000000000..0821c16a1ff5372932f0aebd915a4af86c21dabe --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +torch==1.12.1 +torchvision==0.13.1 +accelerate==0.12.0 +absl-py +ml_collections +einops +ftfy==6.1.1 +transformers==4.23.1 +loguru +gradio==3.34.0 +omegaconf +wget +xformers=0.0.16 diff --git a/style_adapter/0102.pth b/style_adapter/0102.pth new file mode 100644 index 0000000000000000000000000000000000000000..88fa20fde3b05012731dee97627f066502f21824 --- /dev/null +++ b/style_adapter/0102.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c317c550da3782a64b38f065a2be7b3be8f003ea673139a0e0b2d0a997e1b752 +size 8258599 diff --git a/style_adapter/0103.pth b/style_adapter/0103.pth new file mode 100644 index 0000000000000000000000000000000000000000..8691176486c09e0e9ff9978703df6edf8b7a21d4 --- /dev/null +++ b/style_adapter/0103.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac9541eaa4f1c663b550468366a80867b58c00f692df86e9781680056110d775 +size 8258599 diff --git a/style_adapter/0106.pth b/style_adapter/0106.pth new file mode 100644 index 0000000000000000000000000000000000000000..2d99865fb019522ecf5f320384cd35ac1b9fa27e --- /dev/null +++ b/style_adapter/0106.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7a32ccc74cab655a1b80add036e9fc025a913a078efa44c5e9588cb647a8388 +size 8258523 diff --git a/style_adapter/0108.pth b/style_adapter/0108.pth new file mode 100644 index 0000000000000000000000000000000000000000..300afe7eb3b74fd990e3acef566a39c87598de3a --- /dev/null +++ b/style_adapter/0108.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec6b8f78315aaf918d55a70bf8225e0b93e7f8bc816fbb757eb62f8db7681237 +size 8258523 diff --git a/style_adapter/0301.pth b/style_adapter/0301.pth new file mode 100644 index 0000000000000000000000000000000000000000..c3062fbc5d4b1a9247ace18663f37cefdabbc6ca --- /dev/null +++ b/style_adapter/0301.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12221f3631ac031ff7339fdc4258bb8c1f9462e818561c90147e9bba8436b14b +size 8258523 diff --git a/style_adapter/0305.pth b/style_adapter/0305.pth new file mode 100644 index 0000000000000000000000000000000000000000..4a8603e750040183999197d6e05896776fe6f20f --- /dev/null +++ b/style_adapter/0305.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a04f123c0b64fbc363662ba9fdfa6bfef479c5d7f4badb4a70efb1ce5868b571 +size 8258523 diff --git a/taming/models/vqgan.py b/taming/models/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..58864b3d5c947952a0934e4c6c1b7cc4eba829c0 --- /dev/null +++ b/taming/models/vqgan.py @@ -0,0 +1,114 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from taming.modules.diffusionmodules.model import Encoder, Decoder +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + + +class VQModel(nn.Module): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__() + self.n_embed = n_embed + self.embed_dim = embed_dim + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, sane_index_shape=sane_index_shape) + if 
ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.image_key = image_key + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.eval() + self.requires_grad_(False) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu") + if "state_dict" in sd.keys(): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + print("Strict load") + self.load_state_dict(sd, strict=True) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.get_codebook_entry(code_b, [*code_b.shape, self.embed_dim]) + dec = self.decode(quant_b) + return dec + + def forward(self, input): + quant, diff, info = self.encode(input) + return quant, diff, info + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) + return x.float() + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +def get_model(config_file='vq-f16-jax.yaml'): + from omegaconf import OmegaConf + config = OmegaConf.load(f'configs/vae_configs/{config_file}').model + return VQModel(ddconfig=config.params.ddconfig, + lossconfig=config.params.lossconfig, + n_embed=config.params.n_embed, + embed_dim=config.params.embed_dim, + ckpt_path='assets/vqgan_jax_strongaug.ckpt') diff --git a/taming/modules/diffusionmodules/model.py b/taming/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6278f8cd75520cfbcbc178e75959f58c906b25 --- /dev/null +++ b/taming/modules/diffusionmodules/model.py @@ -0,0 +1,324 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False) + else: + self.nin_shortcut = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(h) + else: + x = self.nin_shortcut(h) + + return x+h + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=False, in_channels, + resolution, z_channels, double_z=True, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # 
downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1, + bias=False) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + down = nn.Module() + down.block = block + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + up = nn.Module() + up.block = block + if i_level != 0: + up.upsample = Upsample(block_in, 
resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks): + h = self.up[i_level].block[i_block](h, temb) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h diff --git a/taming/modules/util.py b/taming/modules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..9ee16385d8b1342a2d60a5f1aa5cadcfbe934bd8 --- /dev/null +++ b/taming/modules/util.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn + + +def count_params(model): + total_params = sum(p.numel() for p in model.parameters()) + return total_params + + +class ActNorm(nn.Module): + def __init__(self, num_features, logdet=False, affine=True, + allow_reverse_init=False): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:,:,None,None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height*width*torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." 
+ ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:,:,None,None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class Labelator(AbstractEncoder): + """Net2Net Interface for Class-Conditional Model""" + def __init__(self, n_classes, quantize_interface=True): + super().__init__() + self.n_classes = n_classes + self.quantize_interface = quantize_interface + + def encode(self, c): + c = c[:,None] + if self.quantize_interface: + return c, None, [None, None, c.long()] + return c + + +class SOSProvider(AbstractEncoder): + # for unconditional training + def __init__(self, sos_token, quantize_interface=True): + super().__init__() + self.sos_token = sos_token + self.quantize_interface = quantize_interface + + def encode(self, x): + # get batch size from data and replicate sos_token + c = torch.ones(x.shape[0], 1)*self.sos_token + c = c.long().to(x.device) + if self.quantize_interface: + return c, None, [None, None, c] + return c diff --git a/taming/modules/vqvae/quantize.py b/taming/modules/vqvae/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..d75544e41fa01bce49dd822b1037963d62f79b51 --- /dev/null +++ b/taming/modules/vqvae/quantize.py @@ -0,0 +1,445 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch import einsum +from einops import rearrange + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. + Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for + # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be + # used wherever VectorQuantizer has been used before and is additionally + # more efficient. + def __init__(self, n_e, e_dim, beta): + super(VectorQuantizer, self).__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + ## could possible replace this here + # #\start... 
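+        # (added comment) i.e. the one-hot scatter + matmul lookup below could be swapped
+        # for the simpler argmin + embedding lookup shown commented out further down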
+ # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + + min_encodings = torch.zeros( + min_encoding_indices.shape[0], self.n_e).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # dtype min encodings: torch.float32 + # min_encodings shape: torch.Size([2048, 512]) + # min_encoding_indices.shape: torch.Size([2048, 1]) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + #.........\end + + # with: + # .........\start + #min_encoding_indices = torch.argmin(d, dim=1) + #z_q = self.embedding(min_encoding_indices) + # ......\end......... (TODO) + + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + # TODO: check for more easy handling with nn.Embedding + min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) + min_encodings.scatter_(1, indices[:,None], 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: + z_q = z_q.view(shape) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantize(nn.Module): + """ + credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 + https://arxiv.org/abs/1611.01144 + """ + def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, + kl_weight=5e-4, temp_init=1.0, use_vqinterface=True, + remap=None, unknown_index="random"): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_embed = n_embed + + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + + self.use_vqinterface = use_vqinterface + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. 
" + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, return_logits=False): + # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + + logits = self.proj(z) + if self.remap is not None: + # continue only with used logits + full_zeros = torch.zeros_like(logits) + logits = logits[:,self.used,...] + + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + if self.remap is not None: + # go back to all entries but unused set to zero + full_zeros[:,self.used,...] = soft_one_hot + soft_one_hot = full_zeros + z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + + ind = soft_one_hot.argmax(dim=1) + if self.remap is not None: + ind = self.remap_to_used(ind) + if self.use_vqinterface: + if return_logits: + return z_q, diff, (None, None, ind), logits + return z_q, diff, (None, None, ind) + return z_q, diff, ind + + def get_codebook_entry(self, indices, shape): + b, h, w, c = shape + assert b*h*w == indices.shape[0] + indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) + if self.remap is not None: + indices = self.unmap_to_all(indices) + one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() + z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) + return z_q + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. 
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", + sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" + assert rescale_logits==False, "Only for interface compatible with Gumbel" + assert return_logits==False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if 
self.remap is not None: + indices = indices.reshape(shape[0],-1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super().__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + self.weight = nn.Parameter(weight, requires_grad = False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + #normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + self.weight.data.copy_(embed_normalized) + + +class EMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, + remap=None, unknown_index="random"): + super().__init__() + self.codebook_dim = codebook_dim + self.num_tokens = num_tokens + self.beta = beta + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. 
" + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + #z, 'b c h w -> b h w c' + z = rearrange(z, 'b c h w -> b h w c') + z_flattened = z.reshape(-1, self.codebook_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + if self.training and self.embedding.update: + #EMA cluster size + encodings_sum = encodings.sum(0) + self.embedding.cluster_size_ema_update(encodings_sum) + #EMA embedding average + embed_sum = encodings.transpose(0,1) @ z_flattened + self.embedding.embed_avg_ema_update(embed_sum) + #normalize embed_avg and update weight + self.embedding.weight_update(self.num_tokens) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + #z_q, 'b h w c -> b c h w' + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, (perplexity, encodings, encoding_indices) diff --git a/taming/util.py b/taming/util.py new file mode 100644 index 0000000000000000000000000000000000000000..06053e5defb87977f9ab07e69bf4da12201de9b7 --- /dev/null +++ b/taming/util.py @@ -0,0 +1,157 @@ +import os, hashlib +import requests +from tqdm import tqdm + +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == 
MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class KeyNotFoundError(Exception): + def __init__(self, cause, keys=None, visited=None): + self.cause = cause + self.keys = keys + self.visited = visited + messages = list() + if keys is not None: + messages.append("Key not found: {}".format(keys)) + if visited is not None: + messages.append("Visited: {}".format(visited)) + messages.append("Cause:\n{}".format(cause)) + message = "\n".join(messages) + super().__init__(message) + + +def retrieve( + list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False +): + """Given a nested list or dict return the desired value at key expanding + callable nodes if necessary and :attr:`expand` is ``True``. The expansion + is done in-place. + + Parameters + ---------- + list_or_dict : list or dict + Possibly nested list or dictionary. + key : str + key/to/value, path like string describing all keys necessary to + consider to get to the desired value. List indices can also be + passed here. + splitval : str + String that defines the delimiter between keys of the + different depth levels in `key`. + default : obj + Value returned if :attr:`key` is not found. + expand : bool + Whether to expand callable nodes on the path or not. + + Returns + ------- + The desired value or if :attr:`default` is not ``None`` and the + :attr:`key` is not found returns ``default``. + + Raises + ------ + Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is + ``None``. + """ + + keys = key.split(splitval) + + success = True + try: + visited = [] + parent = None + last_key = None + for key in keys: + if callable(list_or_dict): + if not expand: + raise KeyNotFoundError( + ValueError( + "Trying to get past callable node with expand=False." 
+ ), + keys=keys, + visited=visited, + ) + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + + last_key = key + parent = list_or_dict + + try: + if isinstance(list_or_dict, dict): + list_or_dict = list_or_dict[key] + else: + list_or_dict = list_or_dict[int(key)] + except (KeyError, IndexError, ValueError) as e: + raise KeyNotFoundError(e, keys=keys, visited=visited) + + visited += [key] + # final expansion of retrieved value + if expand and callable(list_or_dict): + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + except KeyNotFoundError as e: + if default is None: + raise e + else: + list_or_dict = default + success = False + + if not pass_success: + return list_or_dict + else: + return list_or_dict, success + + +if __name__ == "__main__": + config = {"keya": "a", + "keyb": "b", + "keyc": + {"cc1": 1, + "cc2": 2, + } + } + from omegaconf import OmegaConf + config = OmegaConf.create(config) + print(config) + retrieve(config, "keya") + diff --git a/timm/__init__.py b/timm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db3d3f22f4defb0f2f6ee7ef53a7e88fe3a7d380 --- /dev/null +++ b/timm/__init__.py @@ -0,0 +1,3 @@ +from .version import __version__ +from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \ + is_scriptable, is_exportable, set_scriptable, set_exportable diff --git a/timm/data/__init__.py b/timm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15617859ef06d070adb96d884e486dbf54f56099 --- /dev/null +++ b/timm/data/__init__.py @@ -0,0 +1,10 @@ +from .constants import * +from .config import resolve_data_config +from .dataset import Dataset, DatasetTar, AugMixDataset +from .transforms import * +from .loader import create_loader +from .transforms_factory import create_transform +from .mixup import Mixup, FastCollateMixup +from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ + rand_augment_transform, auto_augment_transform +from .real_labels import RealLabelsImagenet diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..cbf5464dc12b4ac7058ea2219ea3f6665ec6c012 --- /dev/null +++ b/timm/data/auto_augment.py @@ -0,0 +1,817 @@ +""" AutoAugment, RandAugment, and AugMix for PyTorch + +This code implements the searched ImageNet policies with various tweaks and improvements and +does not include any of the search code. + +AA and RA Implementation adapted from: + https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py + +AugMix adapted from: + https://github.com/google-research/augmix + +Papers: + AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501 + Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172 + RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 + AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781 + +Hacked together by / Copyright 2020 Ross Wightman +""" +import random +import math +import re +from PIL import Image, ImageOps, ImageEnhance, ImageChops +import PIL +import numpy as np + + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10. 
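+# Illustrative example of how a magnitude level is consumed (see the
+# *_level_to_arg helpers below): a Rotate op at magnitude 9 resolves to
+# (9 / _MAX_LEVEL) * 30 = 27 degrees, randomly negated.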
+ +_HPARAMS_DEFAULT = dict( + translate_const=250, + img_mean=_FILL, +) + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop('resample', Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if 'fillcolor' in kwargs and _PIL_VER < (5, 0): + kwargs.pop('fillcolor') + kwargs['resample'] = _interpolation(kwargs) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs['resample']) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, negate the value""" + return -v if random.random() > 0.5 else v + + +def _rotate_level_to_arg(level, _hparams): + # range [-30, 30] + level = 
(level / _MAX_LEVEL) * 30. + level = _randomly_negate(level) + return level, + + +def _enhance_level_to_arg(level, _hparams): + # range [0.1, 1.9] + return (level / _MAX_LEVEL) * 1.8 + 0.1, + + +def _enhance_increasing_level_to_arg(level, _hparams): + # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend + # range [0.1, 1.9] + level = (level / _MAX_LEVEL) * .9 + level = 1.0 + _randomly_negate(level) + return level, + + +def _shear_level_to_arg(level, _hparams): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return level, + + +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams['translate_const'] + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return level, + + +def _translate_rel_level_to_arg(level, hparams): + # default range [-0.45, 0.45] + translate_pct = hparams.get('translate_pct', 0.45) + level = (level / _MAX_LEVEL) * translate_pct + level = _randomly_negate(level) + return level, + + +def _posterize_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + # intensity/severity of augmentation decreases with level + return int((level / _MAX_LEVEL) * 4), + + +def _posterize_increasing_level_to_arg(level, hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image', + # intensity/severity of augmentation increases with level + return 4 - _posterize_level_to_arg(level, hparams)[0], + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + # intensity/severity of augmentation decreases with level + return int((level / _MAX_LEVEL) * 4) + 4, + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation decreases with level + return int((level / _MAX_LEVEL) * 256), + + +def _solarize_increasing_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation increases with level + return 256 - _solarize_level_to_arg(level, _hparams)[0], + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return int((level / _MAX_LEVEL) * 110), + + +LEVEL_TO_ARG = { + 'AutoContrast': None, + 'Equalize': None, + 'Invert': None, + 'Rotate': _rotate_level_to_arg, + # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers + 'Posterize': _posterize_level_to_arg, + 'PosterizeIncreasing': _posterize_increasing_level_to_arg, + 'PosterizeOriginal': _posterize_original_level_to_arg, + 'Solarize': _solarize_level_to_arg, + 'SolarizeIncreasing': _solarize_increasing_level_to_arg, + 'SolarizeAdd': _solarize_add_level_to_arg, + 'Color': _enhance_level_to_arg, + 'ColorIncreasing': _enhance_increasing_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'ContrastIncreasing': _enhance_increasing_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'BrightnessIncreasing': _enhance_increasing_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'SharpnessIncreasing': _enhance_increasing_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'TranslateX': _translate_abs_level_to_arg, + 'TranslateY': _translate_abs_level_to_arg, + 'TranslateXRel': _translate_rel_level_to_arg, + 'TranslateYRel': _translate_rel_level_to_arg, +} + + +NAME_TO_OP = { + 
'AutoContrast': auto_contrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': rotate, + 'Posterize': posterize, + 'PosterizeIncreasing': posterize, + 'PosterizeOriginal': posterize, + 'Solarize': solarize, + 'SolarizeIncreasing': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'ColorIncreasing': color, + 'Contrast': contrast, + 'ContrastIncreasing': contrast, + 'Brightness': brightness, + 'BrightnessIncreasing': brightness, + 'Sharpness': sharpness, + 'SharpnessIncreasing': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x_abs, + 'TranslateY': translate_y_abs, + 'TranslateXRel': translate_x_rel, + 'TranslateYRel': translate_y_rel, +} + + +class AugmentOp: + + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + self.aug_fn = NAME_TO_OP[name] + self.level_fn = LEVEL_TO_ARG[name] + self.prob = prob + self.magnitude = magnitude + self.hparams = hparams.copy() + self.kwargs = dict( + fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL, + resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION, + ) + + # If magnitude_std is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from a normal distribution + # with mean `magnitude` and std-dev of `magnitude_std`. + # NOTE This is my own hack, being tested, not in papers or reference impls. + self.magnitude_std = self.hparams.get('magnitude_std', 0) + + def __call__(self, img): + if self.prob < 1.0 and random.random() > self.prob: + return img + magnitude = self.magnitude + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() + return self.aug_fn(img, *level_args, **self.kwargs) + + +def auto_augment_policy_v0(hparams): + # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference. 
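+    # Each sub-policy below is a pair of (op_name, probability, magnitude) tuples;
+    # AutoAugment samples one sub-policy per image and applies its ops in order.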
+ policy = [ + [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], + [('Color', 0.4, 9), ('Equalize', 0.6, 3)], + [('Color', 0.4, 1), ('Rotate', 0.6, 8)], + [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], + [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], + [('Color', 0.2, 0), ('Equalize', 0.8, 8)], + [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], + [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], + [('Color', 0.6, 1), ('Equalize', 1.0, 2)], + [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], + [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], + [('Color', 0.4, 7), ('Equalize', 0.6, 0)], + [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Solarize', 0.6, 8), ('Color', 0.6, 9)], + [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], + [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], + [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], + [('ShearY', 0.8, 0), ('Color', 0.6, 4)], + [('Color', 1.0, 0), ('Rotate', 0.6, 2)], + [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], + [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], + [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], + [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize + [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], + [('Color', 0.8, 6), ('Rotate', 0.4, 5)], + ] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + return pc + + +def auto_augment_policy_v0r(hparams): + # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used + # in Google research implementation (number of bits discarded increases with magnitude) + policy = [ + [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], + [('Color', 0.4, 9), ('Equalize', 0.6, 3)], + [('Color', 0.4, 1), ('Rotate', 0.6, 8)], + [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], + [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], + [('Color', 0.2, 0), ('Equalize', 0.8, 8)], + [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], + [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], + [('Color', 0.6, 1), ('Equalize', 1.0, 2)], + [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], + [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], + [('Color', 0.4, 7), ('Equalize', 0.6, 0)], + [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Solarize', 0.6, 8), ('Color', 0.6, 9)], + [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], + [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], + [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], + [('ShearY', 0.8, 0), ('Color', 0.6, 4)], + [('Color', 1.0, 0), ('Rotate', 0.6, 2)], + [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], + [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], + [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], + [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)], + [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], + [('Color', 0.8, 6), ('Rotate', 0.4, 5)], + ] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + return pc + + +def auto_augment_policy_original(hparams): + # ImageNet policy from https://arxiv.org/abs/1805.09501 + policy = [ + [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], + [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], + [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)], + [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], + [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)], + [('Rotate', 0.8, 8), ('Color', 0.4, 0)], + [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], + 
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Rotate', 0.8, 8), ('Color', 1.0, 2)], + [('Color', 0.8, 8), ('Solarize', 0.8, 7)], + [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)], + [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)], + [('Color', 0.4, 0), ('Equalize', 0.6, 3)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + ] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + return pc + + +def auto_augment_policy_originalr(hparams): + # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation + policy = [ + [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], + [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], + [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)], + [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], + [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)], + [('Rotate', 0.8, 8), ('Color', 0.4, 0)], + [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], + [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Rotate', 0.8, 8), ('Color', 1.0, 2)], + [('Color', 0.8, 8), ('Solarize', 0.8, 7)], + [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)], + [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)], + [('Color', 0.4, 0), ('Equalize', 0.6, 3)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + ] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + return pc + + +def auto_augment_policy(name='v0', hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + if name == 'original': + return auto_augment_policy_original(hparams) + elif name == 'originalr': + return auto_augment_policy_originalr(hparams) + elif name == 'v0': + return auto_augment_policy_v0(hparams) + elif name == 'v0r': + return auto_augment_policy_v0r(hparams) + else: + assert False, 'Unknown AA policy (%s)' % name + + +class AutoAugment: + + def __init__(self, policy): + self.policy = policy + + def __call__(self, img): + sub_policy = random.choice(self.policy) + for op in sub_policy: + img = op(img) + return img + + +def auto_augment_transform(config_str, hparams): + """ + Create a AutoAugment transform + + :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr'). 
+ The remaining sections, not order sepecific determine + 'mstd' - float std deviation of magnitude noise applied + Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5 + + :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme + + :return: A PyTorch compatible Transform + """ + config = config_str.split('-') + policy_name = config[0] + config = config[1:] + for c in config: + cs = re.split(r'(\d.*)', c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param injected via hparams for now + hparams.setdefault('magnitude_std', float(val)) + else: + assert False, 'Unknown AutoAugment config section' + aa_policy = auto_augment_policy(policy_name, hparams=hparams) + return AutoAugment(aa_policy) + + +_RAND_TRANSFORMS = [ + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'Posterize', + 'Solarize', + 'SolarizeAdd', + 'Color', + 'Contrast', + 'Brightness', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', + #'Cutout' # NOTE I've implement this as random erasing separately +] + + +_RAND_INCREASING_TRANSFORMS = [ + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'PosterizeIncreasing', + 'SolarizeIncreasing', + 'SolarizeAdd', + 'ColorIncreasing', + 'ContrastIncreasing', + 'BrightnessIncreasing', + 'SharpnessIncreasing', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', + #'Cutout' # NOTE I've implement this as random erasing separately +] + + + +# These experimental weights are based loosely on the relative improvements mentioned in paper. +# They may not result in increased performance, but could likely be tuned to so. +_RAND_CHOICE_WEIGHTS_0 = { + 'Rotate': 0.3, + 'ShearX': 0.2, + 'ShearY': 0.2, + 'TranslateXRel': 0.1, + 'TranslateYRel': 0.1, + 'Color': .025, + 'Sharpness': 0.025, + 'AutoContrast': 0.025, + 'Solarize': .005, + 'SolarizeAdd': .005, + 'Contrast': .005, + 'Brightness': .005, + 'Equalize': .005, + 'Posterize': 0, + 'Invert': 0, +} + + +def _select_rand_weights(weight_idx=0, transforms=None): + transforms = transforms or _RAND_TRANSFORMS + assert weight_idx == 0 # only one set of weights currently + rand_weights = _RAND_CHOICE_WEIGHTS_0 + probs = [rand_weights[k] for k in transforms] + probs /= np.sum(probs) + return probs + + +def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS + return [AugmentOp( + name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] + + +class RandAugment: + def __init__(self, ops, num_layers=2, choice_weights=None): + self.ops = ops + self.num_layers = num_layers + self.choice_weights = choice_weights + + def __call__(self, img): + # no replacement when using weighted choice + ops = np.random.choice( + self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights) + for op in ops: + img = op(img) + return img + + +def rand_augment_transform(config_str, hparams): + """ + Create a RandAugment transform + + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). 
The remaining + sections, not order sepecific determine + 'm' - integer magnitude of rand augment + 'n' - integer num layers (number of transform ops selected per image) + 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) + 'mstd' - float std deviation of magnitude noise applied + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) + Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 + + :param hparams: Other hparams (kwargs) for the RandAugmentation scheme + + :return: A PyTorch compatible Transform + """ + magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) + num_layers = 2 # default to 2 ops per image + weight_idx = None # default to no probability weights for op choice + transforms = _RAND_TRANSFORMS + config = config_str.split('-') + assert config[0] == 'rand' + config = config[1:] + for c in config: + cs = re.split(r'(\d.*)', c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param injected via hparams for now + hparams.setdefault('magnitude_std', float(val)) + elif key == 'inc': + if bool(val): + transforms = _RAND_INCREASING_TRANSFORMS + elif key == 'm': + magnitude = int(val) + elif key == 'n': + num_layers = int(val) + elif key == 'w': + weight_idx = int(val) + else: + assert False, 'Unknown RandAugment config section' + ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms) + choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) + return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) + + +_AUGMIX_TRANSFORMS = [ + 'AutoContrast', + 'ColorIncreasing', # not in paper + 'ContrastIncreasing', # not in paper + 'BrightnessIncreasing', # not in paper + 'SharpnessIncreasing', # not in paper + 'Equalize', + 'Rotate', + 'PosterizeIncreasing', + 'SolarizeIncreasing', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', +] + + +def augmix_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _AUGMIX_TRANSFORMS + return [AugmentOp( + name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms] + + +class AugMixAugment: + """ AugMix Transform + Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py + From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - + https://arxiv.org/abs/1912.02781 + """ + def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False): + self.ops = ops + self.alpha = alpha + self.width = width + self.depth = depth + self.blended = blended # blended mode is faster but not well tested + + def _calc_blended_weights(self, ws, m): + ws = ws * m + cump = 1. + rws = [] + for w in ws[::-1]: + alpha = w / cump + cump *= (1 - alpha) + rws.append(alpha) + return np.array(rws[::-1], dtype=np.float32) + + def _apply_blended(self, img, mixing_weights, m): + # This is my first crack and implementing a slightly faster mixed augmentation. Instead + # of accumulating the mix for each chain in a Numpy array and then blending with original, + # it recomputes the blending coefficients and applies one PIL image blend per chain. + # TODO the results appear in the right ballpark but they differ by more than rounding. 
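+        # Sketch of the intended behaviour: _calc_blended_weights turns the
+        # Dirichlet mixing weights (scaled by the Beta sample m) into per-step
+        # alphas, so repeatedly calling Image.blend(img, chain_result, alpha)
+        # should approximate the accumulated mix produced by _apply_basic.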
+ img_orig = img.copy() + ws = self._calc_blended_weights(mixing_weights, m) + for w in ws: + depth = self.depth if self.depth > 0 else np.random.randint(1, 4) + ops = np.random.choice(self.ops, depth, replace=True) + img_aug = img_orig # no ops are in-place, deep copy not necessary + for op in ops: + img_aug = op(img_aug) + img = Image.blend(img, img_aug, w) + return img + + def _apply_basic(self, img, mixing_weights, m): + # This is a literal adaptation of the paper/official implementation without normalizations and + # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the + # typical augmentation transforms, could use a GPU / Kornia implementation. + img_shape = img.size[0], img.size[1], len(img.getbands()) + mixed = np.zeros(img_shape, dtype=np.float32) + for mw in mixing_weights: + depth = self.depth if self.depth > 0 else np.random.randint(1, 4) + ops = np.random.choice(self.ops, depth, replace=True) + img_aug = img # no ops are in-place, deep copy not necessary + for op in ops: + img_aug = op(img_aug) + mixed += mw * np.asarray(img_aug, dtype=np.float32) + np.clip(mixed, 0, 255., out=mixed) + mixed = Image.fromarray(mixed.astype(np.uint8)) + return Image.blend(img, mixed, m) + + def __call__(self, img): + mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width)) + m = np.float32(np.random.beta(self.alpha, self.alpha)) + if self.blended: + mixed = self._apply_blended(img, mixing_weights, m) + else: + mixed = self._apply_basic(img, mixing_weights, m) + return mixed + + +def augment_and_mix_transform(config_str, hparams): + """ Create AugMix PyTorch transform + + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining + sections, not order sepecific determine + 'm' - integer magnitude (severity) of augmentation mix (default: 3) + 'w' - integer width of augmentation chain (default: 3) + 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1) + 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0) + 'mstd' - float std deviation of magnitude noise applied (default: 0) + Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2 + + :param hparams: Other hparams (kwargs) for the Augmentation transforms + + :return: A PyTorch compatible Transform + """ + magnitude = 3 + width = 3 + depth = -1 + alpha = 1. 
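+    # These defaults (and `blended` below) are overridden by the corresponding
+    # config_str sections ('m', 'w', 'd', 'a', 'b') parsed in the loop below.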
+ blended = False + config = config_str.split('-') + assert config[0] == 'augmix' + config = config[1:] + for c in config: + cs = re.split(r'(\d.*)', c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param injected via hparams for now + hparams.setdefault('magnitude_std', float(val)) + elif key == 'm': + magnitude = int(val) + elif key == 'w': + width = int(val) + elif key == 'd': + depth = int(val) + elif key == 'a': + alpha = float(val) + elif key == 'b': + blended = bool(val) + else: + assert False, 'Unknown AugMix config section' + ops = augmix_ops(magnitude=magnitude, hparams=hparams) + return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended) diff --git a/timm/data/config.py b/timm/data/config.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb4bda84ae573d7a2a18ae791ca514f36b4dd9d --- /dev/null +++ b/timm/data/config.py @@ -0,0 +1,75 @@ +import logging +from .constants import * + + +_logger = logging.getLogger(__name__) + + +def resolve_data_config(args, default_cfg={}, model=None, verbose=True): + new_config = {} + default_cfg = default_cfg + if not default_cfg and model is not None and hasattr(model, 'default_cfg'): + default_cfg = model.default_cfg + + # Resolve input/image size + in_chans = 3 + if 'chans' in args and args['chans'] is not None: + in_chans = args['chans'] + + input_size = (in_chans, 224, 224) + if 'input_size' in args and args['input_size'] is not None: + assert isinstance(args['input_size'], (tuple, list)) + assert len(args['input_size']) == 3 + input_size = tuple(args['input_size']) + in_chans = input_size[0] # input_size overrides in_chans + elif 'img_size' in args and args['img_size'] is not None: + assert isinstance(args['img_size'], int) + input_size = (in_chans, args['img_size'], args['img_size']) + elif 'input_size' in default_cfg: + input_size = default_cfg['input_size'] + new_config['input_size'] = input_size + + # resolve interpolation method + new_config['interpolation'] = 'bicubic' + if 'interpolation' in args and args['interpolation']: + new_config['interpolation'] = args['interpolation'] + elif 'interpolation' in default_cfg: + new_config['interpolation'] = default_cfg['interpolation'] + + # resolve dataset + model mean for normalization + new_config['mean'] = IMAGENET_DEFAULT_MEAN + if 'mean' in args and args['mean'] is not None: + mean = tuple(args['mean']) + if len(mean) == 1: + mean = tuple(list(mean) * in_chans) + else: + assert len(mean) == in_chans + new_config['mean'] = mean + elif 'mean' in default_cfg: + new_config['mean'] = default_cfg['mean'] + + # resolve dataset + model std deviation for normalization + new_config['std'] = IMAGENET_DEFAULT_STD + if 'std' in args and args['std'] is not None: + std = tuple(args['std']) + if len(std) == 1: + std = tuple(list(std) * in_chans) + else: + assert len(std) == in_chans + new_config['std'] = std + elif 'std' in default_cfg: + new_config['std'] = default_cfg['std'] + + # resolve default crop percentage + new_config['crop_pct'] = DEFAULT_CROP_PCT + if 'crop_pct' in args and args['crop_pct'] is not None: + new_config['crop_pct'] = args['crop_pct'] + elif 'crop_pct' in default_cfg: + new_config['crop_pct'] = default_cfg['crop_pct'] + + if verbose: + _logger.info('Data processing configuration for current model + dataset:') + for n, v in new_config.items(): + _logger.info('\t%s: %s' % (n, str(v))) + + return new_config diff --git a/timm/data/constants.py b/timm/data/constants.py new file mode 100644 index 
0000000000000000000000000000000000000000..d6d4a01b0316989a3f5142167f1e384b098bc930 --- /dev/null +++ b/timm/data/constants.py @@ -0,0 +1,7 @@ +DEFAULT_CROP_PCT = 0.875 +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) +IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) +IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) +IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) diff --git a/timm/data/dataset.py b/timm/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..99d99917b282c60be78d6b2f6faa313eaadd9225 --- /dev/null +++ b/timm/data/dataset.py @@ -0,0 +1,215 @@ +""" Quick n Simple Image Folder, Tarfile based DataSet + +Hacked together by / Copyright 2020 Ross Wightman +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch.utils.data as data + +import os +import re +import torch +import tarfile +from PIL import Image + + +IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg'] + + +def natural_key(string_): + """See http://www.codinghorror.com/blog/archives/001018.html""" + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): + labels = [] + filenames = [] + for root, subdirs, files in os.walk(folder, topdown=False): + rel_path = os.path.relpath(root, folder) if (root != folder) else '' + label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') + for f in files: + base, ext = os.path.splitext(f) + if ext.lower() in types: + filenames.append(os.path.join(root, f)) + labels.append(label) + if class_to_idx is None: + # building class index + unique_labels = set(labels) + sorted_labels = list(sorted(unique_labels, key=natural_key)) + class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} + images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] + if sort: + images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) + return images_and_targets, class_to_idx + + +def load_class_map(filename, root=''): + class_map_path = filename + if not os.path.exists(class_map_path): + class_map_path = os.path.join(root, filename) + assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename + class_map_ext = os.path.splitext(filename)[-1].lower() + if class_map_ext == '.txt': + with open(class_map_path) as f: + class_to_idx = {v.strip(): k for k, v in enumerate(f)} + else: + assert False, 'Unsupported class map extension' + return class_to_idx + + +class Dataset(data.Dataset): + + def __init__( + self, + root, + load_bytes=False, + transform=None, + class_map=''): + + class_to_idx = None + if class_map: + class_to_idx = load_class_map(class_map, root) + images, class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) + if len(images) == 0: + raise RuntimeError(f'Found 0 images in subfolders of {root}. 
' + f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}') + self.root = root + self.samples = images + self.imgs = self.samples # torchvision ImageFolder compat + self.class_to_idx = class_to_idx + self.load_bytes = load_bytes + self.transform = transform + + def __getitem__(self, index): + path, target = self.samples[index] + img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB') + if self.transform is not None: + img = self.transform(img) + if target is None: + target = torch.zeros(1).long() + return img, target + + def __len__(self): + return len(self.samples) + + def filename(self, index, basename=False, absolute=False): + filename = self.samples[index][0] + if basename: + filename = os.path.basename(filename) + elif not absolute: + filename = os.path.relpath(filename, self.root) + return filename + + def filenames(self, basename=False, absolute=False): + fn = lambda x: x + if basename: + fn = os.path.basename + elif not absolute: + fn = lambda x: os.path.relpath(x, self.root) + return [fn(x[0]) for x in self.samples] + + +def _extract_tar_info(tarfile, class_to_idx=None, sort=True): + files = [] + labels = [] + for ti in tarfile.getmembers(): + if not ti.isfile(): + continue + dirname, basename = os.path.split(ti.path) + label = os.path.basename(dirname) + ext = os.path.splitext(basename)[1] + if ext.lower() in IMG_EXTENSIONS: + files.append(ti) + labels.append(label) + if class_to_idx is None: + unique_labels = set(labels) + sorted_labels = list(sorted(unique_labels, key=natural_key)) + class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} + tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] + if sort: + tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) + return tarinfo_and_targets, class_to_idx + + +class DatasetTar(data.Dataset): + + def __init__(self, root, load_bytes=False, transform=None, class_map=''): + + class_to_idx = None + if class_map: + class_to_idx = load_class_map(class_map, root) + assert os.path.isfile(root) + self.root = root + with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later + self.samples, self.class_to_idx = _extract_tar_info(tf, class_to_idx) + self.imgs = self.samples + self.tarfile = None # lazy init in __getitem__ + self.load_bytes = load_bytes + self.transform = transform + + def __getitem__(self, index): + if self.tarfile is None: + self.tarfile = tarfile.open(self.root) + tarinfo, target = self.samples[index] + iob = self.tarfile.extractfile(tarinfo) + img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB') + if self.transform is not None: + img = self.transform(img) + if target is None: + target = torch.zeros(1).long() + return img, target + + def __len__(self): + return len(self.samples) + + def filename(self, index, basename=False): + filename = self.samples[index][0].name + if basename: + filename = os.path.basename(filename) + return filename + + def filenames(self, basename=False): + fn = os.path.basename if basename else lambda x: x + return [fn(x[0].name) for x in self.samples] + + +class AugMixDataset(torch.utils.data.Dataset): + """Dataset wrapper to perform AugMix or other clean/augmentation mixes""" + + def __init__(self, dataset, num_splits=2): + self.augmentation = None + self.normalize = None + self.dataset = dataset + if self.dataset.transform is not None: + self._set_transforms(self.dataset.transform) + self.num_splits = num_splits + + def 
_set_transforms(self, x): + assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms' + self.dataset.transform = x[0] + self.augmentation = x[1] + self.normalize = x[2] + + @property + def transform(self): + return self.dataset.transform + + @transform.setter + def transform(self, x): + self._set_transforms(x) + + def _normalize(self, x): + return x if self.normalize is None else self.normalize(x) + + def __getitem__(self, i): + x, y = self.dataset[i] # all splits share the same dataset base transform + x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split) + # run the full augmentation on the remaining splits + for _ in range(self.num_splits - 1): + x_list.append(self._normalize(self.augmentation(x))) + return tuple(x_list), y + + def __len__(self): + return len(self.dataset) diff --git a/timm/data/distributed_sampler.py b/timm/data/distributed_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..9506a8805dc3cec25498cd32d7c7476b1b372f8a --- /dev/null +++ b/timm/data/distributed_sampler.py @@ -0,0 +1,51 @@ +import math +import torch +from torch.utils.data import Sampler +import torch.distributed as dist + + +class OrderedDistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. 
+ """ + + def __init__(self, dataset, num_replicas=None, rank=None): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + indices = list(range(len(self.dataset))) + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples diff --git a/timm/data/loader.py b/timm/data/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..317f77df8a9f18d47058a1beca471c9a0d886dab --- /dev/null +++ b/timm/data/loader.py @@ -0,0 +1,257 @@ +""" Loader Factory, Fast Collate, CUDA Prefetcher + +Prefetcher and Fast Collate inspired by NVIDIA APEX example at +https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch.utils.data +import numpy as np + +from .transforms_factory import create_transform +from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .distributed_sampler import OrderedDistributedSampler +from .random_erasing import RandomErasing +from .mixup import FastCollateMixup + + +def fast_collate(batch): + """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)""" + assert isinstance(batch[0], tuple) + batch_size = len(batch) + if isinstance(batch[0][0], tuple): + # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position + # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position + inner_tuple_size = len(batch[0][0]) + flattened_batch_size = batch_size * inner_tuple_size + targets = torch.zeros(flattened_batch_size, dtype=torch.int64) + tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length + for j in range(inner_tuple_size): + targets[i + j * batch_size] = batch[i][1] + tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j]) + return tensor, targets + elif isinstance(batch[0][0], np.ndarray): + targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) + assert len(targets) == batch_size + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + tensor[i] += torch.from_numpy(batch[i][0]) + return tensor, targets + elif isinstance(batch[0][0], torch.Tensor): + targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) + assert len(targets) == batch_size + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + tensor[i].copy_(batch[i][0]) + return tensor, targets + else: + assert False + + +class PrefetchLoader: + + def __init__(self, + loader, + 
mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + fp16=False, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0): + self.loader = loader + self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) + self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) + self.fp16 = fp16 + if fp16: + self.mean = self.mean.half() + self.std = self.std.half() + if re_prob > 0.: + self.random_erasing = RandomErasing( + probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) + else: + self.random_erasing = None + + def __iter__(self): + stream = torch.cuda.Stream() + first = True + + for next_input, next_target in self.loader: + with torch.cuda.stream(stream): + next_input = next_input.cuda(non_blocking=True) + next_target = next_target.cuda(non_blocking=True) + if self.fp16: + next_input = next_input.half().sub_(self.mean).div_(self.std) + else: + next_input = next_input.float().sub_(self.mean).div_(self.std) + if self.random_erasing is not None: + next_input = self.random_erasing(next_input) + + if not first: + yield input, target + else: + first = False + + torch.cuda.current_stream().wait_stream(stream) + input = next_input + target = next_target + + yield input, target + + def __len__(self): + return len(self.loader) + + @property + def sampler(self): + return self.loader.sampler + + @property + def dataset(self): + return self.loader.dataset + + @property + def mixup_enabled(self): + if isinstance(self.loader.collate_fn, FastCollateMixup): + return self.loader.collate_fn.mixup_enabled + else: + return False + + @mixup_enabled.setter + def mixup_enabled(self, x): + if isinstance(self.loader.collate_fn, FastCollateMixup): + self.loader.collate_fn.mixup_enabled = x + + +def create_loader( + dataset, + input_size, + batch_size, + is_training=False, + use_prefetcher=True, + no_aug=False, + re_prob=0., + re_mode='const', + re_count=1, + re_split=False, + scale=None, + ratio=None, + hflip=0.5, + vflip=0., + color_jitter=0.4, + auto_augment=None, + num_aug_splits=0, + interpolation='bilinear', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_workers=1, + distributed=False, + crop_pct=None, + collate_fn=None, + pin_memory=False, + fp16=False, + tf_preprocessing=False, + use_multi_epochs_loader=False +): + re_num_splits = 0 + if re_split: + # apply RE to second half of batch if no aug split otherwise line up with aug split + re_num_splits = num_aug_splits or 2 + dataset.transform = create_transform( + input_size, + is_training=is_training, + use_prefetcher=use_prefetcher, + no_aug=no_aug, + scale=scale, + ratio=ratio, + hflip=hflip, + vflip=vflip, + color_jitter=color_jitter, + auto_augment=auto_augment, + interpolation=interpolation, + mean=mean, + std=std, + crop_pct=crop_pct, + tf_preprocessing=tf_preprocessing, + re_prob=re_prob, + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits, + separate=num_aug_splits > 0, + ) + + sampler = None + if distributed: + if is_training: + sampler = torch.utils.data.distributed.DistributedSampler(dataset) + else: + # This will add extra duplicate entries to result in equal num + # of samples per-process, will slightly alter validation results + sampler = OrderedDistributedSampler(dataset) + + if collate_fn is None: + collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate + + loader_class = torch.utils.data.DataLoader + + if use_multi_epochs_loader: + loader_class = MultiEpochsDataLoader + + loader = loader_class( + dataset, + 
batch_size=batch_size, + shuffle=sampler is None and is_training, + num_workers=num_workers, + sampler=sampler, + collate_fn=collate_fn, + pin_memory=pin_memory, + drop_last=is_training, + ) + if use_prefetcher: + prefetch_re_prob = re_prob if is_training and not no_aug else 0. + loader = PrefetchLoader( + loader, + mean=mean, + std=std, + fp16=fp16, + re_prob=prefetch_re_prob, + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits + ) + + return loader + + +class MultiEpochsDataLoader(torch.utils.data.DataLoader): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._DataLoader__initialized = False + self.batch_sampler = _RepeatSampler(self.batch_sampler) + self._DataLoader__initialized = True + self.iterator = super().__iter__() + + def __len__(self): + return len(self.batch_sampler.sampler) + + def __iter__(self): + for i in range(len(self)): + yield next(self.iterator) + + +class _RepeatSampler(object): + """ Sampler that repeats forever. + + Args: + sampler (Sampler) + """ + + def __init__(self, sampler): + self.sampler = sampler + + def __iter__(self): + while True: + yield from iter(self.sampler) diff --git a/timm/data/mixup.py b/timm/data/mixup.py new file mode 100644 index 0000000000000000000000000000000000000000..38477548a070a1a338ed18ddc74cdaf5050f84be --- /dev/null +++ b/timm/data/mixup.py @@ -0,0 +1,316 @@ +""" Mixup and Cutmix + +Papers: +mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) + +CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) + +Code Reference: +CutMix: https://github.com/clovaai/CutMix-PyTorch + +Hacked together by / Copyright 2020 Ross Wightman +""" +import numpy as np +import torch + + +def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): + x = x.long().view(-1, 1) + return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) + + +def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): + off_value = smoothing / num_classes + on_value = 1. - smoothing + off_value + y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) + y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) + return y1 * lam + y2 * (1. - lam) + + +def rand_bbox(img_shape, lam, margin=0., count=None): + """ Standard CutMix bounding-box + Generates a random square bbox based on lambda value. This impl includes + support for enforcing a border margin as percent of bbox dimensions. 
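+    For example (illustrative numbers only): with lam=0.75 the cut ratio is sqrt(1 - 0.75) = 0.5,
+    so the cut box spans half of each spatial dimension, i.e. roughly a quarter of the image
+    area before any clipping at the borders.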
+ + Args: + img_shape (tuple): Image shape as tuple + lam (float): Cutmix lambda value + margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) + count (int): Number of bbox to generate + """ + ratio = np.sqrt(1 - lam) + img_h, img_w = img_shape[-2:] + cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) + margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) + cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) + cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) + yl = np.clip(cy - cut_h // 2, 0, img_h) + yh = np.clip(cy + cut_h // 2, 0, img_h) + xl = np.clip(cx - cut_w // 2, 0, img_w) + xh = np.clip(cx + cut_w // 2, 0, img_w) + return yl, yh, xl, xh + + +def rand_bbox_minmax(img_shape, minmax, count=None): + """ Min-Max CutMix bounding-box + Inspired by Darknet cutmix impl, generates a random rectangular bbox + based on min/max percent values applied to each dimension of the input image. + + Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. + + Args: + img_shape (tuple): Image shape as tuple + minmax (tuple or list): Min and max bbox ratios (as percent of image size) + count (int): Number of bbox to generate + """ + assert len(minmax) == 2 + img_h, img_w = img_shape[-2:] + cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) + cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) + yl = np.random.randint(0, img_h - cut_h, size=count) + xl = np.random.randint(0, img_w - cut_w, size=count) + yu = yl + cut_h + xu = xl + cut_w + return yl, yu, xl, xu + + +def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): + """ Generate bbox and apply lambda correction. + """ + if ratio_minmax is not None: + yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) + else: + yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) + if correct_lam or ratio_minmax is not None: + bbox_area = (yu - yl) * (xu - xl) + lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) + return (yl, yu, xl, xu), lam + + +class Mixup: + """ Mixup/Cutmix that applies different params to each element or whole batch + + Args: + mixup_alpha (float): mixup alpha value, mixup is active if > 0. + cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. + cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. 
+ prob (float): probability of applying mixup or cutmix per batch or element + switch_prob (float): probability of switching to cutmix instead of mixup when both are active + mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element) + correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders + label_smoothing (float): apply label smoothing to the mixed target tensor + num_classes (int): number of classes for target + """ + def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, + mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000): + self.mixup_alpha = mixup_alpha + self.cutmix_alpha = cutmix_alpha + self.cutmix_minmax = cutmix_minmax + if self.cutmix_minmax is not None: + assert len(self.cutmix_minmax) == 2 + # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe + self.cutmix_alpha = 1.0 + self.mix_prob = prob + self.switch_prob = switch_prob + self.label_smoothing = label_smoothing + self.num_classes = num_classes + self.mode = mode + self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix + self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop) + + def _params_per_elem(self, batch_size): + lam = np.ones(batch_size, dtype=np.float32) + use_cutmix = np.zeros(batch_size, dtype=np.bool) + if self.mixup_enabled: + if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: + use_cutmix = np.random.rand(batch_size) < self.switch_prob + lam_mix = np.where( + use_cutmix, + np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size), + np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)) + elif self.mixup_alpha > 0.: + lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size) + elif self.cutmix_alpha > 0.: + use_cutmix = np.ones(batch_size, dtype=np.bool) + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) + else: + assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." + lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam) + return lam, use_cutmix + + def _params_per_batch(self): + lam = 1. + use_cutmix = False + if self.mixup_enabled and np.random.rand() < self.mix_prob: + if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: + use_cutmix = np.random.rand() < self.switch_prob + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ + np.random.beta(self.mixup_alpha, self.mixup_alpha) + elif self.mixup_alpha > 0.: + lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) + elif self.cutmix_alpha > 0.: + use_cutmix = True + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) + else: + assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." 
+ lam = float(lam_mix) + return lam, use_cutmix + + def _mix_elem(self, x): + batch_size = len(x) + lam_batch, use_cutmix = self._params_per_elem(batch_size) + x_orig = x.clone() # need to keep an unmodified original for mixing source + for i in range(batch_size): + j = batch_size - i - 1 + lam = lam_batch[i] + if lam != 1.: + if use_cutmix[i]: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] + lam_batch[i] = lam + else: + x[i] = x[i] * lam + x_orig[j] * (1 - lam) + return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) + + def _mix_pair(self, x): + batch_size = len(x) + lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) + x_orig = x.clone() # need to keep an unmodified original for mixing source + for i in range(batch_size // 2): + j = batch_size - i - 1 + lam = lam_batch[i] + if lam != 1.: + if use_cutmix[i]: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] + x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh] + lam_batch[i] = lam + else: + x[i] = x[i] * lam + x_orig[j] * (1 - lam) + x[j] = x[j] * lam + x_orig[i] * (1 - lam) + lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) + return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) + + def _mix_batch(self, x): + lam, use_cutmix = self._params_per_batch() + if lam == 1.: + return 1. + if use_cutmix: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh] + else: + x_flipped = x.flip(0).mul_(1. - lam) + x.mul_(lam).add_(x_flipped) + return lam + + def __call__(self, x, target): + assert len(x) % 2 == 0, 'Batch size should be even when using this' + if self.mode == 'elem': + lam = self._mix_elem(x) + elif self.mode == 'pair': + lam = self._mix_pair(x) + else: + lam = self._mix_batch(x) + target = mixup_target(target, self.num_classes, lam, self.label_smoothing) + return x, target + + +class FastCollateMixup(Mixup): + """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch + + A Mixup impl that's performed while collating the batches. 
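+    Example (a hypothetical sketch; `train_dataset` and the parameter values are assumptions,
+    not defaults from this repo — the dataset is expected to yield uint8 CHW numpy images,
+    e.g. via ToNumpy, and the batch size must be even):
+        collate_fn = FastCollateMixup(mixup_alpha=0.2, cutmix_alpha=1.0, num_classes=1000)
+        loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, collate_fn=collate_fn)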
+ """ + + def _mix_elem_collate(self, output, batch, half=False): + batch_size = len(batch) + num_elem = batch_size // 2 if half else batch_size + assert len(output) == num_elem + lam_batch, use_cutmix = self._params_per_elem(num_elem) + for i in range(num_elem): + j = batch_size - i - 1 + lam = lam_batch[i] + mixed = batch[i][0] + if lam != 1.: + if use_cutmix[i]: + if not half: + mixed = mixed.copy() + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] + lam_batch[i] = lam + else: + mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) + np.rint(mixed, out=mixed) + output[i] += torch.from_numpy(mixed.astype(np.uint8)) + if half: + lam_batch = np.concatenate((lam_batch, np.ones(num_elem))) + return torch.tensor(lam_batch).unsqueeze(1) + + def _mix_pair_collate(self, output, batch): + batch_size = len(batch) + lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) + for i in range(batch_size // 2): + j = batch_size - i - 1 + lam = lam_batch[i] + mixed_i = batch[i][0] + mixed_j = batch[j][0] + assert 0 <= lam <= 1.0 + if lam < 1.: + if use_cutmix[i]: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + patch_i = mixed_i[:, yl:yh, xl:xh].copy() + mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh] + mixed_j[:, yl:yh, xl:xh] = patch_i + lam_batch[i] = lam + else: + mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam) + mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam) + mixed_i = mixed_temp + np.rint(mixed_j, out=mixed_j) + np.rint(mixed_i, out=mixed_i) + output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) + output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) + lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) + return torch.tensor(lam_batch).unsqueeze(1) + + def _mix_batch_collate(self, output, batch): + batch_size = len(batch) + lam, use_cutmix = self._params_per_batch() + if use_cutmix: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + for i in range(batch_size): + j = batch_size - i - 1 + mixed = batch[i][0] + if lam != 1.: + if use_cutmix: + mixed = mixed.copy() # don't want to modify the original while iterating + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] + else: + mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) + np.rint(mixed, out=mixed) + output[i] += torch.from_numpy(mixed.astype(np.uint8)) + return lam + + def __call__(self, batch, _=None): + batch_size = len(batch) + assert batch_size % 2 == 0, 'Batch size should be even when using this' + half = 'half' in self.mode + if half: + batch_size //= 2 + output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + if self.mode == 'elem' or self.mode == 'half': + lam = self._mix_elem_collate(output, batch, half=half) + elif self.mode == 'pair': + lam = self._mix_pair_collate(output, batch) + else: + lam = self._mix_batch_collate(output, batch) + target = torch.tensor([b[1] for b in batch], dtype=torch.int64) + target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') + target = target[:batch_size] + return output, target + diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py new file mode 100644 index 
0000000000000000000000000000000000000000..78967d105dd77b56a3ccefb6ff1838a8058c0384 --- /dev/null +++ b/timm/data/random_erasing.py @@ -0,0 +1,97 @@ +""" Random Erasing (Cutout) + +Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 +Copyright Zhun Zhong & Liang Zheng + +Hacked together by / Copyright 2020 Ross Wightman +""" +import random +import math +import torch + + +def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): + # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() + # paths, flip the order so normal is run on CPU if this becomes a problem + # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 + if per_pixel: + return torch.empty(patch_size, dtype=dtype, device=device).normal_() + elif rand_color: + return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() + else: + return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) + + +class RandomErasing: + """ Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + + This variant of RandomErasing is intended to be applied to either a batch + or single image tensor after it has been normalized by dataset mean and std. + Args: + probability: Probability that the Random Erasing operation will be performed. + min_area: Minimum percentage of erased area wrt input image area. + max_area: Maximum percentage of erased area wrt input image area. + min_aspect: Minimum aspect ratio of erased area. + mode: pixel color mode, one of 'const', 'rand', or 'pixel' + 'const' - erase block is constant color of 0 for all channels + 'rand' - erase block is same per-channel random (normal) color + 'pixel' - erase block is per-pixel random (normal) color + max_count: maximum number of erasing blocks per image, area per box is scaled by count. + per-image count is randomly chosen between 1 and this value. 
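+    Example (illustrative only; `x` is assumed to be an already-normalized float image or
+    batch tensor residing on the target device, as the prefetch loader in timm/data/loader.py does):
+        eraser = RandomErasing(probability=0.25, mode='pixel', max_count=1, device='cuda')
+        x = eraser(x)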
+ """ + + def __init__( + self, + probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, + mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): + self.probability = probability + self.min_area = min_area + self.max_area = max_area + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + self.min_count = min_count + self.max_count = max_count or min_count + self.num_splits = num_splits + mode = mode.lower() + self.rand_color = False + self.per_pixel = False + if mode == 'rand': + self.rand_color = True # per block random normal + elif mode == 'pixel': + self.per_pixel = True # per pixel random normal + else: + assert not mode or mode == 'const' + self.device = device + + def _erase(self, img, chan, img_h, img_w, dtype): + if random.random() > self.probability: + return + area = img_h * img_w + count = self.min_count if self.min_count == self.max_count else \ + random.randint(self.min_count, self.max_count) + for _ in range(count): + for attempt in range(10): + target_area = random.uniform(self.min_area, self.max_area) * area / count + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + img[:, top:top + h, left:left + w] = _get_pixels( + self.per_pixel, self.rand_color, (chan, h, w), + dtype=dtype, device=self.device) + break + + def __call__(self, input): + if len(input.size()) == 3: + self._erase(input, *input.size(), input.dtype) + else: + batch_size, chan, img_h, img_w = input.size() + # skip first slice of batch if num_splits is set (for clean portion of samples) + batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 + for i in range(batch_start, batch_size): + self._erase(input[i], chan, img_h, img_w, input.dtype) + return input diff --git a/timm/data/real_labels.py b/timm/data/real_labels.py new file mode 100644 index 0000000000000000000000000000000000000000..939c34867e7915ce3e4cc7da04a5bc1653ec4f2c --- /dev/null +++ b/timm/data/real_labels.py @@ -0,0 +1,42 @@ +""" Real labels evaluator for ImageNet +Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159 +Based on Numpy example at https://github.com/google-research/reassessed-imagenet + +Hacked together by / Copyright 2020 Ross Wightman +""" +import os +import json +import numpy as np + + +class RealLabelsImagenet: + + def __init__(self, filenames, real_json='real.json', topk=(1, 5)): + with open(real_json) as real_labels: + real_labels = json.load(real_labels) + real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} + self.real_labels = real_labels + self.filenames = filenames + assert len(self.filenames) == len(self.real_labels) + self.topk = topk + self.is_correct = {k: [] for k in topk} + self.sample_idx = 0 + + def add_result(self, output): + maxk = max(self.topk) + _, pred_batch = output.topk(maxk, 1, True, True) + pred_batch = pred_batch.cpu().numpy() + for pred in pred_batch: + filename = self.filenames[self.sample_idx] + filename = os.path.basename(filename) + if self.real_labels[filename]: + for k in self.topk: + self.is_correct[k].append( + any([p in self.real_labels[filename] for p in pred[:k]])) + self.sample_idx += 1 + + def get_accuracy(self, k=None): + if k is None: + return {k: 
float(np.mean(self.is_correct[k])) * 100 for k in self.topk} + else: + return float(np.mean(self.is_correct[k])) * 100 diff --git a/timm/data/tf_preprocessing.py b/timm/data/tf_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..899cf36477ff31863d19659842728d45159837e9 --- /dev/null +++ b/timm/data/tf_preprocessing.py @@ -0,0 +1,236 @@ +""" Tensorflow Preprocessing Adapter + +Allows use of Tensorflow preprocessing pipeline in PyTorch Transform + +Copyright of original Tensorflow code below. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ImageNet preprocessing for MnasNet.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import numpy as np + +IMAGE_SIZE = 224 +CROP_PADDING = 32 + + +def distorted_bounding_box_crop(image_bytes, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(0.75, 1.33), + area_range=(0.05, 1.0), + max_attempts=100, + scope=None): + """Generates cropped_image using one of the bboxes randomly distorted. + + See `tf.image.sample_distorted_bounding_box` for more documentation. + + Args: + image_bytes: `Tensor` of binary image data. + bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` + where each coordinate is [0, 1) and the coordinates are arranged + as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole + image. + min_object_covered: An optional `float`. Defaults to `0.1`. The cropped + area of the image must contain at least this fraction of any bounding + box supplied. + aspect_ratio_range: An optional list of `float`s. The cropped area of the + image must have an aspect ratio = width / height within this range. + area_range: An optional list of `float`s. The cropped area of the image + must contain a fraction of the supplied image within in this range. + max_attempts: An optional `int`. Number of attempts at generating a cropped + region of the image of the specified constraints. After `max_attempts` + failures, return the entire image. + scope: Optional `str` for name scope. + Returns: + cropped image `Tensor` + """ + with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]): + shape = tf.image.extract_jpeg_shape(image_bytes) + sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( + shape, + bounding_boxes=bbox, + min_object_covered=min_object_covered, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=max_attempts, + use_image_if_no_bounding_boxes=True) + bbox_begin, bbox_size, _ = sample_distorted_bounding_box + + # Crop the image to the specified bounding box. 
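+    # crop_window layout is [offset_y, offset_x, crop_height, crop_width], the ordering
+    # expected by tf.image.decode_and_crop_jpeg below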
+ offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) + image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + + return image + + +def _at_least_x_are_equal(a, b, x): + """At least `x` of `a` and `b` `Tensors` are equal.""" + match = tf.equal(a, b) + match = tf.cast(match, tf.int32) + return tf.greater_equal(tf.reduce_sum(match), x) + + +def _decode_and_random_crop(image_bytes, image_size, resize_method): + """Make a random crop of image_size.""" + bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) + image = distorted_bounding_box_crop( + image_bytes, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(3. / 4, 4. / 3.), + area_range=(0.08, 1.0), + max_attempts=10, + scope=None) + original_shape = tf.image.extract_jpeg_shape(image_bytes) + bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) + + image = tf.cond( + bad, + lambda: _decode_and_center_crop(image_bytes, image_size), + lambda: tf.image.resize([image], [image_size, image_size], resize_method)[0]) + + return image + + +def _decode_and_center_crop(image_bytes, image_size, resize_method): + """Crops to center of image with padding then scales image_size.""" + shape = tf.image.extract_jpeg_shape(image_bytes) + image_height = shape[0] + image_width = shape[1] + + padded_center_crop_size = tf.cast( + ((image_size / (image_size + CROP_PADDING)) * + tf.cast(tf.minimum(image_height, image_width), tf.float32)), + tf.int32) + + offset_height = ((image_height - padded_center_crop_size) + 1) // 2 + offset_width = ((image_width - padded_center_crop_size) + 1) // 2 + crop_window = tf.stack([offset_height, offset_width, + padded_center_crop_size, padded_center_crop_size]) + image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + image = tf.image.resize([image], [image_size, image_size], resize_method)[0] + + return image + + +def _flip(image): + """Random horizontal image flip.""" + image = tf.image.random_flip_left_right(image) + return image + + +def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'): + """Preprocesses the given image for evaluation. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + interpolation: image interpolation method + + Returns: + A preprocessed image `Tensor`. + """ + resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR + image = _decode_and_random_crop(image_bytes, image_size, resize_method) + image = _flip(image) + image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.convert_image_dtype( + image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) + return image + + +def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'): + """Preprocesses the given image for evaluation. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + interpolation: image interpolation method + + Returns: + A preprocessed image `Tensor`. 
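+    Example (a minimal TF1 graph-mode sketch mirroring how TfPreprocessTransform below wires
+    this function up; the placeholder name is illustrative):
+        image_bytes = tf.placeholder(shape=[], dtype=tf.string)
+        image = preprocess_for_eval(image_bytes, use_bfloat16=False, image_size=224)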
+ """ + resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR + image = _decode_and_center_crop(image_bytes, image_size, resize_method) + image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.convert_image_dtype( + image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) + return image + + +def preprocess_image(image_bytes, + is_training=False, + use_bfloat16=False, + image_size=IMAGE_SIZE, + interpolation='bicubic'): + """Preprocesses the given image. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + is_training: `bool` for whether the preprocessing is for training. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + interpolation: image interpolation method + + Returns: + A preprocessed image `Tensor` with value range of [0, 255]. + """ + if is_training: + return preprocess_for_train(image_bytes, use_bfloat16, image_size, interpolation) + else: + return preprocess_for_eval(image_bytes, use_bfloat16, image_size, interpolation) + + +class TfPreprocessTransform: + + def __init__(self, is_training=False, size=224, interpolation='bicubic'): + self.is_training = is_training + self.size = size[0] if isinstance(size, tuple) else size + self.interpolation = interpolation + self._image_bytes = None + self.process_image = self._build_tf_graph() + self.sess = None + + def _build_tf_graph(self): + with tf.device('/cpu:0'): + self._image_bytes = tf.placeholder( + shape=[], + dtype=tf.string, + ) + img = preprocess_image( + self._image_bytes, self.is_training, False, self.size, self.interpolation) + return img + + def __call__(self, image_bytes): + if self.sess is None: + self.sess = tf.Session() + img = self.sess.run(self.process_image, feed_dict={self._image_bytes: image_bytes}) + img = img.round().clip(0, 255).astype(np.uint8) + if img.ndim < 3: + img = np.expand_dims(img, axis=-1) + img = np.rollaxis(img, 2) # HWC to CHW + return img diff --git a/timm/data/transforms.py b/timm/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..b3b08e309957c31026f77cab1c8c121b198ec5ac --- /dev/null +++ b/timm/data/transforms.py @@ -0,0 +1,158 @@ +import torch +import torchvision.transforms.functional as F +from PIL import Image +import warnings +import math +import random +import numpy as np + + +class ToNumpy: + + def __call__(self, pil_img): + np_img = np.array(pil_img, dtype=np.uint8) + if np_img.ndim < 3: + np_img = np.expand_dims(np_img, axis=-1) + np_img = np.rollaxis(np_img, 2) # HWC to CHW + return np_img + + +class ToTensor: + + def __init__(self, dtype=torch.float32): + self.dtype = dtype + + def __call__(self, pil_img): + np_img = np.array(pil_img, dtype=np.uint8) + if np_img.ndim < 3: + np_img = np.expand_dims(np_img, axis=-1) + np_img = np.rollaxis(np_img, 2) # HWC to CHW + return torch.from_numpy(np_img).to(dtype=self.dtype) + + +_pil_interpolation_to_str = { + Image.NEAREST: 'PIL.Image.NEAREST', + Image.BILINEAR: 'PIL.Image.BILINEAR', + Image.BICUBIC: 'PIL.Image.BICUBIC', + Image.LANCZOS: 'PIL.Image.LANCZOS', + Image.HAMMING: 'PIL.Image.HAMMING', + Image.BOX: 'PIL.Image.BOX', +} + + +def _pil_interp(method): + if method == 'bicubic': + return Image.BICUBIC + elif method == 'lanczos': + return Image.LANCZOS + elif method == 'hamming': + return Image.HAMMING + else: + # default bilinear, do we want to allow nearest? 
+ return Image.BILINEAR + + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +class RandomResizedCropAndInterpolation: + """Crop the given PIL Image to random size and aspect ratio with random interpolation. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), + interpolation='bilinear'): + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("range should be of kind (min, max)") + + if interpolation == 'random': + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = _pil_interp(interpolation) + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + area = img.size[0] * img.size[1] + + for attempt in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Randomly cropped and resized image. 
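+        Example (illustrative usage; 'img.jpg' is a hypothetical path, not a file in this repo):
+            rrc = RandomResizedCropAndInterpolation(224, interpolation='random')
+            out = rrc(Image.open('img.jpg').convert('RGB'))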
+ """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + return F.resized_crop(img, i, j, h, w, self.size, interpolation) + + def __repr__(self): + if isinstance(self.interpolation, (tuple, list)): + interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation]) + else: + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) + format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..01c9fcf238ec7f856053683348cc9edde1640370 --- /dev/null +++ b/timm/data/transforms_factory.py @@ -0,0 +1,236 @@ +""" Transforms Factory +Factory methods for building image transforms for use with TIMM (PyTorch Image Models) + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math + +import torch +from torchvision import transforms + +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT +from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform +from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor +from timm.data.random_erasing import RandomErasing + + +def transforms_noaug_train( + img_size=224, + interpolation='bilinear', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, +): + if interpolation == 'random': + # random interpolation not supported with no-aug + interpolation = 'bilinear' + tfl = [ + transforms.Resize(img_size, _pil_interp(interpolation)), + transforms.CenterCrop(img_size) + ] + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + tfl += [ToNumpy()] + else: + tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + return transforms.Compose(tfl) + + +def transforms_imagenet_train( + img_size=224, + scale=None, + ratio=None, + hflip=0.5, + vflip=0., + color_jitter=0.4, + auto_augment=None, + interpolation='random', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0, + separate=False, +): + """ + If separate==True, the transforms are returned as a tuple of 3 separate transforms + for use in a mixing dataset that passes + * all data through the first (primary) transform, called the 'clean' data + * a portion of the data through the secondary transform + * normalizes and converts the branches above with the third, final transform + """ + scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range + ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range + primary_tfl = [ + RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)] + if hflip > 0.: + primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)] + if vflip > 0.: + primary_tfl += [transforms.RandomVerticalFlip(p=vflip)] + + secondary_tfl = [] + if auto_augment: + assert isinstance(auto_augment, str) + if 
isinstance(img_size, tuple): + img_size_min = min(img_size) + else: + img_size_min = img_size + aa_params = dict( + translate_const=int(img_size_min * 0.45), + img_mean=tuple([min(255, round(255 * x)) for x in mean]), + ) + if interpolation and interpolation != 'random': + aa_params['interpolation'] = _pil_interp(interpolation) + if auto_augment.startswith('rand'): + secondary_tfl += [rand_augment_transform(auto_augment, aa_params)] + elif auto_augment.startswith('augmix'): + aa_params['translate_pct'] = 0.3 + secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)] + else: + secondary_tfl += [auto_augment_transform(auto_augment, aa_params)] + elif color_jitter is not None: + # color jitter is enabled when not using AA + if isinstance(color_jitter, (list, tuple)): + # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation + # or 4 if also augmenting hue + assert len(color_jitter) in (3, 4) + else: + # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue + color_jitter = (float(color_jitter),) * 3 + secondary_tfl += [transforms.ColorJitter(*color_jitter)] + + final_tfl = [] + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + final_tfl += [ToNumpy()] + else: + final_tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + if re_prob > 0.: + final_tfl.append( + RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu')) + + if separate: + return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) + else: + return transforms.Compose(primary_tfl + secondary_tfl + final_tfl) + + +def transforms_imagenet_eval( + img_size=224, + crop_pct=None, + interpolation='bilinear', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD): + crop_pct = crop_pct or DEFAULT_CROP_PCT + + if isinstance(img_size, tuple): + assert len(img_size) == 2 + if img_size[-1] == img_size[-2]: + # fall-back to older behaviour so Resize scales to shortest edge if target is square + scale_size = int(math.floor(img_size[0] / crop_pct)) + else: + scale_size = tuple([int(x / crop_pct) for x in img_size]) + else: + scale_size = int(math.floor(img_size / crop_pct)) + + tfl = [ + transforms.Resize(scale_size, _pil_interp(interpolation)), + transforms.CenterCrop(img_size), + ] + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + tfl += [ToNumpy()] + else: + tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + + return transforms.Compose(tfl) + + +def create_transform( + input_size, + is_training=False, + use_prefetcher=False, + no_aug=False, + scale=None, + ratio=None, + hflip=0.5, + vflip=0., + color_jitter=0.4, + auto_augment=None, + interpolation='bilinear', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0, + crop_pct=None, + tf_preprocessing=False, + separate=False): + + if isinstance(input_size, tuple): + img_size = input_size[-2:] + else: + img_size = input_size + + if tf_preprocessing and use_prefetcher: + assert not separate, "Separate transforms not supported for TF preprocessing" + from timm.data.tf_preprocessing import TfPreprocessTransform + transform = TfPreprocessTransform( + is_training=is_training, size=img_size, interpolation=interpolation) + else: + if is_training and no_aug: + assert 
not separate, "Cannot perform split augmentation with no_aug" + transform = transforms_noaug_train( + img_size, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std) + elif is_training: + transform = transforms_imagenet_train( + img_size, + scale=scale, + ratio=ratio, + hflip=hflip, + vflip=vflip, + color_jitter=color_jitter, + auto_augment=auto_augment, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std, + re_prob=re_prob, + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits, + separate=separate) + else: + assert not separate, "Separate transforms not supported for validation preprocessing" + transform = transforms_imagenet_eval( + img_size, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std, + crop_pct=crop_pct) + + return transform diff --git a/timm/loss/__init__.py b/timm/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..28a686ce896f4335dcf074717b52077b43c237d7 --- /dev/null +++ b/timm/loss/__init__.py @@ -0,0 +1,3 @@ +from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from .jsd import JsdCrossEntropy +from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel \ No newline at end of file diff --git a/timm/loss/asymmetric_loss.py b/timm/loss/asymmetric_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a8b10f9c797c2cb3b2652302717b592dada216f3 --- /dev/null +++ b/timm/loss/asymmetric_loss.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn + + +class AsymmetricLossMultiLabel(nn.Module): + def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False): + super(AsymmetricLossMultiLabel, self).__init__() + + self.gamma_neg = gamma_neg + self.gamma_pos = gamma_pos + self.clip = clip + self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss + self.eps = eps + + def forward(self, x, y): + """" + Parameters + ---------- + x: input logits + y: targets (multi-label binarized vector) + """ + + # Calculating Probabilities + x_sigmoid = torch.sigmoid(x) + xs_pos = x_sigmoid + xs_neg = 1 - x_sigmoid + + # Asymmetric Clipping + if self.clip is not None and self.clip > 0: + xs_neg = (xs_neg + self.clip).clamp(max=1) + + # Basic CE calculation + los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) + los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) + loss = los_pos + los_neg + + # Asymmetric Focusing + if self.gamma_neg > 0 or self.gamma_pos > 0: + if self.disable_torch_grad_focal_loss: + torch._C.set_grad_enabled(False) + pt0 = xs_pos * y + pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p + pt = pt0 + pt1 + one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) + one_sided_w = torch.pow(1 - pt, one_sided_gamma) + if self.disable_torch_grad_focal_loss: + torch._C.set_grad_enabled(True) + loss *= one_sided_w + + return -loss.sum() + + +class AsymmetricLossSingleLabel(nn.Module): + def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'): + super(AsymmetricLossSingleLabel, self).__init__() + + self.eps = eps + self.logsoftmax = nn.LogSoftmax(dim=-1) + self.targets_classes = [] # prevent gpu repeated memory allocation + self.gamma_pos = gamma_pos + self.gamma_neg = gamma_neg + self.reduction = reduction + + def forward(self, inputs, target, reduction=None): + """" + Parameters + ---------- + x: input logits + y: targets (1-hot vector) + """ + + num_classes = inputs.size()[-1] + 
log_preds = self.logsoftmax(inputs) + self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1) + + # ASL weights + targets = self.targets_classes + anti_targets = 1 - targets + xs_pos = torch.exp(log_preds) + xs_neg = 1 - xs_pos + xs_pos = xs_pos * targets + xs_neg = xs_neg * anti_targets + asymmetric_w = torch.pow(1 - xs_pos - xs_neg, + self.gamma_pos * targets + self.gamma_neg * anti_targets) + log_preds = log_preds * asymmetric_w + + if self.eps > 0: # label smoothing + self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes) + + # loss calculation + loss = - self.targets_classes.mul(log_preds) + + loss = loss.sum(dim=-1) + if self.reduction == 'mean': + loss = loss.mean() + + return loss diff --git a/timm/loss/cross_entropy.py b/timm/loss/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..60bef646cc6c31fd734f234346dbc4255def6622 --- /dev/null +++ b/timm/loss/cross_entropy.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LabelSmoothingCrossEntropy(nn.Module): + """ + NLL loss with label smoothing. + """ + def __init__(self, smoothing=0.1): + """ + Constructor for the LabelSmoothing module. + :param smoothing: label smoothing factor + """ + super(LabelSmoothingCrossEntropy, self).__init__() + assert smoothing < 1.0 + self.smoothing = smoothing + self.confidence = 1. - smoothing + + def forward(self, x, target): + logprobs = F.log_softmax(x, dim=-1) + nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) + nll_loss = nll_loss.squeeze(1) + smooth_loss = -logprobs.mean(dim=-1) + loss = self.confidence * nll_loss + self.smoothing * smooth_loss + return loss.mean() + + +class SoftTargetCrossEntropy(nn.Module): + + def __init__(self): + super(SoftTargetCrossEntropy, self).__init__() + + def forward(self, x, target): + loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) + return loss.mean() diff --git a/timm/loss/jsd.py b/timm/loss/jsd.py new file mode 100644 index 0000000000000000000000000000000000000000..dd64e156c23d27aa03817a587ae367e8175fc126 --- /dev/null +++ b/timm/loss/jsd.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .cross_entropy import LabelSmoothingCrossEntropy + + +class JsdCrossEntropy(nn.Module): + """ Jensen-Shannon Divergence + Cross-Entropy Loss + + Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py + From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - + https://arxiv.org/abs/1912.02781 + + Hacked together by / Copyright 2020 Ross Wightman + """ + def __init__(self, num_splits=3, alpha=12, smoothing=0.1): + super().__init__() + self.num_splits = num_splits + self.alpha = alpha + if smoothing is not None and smoothing > 0: + self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) + else: + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def __call__(self, output, target): + split_size = output.shape[0] // self.num_splits + assert split_size * self.num_splits == output.shape[0] + logits_split = torch.split(output, split_size) + + # Cross-entropy is only computed on clean images + loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) + probs = [F.softmax(logits, dim=1) for logits in logits_split] + + # Clamp mixture distribution to avoid exploding KL divergence + logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() + loss += self.alpha * sum([F.kl_div( 
+ logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) + return loss diff --git a/timm/models/__init__.py b/timm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53765fc86c89c4ba2ad74a064528b8e68d956065 --- /dev/null +++ b/timm/models/__init__.py @@ -0,0 +1,34 @@ +from .cspnet import * +from .densenet import * +from .dla import * +from .dpn import * +from .efficientnet import * +from .gluon_resnet import * +from .gluon_xception import * +from .hrnet import * +from .inception_resnet_v2 import * +from .inception_v3 import * +from .inception_v4 import * +from .mobilenetv3 import * +from .nasnet import * +from .pnasnet import * +from .regnet import * +from .res2net import * +from .resnest import * +from .resnet import * +from .rexnet import * +from .selecsls import * +from .senet import * +from .sknet import * +from .tresnet import * +from .vision_transformer import * +from .vovnet import * +from .xception import * +from .xception_aligned import * + +from .factory import create_model +from .helpers import load_checkpoint, resume_checkpoint +from .layers import TestTimePoolHead, apply_test_time_pool +from .layers import convert_splitbn_model +from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit +from .registry import * diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ca9eaf160e21f2f9e69029787d99de2d41f21d54 --- /dev/null +++ b/timm/models/cspnet.py @@ -0,0 +1,453 @@ +"""PyTorch CspNet + +A PyTorch implementation of Cross Stage Partial Networks including: +* CSPResNet50 +* CSPResNeXt50 +* CSPDarkNet53 +* and DarkNet53 for good measure + +Based on paper `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929 + +Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStagePartialNetworks + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, ConvBnAct, DropPath, create_attn, get_norm_act_layer +from .registry import register_model + + +__all__ = ['CspNet'] # model_registry will add each entrypoint fn to this + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8), + 'crop_pct': 0.887, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = { + 'cspresnet50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnet50_ra-d3e8d487.pth'), + 'cspresnet50d': _cfg(url=''), + 'cspresnet50w': _cfg(url=''), + 'cspresnext50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth', + input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.875 # FIXME I trained this at 224x224, not 256 like ref impl + ), + 'cspresnext50_iabn': _cfg(url=''), + 'cspdarknet53': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'), + 'cspdarknet53_iabn': _cfg(url=''), + 'darknet53': _cfg(url=''), +} + + +model_cfgs = dict( + cspresnet50=dict( + 
stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'), + stage=dict( + out_chs=(128, 256, 512, 1024), + depth=(3, 3, 5, 2), + stride=(1,) + (2,) * 3, + exp_ratio=(2.,) * 4, + bottle_ratio=(0.5,) * 4, + block_ratio=(1.,) * 4, + cross_linear=True, + ) + ), + cspresnet50d=dict( + stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'), + stage=dict( + out_chs=(128, 256, 512, 1024), + depth=(3, 3, 5, 2), + stride=(1,) + (2,) * 3, + exp_ratio=(2.,) * 4, + bottle_ratio=(0.5,) * 4, + block_ratio=(1.,) * 4, + cross_linear=True, + ) + ), + cspresnet50w=dict( + stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'), + stage=dict( + out_chs=(256, 512, 1024, 2048), + depth=(3, 3, 5, 2), + stride=(1,) + (2,) * 3, + exp_ratio=(1.,) * 4, + bottle_ratio=(0.25,) * 4, + block_ratio=(0.5,) * 4, + cross_linear=True, + ) + ), + cspresnext50=dict( + stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'), + stage=dict( + out_chs=(256, 512, 1024, 2048), + depth=(3, 3, 5, 2), + stride=(1,) + (2,) * 3, + groups=(32,) * 4, + exp_ratio=(1.,) * 4, + bottle_ratio=(1.,) * 4, + block_ratio=(0.5,) * 4, + cross_linear=True, + ) + ), + cspdarknet53=dict( + stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), + stage=dict( + out_chs=(64, 128, 256, 512, 1024), + depth=(1, 2, 8, 8, 4), + stride=(2,) * 5, + exp_ratio=(2.,) + (1.,) * 4, + bottle_ratio=(0.5,) + (1.0,) * 4, + block_ratio=(1.,) + (0.5,) * 4, + down_growth=True, + ) + ), + darknet53=dict( + stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), + stage=dict( + out_chs=(64, 128, 256, 512, 1024), + depth=(1, 2, 8, 8, 4), + stride=(2,) * 5, + bottle_ratio=(0.5,) * 5, + block_ratio=(1.,) * 5, + ) + ) +) + + +def create_stem( + in_chans=3, out_chs=32, kernel_size=3, stride=2, pool='', + act_layer=None, norm_layer=None, aa_layer=None): + stem = nn.Sequential() + if not isinstance(out_chs, (tuple, list)): + out_chs = [out_chs] + assert len(out_chs) + in_c = in_chans + for i, out_c in enumerate(out_chs): + conv_name = f'conv{i + 1}' + stem.add_module(conv_name, ConvBnAct( + in_c, out_c, kernel_size, stride=stride if i == 0 else 1, + act_layer=act_layer, norm_layer=norm_layer)) + in_c = out_c + last_conv = conv_name + if pool: + if aa_layer is not None: + stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)) + stem.add_module('aa', aa_layer(channels=in_c, stride=2)) + else: + stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + return stem, dict(num_chs=in_c, reduction=stride, module='.'.join(['stem', last_conv])) + + +class ResBottleneck(nn.Module): + """ ResNe(X)t Bottleneck Block + """ + + def __init__(self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(ResBottleneck, self).__init__() + mid_chs = int(round(out_chs * bottle_ratio)) + ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block) + + self.conv1 = ConvBnAct(in_chs, mid_chs, kernel_size=1, **ckwargs) + self.conv2 = ConvBnAct(mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups, **ckwargs) + self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None + self.conv3 = ConvBnAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs) + self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None + self.drop_path = drop_path + self.act3 = act_layer(inplace=True) + + def 
zero_init_last_bn(self): + nn.init.zeros_(self.conv3.bn.weight) + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.conv2(x) + if self.attn2 is not None: + x = self.attn2(x) + x = self.conv3(x) + if self.attn3 is not None: + x = self.attn3(x) + if self.drop_path is not None: + x = self.drop_path(x) + x = x + shortcut + # FIXME partial shortcut needed if first block handled as per original, not used for my current impl + #x[:, :shortcut.size(1)] += shortcut + x = self.act3(x) + return x + + +class DarkBlock(nn.Module): + """ DarkNet Block + """ + + def __init__(self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, + drop_block=None, drop_path=None): + super(DarkBlock, self).__init__() + mid_chs = int(round(out_chs * bottle_ratio)) + ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block) + self.conv1 = ConvBnAct(in_chs, mid_chs, kernel_size=1, **ckwargs) + self.conv2 = ConvBnAct(mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, **ckwargs) + self.attn = create_attn(attn_layer, channels=out_chs) + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv2.bn.weight) + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.conv2(x) + if self.attn is not None: + x = self.attn(x) + if self.drop_path is not None: + x = self.drop_path(x) + x = x + shortcut + return x + + +class CrossStage(nn.Module): + """Cross Stage.""" + def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1., + groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None, + block_fn=ResBottleneck, **block_kwargs): + super(CrossStage, self).__init__() + first_dilation = first_dilation or dilation + down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels + exp_chs = int(round(out_chs * exp_ratio)) + block_out_chs = int(round(out_chs * block_ratio)) + conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) + + if stride != 1 or first_dilation != dilation: + self.conv_down = ConvBnAct( + in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, + aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) + prev_chs = down_chs + else: + self.conv_down = None + prev_chs = in_chs + + # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also, + # there is also special case for the first stage for some of the model that results in uneven split + # across the two paths. I did it this way for simplicity for now. 
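+        # 1x1 expansion conv; activation is skipped when cross_linear so the cross (shortcut) path stays linear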
+ self.conv_exp = ConvBnAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs) + prev_chs = exp_chs // 2 # output of conv_exp is always split in two + + self.blocks = nn.Sequential() + for i in range(depth): + drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None + self.blocks.add_module(str(i), block_fn( + prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + prev_chs = block_out_chs + + # transition convs + self.conv_transition_b = ConvBnAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs) + self.conv_transition = ConvBnAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs) + + def forward(self, x): + if self.conv_down is not None: + x = self.conv_down(x) + x = self.conv_exp(x) + xs, xb = x.chunk(2, dim=1) + xb = self.blocks(xb) + out = self.conv_transition(torch.cat([xs, self.conv_transition_b(xb)], dim=1)) + return out + + +class DarkStage(nn.Module): + """DarkNet stage.""" + + def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1, + first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs): + super(DarkStage, self).__init__() + first_dilation = first_dilation or dilation + + self.conv_down = ConvBnAct( + in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, + act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'), + aa_layer=block_kwargs.get('aa_layer', None)) + + prev_chs = out_chs + block_out_chs = int(round(out_chs * block_ratio)) + self.blocks = nn.Sequential() + for i in range(depth): + drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None + self.blocks.add_module(str(i), block_fn( + prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + prev_chs = block_out_chs + + def forward(self, x): + x = self.conv_down(x) + x = self.blocks(x) + return x + + +def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.): + # get per stage args for stage and containing blocks, calculate strides to meet target output_stride + num_stages = len(cfg['depth']) + if 'groups' not in cfg: + cfg['groups'] = (1,) * num_stages + if 'down_growth' in cfg and not isinstance(cfg['down_growth'], (list, tuple)): + cfg['down_growth'] = (cfg['down_growth'],) * num_stages + if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)): + cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages + cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \ + [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])] + stage_strides = [] + stage_dilations = [] + stage_first_dilations = [] + dilation = 1 + for cfg_stride in cfg['stride']: + stage_first_dilations.append(dilation) + if curr_stride >= output_stride: + dilation *= cfg_stride + stride = 1 + else: + stride = cfg_stride + curr_stride *= stride + stage_strides.append(stride) + stage_dilations.append(dilation) + cfg['stride'] = stage_strides + cfg['dilation'] = stage_dilations + cfg['first_dilation'] = stage_first_dilations + stage_args = [dict(zip(cfg.keys(), values)) for values in zip(*cfg.values())] + return stage_args + + +class CspNet(nn.Module): + """Cross Stage Partial base model. 
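+
+    Example (illustrative usage sketch, not from the original docstring):
+        model = cspresnet50(pretrained=False)
+        out = model(torch.randn(1, 3, 256, 256))  # default cfg input size -> (1, 1000) logits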
+ + Paper: `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929 + Ref Impl: https://github.com/WongKinYiu/CrossStagePartialNetworks + + NOTE: There are differences in the way I handle the 1x1 'expansion' conv in this impl vs the + darknet impl. I did it this way for simplicity and less special cases. + """ + + def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., + act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_path_rate=0., + zero_init_last_bn=True, stage_fn=CrossStage, block_fn=ResBottleneck): + super().__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + assert output_stride in (8, 16, 32) + layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + + # Construct the stem + self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'], **layer_args) + self.feature_info = [stem_feat_info] + prev_chs = stem_feat_info['num_chs'] + curr_stride = stem_feat_info['reduction'] # reduction does not include pool + if cfg['stem']['pool']: + curr_stride *= 2 + + # Construct the stages + per_stage_args = _cfg_to_stage_args( + cfg['stage'], curr_stride=curr_stride, output_stride=output_stride, drop_path_rate=drop_path_rate) + self.stages = nn.Sequential() + for i, sa in enumerate(per_stage_args): + self.stages.add_module( + str(i), stage_fn(prev_chs, **sa, **layer_args, block_fn=block_fn)) + prev_chs = sa['out_chs'] + curr_stride *= sa['stride'] + self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] + + # Construct the head + self.num_features = prev_chs + self.head = ClassifierHead( + in_chs=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, mean=0.0, std=0.01) + nn.init.zeros_(m.bias) + if zero_init_last_bn: + for m in self.modules(): + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_cspnet(variant, pretrained=False, **kwargs): + cfg_variant = variant.split('_')[0] + return build_model_with_cfg( + CspNet, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant], **kwargs) + + +@register_model +def cspresnet50(pretrained=False, **kwargs): + return _create_cspnet('cspresnet50', pretrained=pretrained, **kwargs) + + +@register_model +def cspresnet50d(pretrained=False, **kwargs): + return _create_cspnet('cspresnet50d', pretrained=pretrained, **kwargs) + + +@register_model +def cspresnet50w(pretrained=False, **kwargs): + return _create_cspnet('cspresnet50w', pretrained=pretrained, **kwargs) + + +@register_model +def cspresnext50(pretrained=False, **kwargs): + return _create_cspnet('cspresnext50', pretrained=pretrained, **kwargs) + + +@register_model +def cspresnext50_iabn(pretrained=False, 
**kwargs): + norm_layer = get_norm_act_layer('iabn') + return _create_cspnet('cspresnext50_iabn', pretrained=pretrained, norm_layer=norm_layer, **kwargs) + + +@register_model +def cspdarknet53(pretrained=False, **kwargs): + return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs) + + +@register_model +def cspdarknet53_iabn(pretrained=False, **kwargs): + norm_layer = get_norm_act_layer('iabn') + return _create_cspnet('cspdarknet53_iabn', pretrained=pretrained, block_fn=DarkBlock, norm_layer=norm_layer, **kwargs) + + +@register_model +def darknet53(pretrained=False, **kwargs): + return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) diff --git a/timm/models/densenet.py b/timm/models/densenet.py new file mode 100644 index 0000000000000000000000000000000000000000..e4e2056458959d0f1117ed27afac456253ce9061 --- /dev/null +++ b/timm/models/densenet.py @@ -0,0 +1,385 @@ +"""Pytorch Densenet implementation w/ tweaks +This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with +fixed kwargs passthrough and addition of dynamic global avg/max pool. +""" +import re +from collections import OrderedDict +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from torch.jit.annotations import List + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import BatchNormAct2d, create_norm_act, BlurPool2d, create_classifier +from .registry import register_model + +__all__ = ['DenseNet'] + + +def _cfg(url=''): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'features.conv0', 'classifier': 'classifier', + } + + +default_cfgs = { + 'densenet121': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenet121_ra-50efcf5c.pth'), + 'densenet121d': _cfg(url=''), + 'densenetblur121d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenetblur121d_ra-100dcfbc.pth'), + 'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'), + 'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'), + 'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'), + 'densenet264': _cfg(url=''), + 'densenet264d_iabn': _cfg(url=''), + 'tv_densenet121': _cfg(url='https://download.pytorch.org/models/densenet121-a639ec97.pth'), +} + + +class DenseLayer(nn.Module): + def __init__(self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d, + drop_rate=0., memory_efficient=False): + super(DenseLayer, self).__init__() + self.add_module('norm1', norm_layer(num_input_features)), + self.add_module('conv1', nn.Conv2d( + num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)), + self.add_module('norm2', norm_layer(bn_size * growth_rate)), + self.add_module('conv2', nn.Conv2d( + bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)), + self.drop_rate = float(drop_rate) + self.memory_efficient = memory_efficient + + def bottleneck_fn(self, xs): + # type: (List[torch.Tensor]) -> torch.Tensor + concated_features = torch.cat(xs, 1) + bottleneck_output = 
self.conv1(self.norm1(concated_features)) # noqa: T484 + return bottleneck_output + + # todo: rewrite when torchscript supports any + def any_requires_grad(self, x): + # type: (List[torch.Tensor]) -> bool + for tensor in x: + if tensor.requires_grad: + return True + return False + + @torch.jit.unused # noqa: T484 + def call_checkpoint_bottleneck(self, x): + # type: (List[torch.Tensor]) -> torch.Tensor + def closure(*xs): + return self.bottleneck_fn(xs) + + return cp.checkpoint(closure, *x) + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (List[torch.Tensor]) -> (torch.Tensor) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> (torch.Tensor) + pass + + # torchscript does not yet support *args, so we overload method + # allowing it to take either a List[Tensor] or single Tensor + def forward(self, x): # noqa: F811 + if isinstance(x, torch.Tensor): + prev_features = [x] + else: + prev_features = x + + if self.memory_efficient and self.any_requires_grad(prev_features): + if torch.jit.is_scripting(): + raise Exception("Memory Efficient not supported in JIT") + bottleneck_output = self.call_checkpoint_bottleneck(prev_features) + else: + bottleneck_output = self.bottleneck_fn(prev_features) + + new_features = self.conv2(self.norm2(bottleneck_output)) + if self.drop_rate > 0: + new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) + return new_features + + +class DenseBlock(nn.ModuleDict): + _version = 2 + + def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU, + drop_rate=0., memory_efficient=False): + super(DenseBlock, self).__init__() + for i in range(num_layers): + layer = DenseLayer( + num_input_features + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + norm_layer=norm_layer, + drop_rate=drop_rate, + memory_efficient=memory_efficient, + ) + self.add_module('denselayer%d' % (i + 1), layer) + + def forward(self, init_features): + features = [init_features] + for name, layer in self.items(): + new_features = layer(features) + features.append(new_features) + return torch.cat(features, 1) + + +class DenseTransition(nn.Sequential): + def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None): + super(DenseTransition, self).__init__() + self.add_module('norm', norm_layer(num_input_features)) + self.add_module('conv', nn.Conv2d( + num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) + if aa_layer is not None: + self.add_module('pool', aa_layer(num_output_features, stride=2)) + else: + self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) + + +class DenseNet(nn.Module): + r"""Densenet-BC model class, based on + `"Densely Connected Convolutional Networks" `_ + + Args: + growth_rate (int) - how many filters to add each layer (`k` in paper) + block_config (list of 4 ints) - how many layers in each pooling block + bn_size (int) - multiplicative factor for number of bottle neck layers + (i.e. bn_size * k features in the bottleneck layer) + drop_rate (float) - dropout rate after each dense layer + num_classes (int) - number of classification classes + memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, + but slower. Default: *False*. 
See `"paper" `_ + """ + + def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem_type='', + num_classes=1000, in_chans=3, global_pool='avg', + norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False, + aa_stem_only=True): + self.num_classes = num_classes + self.drop_rate = drop_rate + super(DenseNet, self).__init__() + + # Stem + deep_stem = 'deep' in stem_type # 3x3 deep stem + num_init_features = growth_rate * 2 + if aa_layer is None: + stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + else: + stem_pool = nn.Sequential(*[ + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + aa_layer(channels=num_init_features, stride=2)]) + if deep_stem: + stem_chs_1 = stem_chs_2 = growth_rate + if 'tiered' in stem_type: + stem_chs_1 = 3 * (growth_rate // 4) + stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4) + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)), + ('norm0', norm_layer(stem_chs_1)), + ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)), + ('norm1', norm_layer(stem_chs_2)), + ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)), + ('norm2', norm_layer(num_init_features)), + ('pool0', stem_pool), + ])) + else: + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), + ('norm0', norm_layer(num_init_features)), + ('pool0', stem_pool), + ])) + self.feature_info = [ + dict(num_chs=num_init_features, reduction=2, module=f'features.norm{2 if deep_stem else 0}')] + current_stride = 4 + + # DenseBlocks + num_features = num_init_features + for i, num_layers in enumerate(block_config): + block = DenseBlock( + num_layers=num_layers, + num_input_features=num_features, + bn_size=bn_size, + growth_rate=growth_rate, + norm_layer=norm_layer, + drop_rate=drop_rate, + memory_efficient=memory_efficient + ) + module_name = f'denseblock{(i + 1)}' + self.features.add_module(module_name, block) + num_features = num_features + num_layers * growth_rate + transition_aa_layer = None if aa_stem_only else aa_layer + if i != len(block_config) - 1: + self.feature_info += [ + dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)] + current_stride *= 2 + trans = DenseTransition( + num_input_features=num_features, num_output_features=num_features // 2, + norm_layer=norm_layer, aa_layer=transition_aa_layer) + self.features.add_module(f'transition{i + 1}', trans) + num_features = num_features // 2 + + # Final batch norm + self.features.add_module('norm5', norm_layer(num_features)) + + self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')] + self.num_features = num_features + + # Linear layer + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + # Official init from torch repo. 
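+        # kaiming-normal conv weights, BatchNorm weight=1 / bias=0, zero Linear bias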
+ for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.constant_(m.bias, 0) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + return self.features(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + # both classifier and block drop? + # if self.drop_rate > 0.: + # x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classifier(x) + return x + + +def _filter_torchvision_pretrained(state_dict): + pattern = re.compile( + r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') + + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + state_dict[new_key] = state_dict[key] + del state_dict[key] + return state_dict + + +def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs): + kwargs['growth_rate'] = growth_rate + kwargs['block_config'] = block_config + return build_model_with_cfg( + DenseNet, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained, **kwargs) + + +@register_model +def densenet121(pretrained=False, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenetblur121d(pretrained=False, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep', + aa_layer=BlurPool2d, **kwargs) + return model + + +@register_model +def densenet121d(pretrained=False, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep', + pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet169(pretrained=False, **kwargs): + r"""Densenet-169 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet169', growth_rate=32, block_config=(6, 12, 32, 32), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet201(pretrained=False, **kwargs): + r"""Densenet-201 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet201', growth_rate=32, block_config=(6, 12, 48, 32), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet161(pretrained=False, **kwargs): + r"""Densenet-161 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet161', growth_rate=48, block_config=(6, 12, 36, 24), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet264(pretrained=False, **kwargs): + r"""Densenet-264 model from + `"Densely Connected Convolutional Networks" ` + """ + 
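+    # no pretrained weights are published for this config (url='' in default_cfgs)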
model = _create_densenet( + 'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet264d_iabn(pretrained=False, **kwargs): + r"""Densenet-264 model with deep stem and Inplace-ABN + """ + def norm_act_fn(num_features, **kwargs): + return create_norm_act('iabn', num_features, **kwargs) + model = _create_densenet( + 'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep', + norm_layer=norm_act_fn, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tv_densenet121(pretrained=False, **kwargs): + r"""Densenet-121 model with original Torchvision weights, from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'tv_densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs) + return model diff --git a/timm/models/dla.py b/timm/models/dla.py new file mode 100644 index 0000000000000000000000000000000000000000..a41ec3260b9aacf01e2f9fb953b31649c374fd55 --- /dev/null +++ b/timm/models/dla.py @@ -0,0 +1,438 @@ +""" Deep Layer Aggregation and DLA w/ Res2Net +DLA original adapted from Official Pytorch impl at: +DLA Paper: `Deep Layer Aggregation` - https://arxiv.org/abs/1707.06484 + +Res2Net additions from: https://github.com/gasvn/Res2Net/ +Res2Net Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169 +""" +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['DLA'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'base_layer.0', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'dla34': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla34-ba72cf86.pth'), + 'dla46_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla46_c-2bfd52c3.pth'), + 'dla46x_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla46x_c-d761bae7.pth'), + 'dla60x_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60x_c-b870c45c.pth'), + 'dla60': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60-24839fc4.pth'), + 'dla60x': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60x-d15cacda.pth'), + 'dla102': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla102-d94d9790.pth'), + 'dla102x': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla102x-ad62be81.pth'), + 'dla102x2': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla102x2-262837b6.pth'), + 'dla169': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla169-0914e092.pth'), + 'dla60_res2net': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net_dla60_4s-d88db7f9.pth'), + 'dla60_res2next': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next_dla60_4s-d327927b.pth'), +} + + +class DlaBasic(nn.Module): + """DLA Basic""" + + def __init__(self, inplanes, planes, stride=1, dilation=1, **_): + super(DlaBasic, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + 
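+        # second 3x3 conv of the basic block; the residual add and final ReLU are applied in forward()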
self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=dilation, bias=False, dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += residual + out = self.relu(out) + + return out + + +class DlaBottleneck(nn.Module): + """DLA/DLA-X Bottleneck""" + expansion = 2 + + def __init__(self, inplanes, outplanes, stride=1, dilation=1, cardinality=1, base_width=64): + super(DlaBottleneck, self).__init__() + self.stride = stride + mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality) + mid_planes = mid_planes // self.expansion + + self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(mid_planes) + self.conv2 = nn.Conv2d( + mid_planes, mid_planes, kernel_size=3, stride=stride, padding=dilation, + bias=False, dilation=dilation, groups=cardinality) + self.bn2 = nn.BatchNorm2d(mid_planes) + self.conv3 = nn.Conv2d(mid_planes, outplanes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(outplanes) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + out = self.relu(out) + + return out + + +class DlaBottle2neck(nn.Module): + """ Res2Net/Res2NeXT DLA Bottleneck + Adapted from https://github.com/gasvn/Res2Net/blob/master/dla.py + """ + expansion = 2 + + def __init__(self, inplanes, outplanes, stride=1, dilation=1, scale=4, cardinality=8, base_width=4): + super(DlaBottle2neck, self).__init__() + self.is_first = stride > 1 + self.scale = scale + mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality) + mid_planes = mid_planes // self.expansion + self.width = mid_planes + + self.conv1 = nn.Conv2d(inplanes, mid_planes * scale, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(mid_planes * scale) + + num_scale_convs = max(1, scale - 1) + convs = [] + bns = [] + for _ in range(num_scale_convs): + convs.append(nn.Conv2d( + mid_planes, mid_planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, groups=cardinality, bias=False)) + bns.append(nn.BatchNorm2d(mid_planes)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + if self.is_first: + self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) + + self.conv3 = nn.Conv2d(mid_planes * scale, outplanes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(outplanes) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + spx = torch.split(out, self.width, 1) + spo = [] + for i, (conv, bn) in enumerate(zip(self.convs, self.bns)): + sp = spx[i] if i == 0 or self.is_first else sp + spx[i] + sp = conv(sp) + sp = bn(sp) + sp = self.relu(sp) + spo.append(sp) + if self.scale > 1: + spo.append(self.pool(spx[-1]) if self.is_first else spx[-1]) + out = torch.cat(spo, 1) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + out = self.relu(out) + + return out + + +class DlaRoot(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, residual): + super(DlaRoot, 
self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + self.residual = residual + + def forward(self, *x): + children = x + x = self.conv(torch.cat(x, 1)) + x = self.bn(x) + if self.residual: + x += children[0] + x = self.relu(x) + + return x + + +class DlaTree(nn.Module): + def __init__(self, levels, block, in_channels, out_channels, stride=1, + dilation=1, cardinality=1, base_width=64, + level_root=False, root_dim=0, root_kernel_size=1, root_residual=False): + super(DlaTree, self).__init__() + if root_dim == 0: + root_dim = 2 * out_channels + if level_root: + root_dim += in_channels + self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else nn.Identity() + self.project = nn.Identity() + cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width) + if levels == 1: + self.tree1 = block(in_channels, out_channels, stride, **cargs) + self.tree2 = block(out_channels, out_channels, 1, **cargs) + if in_channels != out_channels: + # NOTE the official impl/weights have project layers in levels > 1 case that are never + # used, I've moved the project layer here to avoid wasted params but old checkpoints will + # need strict=False while loading. + self.project = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), + nn.BatchNorm2d(out_channels)) + else: + cargs.update(dict(root_kernel_size=root_kernel_size, root_residual=root_residual)) + self.tree1 = DlaTree( + levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs) + self.tree2 = DlaTree( + levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs) + if levels == 1: + self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_residual) + self.level_root = level_root + self.root_dim = root_dim + self.levels = levels + + def forward(self, x, residual=None, children=None): + children = [] if children is None else children + bottom = self.downsample(x) + residual = self.project(bottom) + if self.level_root: + children.append(bottom) + x1 = self.tree1(x, residual) + if self.levels == 1: + x2 = self.tree2(x1) + x = self.root(x2, x1, *children) + else: + children.append(x1) + x = self.tree2(x1, children=children) + return x + + +class DLA(nn.Module): + def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3, + cardinality=1, base_width=64, block=DlaBottle2neck, residual_root=False, + drop_rate=0.0, global_pool='avg'): + super(DLA, self).__init__() + self.channels = channels + self.num_classes = num_classes + self.cardinality = cardinality + self.base_width = base_width + self.drop_rate = drop_rate + assert output_stride == 32 # FIXME support dilation + + self.base_layer = nn.Sequential( + nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(channels[0]), + nn.ReLU(inplace=True)) + self.level0 = self._make_conv_level(channels[0], channels[0], levels[0]) + self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2) + cargs = dict(cardinality=cardinality, base_width=base_width, root_residual=residual_root) + self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs) + self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs) + self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, 
level_root=True, **cargs) + self.level5 = DlaTree(levels[5], block, channels[4], channels[5], 2, level_root=True, **cargs) + self.feature_info = [ + dict(num_chs=channels[0], reduction=1, module='level0'), # rare to have a meaningful stride 1 level + dict(num_chs=channels[1], reduction=2, module='level1'), + dict(num_chs=channels[2], reduction=4, module='level2'), + dict(num_chs=channels[3], reduction=8, module='level3'), + dict(num_chs=channels[4], reduction=16, module='level4'), + dict(num_chs=channels[5], reduction=32, module='level5'), + ] + + self.num_features = channels[-1] + self.global_pool, self.fc = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): + modules = [] + for i in range(convs): + modules.extend([ + nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride if i == 0 else 1, + padding=dilation, bias=False, dilation=dilation), + nn.BatchNorm2d(planes), + nn.ReLU(inplace=True)]) + inplanes = planes + return nn.Sequential(*modules) + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + + def forward_features(self, x): + x = self.base_layer(x) + x = self.level0(x) + x = self.level1(x) + x = self.level2(x) + x = self.level3(x) + x = self.level4(x) + x = self.level5(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + if not self.global_pool.is_identity(): + x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled) + return x + + +def _create_dla(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + DLA, variant, pretrained, default_cfg=default_cfgs[variant], + pretrained_strict=False, feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)), **kwargs) + + +@register_model +def dla60_res2net(pretrained=False, **kwargs): + model_kwargs = dict( + levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024), + block=DlaBottle2neck, cardinality=1, base_width=28, **kwargs) + return _create_dla('dla60_res2net', pretrained, **model_kwargs) + + +@register_model +def dla60_res2next(pretrained=False,**kwargs): + model_kwargs = dict( + levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024), + block=DlaBottle2neck, cardinality=8, base_width=4, **kwargs) + return _create_dla('dla60_res2next', pretrained, **model_kwargs) + + +@register_model +def dla34(pretrained=False, **kwargs): # DLA-34 + model_kwargs = dict( + levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512], + block=DlaBasic, **kwargs) + return _create_dla('dla34', pretrained, **model_kwargs) + + +@register_model +def dla46_c(pretrained=False, **kwargs): # DLA-46-C + model_kwargs = dict( + levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], + block=DlaBottleneck, **kwargs) + return _create_dla('dla46_c', pretrained, **model_kwargs) + + +@register_model +def dla46x_c(pretrained=False, **kwargs): # DLA-X-46-C + model_kwargs = dict( + 
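+        # DLA-X variant: grouped 3x3 convs via cardinality=32, base_width=4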
levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], + block=DlaBottleneck, cardinality=32, base_width=4, **kwargs) + return _create_dla('dla46x_c', pretrained, **model_kwargs) + + +@register_model +def dla60x_c(pretrained=False, **kwargs): # DLA-X-60-C + model_kwargs = dict( + levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 64, 64, 128, 256], + block=DlaBottleneck, cardinality=32, base_width=4, **kwargs) + return _create_dla('dla60x_c', pretrained, **model_kwargs) + + +@register_model +def dla60(pretrained=False, **kwargs): # DLA-60 + model_kwargs = dict( + levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, **kwargs) + return _create_dla('dla60', pretrained, **model_kwargs) + + +@register_model +def dla60x(pretrained=False, **kwargs): # DLA-X-60 + model_kwargs = dict( + levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, cardinality=32, base_width=4, **kwargs) + return _create_dla('dla60x', pretrained, **model_kwargs) + + +@register_model +def dla102(pretrained=False, **kwargs): # DLA-102 + model_kwargs = dict( + levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, residual_root=True, **kwargs) + return _create_dla('dla102', pretrained, **model_kwargs) + + +@register_model +def dla102x(pretrained=False, **kwargs): # DLA-X-102 + model_kwargs = dict( + levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, cardinality=32, base_width=4, residual_root=True, **kwargs) + return _create_dla('dla102x', pretrained, **model_kwargs) + + +@register_model +def dla102x2(pretrained=False, **kwargs): # DLA-X-102 64 + model_kwargs = dict( + levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, cardinality=64, base_width=4, residual_root=True, **kwargs) + return _create_dla('dla102x2', pretrained, **model_kwargs) + + +@register_model +def dla169(pretrained=False, **kwargs): # DLA-169 + model_kwargs = dict( + levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, residual_root=True, **kwargs) + return _create_dla('dla169', pretrained, **model_kwargs) diff --git a/timm/models/dpn.py b/timm/models/dpn.py new file mode 100644 index 0000000000000000000000000000000000000000..61ce6a0e016184e0cc60e4586a6b38364017efe4 --- /dev/null +++ b/timm/models/dpn.py @@ -0,0 +1,316 @@ +""" PyTorch implementation of DualPathNetworks +Based on original MXNet implementation https://github.com/cypw/DPNs with +many ideas from another PyTorch implementation https://github.com/oyam/pytorch-DPNs. + +This implementation is compatible with the pretrained weights from cypw's MXNet implementation. 
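+
+Example (illustrative usage sketch, not part of the original header):
+    model = dpn92(pretrained=False)
+    out = model(torch.randn(1, 3, 224, 224))  # default cfg input size -> (1, 1000) logits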
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import BatchNormAct2d, ConvBnAct, create_conv2d, create_classifier +from .registry import register_model + +__all__ = ['DPN'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD, + 'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'dpn68': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'), + 'dpn68b': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dpn68b_ra-a31ca160.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'dpn92': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'), + 'dpn98': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'), + 'dpn131': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn131-71dfe43e0.pth'), + 'dpn107': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn107_extra-1ac7121e2.pth') +} + + +class CatBnAct(nn.Module): + def __init__(self, in_chs, norm_layer=BatchNormAct2d): + super(CatBnAct, self).__init__() + self.bn = norm_layer(in_chs, eps=0.001) + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (Tuple[torch.Tensor, torch.Tensor]) -> (torch.Tensor) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> (torch.Tensor) + pass + + def forward(self, x): + if isinstance(x, tuple): + x = torch.cat(x, dim=1) + return self.bn(x) + + +class BnActConv2d(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, stride, groups=1, norm_layer=BatchNormAct2d): + super(BnActConv2d, self).__init__() + self.bn = norm_layer(in_chs, eps=0.001) + self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, groups=groups) + + def forward(self, x): + return self.conv(self.bn(x)) + + +class DualPathBlock(nn.Module): + def __init__( + self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False): + super(DualPathBlock, self).__init__() + self.num_1x1_c = num_1x1_c + self.inc = inc + self.b = b + if block_type == 'proj': + self.key_stride = 1 + self.has_proj = True + elif block_type == 'down': + self.key_stride = 2 + self.has_proj = True + else: + assert block_type == 'normal' + self.key_stride = 1 + self.has_proj = False + + self.c1x1_w_s1 = None + self.c1x1_w_s2 = None + if self.has_proj: + # Using different member names here to allow easier parameter key matching for conversion + if self.key_stride == 2: + self.c1x1_w_s2 = BnActConv2d( + in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2) + else: + self.c1x1_w_s1 = BnActConv2d( + in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1) + + self.c1x1_a = 
BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1) + self.c3x3_b = BnActConv2d( + in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3, stride=self.key_stride, groups=groups) + if b: + self.c1x1_c = CatBnAct(in_chs=num_3x3_b) + self.c1x1_c1 = create_conv2d(num_3x3_b, num_1x1_c, kernel_size=1) + self.c1x1_c2 = create_conv2d(num_3x3_b, inc, kernel_size=1) + else: + self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1) + self.c1x1_c1 = None + self.c1x1_c2 = None + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor] + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] + pass + + def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: + if isinstance(x, tuple): + x_in = torch.cat(x, dim=1) + else: + x_in = x + if self.c1x1_w_s1 is None and self.c1x1_w_s2 is None: + # self.has_proj == False, torchscript requires condition on module == None + x_s1 = x[0] + x_s2 = x[1] + else: + # self.has_proj == True + if self.c1x1_w_s1 is not None: + # self.key_stride = 1 + x_s = self.c1x1_w_s1(x_in) + else: + # self.key_stride = 2 + x_s = self.c1x1_w_s2(x_in) + x_s1 = x_s[:, :self.num_1x1_c, :, :] + x_s2 = x_s[:, self.num_1x1_c:, :, :] + x_in = self.c1x1_a(x_in) + x_in = self.c3x3_b(x_in) + x_in = self.c1x1_c(x_in) + if self.c1x1_c1 is not None: + # self.b == True, using None check for torchscript compat + out1 = self.c1x1_c1(x_in) + out2 = self.c1x1_c2(x_in) + else: + out1 = x_in[:, :self.num_1x1_c, :, :] + out2 = x_in[:, self.num_1x1_c:, :, :] + resid = x_s1 + out1 + dense = torch.cat([x_s2, out2], dim=1) + return resid, dense + + +class DPN(nn.Module): + def __init__(self, small=False, num_init_features=64, k_r=96, groups=32, + b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32, + num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', fc_act=nn.ELU): + super(DPN, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + self.b = b + assert output_stride == 32 # FIXME look into dilation support + bw_factor = 1 if small else 4 + blocks = OrderedDict() + + # conv1 + blocks['conv1_1'] = ConvBnAct( + in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_kwargs=dict(eps=.001)) + blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')] + + # conv2 + bw = 64 * bw_factor + inc = inc_sec[0] + r = (k_r * bw) // (64 * bw_factor) + blocks['conv2_1'] = DualPathBlock(num_init_features, r, r, bw, inc, groups, 'proj', b) + in_chs = bw + 3 * inc + for i in range(2, k_sec[0] + 1): + blocks['conv2_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=4, module=f'features.conv2_{k_sec[0]}')] + + # conv3 + bw = 128 * bw_factor + inc = inc_sec[1] + r = (k_r * bw) // (64 * bw_factor) + blocks['conv3_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b) + in_chs = bw + 3 * inc + for i in range(2, k_sec[1] + 1): + blocks['conv3_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=8, module=f'features.conv3_{k_sec[1]}')] + + # conv4 + bw = 256 * bw_factor + inc = inc_sec[2] + r = (k_r * bw) // (64 * bw_factor) + 
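+        # stage opens with a stride-2 'down' projection block, followed by k_sec[2]-1 'normal' blocks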
blocks['conv4_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b) + in_chs = bw + 3 * inc + for i in range(2, k_sec[2] + 1): + blocks['conv4_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=16, module=f'features.conv4_{k_sec[2]}')] + + # conv5 + bw = 512 * bw_factor + inc = inc_sec[3] + r = (k_r * bw) // (64 * bw_factor) + blocks['conv5_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b) + in_chs = bw + 3 * inc + for i in range(2, k_sec[3] + 1): + blocks['conv5_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')] + + def _fc_norm(f, eps): return BatchNormAct2d(f, eps=eps, act_layer=fc_act, inplace=False) + blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=_fc_norm) + + self.num_features = in_chs + self.features = nn.Sequential(blocks) + + # Using 1x1 conv for the FC layer to allow the extra pooling scheme + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + + def forward_features(self, x): + return self.features(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classifier(x) + if not self.global_pool.is_identity(): + x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled) + return x + + +def _create_dpn(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + DPN, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_concat=True, flatten_sequential=True), **kwargs) + + +@register_model +def dpn68(pretrained=False, **kwargs): + model_kwargs = dict( + small=True, num_init_features=10, k_r=128, groups=32, + k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs) + return _create_dpn('dpn68', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn68b(pretrained=False, **kwargs): + model_kwargs = dict( + small=True, num_init_features=10, k_r=128, groups=32, + b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs) + return _create_dpn('dpn68b', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn92(pretrained=False, **kwargs): + model_kwargs = dict( + num_init_features=64, k_r=96, groups=32, + k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), **kwargs) + return _create_dpn('dpn92', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn98(pretrained=False, **kwargs): + model_kwargs = dict( + num_init_features=96, k_r=160, groups=40, + k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), **kwargs) + return _create_dpn('dpn98', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn131(pretrained=False, **kwargs): + model_kwargs = dict( + num_init_features=128, k_r=160, groups=40, + k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), **kwargs) + return _create_dpn('dpn131', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn107(pretrained=False, **kwargs): + model_kwargs = dict( + num_init_features=128, k_r=200, 
groups=50, + k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), **kwargs) + return _create_dpn('dpn107', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4a89590b5bf78f20f2cf4469700eb493eab668b4 --- /dev/null +++ b/timm/models/efficientnet.py @@ -0,0 +1,1727 @@ +""" PyTorch EfficientNet Family + +An implementation of EfficienNet that covers variety of related models with efficient architectures: + +* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent weight ports) + - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946 + - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971 + - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665 + - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252 + +* MixNet (Small, Medium, and Large) + - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595 + +* MNasNet B1, A1 (SE), Small + - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626 + +* FBNet-C + - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443 + +* Single-Path NAS Pixel1 + - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877 + +* And likely more... + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import List + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT +from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights +from .features import FeatureInfo, FeatureHooks +from .helpers import build_model_with_cfg, default_cfg_for_features +from .layers import create_conv2d, create_classifier +from .registry import register_model + +__all__ = ['EfficientNet'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'mnasnet_050': _cfg(url=''), + 'mnasnet_075': _cfg(url=''), + 'mnasnet_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth'), + 'mnasnet_140': _cfg(url=''), + + 'semnasnet_050': _cfg(url=''), + 'semnasnet_075': _cfg(url=''), + 'semnasnet_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'), + 'semnasnet_140': _cfg(url=''), + 'mnasnet_small': _cfg(url=''), + + 'mobilenetv2_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth'), + 'mobilenetv2_110d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth'), + 'mobilenetv2_120d': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth'), + 'mobilenetv2_140': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth'), + + 'fbnetc_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', + interpolation='bilinear'), + 'spnasnet_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', + interpolation='bilinear'), + + 'efficientnet_b0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth'), + 'efficientnet_b1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', + input_size=(3, 240, 240), pool_size=(8, 8)), + 'efficientnet_b2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', + input_size=(3, 260, 260), pool_size=(9, 9)), + 'efficientnet_b2a': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', + input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0), + 'efficientnet_b3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'efficientnet_b3a': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0), + 'efficientnet_b4': _cfg( + url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'efficientnet_b5': _cfg( + url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'efficientnet_b6': _cfg( + url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'efficientnet_b7': _cfg( + url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'efficientnet_b8': _cfg( + url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + 'efficientnet_l2': _cfg( + url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961), + + 'efficientnet_es': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'), + 'efficientnet_em': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_em_ra2-66250f76.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_el': _cfg( + url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + + 'efficientnet_cc_b0_4e': _cfg(url=''), + 'efficientnet_cc_b0_8e': _cfg(url=''), + 'efficientnet_cc_b1_8e': _cfg(url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + + 'efficientnet_lite0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth'), + 'efficientnet_lite1': _cfg( + url='', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_lite2': _cfg( + url='', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'efficientnet_lite3': _cfg( + url='', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'efficientnet_lite4': _cfg( + url='', 
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + + 'efficientnet_b1_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb1_pruned_9ebb3fe6.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'efficientnet_b2_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb2_pruned_203f55bc.pth', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'efficientnet_b3_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + + 'tf_efficientnet_b0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', + input_size=(3, 224, 224)), + 'tf_efficientnet_b1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_b8': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', + input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + + 'tf_efficientnet_b0_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 224, 224)), + 'tf_efficientnet_b1_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', + 
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_b8_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + + 'tf_efficientnet_b0_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth', + input_size=(3, 224, 224)), + 'tf_efficientnet_b1_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth', + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth', + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth', + input_size=(3, 600, 600), pool_size=(19, 19), 
crop_pct=0.949), + 'tf_efficientnet_l2_ns_475': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth', + input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936), + 'tf_efficientnet_l2_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', + input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96), + + 'tf_efficientnet_es': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 224, 224), ), + 'tf_efficientnet_em': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_el': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + + 'tf_efficientnet_cc_b0_4e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_efficientnet_cc_b0_8e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_efficientnet_cc_b1_8e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + + 'tf_efficientnet_lite0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, interpolation='bilinear'), + 'tf_efficientnet_lite4': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 380, 380), pool_size=(12, 
12), crop_pct=0.920, interpolation='bilinear'), + + 'mixnet_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'), + 'mixnet_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'), + 'mixnet_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth'), + 'mixnet_xl': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth'), + 'mixnet_xxl': _cfg(), + + 'tf_mixnet_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'), + 'tf_mixnet_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth'), + 'tf_mixnet_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'), +} + +_DEBUG = False + + +class EfficientNet(nn.Module): + """ (Generic) EfficientNet + + A flexible and performant PyTorch implementation of efficient network architectures, including: + * EfficientNet B0-B8, L2 + * EfficientNet-EdgeTPU + * EfficientNet-CondConv + * MixNet S, M, L, XL + * MnasNet A1, B1, and small + * FBNet C + * Single-Path NAS Pixel1 + + """ + + def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, + channel_multiplier=1.0, channel_divisor=8, channel_min=None, + output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'): + super(EfficientNet, self).__init__() + norm_kwargs = norm_kwargs or {} + + self.num_classes = num_classes + self.num_features = num_features + self.drop_rate = drop_rate + + # Stem + if not fix_stem: + stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs, + norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = builder.features + head_chs = builder.in_chs + + # Head + Pooling + self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type) + self.bn2 = norm_layer(self.num_features, **norm_kwargs) + self.act2 = act_layer(inplace=True) + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + efficientnet_init_weights(self) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([self.conv_head, self.bn2, self.act2, self.global_pool]) + layers.extend([nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = 
self.act1(x) + x = self.blocks(x) + x = self.conv_head(x) + x = self.bn2(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +class EfficientNetFeatures(nn.Module): + """ EfficientNet Feature Extractor + + A work-in-progress feature extraction module for EfficientNet, to use as a backbone for segmentation + and object detection models. + """ + + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', + in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None, + output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None): + super(EfficientNetFeatures, self).__init__() + norm_kwargs = norm_kwargs or {} + self.drop_rate = drop_rate + + # Stem + if not fix_stem: + stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs, + norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = FeatureInfo(builder.features, out_indices) + self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} + + efficientnet_init_weights(self) + + # Register feature extraction hooks with FeatureHooks helper + self.feature_hooks = None + if feature_location != 'bottleneck': + hooks = self.feature_info.get_dicts(keys=('module', 'hook_type')) + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) + + def forward(self, x) -> List[torch.Tensor]: + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + if self.feature_hooks is None: + features = [] + if 0 in self._stage_out_idx: + features.append(x) # add stem out + for i, b in enumerate(self.blocks): + x = b(x) + if i + 1 in self._stage_out_idx: + features.append(x) + return features + else: + self.blocks(x) + out = self.feature_hooks.get_output(x.device) + return list(out.values()) + + +def _create_effnet(model_kwargs, variant, pretrained=False): + features_only = False + model_cls = EfficientNet + if model_kwargs.pop('features_only', False): + features_only = True + model_kwargs.pop('num_classes', 0) + model_kwargs.pop('num_features', 0) + model_kwargs.pop('head_conv', None) + model_cls = EfficientNetFeatures + model = build_model_with_cfg( + model_cls, variant, pretrained, default_cfg=default_cfgs[variant], + pretrained_strict=not features_only, **model_kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model + + +def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-a1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. 
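# The classification model (EfficientNet) and the backbone variant (EfficientNetFeatures)
# above are selected by the `features_only` flag popped in _create_effnet. A minimal usage
# sketch, assuming the upstream timm factory API (timm.create_model, feature_info) is intact
# in this vendored copy; illustrative only, not part of the patch.
import torch
import timm

# Full classifier: stem -> blocks -> conv_head -> global pool -> dropout -> classifier
clf = timm.create_model('efficientnet_b0', pretrained=False)
print(clf(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])

# Feature backbone: EfficientNetFeatures returns intermediate feature maps instead of logits
backbone = timm.create_model('efficientnet_b0', features_only=True, out_indices=(1, 2, 3, 4))
feats = backbone(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats], backbone.feature_info.channels())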
+ """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r2_k3_s2_e6_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25'], + # stage 3, 28x28 in + ['ir_r4_k3_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_effnet(model_kwargs, variant, pretrained) + return model + + +def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r3_k5_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_effnet(model_kwargs, variant, pretrained) + return model + + +def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. 
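# A hedged reading of the block-definition shorthand consumed by decode_arch_def in the
# arch_def lists above (my summary; the decoder in efficientnet_builder.py is authoritative):
#   ds = depthwise-separable block, ir = inverted residual, er = edge residual (EdgeTPU),
#   r<N> = repeats, k<N> = kernel size, s<N> = stride, e<N> = expansion ratio,
#   c<N> = output channels, se<F> = squeeze-excite ratio, noskip = disable the residual.
# A tiny standalone illustration of the string layout (not the real decoder):
def split_block_def(block_str):
    block_type, *opts = block_str.split('_')
    return block_type, opts

print(split_block_def('ir_r2_k3_s2_e6_c24'))         # ('ir', ['r2', 'k3', 's2', 'e6', 'c24'])
print(split_block_def('ds_r1_k3_s1_e1_c16_noskip'))  # ('ds', ['r1', 'k3', 's1', 'e1', 'c16', 'noskip'])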
+ """ + arch_def = [ + ['ds_r1_k3_s1_c8'], + ['ir_r1_k3_s2_e3_c16'], + ['ir_r2_k3_s2_e6_c16'], + ['ir_r4_k5_s2_e6_c32_se0.25'], + ['ir_r3_k3_s1_e6_c32_se0.25'], + ['ir_r3_k5_s2_e6_c88_se0.25'], + ['ir_r1_k3_s1_e6_c144'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=8, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_effnet(model_kwargs,variant, pretrained) + return model + + +def _gen_mobilenet_v2( + variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs): + """ Generate MobileNet-V2 network + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py + Paper: https://arxiv.org/abs/1801.04381 + """ + arch_def = [ + ['ds_r1_k3_s1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r3_k3_s2_e6_c32'], + ['ir_r4_k3_s2_e6_c64'], + ['ir_r3_k3_s1_e6_c96'], + ['ir_r3_k3_s2_e6_c160'], + ['ir_r1_k3_s1_e6_c320'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head), + num_features=1280 if fix_stem_head else round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + fix_stem=fix_stem_head, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=resolve_act_layer(kwargs, 'relu6'), + **kwargs + ) + model = _create_effnet(model_kwargs, variant, pretrained) + return model + + +def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """ FBNet-C + + Paper: https://arxiv.org/abs/1812.03443 + Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py + + NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper, + it was used to confirm some building block details + """ + arch_def = [ + ['ir_r1_k3_s1_e1_c16'], + ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'], + ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'], + ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'], + ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'], + ['ir_r4_k5_s2_e6_c184'], + ['ir_r1_k3_s1_e6_c352'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=16, + num_features=1984, # paper suggests this, but is not 100% clear + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_effnet(model_kwargs, variant, pretrained) + return model + + +def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates the Single-Path NAS model from search targeted for Pixel1 phone. + + Paper: https://arxiv.org/abs/1904.02877 + + Args: + channel_multiplier: multiplier to number of channels per layer. 
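# Hedged sketch of the rounding round_channels(...) is expected to apply when scaling widths
# (e.g. the MobileNetV2 head above): the standard "make divisible" rule from the MobileNet
# reference code. The vendored implementation in efficientnet_blocks.py is authoritative.
def make_divisible_sketch(v, divisor=8, min_value=None):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # never round down by more than 10%
        new_v += divisor
    return new_v

# MobileNetV2 head width at channel_multiplier=1.4: 1280 * 1.4 = 1792, already a multiple of 8
print(make_divisible_sketch(1280 * 1.4))  # 1792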
+ """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'], + # stage 4, 14x14in + ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_effnet(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet model. + + Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_effnet(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-EdgeTPU model + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu + """ + + arch_def = [ + # NOTE `fc` is present to override a mismatch between stem channels and in chs not + # present in other models + ['er_r1_k3_s1_e4_c24_fc24_noskip'], + ['er_r2_k3_s2_e8_c32'], + ['er_r4_k3_s2_e8_c48'], + ['ir_r5_k5_s2_e8_c96'], + ['ir_r4_k5_s1_e8_c144'], + ['ir_r2_k5_s2_e8_c192'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=resolve_act_layer(kwargs, 'relu'), + **kwargs, + ) + model = _create_effnet(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_condconv( + variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs): + """Creates an EfficientNet-CondConv model. 
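# The compound-scaling table from the _gen_efficientnet docstring above, restated as plain data
# (illustrative only, nothing in the module reads it). Only the first two columns are consumed
# by the builders; input resolution lives in default_cfgs and dropout must be passed by the
# caller, per the per-model NOTE comments further down.
EFFICIENTNET_SCALING = {
    # name: (channel_multiplier, depth_multiplier, train_resolution, dropout_rate)
    'efficientnet_b0': (1.0, 1.0, 224, 0.2),
    'efficientnet_b1': (1.0, 1.1, 240, 0.2),
    'efficientnet_b2': (1.1, 1.2, 260, 0.3),
    'efficientnet_b3': (1.2, 1.4, 300, 0.3),
    'efficientnet_b4': (1.4, 1.8, 380, 0.4),
    'efficientnet_b5': (1.6, 2.2, 456, 0.4),
    'efficientnet_b6': (1.8, 2.6, 528, 0.5),
    'efficientnet_b7': (2.0, 3.1, 600, 0.5),
    'efficientnet_b8': (2.2, 3.6, 672, 0.5),
    'efficientnet_l2': (4.3, 5.3, 800, 0.5),
}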
+ + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25_cc4'], + ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], + ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], + ] + # NOTE unlike official impl, this one uses `cc` option where x is the base number of experts for each stage and + # the expert_multiplier increases that on a per-model basis as with depth/channel multipliers + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=resolve_act_layer(kwargs, 'swish'), + **kwargs, + ) + model = _create_effnet(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet-Lite model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-lite0': (1.0, 1.0, 224, 0.2), + 'efficientnet-lite1': (1.0, 1.1, 240, 0.2), + 'efficientnet-lite2': (1.1, 1.2, 260, 0.3), + 'efficientnet-lite3': (1.2, 1.4, 280, 0.3), + 'efficientnet-lite4': (1.4, 1.8, 300, 0.3), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r2_k5_s2_e6_c40'], + ['ir_r3_k3_s2_e6_c80'], + ['ir_r3_k5_s1_e6_c112'], + ['ir_r4_k5_s2_e6_c192'], + ['ir_r1_k3_s1_e6_c320'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True), + num_features=1280, + stem_size=32, + fix_stem=True, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu6'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_effnet(model_kwargs, variant, pretrained) + return model + + +def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Small model. 
+ + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=1536, + stem_size=16, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_effnet(model_kwargs, variant, pretrained) + return model + + +def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Medium-Large model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c24'], # relu + # stage 1, 112x112 in + ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), + num_features=1536, + stem_size=24, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_effnet(model_kwargs, variant, pretrained) + return model + + +@register_model +def mnasnet_050(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.5. """ + model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_075(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.75. """ + model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_100(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_b1(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. 
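# Each function decorated with @register_model is added to timm's model registry, which is
# what makes these names creatable by string. A hedged sketch assuming the upstream registry
# helpers (timm.list_models / timm.create_model) are intact in this vendored copy:
import timm

print(timm.list_models('mnasnet*'))                             # e.g. ['mnasnet_050', 'mnasnet_075', ...]
print(timm.list_models('tf_efficientnet_b*', pretrained=True))  # only names with pretrained weights
model = timm.create_model('mnasnet_100', pretrained=False, num_classes=10)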
""" + return mnasnet_100(pretrained, **kwargs) + + +@register_model +def mnasnet_140(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.4 """ + model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_050(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ + model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_075(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ + model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_100(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_a1(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + return semnasnet_100(pretrained, **kwargs) + + +@register_model +def semnasnet_140(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ + model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_small(pretrained=False, **kwargs): + """ MNASNet Small, depth multiplier of 1.0. """ + model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_100(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.0 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_140(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.4 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_110d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.1 channel, 1.2 depth multipliers""" + model = _gen_mobilenet_v2( + 'mobilenetv2_110d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_120d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.2 channel, 1.4 depth multipliers """ + model = _gen_mobilenet_v2( + 'mobilenetv2_120d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetc_100(pretrained=False, **kwargs): + """ FBNet-C """ + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def spnasnet_100(pretrained=False, **kwargs): + """ Single-Path NAS Pixel1""" + model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b1', channel_multiplier=1.0, 
depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b2a(pretrained=False, **kwargs): + """ EfficientNet-B2 @ 288x288 w/ 1.0 test crop""" + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b2a', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b3a(pretrained=False, **kwargs): + """ EfficientNet-B3 @ 320x320 w/ 1.0 test crop-pct """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b3a', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6 """ + # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7 """ + # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8 """ + # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_l2(pretrained=False, **kwargs): + """ EfficientNet-L2.""" + # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. 
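# The "NOTE for train" comments on the builders above mean dropout / stochastic depth are not
# baked into the configs; both values flow through **kwargs into the EfficientNet constructor.
# A hedged training-time sketch (values taken from the efficientnet_b2 note above):
import timm

model = timm.create_model(
    'efficientnet_b2',
    pretrained=False,
    drop_rate=0.3,       # classifier dropout, per the NOTE on efficientnet_b2
    drop_path_rate=0.2,  # stochastic depth applied inside the residual blocks
)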
""" + model = _gen_efficientnet_edge( + 'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. """ + model = _gen_efficientnet_edge( + 'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. """ + model = _gen_efficientnet_edge( + 'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b1_pruned(pretrained=False, **kwargs): + """ EfficientNet-B1 Pruned. 
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + variant = 'efficientnet_b1_pruned' + model = _gen_efficientnet( + variant, channel_multiplier=1.0, depth_multiplier=1.1, pruned=True, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b2_pruned(pretrained=False, **kwargs): + """ EfficientNet-B2 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pruned=True, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b3_pruned(pretrained=False, **kwargs): + """ EfficientNet-B3 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pruned=True, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6. 
Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b0_ap(pretrained=False, **kwargs): + """ EfficientNet-B0 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b1_ap(pretrained=False, **kwargs): + """ EfficientNet-B1 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b2_ap(pretrained=False, **kwargs): + """ EfficientNet-B2 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b3_ap(pretrained=False, **kwargs): + """ EfficientNet-B3 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b4_ap(pretrained=False, **kwargs): + """ EfficientNet-B4 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b5_ap(pretrained=False, **kwargs): + """ EfficientNet-B5 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b6_ap(pretrained=False, **kwargs): + """ EfficientNet-B6 AdvProp. 
Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b7_ap(pretrained=False, **kwargs): + """ EfficientNet-B7 AdvProp. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b8_ap(pretrained=False, **kwargs): + """ EfficientNet-B8 AdvProp. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b0_ns(pretrained=False, **kwargs): + """ EfficientNet-B0 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b1_ns(pretrained=False, **kwargs): + """ EfficientNet-B1 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b2_ns(pretrained=False, **kwargs): + """ EfficientNet-B2 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b3_ns(pretrained=False, **kwargs): + """ EfficientNet-B3 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b4_ns(pretrained=False, **kwargs): + """ EfficientNet-B4 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b5_ns(pretrained=False, **kwargs): + """ EfficientNet-B5 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b6_ns(pretrained=False, **kwargs): + """ EfficientNet-B6 NoisyStudent. 
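# Every tf_* builder in this file repeats the same two overrides before delegating: the
# TF-ported checkpoints were trained with 'SAME' padding and TensorFlow's batch-norm epsilon,
# so both must be set for the weights to reproduce. A hedged refactoring sketch of that shared
# pattern (illustrative only; the patch keeps the explicit per-function form):
from timm.models.efficientnet_blocks import BN_EPS_TF_DEFAULT

def _tf_port_kwargs(kwargs):
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT  # TF batch-norm epsilon, larger than torch's 1e-5 default
    kwargs['pad_type'] = 'same'           # emulate TF 'SAME' padding in all convolutions
    return kwargs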
Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b7_ns(pretrained=False, **kwargs): + """ EfficientNet-B7 NoisyStudent. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_l2_ns(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 4 Experts. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts. 
Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. + """ + model = _gen_mixnet_s( + 'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. + """ + model = _gen_mixnet_m( + 'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. 
+ """ + model = _gen_mixnet_m( + 'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_xl(pretrained=False, **kwargs): + """Creates a MixNet Extra-Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + model = _gen_mixnet_m( + 'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_xxl(pretrained=False, **kwargs): + """Creates a MixNet Double Extra Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + model = _gen_mixnet_m( + 'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_s( + 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..d7421ff4f07b230ef50268ea65f7f7aa2c0da8cf --- /dev/null +++ b/timm/models/efficientnet_blocks.py @@ -0,0 +1,397 @@ +""" EfficientNet, MobileNetV3, etc Blocks + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from .layers import create_conv2d, drop_path, get_act_layer +from .layers.activations import sigmoid + +# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per +# papers and TF reference implementations. 
PT momentum equiv for TF decay is (1 - TF decay) +# NOTE: momentum varies btw .99 and .9997 depending on source +# .99 in official TF TPU impl +# .9997 (/w .999 in search space) for paper +BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 +BN_EPS_TF_DEFAULT = 1e-3 +_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) + + +def get_bn_args_tf(): + return _BN_ARGS_TF.copy() + + +def resolve_bn_args(kwargs): + bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {} + bn_momentum = kwargs.pop('bn_momentum', None) + if bn_momentum is not None: + bn_args['momentum'] = bn_momentum + bn_eps = kwargs.pop('bn_eps', None) + if bn_eps is not None: + bn_args['eps'] = bn_eps + return bn_args + + +_SE_ARGS_DEFAULT = dict( + gate_fn=sigmoid, + act_layer=None, + reduce_mid=False, + divisor=1) + + +def resolve_se_args(kwargs, in_chs, act_layer=None): + se_kwargs = kwargs.copy() if kwargs is not None else {} + # fill in args that aren't specified with the defaults + for k, v in _SE_ARGS_DEFAULT.items(): + se_kwargs.setdefault(k, v) + # some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch + if not se_kwargs.pop('reduce_mid'): + se_kwargs['reduced_base_chs'] = in_chs + # act_layer override, if it remains None, the containing block's act_layer will be used + if se_kwargs['act_layer'] is None: + assert act_layer is not None + se_kwargs['act_layer'] = act_layer + return se_kwargs + + +def resolve_act_layer(kwargs, default='relu'): + act_layer = kwargs.pop('act_layer', default) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + return act_layer + + +def make_divisible(v, divisor=8, min_value=None): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
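+    # Illustrative values: make_divisible(70.4) -> 72, while make_divisible(10) -> 16,
+    # because rounding 10 down to 8 would shrink the channel count by more than 10%.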
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): + """Round number of filters based on depth multiplier.""" + if not multiplier: + return channels + channels *= multiplier + return make_divisible(channels, divisor, channel_min) + + +class ChannelShuffle(nn.Module): + # FIXME haven't used yet + def __init__(self, groups): + super(ChannelShuffle, self).__init__() + self.groups = groups + + def forward(self, x): + """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]""" + N, C, H, W = x.size() + g = self.groups + assert C % g == 0, "Incompatible group size {} for input channel {}".format( + g, C + ) + return ( + x.view(N, g, int(C / g), H, W) + .permute(0, 2, 1, 3, 4) + .contiguous() + .view(N, C, H, W) + ) + + +class SqueezeExcite(nn.Module): + def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, + act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_): + super(SqueezeExcite, self).__init__() + reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + self.gate_fn = gate_fn + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + return x * self.gate_fn(x_se) + + +class ConvBnAct(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, norm_kwargs=None): + super(ConvBnAct, self).__init__() + norm_kwargs = norm_kwargs or {} + self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type) + self.bn1 = norm_layer(out_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + def feature_info(self, location): + if location == 'expansion': # output of conv after act, same as block coutput + info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels) + else: # location == 'bottleneck', block output + info = dict(module='', hook_type='', num_chs=self.conv.out_channels) + return info + + def forward(self, x): + x = self.conv(x) + x = self.bn1(x) + x = self.act1(x) + return x + + +class DepthwiseSeparableConv(nn.Module): + """ DepthwiseSeparable block + Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion + (factor of 1.0). This is an alternative to having a IR with an optional first pw conv. + """ + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, + pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.): + super(DepthwiseSeparableConv, self).__init__() + norm_kwargs = norm_kwargs or {} + has_se = se_ratio is not None and se_ratio > 0. 
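+        # identity shortcut is only used when spatial size and channel count are preserved (and skip not disabled)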
+ self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + self.has_pw_act = pw_act # activation after point-wise conv + self.drop_path_rate = drop_path_rate + + self.conv_dw = create_conv2d( + in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True) + self.bn1 = norm_layer(in_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if has_se: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = None + + self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) + self.bn2 = norm_layer(out_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity() + + def feature_info(self, location): + if location == 'expansion': # after SE, input to PW + info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) + else: # location == 'bottleneck', block output + info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels) + return info + + def forward(self, x): + residual = x + + x = self.conv_dw(x) + x = self.bn1(x) + x = self.act1(x) + + if self.se is not None: + x = self.se(x) + + x = self.conv_pw(x) + x = self.bn2(x) + x = self.act2(x) + + if self.has_residual: + if self.drop_path_rate > 0.: + x = drop_path(x, self.drop_path_rate, self.training) + x += residual + return x + + +class InvertedResidual(nn.Module): + """ Inverted residual block w/ optional SE and CondConv routing""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + conv_kwargs=None, drop_path_rate=0.): + super(InvertedResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + conv_kwargs = conv_kwargs or {} + mid_chs = make_divisible(in_chs * exp_ratio) + has_se = se_ratio is not None and se_ratio > 0. 
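+        # illustrative: in_chs=24 with exp_ratio=6.0 gives mid_chs=144 expanded channels for the depthwise conv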
+ self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_path_rate = drop_path_rate + + # Point-wise expansion + self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Depth-wise convolution + self.conv_dw = create_conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, + padding=pad_type, depthwise=True, **conv_kwargs) + self.bn2 = norm_layer(mid_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) + + # Squeeze-and-excitation + if has_se: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = None + + # Point-wise linear projection + self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) + self.bn3 = norm_layer(out_chs, **norm_kwargs) + + def feature_info(self, location): + if location == 'expansion': # after SE, input to PWL + info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) + else: # location == 'bottleneck', block output + info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) + return info + + def forward(self, x): + residual = x + + # Point-wise expansion + x = self.conv_pw(x) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + if self.se is not None: + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn3(x) + + if self.has_residual: + if self.drop_path_rate > 0.: + x = drop_path(x, self.drop_path_rate, self.training) + x += residual + + return x + + +class CondConvResidual(InvertedResidual): + """ Inverted residual block w/ CondConv routing""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + num_experts=0, drop_path_rate=0.): + + self.num_experts = num_experts + conv_kwargs = dict(num_experts=self.num_experts) + + super(CondConvResidual, self).__init__( + in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type, + act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs, + drop_path_rate=drop_path_rate) + + self.routing_fn = nn.Linear(in_chs, self.num_experts) + + def forward(self, x): + residual = x + + # CondConv routing + pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) + routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) + + # Point-wise expansion + x = self.conv_pw(x, routing_weights) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x, routing_weights) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + if self.se is not None: + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x, routing_weights) + x = self.bn3(x) + + if self.has_residual: + if self.drop_path_rate > 0.: + x = drop_path(x, self.drop_path_rate, self.training) + x += residual + return x + + +class EdgeResidual(nn.Module): + """ Residual block with expansion convolution followed by 
pointwise-linear w/ stride""" + + def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + drop_path_rate=0.): + super(EdgeResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + if fake_in_chs > 0: + mid_chs = make_divisible(fake_in_chs * exp_ratio) + else: + mid_chs = make_divisible(in_chs * exp_ratio) + has_se = se_ratio is not None and se_ratio > 0. + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_path_rate = drop_path_rate + + # Expansion convolution + self.conv_exp = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if has_se: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = None + + # Point-wise linear projection + self.conv_pwl = create_conv2d( + mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type) + self.bn2 = norm_layer(out_chs, **norm_kwargs) + + def feature_info(self, location): + if location == 'expansion': # after SE, before PWL + info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) + else: # location == 'bottleneck', block output + info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) + return info + + def forward(self, x): + residual = x + + # Expansion convolution + x = self.conv_exp(x) + x = self.bn1(x) + x = self.act1(x) + + # Squeeze-and-excitation + if self.se is not None: + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn2(x) + + if self.has_residual: + if self.drop_path_rate > 0.: + x = drop_path(x, self.drop_path_rate, self.training) + x += residual + + return x diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..f670aa6cc08847bc24a4774ae039539682e1c7da --- /dev/null +++ b/timm/models/efficientnet_builder.py @@ -0,0 +1,414 @@ +""" EfficientNet, MobileNetV3, etc Builder + +Assembles EfficieNet and related network feature blocks from string definitions. +Handles stride, dilation calculations, and selects feature extraction points. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import logging +import math +import re +from copy import deepcopy + +import torch.nn as nn + +from .efficientnet_blocks import * +from .layers import CondConv2d, get_condconv_initializer + +__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"] + +_logger = logging.getLogger(__name__) + + +def _log_info_if(msg, condition): + if condition: + _logger.info(msg) + + +def _parse_ksize(ss): + if ss.isdigit(): + return int(ss) + else: + return [int(k) for k in ss.split('.')] + + +def _decode_block_str(block_str): + """ Decode block definition string + + Gets a list of block arg (dicts) through a string notation of arguments. + E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip + + All args can exist in any order with the exception of the leading string which + is assumed to indicate the block type. 
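+    For example, 'ir_r2_k3_s2_e6_c40_se0.25' decodes to an InvertedResidual definition repeated
+    twice, with a 3x3 depthwise kernel, stride 2, expansion ratio 6.0, 40 output channels and a
+    squeeze-excite ratio of 0.25.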
+ + leading string - block type ( + ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct) + r - number of repeat blocks, + k - kernel size, + s - strides (1-9), + e - expansion ratio, + c - output channels, + se - squeeze/excitation ratio + n - activation fn ('re', 'r6', 'hs', or 'sw') + Args: + block_str: a string representation of block arguments. + Returns: + A list of block args (dicts) + Raises: + ValueError: if the string def not properly specified (TODO) + """ + assert isinstance(block_str, str) + ops = block_str.split('_') + block_type = ops[0] # take the block type off the front + ops = ops[1:] + options = {} + noskip = False + for op in ops: + # string options being checked on individual basis, combine if they grow + if op == 'noskip': + noskip = True + elif op.startswith('n'): + # activation fn + key = op[0] + v = op[1:] + if v == 're': + value = get_act_layer('relu') + elif v == 'r6': + value = get_act_layer('relu6') + elif v == 'hs': + value = get_act_layer('hard_swish') + elif v == 'sw': + value = get_act_layer('swish') + else: + continue + options[key] = value + else: + # all numeric options + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # if act_layer is None, the model default (passed to model init) will be used + act_layer = options['n'] if 'n' in options else None + exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 + pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 + fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def + + num_repeat = int(options['r']) + # each type of block has different valid arguments, fill accordingly + if block_type == 'ir': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + if 'cc' in options: + block_args['num_experts'] = int(options['cc']) + elif block_type == 'ds' or block_type == 'dsa': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + pw_act=block_type == 'dsa', + noskip=block_type == 'dsa' or noskip, + ) + elif block_type == 'er': + block_args = dict( + block_type=block_type, + exp_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + fake_in_chs=fake_in_chs, + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + elif block_type == 'cn': + block_args = dict( + block_type=block_type, + kernel_size=int(options['k']), + out_chs=int(options['c']), + stride=int(options['s']), + act_layer=act_layer, + ) + else: + assert False, 'Unknown block type (%s)' % block_type + + return block_args, num_repeat + + +def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): + """ Per-stage depth scaling + Scales the block repeats in each stage. 
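+    E.g. repeats=[1, 2] with depth_multiplier=1.4 and the default 'ceil' truncation gives
+    ceil(3 * 1.4) = 5 total repeats, redistributed as [2, 3] across the two block definitions.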
This depth scaling impl maintains + compatibility with the EfficientNet scaling method, while allowing sensible + scaling for other models that may have multiple block arg definitions in each stage. + """ + + # We scale the total repeat count for each stage, there may be multiple + # block arg defs per stage so we need to sum. + num_repeat = sum(repeats) + if depth_trunc == 'round': + # Truncating to int by rounding allows stages with few repeats to remain + # proportionally smaller for longer. This is a good choice when stage definitions + # include single repeat stages that we'd prefer to keep that way as long as possible + num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) + else: + # The default for EfficientNet truncates repeats to int via 'ceil'. + # Any multiplier > 1.0 will result in an increased depth for every stage. + num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) + + # Proportionally distribute repeat count scaling to each block definition in the stage. + # Allocation is done in reverse as it results in the first block being less likely to be scaled. + # The first block makes less sense to repeat in most of the arch definitions. + repeats_scaled = [] + for r in repeats[::-1]: + rs = max(1, round((r / num_repeat * num_repeat_scaled))) + repeats_scaled.append(rs) + num_repeat -= r + num_repeat_scaled -= rs + repeats_scaled = repeats_scaled[::-1] + + # Apply the calculated scaling to each block arg in the stage + sa_scaled = [] + for ba, rep in zip(stack_args, repeats_scaled): + sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) + return sa_scaled + + +def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False): + arch_args = [] + for stack_idx, block_strings in enumerate(arch_def): + assert isinstance(block_strings, list) + stack_args = [] + repeats = [] + for block_str in block_strings: + assert isinstance(block_str, str) + ba, rep = _decode_block_str(block_str) + if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: + ba['num_experts'] *= experts_multiplier + stack_args.append(ba) + repeats.append(rep) + if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1): + arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc)) + else: + arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) + return arch_args + + +class EfficientNetBuilder: + """ Build Trunk Blocks + + This ended up being somewhat of a cross between + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py + and + https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py + + """ + def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, + output_stride=32, pad_type='', act_layer=None, se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='', + verbose=False): + self.channel_multiplier = channel_multiplier + self.channel_divisor = channel_divisor + self.channel_min = channel_min + self.output_stride = output_stride + self.pad_type = pad_type + self.act_layer = act_layer + self.se_kwargs = se_kwargs + self.norm_layer = norm_layer + self.norm_kwargs = norm_kwargs + self.drop_path_rate = drop_path_rate + if feature_location == 'depthwise': + # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense + _logger.warning("feature_location=='depthwise' is deprecated, 
using 'expansion'") + feature_location = 'expansion' + self.feature_location = feature_location + assert feature_location in ('bottleneck', 'expansion', '') + self.verbose = verbose + + # state updated during build, consumed by model + self.in_chs = None + self.features = [] + + def _round_channels(self, chs): + return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) + + def _make_block(self, ba, block_idx, block_count): + drop_path_rate = self.drop_path_rate * block_idx / block_count + bt = ba.pop('block_type') + ba['in_chs'] = self.in_chs + ba['out_chs'] = self._round_channels(ba['out_chs']) + if 'fake_in_chs' in ba and ba['fake_in_chs']: + # FIXME this is a hack to work around mismatch in origin impl input filters + ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) + ba['norm_layer'] = self.norm_layer + ba['norm_kwargs'] = self.norm_kwargs + ba['pad_type'] = self.pad_type + # block act fn overrides the model default + ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer + assert ba['act_layer'] is not None + if bt == 'ir': + ba['drop_path_rate'] = drop_path_rate + ba['se_kwargs'] = self.se_kwargs + _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) + if ba.get('num_experts', 0) > 0: + block = CondConvResidual(**ba) + else: + block = InvertedResidual(**ba) + elif bt == 'ds' or bt == 'dsa': + ba['drop_path_rate'] = drop_path_rate + ba['se_kwargs'] = self.se_kwargs + _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose) + block = DepthwiseSeparableConv(**ba) + elif bt == 'er': + ba['drop_path_rate'] = drop_path_rate + ba['se_kwargs'] = self.se_kwargs + _log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) + block = EdgeResidual(**ba) + elif bt == 'cn': + _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose) + block = ConvBnAct(**ba) + else: + assert False, 'Uknkown block type (%s) while building model.' % bt + self.in_chs = ba['out_chs'] # update in_chs for arg of next block + + return block + + def __call__(self, in_chs, model_block_args): + """ Build the blocks + Args: + in_chs: Number of input-channels passed to first block + model_block_args: A list of lists, outer list defines stages, inner + list contains strings defining block configuration(s) + Return: + List of block stacks (each stack wrapped in nn.Sequential) + """ + _log_info_if('Building model trunk with %d stages...' 
% len(model_block_args), self.verbose) + self.in_chs = in_chs + total_block_count = sum([len(x) for x in model_block_args]) + total_block_idx = 0 + current_stride = 2 + current_dilation = 1 + stages = [] + if model_block_args[0][0]['stride'] > 1: + # if the first block starts with a stride, we need to extract first level feat from stem + feature_info = dict( + module='act1', num_chs=in_chs, stage=0, reduction=current_stride, + hook_type='forward' if self.feature_location != 'bottleneck' else '') + self.features.append(feature_info) + + # outer list of block_args defines the stacks + for stack_idx, stack_args in enumerate(model_block_args): + last_stack = stack_idx + 1 == len(model_block_args) + _log_info_if('Stack: {}'.format(stack_idx), self.verbose) + assert isinstance(stack_args, list) + + blocks = [] + # each stack (stage of blocks) contains a list of block arguments + for block_idx, block_args in enumerate(stack_args): + last_block = block_idx + 1 == len(stack_args) + _log_info_if(' Block: {}'.format(block_idx), self.verbose) + + assert block_args['stride'] in (1, 2) + if block_idx >= 1: # only the first block in any stack can have a stride > 1 + block_args['stride'] = 1 + + extract_features = False + if last_block: + next_stack_idx = stack_idx + 1 + extract_features = next_stack_idx >= len(model_block_args) or \ + model_block_args[next_stack_idx][0]['stride'] > 1 + + next_dilation = current_dilation + if block_args['stride'] > 1: + next_output_stride = current_stride * block_args['stride'] + if next_output_stride > self.output_stride: + next_dilation = current_dilation * block_args['stride'] + block_args['stride'] = 1 + _log_info_if(' Converting stride to dilation to maintain output_stride=={}'.format( + self.output_stride), self.verbose) + else: + current_stride = next_output_stride + block_args['dilation'] = current_dilation + if next_dilation != current_dilation: + current_dilation = next_dilation + + # create the block + block = self._make_block(block_args, total_block_idx, total_block_count) + blocks.append(block) + + # stash feature module name and channel info for model feature extraction + if extract_features: + feature_info = dict( + stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location)) + module_name = f'blocks.{stack_idx}.{block_idx}' + leaf_name = feature_info.get('module', '') + feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name + self.features.append(feature_info) + + total_block_idx += 1 # incr global block idx (across all stacks) + stages.append(nn.Sequential(*blocks)) + return stages + + +def _init_weight_goog(m, n='', fix_group_fanout=True): + """ Weight initialization as per Tensorflow official implementations. 
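+    Conv weights are drawn from N(0, sqrt(2 / fan_out)), BatchNorm weight/bias are set to 1/0, and
+    Linear layers use a uniform range of 1/sqrt(fan_out) (fan_in is only added for the CondConv routing fn).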
+ + Args: + m (nn.Module): module to init + n (str): module name + fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs + + Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc: + * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py + * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + """ + if isinstance(m, CondConv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + init_weight_fn = get_condconv_initializer( + lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) + init_weight_fn(m.weight) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + fan_out = m.weight.size(0) # fan-out + fan_in = 0 + if 'routing_fn' in n: + fan_in = m.weight.size(1) + init_range = 1.0 / math.sqrt(fan_in + fan_out) + m.weight.data.uniform_(-init_range, init_range) + m.bias.data.zero_() + + +def efficientnet_init_weights(model: nn.Module, init_fn=None): + init_fn = init_fn or _init_weight_goog + for n, m in model.named_modules(): + init_fn(m, n) + diff --git a/timm/models/factory.py b/timm/models/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..70209c960c2f4f9c4648ce4b721d2ea24c50cc0b --- /dev/null +++ b/timm/models/factory.py @@ -0,0 +1,64 @@ +from .registry import is_model, is_model_in_modules, model_entrypoint +from .helpers import load_checkpoint +from .layers import set_layer_config + + +def create_model( + model_name, + pretrained=False, + num_classes=1000, + in_chans=3, + checkpoint_path='', + scriptable=None, + exportable=None, + no_jit=None, + **kwargs): + """Create a model + + Args: + model_name (str): name of model to instantiate + pretrained (bool): load pretrained ImageNet-1k weights if true + num_classes (int): number of classes for final fully connected layer (default: 1000) + in_chans (int): number of input channels / colors (default: 3) + checkpoint_path (str): path of checkpoint to load after model is initialized + scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) + exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) + no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) + + Keyword Args: + drop_rate (float): dropout rate for training (default: 0.0) + global_pool (str): global pool type (default: 'avg') + **: other kwargs are model specific + """ + model_args = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans) + + # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args + is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']) + if not is_efficientnet: + kwargs.pop('bn_tf', None) + kwargs.pop('bn_momentum', None) + kwargs.pop('bn_eps', None) + + # handle backwards compat with drop_connect -> drop_path change + drop_connect_rate = kwargs.pop('drop_connect_rate', None) + if 
drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: + print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'." + " Setting drop_path to %f." % drop_connect_rate) + kwargs['drop_path_rate'] = drop_connect_rate + + # Parameters that aren't supported by all models or are intended to only override model defaults if set + # should default to None in command line args/cfg. Remove them if they are present and not set so that + # non-supporting models don't break and default args remain in effect. + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): + if is_model(model_name): + create_fn = model_entrypoint(model_name) + model = create_fn(**model_args, **kwargs) + else: + raise RuntimeError('Unknown model (%s)' % model_name) + + if checkpoint_path: + load_checkpoint(model, checkpoint_path) + + return model diff --git a/timm/models/features.py b/timm/models/features.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d6890f3ed07311c5484b4a397c3b1da555880a --- /dev/null +++ b/timm/models/features.py @@ -0,0 +1,284 @@ +""" PyTorch Feature Extraction Helpers + +A collection of classes, functions, modules to help extract features from models +and provide a common interface for describing them. + +The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter +https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py + +Hacked together by / Copyright 2020 Ross Wightman +""" +from collections import OrderedDict, defaultdict +from copy import deepcopy +from functools import partial +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn + + +class FeatureInfo: + + def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): + prev_reduction = 1 + for fi in feature_info: + # sanity check the mandatory fields, there may be additional fields depending on the model + assert 'num_chs' in fi and fi['num_chs'] > 0 + assert 'reduction' in fi and fi['reduction'] >= prev_reduction + prev_reduction = fi['reduction'] + assert 'module' in fi + self.out_indices = out_indices + self.info = feature_info + + def from_other(self, out_indices: Tuple[int]): + return FeatureInfo(deepcopy(self.info), out_indices) + + def get(self, key, idx=None): + """ Get value by key at specified index (indices) + if idx == None, returns value for key at each output index + if idx is an integer, return value for that feature module index (ignoring output indices) + if idx is a list/tupple, return value for each module index (ignoring output indices) + """ + if idx is None: + return [self.info[i][key] for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i][key] for i in idx] + else: + return self.info[idx][key] + + def get_dicts(self, keys=None, idx=None): + """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None) + """ + if idx is None: + if keys is None: + return [self.info[i] for i in self.out_indices] + else: + return [{k: self.info[i][k] for k in keys} for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx] + else: + return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys} + + def channels(self, idx=None): + """ feature channels accessor + """ + return 
self.get('num_chs', idx) + + def reduction(self, idx=None): + """ feature reduction (output stride) accessor + """ + return self.get('reduction', idx) + + def module_name(self, idx=None): + """ feature module name accessor + """ + return self.get('module', idx) + + def __getitem__(self, item): + return self.info[item] + + def __len__(self): + return len(self.info) + + +class FeatureHooks: + """ Feature Hook Helper + + This module helps with the setup and extraction of hooks for extracting features from + internal nodes in a model by node name. This works quite well in eager Python but needs + redesign for torcscript. + """ + + def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'): + # setup feature hooks + modules = {k: v for k, v in named_modules} + for i, h in enumerate(hooks): + hook_name = h['module'] + m = modules[hook_name] + hook_id = out_map[i] if out_map else hook_name + hook_fn = partial(self._collect_output_hook, hook_id) + hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type + if hook_type == 'forward_pre': + m.register_forward_pre_hook(hook_fn) + elif hook_type == 'forward': + m.register_forward_hook(hook_fn) + else: + assert False, "Unsupported hook type" + self._feature_outputs = defaultdict(OrderedDict) + + def _collect_output_hook(self, hook_id, *args): + x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre + if isinstance(x, tuple): + x = x[0] # unwrap input tuple + self._feature_outputs[x.device][hook_id] = x + + def get_output(self, device) -> Dict[str, torch.tensor]: + output = self._feature_outputs[device] + self._feature_outputs[device] = OrderedDict() # clear after reading + return output + + +def _module_list(module, flatten_sequential=False): + # a yield/iter would be better for this but wouldn't be compatible with torchscript + ml = [] + for name, module in module.named_children(): + if flatten_sequential and isinstance(module, nn.Sequential): + # first level of Sequential containers is flattened into containing model + for child_name, child_module in module.named_children(): + combined = [name, child_name] + ml.append(('_'.join(combined), '.'.join(combined), child_module)) + else: + ml.append((name, name, module)) + return ml + + +def _get_feature_info(net, out_indices): + feature_info = getattr(net, 'feature_info') + if isinstance(feature_info, FeatureInfo): + return feature_info.from_other(out_indices) + elif isinstance(feature_info, (list, tuple)): + return FeatureInfo(net.feature_info, out_indices) + else: + assert False, "Provided feature_info is not valid" + + +def _get_return_layers(feature_info, out_map): + module_names = feature_info.module_name() + return_layers = {} + for i, name in enumerate(module_names): + return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i] + return return_layers + + +class FeatureDictNet(nn.ModuleDict): + """ Feature extractor with OrderedDict return + + Wrap a model and extract features as specified by the out indices, the network is + partially re-built from contained modules. + + There is a strong assumption that the modules have been registered into the model in the same + order as they are used. There should be no reuse of the same nn.Module more than once, including + trivial modules like `self.relu = nn.ReLU`. + + Only submodules that are directly assigned to the model class (`model.feature1`) or at most + one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured. 
+ All Sequential containers that are directly assigned to the original model will have their + modules assigned to this module with the name `model.features.1` being changed to `model.features_1` + + Arguments: + model (nn.Module): model from which we will extract the features + out_indices (tuple[int]): model output indices to extract features for + out_map (sequence): list or tuple specifying desired return id for each out index, + otherwise str(index) is used + feature_concat (bool): whether to concatenate intermediate features that are lists or tuples + vs select element [0] + flatten_sequential (bool): whether to flatten sequential modules assigned to model + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): + super(FeatureDictNet, self).__init__() + self.feature_info = _get_feature_info(model, out_indices) + self.concat = feature_concat + self.return_layers = {} + return_layers = _get_return_layers(self.feature_info, out_map) + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = set(return_layers.keys()) + layers = OrderedDict() + for new_name, old_name, module in modules: + layers[new_name] = module + if old_name in remaining: + # return id has to be consistently str type for torchscript + self.return_layers[new_name] = str(return_layers[old_name]) + remaining.remove(old_name) + if not remaining: + break + assert not remaining and len(self.return_layers) == len(return_layers), \ + f'Return layers ({remaining}) are not present in model' + self.update(layers) + + def _collect(self, x) -> (Dict[str, torch.Tensor]): + out = OrderedDict() + for name, module in self.items(): + x = module(x) + if name in self.return_layers: + out_id = self.return_layers[name] + if isinstance(x, (tuple, list)): + # If model tap is a tuple or list, concat or select first element + # FIXME this may need to be more generic / flexible for some nets + out[out_id] = torch.cat(x, 1) if self.concat else x[0] + else: + out[out_id] = x + return out + + def forward(self, x) -> Dict[str, torch.Tensor]: + return self._collect(x) + + +class FeatureListNet(FeatureDictNet): + """ Feature extractor with list return + + See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints. + In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool. + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): + super(FeatureListNet, self).__init__( + model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat, + flatten_sequential=flatten_sequential) + + def forward(self, x) -> (List[torch.Tensor]): + return list(self._collect(x).values()) + + +class FeatureHookNet(nn.ModuleDict): + """ FeatureHookNet + + Wrap a model and extract features specified by the out indices using forward/forward-pre hooks. + + If `no_rewrite` is True, features are extracted via hooks without modifying the underlying + network in any way. + + If `no_rewrite` is False, the model will be re-written as in the + FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one. 
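+    Outputs are returned as a dict keyed by out_map entries (or module names) when out_as_dict=True,
+    otherwise as a list of tensors.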
+ + FIXME this does not currently work with Torchscript, see FeatureHooks class + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False, + feature_concat=False, flatten_sequential=False, default_hook_type='forward'): + super(FeatureHookNet, self).__init__() + assert not torch.jit.is_scripting() + self.feature_info = _get_feature_info(model, out_indices) + self.out_as_dict = out_as_dict + layers = OrderedDict() + hooks = [] + if no_rewrite: + assert not flatten_sequential + if hasattr(model, 'reset_classifier'): # make sure classifier is removed? + model.reset_classifier(0) + layers['body'] = model + hooks.extend(self.feature_info.get_dicts()) + else: + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type + for f in self.feature_info.get_dicts()} + for new_name, old_name, module in modules: + layers[new_name] = module + for fn, fm in module.named_modules(prefix=old_name): + if fn in remaining: + hooks.append(dict(module=fn, hook_type=remaining[fn])) + del remaining[fn] + if not remaining: + break + assert not remaining, f'Return layers ({remaining}) are not present in model' + self.update(layers) + self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map) + + def forward(self, x): + for name, module in self.items(): + x = module(x) + out = self.hooks.get_output(x.device) + return out if self.out_as_dict else list(out.values()) diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd99d04ecf6851d1d70ea0fd87ed8103063978e --- /dev/null +++ b/timm/models/gluon_resnet.py @@ -0,0 +1,245 @@ +"""Pytorch impl of MxNet Gluon ResNet/(SE)ResNeXt variants +This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-ResNeXt additions +and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py) +by Ross Wightman +""" + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import SEModule +from .registry import register_model +from .resnet import ResNet, Bottleneck, BasicBlock + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'gluon_resnet18_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet18_v1b-0757602b.pth'), + 'gluon_resnet34_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet34_v1b-c6d82d59.pth'), + 'gluon_resnet50_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1b-0ebe02e2.pth'), + 'gluon_resnet101_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1b-3b017079.pth'), + 'gluon_resnet152_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1b-c1edb0dd.pth'), + 'gluon_resnet50_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1c-48092f55.pth', + 
first_conv='conv1.0'), + 'gluon_resnet101_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1c-1f26822a.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1c-a3bb0b98.pth', + first_conv='conv1.0'), + 'gluon_resnet50_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth', + first_conv='conv1.0'), + 'gluon_resnet101_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1d-0f9c8644.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1d-bd354e12.pth', + first_conv='conv1.0'), + 'gluon_resnet50_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth', + first_conv='conv1.0'), + 'gluon_resnet101_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1s-dcc41b81.pth', + first_conv='conv1.0'), + 'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'), + 'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'), + 'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'), + 'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'), + 'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'), + 'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'), + 'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth', + first_conv='conv1.0'), +} + + +def _create_resnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg(ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) + + +@register_model +def gluon_resnet18_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('gluon_resnet18_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet34_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + """ + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('gluon_resnet34_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet50_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. 
+ """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('gluon_resnet50_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet101_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs) + return _create_resnet('gluon_resnet101_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet152_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs) + return _create_resnet('gluon_resnet152_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet50_v1c(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet50_v1c', pretrained, **model_args) + + +@register_model +def gluon_resnet101_v1c(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet101_v1c', pretrained, **model_args) + + +@register_model +def gluon_resnet152_v1c(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet152_v1c', pretrained, **model_args) + + +@register_model +def gluon_resnet50_v1d(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('gluon_resnet50_v1d', pretrained, **model_args) + + +@register_model +def gluon_resnet101_v1d(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('gluon_resnet101_v1d', pretrained, **model_args) + + +@register_model +def gluon_resnet152_v1d(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('gluon_resnet152_v1d', pretrained, **model_args) + + +@register_model +def gluon_resnet50_v1s(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=64, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet50_v1s', pretrained, **model_args) + + + +@register_model +def gluon_resnet101_v1s(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=64, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet101_v1s', pretrained, **model_args) + + +@register_model +def gluon_resnet152_v1s(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=64, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet152_v1s', pretrained, **model_args) + + + +@register_model +def gluon_resnext50_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt50-32x4d model. 
+ """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('gluon_resnext50_32x4d', pretrained, **model_args) + + +@register_model +def gluon_resnext101_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('gluon_resnext101_32x4d', pretrained, **model_args) + + +@register_model +def gluon_resnext101_64x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs) + return _create_resnet('gluon_resnext101_64x4d', pretrained, **model_args) + + +@register_model +def gluon_seresnext50_32x4d(pretrained=False, **kwargs): + """Constructs a SEResNeXt50-32x4d model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + block_args=dict(attn_layer=SEModule), **kwargs) + return _create_resnet('gluon_seresnext50_32x4d', pretrained, **model_args) + + +@register_model +def gluon_seresnext101_32x4d(pretrained=False, **kwargs): + """Constructs a SEResNeXt-101-32x4d model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, + block_args=dict(attn_layer=SEModule), **kwargs) + return _create_resnet('gluon_seresnext101_32x4d', pretrained, **model_args) + + +@register_model +def gluon_seresnext101_64x4d(pretrained=False, **kwargs): + """Constructs a SEResNeXt-101-64x4d model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, + block_args=dict(attn_layer=SEModule), **kwargs) + return _create_resnet('gluon_seresnext101_64x4d', pretrained, **model_args) + + +@register_model +def gluon_senet154(pretrained=False, **kwargs): + """Constructs an SENet-154 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', + down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer=SEModule), **kwargs) + return _create_resnet('gluon_senet154', pretrained, **model_args) diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py new file mode 100644 index 0000000000000000000000000000000000000000..3782c50018e6c9ba82b51e31a191ff950b5d2e65 --- /dev/null +++ b/timm/models/gluon_xception.py @@ -0,0 +1,260 @@ +"""Pytorch impl of Gluon Xception +This is a port of the Gluon Xception code and weights, itself ported from a PyTorch DeepLab impl. 
+ +Gluon model: (https://gluon-cv.mxnet.io/_modules/gluoncv/model_zoo/xception.html) +Original PyTorch DeepLab impl: https://github.com/jfzhang95/pytorch-deeplab-xception + +Hacked together by / Copyright 2020 Ross Wightman +""" +from collections import OrderedDict + +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier, get_padding +from .registry import register_model + +__all__ = ['Xception65'] + +default_cfgs = { + 'gluon_xception65': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth', + 'input_size': (3, 299, 299), + 'crop_pct': 0.903, + 'pool_size': (10, 10), + 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'num_classes': 1000, + 'first_conv': 'conv1', + 'classifier': 'fc' + # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 + }, +} + +""" PADDING NOTES +The original PyTorch and Gluon impl of these models dutifully reproduced the +aligned padding added to Tensorflow models for Deeplab. This padding was compensating +for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd want it to. +""" + + +class SeparableConv2d(nn.Module): + def __init__(self, inplanes, planes, kernel_size=3, stride=1, + dilation=1, bias=False, norm_layer=None, norm_kwargs=None): + super(SeparableConv2d, self).__init__() + norm_kwargs = norm_kwargs if norm_kwargs is not None else {} + self.kernel_size = kernel_size + self.dilation = dilation + + # depthwise convolution + padding = get_padding(kernel_size, stride, dilation) + self.conv_dw = nn.Conv2d( + inplanes, inplanes, kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=inplanes, bias=bias) + self.bn = norm_layer(num_features=inplanes, **norm_kwargs) + # pointwise convolution + self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.conv_dw(x) + x = self.bn(x) + x = self.conv_pw(x) + return x + + +class Block(nn.Module): + def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, + norm_layer=None, norm_kwargs=None, ): + super(Block, self).__init__() + norm_kwargs = norm_kwargs if norm_kwargs is not None else {} + if isinstance(planes, (list, tuple)): + assert len(planes) == 3 + else: + planes = (planes,) * 3 + outplanes = planes[-1] + + if outplanes != inplanes or stride != 1: + self.skip = nn.Sequential() + self.skip.add_module('conv1', nn.Conv2d( + inplanes, outplanes, 1, stride=stride, bias=False)), + self.skip.add_module('bn1', norm_layer(num_features=outplanes, **norm_kwargs)) + else: + self.skip = None + + rep = OrderedDict() + for i in range(3): + rep['act%d' % (i + 1)] = nn.ReLU(inplace=True) + rep['conv%d' % (i + 1)] = SeparableConv2d( + inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, + norm_layer=norm_layer, norm_kwargs=norm_kwargs) + rep['bn%d' % (i + 1)] = norm_layer(planes[i], **norm_kwargs) + inplanes = planes[i] + + if not start_with_relu: + del rep['act1'] + else: + rep['act1'] = nn.ReLU(inplace=False) + self.rep = nn.Sequential(rep) + + def forward(self, x): + skip = x + if self.skip is not None: + skip = self.skip(skip) + x = self.rep(x) + skip + return x + + +class Xception65(nn.Module): + """Modified Aligned Xception. 
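+ Supports output_stride of 32, 16, or 8; lower strides are obtained by replacing stride with dilation in the later blocks.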
+ + NOTE: only the 65 layer version is included here, the 71 layer variant + was not correct and had no pretrained weights + """ + + def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d, + norm_kwargs=None, drop_rate=0., global_pool='avg'): + super(Xception65, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + norm_kwargs = norm_kwargs if norm_kwargs is not None else {} + if output_stride == 32: + entry_block3_stride = 2 + exit_block20_stride = 2 + middle_block_dilation = 1 + exit_block_dilations = (1, 1) + elif output_stride == 16: + entry_block3_stride = 2 + exit_block20_stride = 1 + middle_block_dilation = 1 + exit_block_dilations = (1, 2) + elif output_stride == 8: + entry_block3_stride = 1 + exit_block20_stride = 1 + middle_block_dilation = 2 + exit_block_dilations = (2, 4) + else: + raise NotImplementedError + + # Entry flow + self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = norm_layer(num_features=32, **norm_kwargs) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = norm_layer(num_features=64) + self.act2 = nn.ReLU(inplace=True) + + self.block1 = Block( + 64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs) + self.block1_act = nn.ReLU(inplace=True) + self.block2 = Block( + 128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs) + self.block3 = Block( + 256, 728, stride=entry_block3_stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs) + + # Middle flow + self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block( + 728, 728, stride=1, dilation=middle_block_dilation, + norm_layer=norm_layer, norm_kwargs=norm_kwargs)) for i in range(4, 20)])) + + # Exit flow + self.block20 = Block( + 728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_block_dilations[0], + norm_layer=norm_layer, norm_kwargs=norm_kwargs) + self.block20_act = nn.ReLU(inplace=True) + + self.conv3 = SeparableConv2d( + 1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], + norm_layer=norm_layer, norm_kwargs=norm_kwargs) + self.bn3 = norm_layer(num_features=1536, **norm_kwargs) + self.act3 = nn.ReLU(inplace=True) + + self.conv4 = SeparableConv2d( + 1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], + norm_layer=norm_layer, norm_kwargs=norm_kwargs) + self.bn4 = norm_layer(num_features=1536, **norm_kwargs) + self.act4 = nn.ReLU(inplace=True) + + self.num_features = 2048 + self.conv5 = SeparableConv2d( + 1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1], + norm_layer=norm_layer, norm_kwargs=norm_kwargs) + self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs) + self.act5 = nn.ReLU(inplace=True) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='act2'), + dict(num_chs=128, reduction=4, module='block1_act'), + dict(num_chs=256, reduction=8, module='block3.rep.act1'), + dict(num_chs=728, reduction=16, module='block20.rep.act1'), + dict(num_chs=2048, reduction=32, module='act5'), + ] + + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + # Entry 
flow + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + + x = self.block1(x) + x = self.block1_act(x) + # c1 = x + x = self.block2(x) + # c2 = x + x = self.block3(x) + + # Middle flow + x = self.mid(x) + # c3 = x + + # Exit flow + x = self.block20(x) + x = self.block20_act(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.act3(x) + + x = self.conv4(x) + x = self.bn4(x) + x = self.act4(x) + + x = self.conv5(x) + x = self.bn5(x) + x = self.act5(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate: + F.dropout(x, self.drop_rate, training=self.training) + x = self.fc(x) + return x + + +def _create_gluon_xception(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + Xception65, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_cls='hook'), **kwargs) + + +@register_model +def gluon_xception65(pretrained=False, **kwargs): + """ Modified Aligned Xception-65 + """ + return _create_gluon_xception('gluon_xception65', pretrained, **kwargs) diff --git a/timm/models/helpers.py b/timm/models/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..77b98dc6b487c3eaacdaf0fd3032ad616412faf4 --- /dev/null +++ b/timm/models/helpers.py @@ -0,0 +1,310 @@ +""" Model creation / weight loading / state_dict helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import os +import math +from collections import OrderedDict +from copy import deepcopy +from typing import Callable + +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + +from .features import FeatureListNet, FeatureDictNet, FeatureHookNet +from .layers import Conv2dSame, Linear + + +_logger = logging.getLogger(__name__) + + +def load_state_dict(checkpoint_path, use_ema=False): + if checkpoint_path and os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + state_dict_key = 'state_dict' + if isinstance(checkpoint, dict): + if use_ema and 'state_dict_ema' in checkpoint: + state_dict_key = 'state_dict_ema' + if state_dict_key and state_dict_key in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint[state_dict_key].items(): + # strip `module.` prefix + name = k[7:] if k.startswith('module') else k + new_state_dict[name] = v + state_dict = new_state_dict + else: + state_dict = checkpoint + _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) + return state_dict + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): + state_dict = load_state_dict(checkpoint_path, use_ema) + model.load_state_dict(state_dict, strict=strict) + + +def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): + resume_epoch = None + if os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + if log_info: + _logger.info('Restoring model state from checkpoint...') + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict'].items(): + name = k[7:] if k.startswith('module') else k + new_state_dict[name] = v + model.load_state_dict(new_state_dict) + + if optimizer is not None and 'optimizer' in checkpoint: + if log_info: + _logger.info('Restoring optimizer 
state from checkpoint...') + optimizer.load_state_dict(checkpoint['optimizer']) + + if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: + if log_info: + _logger.info('Restoring AMP loss scaler state from checkpoint...') + loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) + + if 'epoch' in checkpoint: + resume_epoch = checkpoint['epoch'] + if 'version' in checkpoint and checkpoint['version'] > 1: + resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save + + if log_info: + _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + else: + model.load_state_dict(checkpoint) + if log_info: + _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) + return resume_epoch + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True): + if cfg is None: + cfg = getattr(model, 'default_cfg') + if cfg is None or 'url' not in cfg or not cfg['url']: + _logger.warning("Pretrained model URL is invalid, using random initialization.") + return + + state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu') + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + if in_chans == 1: + conv1_name = cfg['first_conv'] + _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name) + conv1_weight = state_dict[conv1_name + '.weight'] + # Some weights are in torch.half, ensure it's float for sum on CPU + conv1_type = conv1_weight.dtype + conv1_weight = conv1_weight.float() + O, I, J, K = conv1_weight.shape + if I > 3: + assert conv1_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K) + conv1_weight = conv1_weight.sum(dim=2, keepdim=False) + else: + conv1_weight = conv1_weight.sum(dim=1, keepdim=True) + conv1_weight = conv1_weight.to(conv1_type) + state_dict[conv1_name + '.weight'] = conv1_weight + elif in_chans != 3: + conv1_name = cfg['first_conv'] + conv1_weight = state_dict[conv1_name + '.weight'] + conv1_type = conv1_weight.dtype + conv1_weight = conv1_weight.float() + O, I, J, K = conv1_weight.shape + if I != 3: + _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name) + del state_dict[conv1_name + '.weight'] + strict = False + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. + _logger.info('Repeating first conv (%s) weights in channel dim.' 
% conv1_name) + repeat = int(math.ceil(in_chans / 3)) + conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv1_weight *= (3 / float(in_chans)) + conv1_weight = conv1_weight.to(conv1_type) + state_dict[conv1_name + '.weight'] = conv1_weight + + classifier_name = cfg['classifier'] + if num_classes == 1000 and cfg['num_classes'] == 1001: + # special case for imagenet trained models with extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[1:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[1:] + elif num_classes != cfg['num_classes']: + # completely discard fully connected for all other differences between pretrained and created model + del state_dict[classifier_name + '.weight'] + del state_dict[classifier_name + '.bias'] + strict = False + + model.load_state_dict(state_dict, strict=strict) + + +def extract_layer(model, layer): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + if not hasattr(model, 'module') and layer[0] == 'module': + layer = layer[1:] + for l in layer: + if hasattr(module, l): + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + else: + return module + return module + + +def set_layer(model, layer, val): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + lst_index = 0 + module2 = module + for l in layer: + if hasattr(module2, l): + if not l.isdigit(): + module2 = getattr(module2, l) + else: + module2 = module2[int(l)] + lst_index += 1 + lst_index -= 1 + for l in layer[:lst_index]: + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + l = layer[lst_index] + setattr(module, l, val) + + +def adapt_model_from_string(parent_module, model_string): + separator = '***' + state_dict = {} + lst_shape = model_string.split(separator) + for k in lst_shape: + k = k.split(':') + key = k[0] + shape = k[1][1:-1].split(',') + if shape[0] != '': + state_dict[key] = [int(i) for i in shape] + + new_module = deepcopy(parent_module) + for n, m in parent_module.named_modules(): + old_module = extract_layer(parent_module, n) + if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): + if isinstance(old_module, Conv2dSame): + conv = Conv2dSame + else: + conv = nn.Conv2d + s = state_dict[n + '.weight'] + in_channels = s[1] + out_channels = s[0] + g = 1 + if old_module.groups > 1: + in_channels = out_channels + g = in_channels + new_conv = conv( + in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, + bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, + groups=g, stride=old_module.stride) + set_layer(new_module, n, new_conv) + if isinstance(old_module, nn.BatchNorm2d): + new_bn = nn.BatchNorm2d( + num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, + affine=old_module.affine, track_running_stats=True) + set_layer(new_module, n, new_bn) + if isinstance(old_module, nn.Linear): + # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? 
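+ # The pruned-model string stores a target shape per parameter name; the new Linear takes its
+ # in_features from the stored weight shape [out_features, in_features] and keeps out_features.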
+ num_features = state_dict[n + '.weight'][1] + new_fc = Linear( + in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) + set_layer(new_module, n, new_fc) + if hasattr(new_module, 'num_features'): + new_module.num_features = num_features + new_module.eval() + parent_module.eval() + + return new_module + + +def adapt_model_from_file(parent_module, model_variant): + adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt') + with open(adapt_file, 'r') as f: + return adapt_model_from_string(parent_module, f.read().strip()) + + +def default_cfg_for_features(default_cfg): + default_cfg = deepcopy(default_cfg) + # remove default pretrained cfg fields that don't have much relevance for feature backbone + to_remove = ('num_classes', 'crop_pct', 'classifier') # add default final pool size? + for tr in to_remove: + default_cfg.pop(tr, None) + return default_cfg + + +def build_model_with_cfg( + model_cls: Callable, + variant: str, + pretrained: bool, + default_cfg: dict, + model_cfg: dict = None, + feature_cfg: dict = None, + pretrained_strict: bool = True, + pretrained_filter_fn: Callable = None, + **kwargs): + pruned = kwargs.pop('pruned', False) + features = False + feature_cfg = feature_cfg or {} + + if kwargs.pop('features_only', False): + features = True + feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) + if 'out_indices' in kwargs: + feature_cfg['out_indices'] = kwargs.pop('out_indices') + + model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs) + model.default_cfg = deepcopy(default_cfg) + + if pruned: + model = adapt_model_from_file(model, variant) + + # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats + num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) + if pretrained: + load_pretrained( + model, + num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3), + filter_fn=pretrained_filter_fn, strict=pretrained_strict) + + if features: + feature_cls = FeatureListNet + if 'feature_cls' in feature_cfg: + feature_cls = feature_cfg.pop('feature_cls') + if isinstance(feature_cls, str): + feature_cls = feature_cls.lower() + if 'hook' in feature_cls: + feature_cls = FeatureHookNet + else: + assert False, f'Unknown feature class {feature_cls}' + model = feature_cls(model, **feature_cfg) + model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg + + return model diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1c0bc9f0dc3f972b6b663be73184e8a2a4c1fe77 --- /dev/null +++ b/timm/models/hrnet.py @@ -0,0 +1,831 @@ +""" HRNet + +Copied from https://github.com/HRNet/HRNet-Image-Classification + +Original header: + Copyright (c) Microsoft + Licensed under the MIT License. 
+ Written by Bin Xiao (Bin.Xiao@microsoft.com) + Modified by Ke Sun (sunk@mail.ustc.edu.cn) +""" +import logging +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .features import FeatureInfo +from .helpers import build_model_with_cfg, default_cfg_for_features +from .layers import create_classifier +from .registry import register_model +from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE + +_BN_MOMENTUM = 0.1 +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'hrnet_w18_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v1-f460c6bc.pth'), + 'hrnet_w18_small_v2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v2-4c50a8cb.pth'), + 'hrnet_w18': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w18-8cb57bb9.pth'), + 'hrnet_w30': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w30-8d7f8dab.pth'), + 'hrnet_w32': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w32-90d8c5fb.pth'), + 'hrnet_w40': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w40-7cd397a4.pth'), + 'hrnet_w44': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w44-c9ac8c18.pth'), + 'hrnet_w48': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w48-abd2e6ab.pth'), + 'hrnet_w64': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w64-b47cc881.pth'), +} + +cfg_cls = dict( + hrnet_w18_small=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(1,), + NUM_CHANNELS=(32,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2), + NUM_CHANNELS=(16, 32), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=1, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2), + NUM_CHANNELS=(16, 32, 64), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=1, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2, 2), + NUM_CHANNELS=(16, 32, 64, 128), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w18_small_v2=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(2,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2), + NUM_CHANNELS=(18, 36), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=3, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2), + NUM_CHANNELS=(18, 36, 72), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=2, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2, 2), + NUM_CHANNELS=(18, 36, 72, 144), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w18=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + 
NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(18, 36), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(18, 36, 72), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(18, 36, 72, 144), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w30=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(30, 60), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(30, 60, 120), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(30, 60, 120, 240), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w32=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(32, 64), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(32, 64, 128), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(32, 64, 128, 256), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w40=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(40, 80), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(40, 80, 160), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(40, 80, 160, 320), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w44=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(44, 88), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(44, 88, 176), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(44, 88, 176, 352), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w48=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(48, 96), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(48, 96, 192), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', 
+ NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(48, 96, 192, 384), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w64=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(64, 128), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(64, 128, 256), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(64, 128, 256, 512), + FUSE_METHOD='SUM', + ), + ) +) + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.fuse_act = nn.ReLU(False) + + def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels): + error_msg = '' + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks)) + elif num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(num_branches, len(num_channels)) + elif num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(num_branches, len(num_inchannels)) + if error_msg: + _logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): + downsample = None + if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=_BN_MOMENTUM), + ) + + layers = [block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)] + self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + for i in range(num_branches): + branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return nn.Identity() + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), + nn.BatchNorm2d(num_inchannels[i], momentum=_BN_MOMENTUM), + nn.Upsample(scale_factor=2 ** (j - i), mode='nearest'))) + elif j == i: 
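+ # j == i: same resolution, pass features through unchanged; the j < i case below downsamples with chains of stride-2 3x3 conv + BN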
+ fuse_layer.append(nn.Identity()) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM), + nn.ReLU(False))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x: List[torch.Tensor]): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i, branch in enumerate(self.branches): + x[i] = branch(x[i]) + + x_fuse = [] + for i, fuse_outer in enumerate(self.fuse_layers): + y = x[0] if i == 0 else fuse_outer[0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + fuse_outer[j](x[j]) + x_fuse.append(self.fuse_act(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HighResolutionNet(nn.Module): + + def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0, head='classification'): + super(HighResolutionNet, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + + stem_width = cfg['STEM_WIDTH'] + self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(stem_width, momentum=_BN_MOMENTUM) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(stem_width, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=_BN_MOMENTUM) + self.act2 = nn.ReLU(inplace=True) + + self.stage1_cfg = cfg['STAGE1'] + num_channels = self.stage1_cfg['NUM_CHANNELS'][0] + block = blocks_dict[self.stage1_cfg['BLOCK']] + num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + stage1_out_channel = block.expansion * num_channels + + self.stage2_cfg = cfg['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels) + self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True) + + self.head = head + self.head_channels = None # set if _make_head called + if head == 
'classification': + # Classification Head + self.num_features = 2048 + self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head(pre_stage_channels) + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + elif head == 'incre': + self.num_features = 2048 + self.incre_modules, _, _ = self._make_head(pre_stage_channels, True) + else: + self.incre_modules = None + self.num_features = 256 + + curr_stride = 2 + # module names aren't actually valid here, hook or FeatureNet based extraction would not work + self.feature_info = [dict(num_chs=64, reduction=curr_stride, module='stem')] + for i, c in enumerate(self.head_channels if self.head_channels else num_channels): + curr_stride *= 2 + c = c * 4 if self.head_channels else c # head block expansion factor of 4 + self.feature_info += [dict(num_chs=c, reduction=curr_stride, module=f'stage{i + 1}')] + + self.init_weights() + + def _make_head(self, pre_stage_channels, incre_only=False): + head_block = Bottleneck + self.head_channels = [32, 64, 128, 256] + + # Increasing the #channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + incre_modules = [] + for i, channels in enumerate(pre_stage_channels): + incre_modules.append(self._make_layer(head_block, channels, self.head_channels[i], 1, stride=1)) + incre_modules = nn.ModuleList(incre_modules) + if incre_only: + return incre_modules, None, None + + # downsampling modules + downsamp_modules = [] + for i in range(len(pre_stage_channels) - 1): + in_channels = self.head_channels[i] * head_block.expansion + out_channels = self.head_channels[i + 1] * head_block.expansion + downsamp_module = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(out_channels, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + downsamp_modules.append(downsamp_module) + downsamp_modules = nn.ModuleList(downsamp_modules) + + final_layer = nn.Sequential( + nn.Conv2d( + in_channels=self.head_channels[3] * head_block.expansion, + out_channels=self.num_features, kernel_size=1, stride=1, padding=0 + ), + nn.BatchNorm2d(self.num_features, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + return incre_modules, downsamp_modules, final_layer + + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False), + nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(nn.Identity()) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + 
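+ # projection shortcut: 1x1 conv + BN so the skip path matches the block output in stride and channel count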
downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=_BN_MOMENTUM), + ) + + layers = [block(inplanes, planes, stride, downsample)] + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + reset_multi_scale_output = multi_scale_output or i < num_modules - 1 + modules.append(HighResolutionModule( + num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def stages(self, x) -> List[torch.Tensor]: + x = self.layer1(x) + + xl = [t(x) for i, t in enumerate(self.transition1)] + yl = self.stage2(xl) + + xl = [t(yl[-1]) if not isinstance(t, nn.Identity) else yl[i] for i, t in enumerate(self.transition2)] + yl = self.stage3(xl) + + xl = [t(yl[-1]) if not isinstance(t, nn.Identity) else yl[i] for i, t in enumerate(self.transition3)] + yl = self.stage4(xl) + return yl + + def forward_features(self, x): + # Stem + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + + # Stages + yl = self.stages(x) + + # Classification Head + y = self.incre_modules[0](yl[0]) + for i, down in enumerate(self.downsamp_modules): + y = self.incre_modules[i + 1](yl[i + 1]) + down(y) + y = self.final_layer(y) + return y + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classifier(x) + return x + + +class HighResolutionNetFeatures(HighResolutionNet): + """HighResolutionNet feature extraction + + The design of HRNet makes it easy to grab feature maps, this class provides a simple wrapper to do so. + It would be more complicated to use the FeatureNet helpers. + + The `feature_location=incre` allows grabbing increased channel count features using part of the + classification head. If `feature_location=''` the default HRNet features are returned. First stem + conv is used for stride 2 features. 
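+ out_indices selects which of the five feature maps (stem output plus the four multi-resolution outputs) are returned by forward().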
+ """ + + def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0, + feature_location='incre', out_indices=(0, 1, 2, 3, 4)): + assert feature_location in ('incre', '') + super(HighResolutionNetFeatures, self).__init__( + cfg, in_chans=in_chans, num_classes=num_classes, global_pool=global_pool, + drop_rate=drop_rate, head=feature_location) + self.feature_info = FeatureInfo(self.feature_info, out_indices) + self._out_idx = {i for i in out_indices} + + def forward_features(self, x): + assert False, 'Not supported' + + def forward(self, x) -> List[torch.tensor]: + out = [] + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + if 0 in self._out_idx: + out.append(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + x = self.stages(x) + if self.incre_modules is not None: + x = [incre(f) for f, incre in zip(x, self.incre_modules)] + for i, f in enumerate(x): + if i + 1 in self._out_idx: + out.append(f) + return out + + +def _create_hrnet(variant, pretrained, **model_kwargs): + model_cls = HighResolutionNet + features_only = False + if model_kwargs.pop('features_only', False): + model_cls = HighResolutionNetFeatures + model_kwargs['num_classes'] = 0 + features_only = True + model = build_model_with_cfg( + model_cls, variant, pretrained, default_cfg=default_cfgs[variant], + model_cfg=cfg_cls[variant], pretrained_strict=not features_only, **model_kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model + + +@register_model +def hrnet_w18_small(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w18_small', pretrained, **kwargs) + + +@register_model +def hrnet_w18_small_v2(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w18_small_v2', pretrained, **kwargs) + + +@register_model +def hrnet_w18(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w18', pretrained, **kwargs) + + +@register_model +def hrnet_w30(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w30', pretrained, **kwargs) + + +@register_model +def hrnet_w32(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w32', pretrained, **kwargs) + + +@register_model +def hrnet_w40(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w40', pretrained, **kwargs) + + +@register_model +def hrnet_w44(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w44', pretrained, **kwargs) + + +@register_model +def hrnet_w48(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w48', pretrained, **kwargs) + + +@register_model +def hrnet_w64(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w64', pretrained, **kwargs) diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..a5efa33057cb5a81932e7a65083f4668c98f7edd --- /dev/null +++ b/timm/models/inception_resnet_v2.py @@ -0,0 +1,354 @@ +""" Pytorch Inception-Resnet-V2 implementation +Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is +based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['InceptionResnetV2'] + +default_cfgs = { + # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz + 
'inception_resnet_v2': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth', + 'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.8975, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + }, + # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz + 'ens_adv_inception_resnet_v2': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth', + 'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.8975, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + } +} + + +class BasicConv2d(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d( + in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_planes, eps=.001) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Mixed_5b(nn.Module): + def __init__(self): + super(Mixed_5b, self).__init__() + + self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(192, 48, kernel_size=1, stride=1), + BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(192, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), + BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(192, 64, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block35(nn.Module): + def __init__(self, scale=1.0): + super(Block35, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(320, 32, kernel_size=1, stride=1), + BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(320, 32, kernel_size=1, stride=1), + BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1), + BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1) + ) + + self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.relu(out) + return out + + +class Mixed_6a(nn.Module): + def __init__(self): + super(Mixed_6a, self).__init__() + + self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2) + + self.branch1 = nn.Sequential( + BasicConv2d(320, 256, kernel_size=1, stride=1), + BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), + BasicConv2d(256, 384, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = 
torch.cat((x0, x1, x2), 1) + return out + + +class Block17(nn.Module): + def __init__(self, scale=1.0): + super(Block17, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(1088, 128, kernel_size=1, stride=1), + BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)) + ) + + self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.relu(out) + return out + + +class Mixed_7a(nn.Module): + def __init__(self): + super(Mixed_7a, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 384, kernel_size=3, stride=2) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 288, kernel_size=3, stride=2) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1), + BasicConv2d(288, 320, kernel_size=3, stride=2) + ) + + self.branch3 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block8(nn.Module): + + def __init__(self, scale=1.0, no_relu=False): + super(Block8, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(2080, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)), + BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + ) + + self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) + self.relu = None if no_relu else nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + if self.relu is not None: + out = self.relu(out) + return out + + +class InceptionResnetV2(nn.Module): + def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'): + super(InceptionResnetV2, self).__init__() + self.drop_rate = drop_rate + self.num_classes = num_classes + self.num_features = 1536 + assert output_stride == 32 + + self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2) + self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) + self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) + self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')] + + self.maxpool_3a = nn.MaxPool2d(3, stride=2) + self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) + self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) + self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')] + + self.maxpool_5a = nn.MaxPool2d(3, stride=2) + self.mixed_5b = Mixed_5b() + self.repeat = nn.Sequential( + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17) + ) + self.feature_info += [dict(num_chs=320, reduction=8, 
module='repeat')] + + self.mixed_6a = Mixed_6a() + self.repeat_1 = nn.Sequential( + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10) + ) + self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')] + + self.mixed_7a = Mixed_7a() + self.repeat_2 = nn.Sequential( + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20) + ) + self.block8 = Block8(no_relu=True) + self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1) + self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')] + + self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def get_classifier(self): + return self.classif + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv2d_1a(x) + x = self.conv2d_2a(x) + x = self.conv2d_2b(x) + x = self.maxpool_3a(x) + x = self.conv2d_3b(x) + x = self.conv2d_4a(x) + x = self.maxpool_5a(x) + x = self.mixed_5b(x) + x = self.repeat(x) + x = self.mixed_6a(x) + x = self.repeat_1(x) + x = self.mixed_7a(x) + x = self.repeat_2(x) + x = self.block8(x) + x = self.conv2d_7b(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classif(x) + return x + + +def _create_inception_resnet_v2(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + InceptionResnetV2, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs) + + +@register_model +def inception_resnet_v2(pretrained=False, **kwargs): + r"""InceptionResnetV2 model architecture from the + `"InceptionV4, Inception-ResNet..." ` paper. + """ + return _create_inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs) + + +@register_model +def ens_adv_inception_resnet_v2(pretrained=False, **kwargs): + r""" Ensemble Adversarially trained InceptionResnetV2 model architecture + As per https://arxiv.org/abs/1705.07204 and + https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models. 
+ """ + return _create_inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs) diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..9ae7105feeb107cf74978e917cfdc15f1bea733e --- /dev/null +++ b/timm/models/inception_v3.py @@ -0,0 +1,468 @@ +""" Inception-V3 + +Originally from torchvision Inception3 model +Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg +from .registry import register_model +from .layers import trunc_normal_, create_classifier, Linear + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'Conv2d_1a_3x3.conv', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + # original PyTorch weights, ported from Tensorflow but modified + 'inception_v3': _cfg( + url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', + has_aux=True), # checkpoint has aux logit layer weights + # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) + 'tf_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth', + num_classes=1001, has_aux=False), + # my port of Tensorflow adversarially trained Inception V3 from + # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz + 'adv_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth', + num_classes=1001, has_aux=False), + # from gluon pretrained models, best performing in terms of accuracy/loss metrics + # https://gluon-cv.mxnet.io/model_zoo/classification.html + 'gluon_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth', + mean=IMAGENET_DEFAULT_MEAN, # also works well with inception defaults + std=IMAGENET_DEFAULT_STD, # also works well with inception defaults + has_aux=False, + ) +} + + +class InceptionA(nn.Module): + + def __init__(self, in_channels, pool_features, conv_block=None): + super(InceptionA, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) + + self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) + self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) + + self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + 
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionB(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionB, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3(x) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + + outputs = [branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionC(nn.Module): + + def __init__(self, in_channels, channels_7x7, conv_block=None): + super(InceptionC, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) + + c7 = channels_7x7 + self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) + + self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionD(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionD, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) + + self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3_1(x) + branch3x3 = self.branch3x3_2(branch3x3) + + branch7x7x3 = self.branch7x7x3_1(x) + branch7x7x3 = self.branch7x7x3_2(branch7x7x3) + branch7x7x3 = self.branch7x7x3_3(branch7x7x3) + branch7x7x3 = 
self.branch7x7x3_4(branch7x7x3) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + outputs = [branch3x3, branch7x7x3, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionE(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionE, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) + + self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) + self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) + self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) + self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionAux(nn.Module): + + def __init__(self, in_channels, num_classes, conv_block=None): + super(InceptionAux, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.conv0 = conv_block(in_channels, 128, kernel_size=1) + self.conv1 = conv_block(128, 768, kernel_size=5) + self.conv1.stddev = 0.01 + self.fc = Linear(768, num_classes) + self.fc.stddev = 0.001 + + def forward(self, x): + # N x 768 x 17 x 17 + x = F.avg_pool2d(x, kernel_size=5, stride=3) + # N x 768 x 5 x 5 + x = self.conv0(x) + # N x 128 x 5 x 5 + x = self.conv1(x) + # N x 768 x 1 x 1 + # Adaptive average pooling + x = F.adaptive_avg_pool2d(x, (1, 1)) + # N x 768 x 1 x 1 + x = torch.flatten(x, 1) + # N x 768 + x = self.fc(x) + # N x 1000 + return x + + +class BasicConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return F.relu(x, inplace=True) + + +class InceptionV3(nn.Module): + """Inception-V3 with no AuxLogits + FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns + """ + + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=False): + super(InceptionV3, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + self.aux_logits = aux_logits + + self.Conv2d_1a_3x3 = BasicConv2d(in_chans, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, 
padding=1) + self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) + self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Mixed_5b = InceptionA(192, pool_features=32) + self.Mixed_5c = InceptionA(256, pool_features=64) + self.Mixed_5d = InceptionA(288, pool_features=64) + self.Mixed_6a = InceptionB(288) + self.Mixed_6b = InceptionC(768, channels_7x7=128) + self.Mixed_6c = InceptionC(768, channels_7x7=160) + self.Mixed_6d = InceptionC(768, channels_7x7=160) + self.Mixed_6e = InceptionC(768, channels_7x7=192) + if aux_logits: + self.AuxLogits = InceptionAux(768, num_classes) + else: + self.AuxLogits = None + self.Mixed_7a = InceptionD(768) + self.Mixed_7b = InceptionE(1280) + self.Mixed_7c = InceptionE(2048) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'), + dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'), + dict(num_chs=288, reduction=8, module='Mixed_5d'), + dict(num_chs=768, reduction=16, module='Mixed_6e'), + dict(num_chs=2048, reduction=32, module='Mixed_7c'), + ] + + self.num_features = 2048 + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + stddev = m.stddev if hasattr(m, 'stddev') else 0.1 + trunc_normal_(m.weight, std=stddev) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward_preaux(self, x): + # N x 3 x 299 x 299 + x = self.Conv2d_1a_3x3(x) + # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) + # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) + # N x 64 x 147 x 147 + x = self.Pool1(x) + # N x 64 x 73 x 73 + x = self.Conv2d_3b_1x1(x) + # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) + # N x 192 x 71 x 71 + x = self.Pool2(x) + # N x 192 x 35 x 35 + x = self.Mixed_5b(x) + # N x 256 x 35 x 35 + x = self.Mixed_5c(x) + # N x 288 x 35 x 35 + x = self.Mixed_5d(x) + # N x 288 x 35 x 35 + x = self.Mixed_6a(x) + # N x 768 x 17 x 17 + x = self.Mixed_6b(x) + # N x 768 x 17 x 17 + x = self.Mixed_6c(x) + # N x 768 x 17 x 17 + x = self.Mixed_6d(x) + # N x 768 x 17 x 17 + x = self.Mixed_6e(x) + # N x 768 x 17 x 17 + return x + + def forward_postaux(self, x): + x = self.Mixed_7a(x) + # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) + # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) + # N x 2048 x 8 x 8 + return x + + def forward_features(self, x): + x = self.forward_preaux(x) + x = self.forward_postaux(x) + return x + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + return x + + +class InceptionV3Aux(InceptionV3): + """InceptionV3 with AuxLogits + """ + + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=True): + super(InceptionV3Aux, self).__init__( + num_classes, in_chans, drop_rate, global_pool, aux_logits) + + def forward_features(self, x): + x = self.forward_preaux(x) + aux = self.AuxLogits(x) if self.training else None + x = self.forward_postaux(x) + return x, aux + + def forward(self, x): + x, aux = 
self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + return x, aux + + +def _create_inception_v3(variant, pretrained=False, **kwargs): + default_cfg = default_cfgs[variant] + aux_logits = kwargs.pop('aux_logits', False) + if aux_logits: + assert not kwargs.pop('features_only', False) + model_cls = InceptionV3Aux + load_strict = default_cfg['has_aux'] + else: + model_cls = InceptionV3 + load_strict = not default_cfg['has_aux'] + return build_model_with_cfg( + model_cls, variant, pretrained, default_cfg=default_cfgs[variant], + pretrained_strict=load_strict, **kwargs) + + +@register_model +def inception_v3(pretrained=False, **kwargs): + # original PyTorch weights, ported from Tensorflow but modified + model = _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_inception_v3(pretrained=False, **kwargs): + # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) + model = _create_inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs) + return model + + +@register_model +def adv_inception_v3(pretrained=False, **kwargs): + # my port of Tensorflow adversarially trained Inception V3 from + # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz + model = _create_inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs) + return model + + +@register_model +def gluon_inception_v3(pretrained=False, **kwargs): + # from gluon pretrained models, best performing in terms of accuracy/loss metrics + # https://gluon-cv.mxnet.io/model_zoo/classification.html + model = _create_inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs) + return model diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py new file mode 100644 index 0000000000000000000000000000000000000000..40a0f2911e090d7be5a47d5658d48f9fe3485417 --- /dev/null +++ b/timm/models/inception_v4.py @@ -0,0 +1,313 @@ +""" Pytorch Inception-V4 implementation +Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is +based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['InceptionV4'] + +default_cfgs = { + 'inception_v4': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth', + 'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'features.0.conv', 'classifier': 'last_linear', + } +} + + +class BasicConv2d(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d( + in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_planes, eps=0.001) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Mixed3a(nn.Module): + def __init__(self): + super(Mixed3a, self).__init__() + self.maxpool = 
nn.MaxPool2d(3, stride=2) + self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2) + + def forward(self, x): + x0 = self.maxpool(x) + x1 = self.conv(x) + out = torch.cat((x0, x1), 1) + return out + + +class Mixed4a(nn.Module): + def __init__(self): + super(Mixed4a, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(160, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(160, 64, kernel_size=1, stride=1), + BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(64, 96, kernel_size=(3, 3), stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + return out + + +class Mixed5a(nn.Module): + def __init__(self): + super(Mixed5a, self).__init__() + self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2) + self.maxpool = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.conv(x) + x1 = self.maxpool(x) + out = torch.cat((x0, x1), 1) + return out + + +class InceptionA(nn.Module): + def __init__(self): + super(InceptionA, self).__init__() + self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(384, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(384, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), + BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(384, 96, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class ReductionA(nn.Module): + def __init__(self): + super(ReductionA, self).__init__() + self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2) + + self.branch1 = nn.Sequential( + BasicConv2d(384, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1), + BasicConv2d(224, 256, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class InceptionB(nn.Module): + def __init__(self): + super(InceptionB, self).__init__() + self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0)) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(1024, 128, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) 
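+ # Every branch preserves spatial size (stride 1 with matching padding), so the
+ # concatenation is purely over channels: 384 + 256 + 256 + 128 = 1024 outputs.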
+ return out + + +class ReductionB(nn.Module): + def __init__(self): + super(ReductionB, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 192, kernel_size=3, stride=2) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(1024, 256, kernel_size=1, stride=1), + BasicConv2d(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(320, 320, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class InceptionC(nn.Module): + def __init__(self): + super(InceptionC, self).__init__() + + self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) + + self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) + self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + + self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) + self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0)) + self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(1536, 256, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + + x1_0 = self.branch1_0(x) + x1_1a = self.branch1_1a(x1_0) + x1_1b = self.branch1_1b(x1_0) + x1 = torch.cat((x1_1a, x1_1b), 1) + + x2_0 = self.branch2_0(x) + x2_1 = self.branch2_1(x2_0) + x2_2 = self.branch2_2(x2_1) + x2_3a = self.branch2_3a(x2_2) + x2_3b = self.branch2_3b(x2_2) + x2 = torch.cat((x2_3a, x2_3b), 1) + + x3 = self.branch3(x) + + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class InceptionV4(nn.Module): + def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'): + super(InceptionV4, self).__init__() + assert output_stride == 32 + self.drop_rate = drop_rate + self.num_classes = num_classes + self.num_features = 1536 + + self.features = nn.Sequential( + BasicConv2d(in_chans, 32, kernel_size=3, stride=2), + BasicConv2d(32, 32, kernel_size=3, stride=1), + BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), + Mixed3a(), + Mixed4a(), + Mixed5a(), + InceptionA(), + InceptionA(), + InceptionA(), + InceptionA(), + ReductionA(), # Mixed6a + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + ReductionB(), # Mixed7a + InceptionC(), + InceptionC(), + InceptionC(), + ) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='features.2'), + dict(num_chs=160, reduction=4, module='features.3'), + dict(num_chs=384, reduction=8, module='features.9'), + dict(num_chs=1024, reduction=16, module='features.17'), + dict(num_chs=1536, reduction=32, module='features.21'), + ] + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, 
self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + return self.features(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.last_linear(x) + return x + + +def _create_inception_v4(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + InceptionV4, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), **kwargs) + + +@register_model +def inception_v4(pretrained=False, **kwargs): + return _create_inception_v4('inception_v4', pretrained, **kwargs) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dac1beb8e6854d7c8a4f4a074eac1b3d44fa549e --- /dev/null +++ b/timm/models/layers/__init__.py @@ -0,0 +1,33 @@ +from .activations import * +from .adaptive_avgmax_pool import \ + adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .anti_aliasing import AntiAliasDownsampleLayer +from .blur_pool import BlurPool2d +from .classifier import ClassifierHead, create_classifier +from .cond_conv2d import CondConv2d, get_condconv_initializer +from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ + set_layer_config +from .conv2d_same import Conv2dSame +from .conv_bn_act import ConvBnAct +from .create_act import create_act_layer, get_act_layer, get_act_fn +from .create_attn import create_attn +from .create_conv2d import create_conv2d +from .create_norm_act import create_norm_act, get_norm_act_layer +from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from .eca import EcaModule, CecaModule +from .evo_norm import EvoNormBatch2d, EvoNormSample2d +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple +from .inplace_abn import InplaceAbn +from .linear import Linear +from .mixed_conv2d import MixedConv2d +from .norm_act import BatchNormAct2d +from .padding import get_padding +from .pool2d_same import AvgPool2dSame, create_pool2d +from .se import SEModule +from .selective_kernel import SelectiveKernelConv +from .separable_conv import SeparableConv2d, SeparableConvBnAct +from .space_to_depth import SpaceToDepthModule +from .split_attn import SplitAttnConv2d +from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .weight_init import trunc_normal_ diff --git a/timm/models/layers/activations.py b/timm/models/layers/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..e16b3bd3a1898365530c1ffc5154a0a4746a136e --- /dev/null +++ b/timm/models/layers/activations.py @@ -0,0 +1,145 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. 
+ +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +def swish(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941 + """ + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) + + +class Swish(nn.Module): + def __init__(self, inplace: bool = False): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return swish(x, self.inplace) + + +def mish(x, inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + NOTE: I don't have a working inplace variant + """ + return x.mul(F.softplus(x).tanh()) + + +class Mish(nn.Module): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + def __init__(self, inplace: bool = False): + super(Mish, self).__init__() + + def forward(self, x): + return mish(x) + + +def sigmoid(x, inplace: bool = False): + return x.sigmoid_() if inplace else x.sigmoid() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Sigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(Sigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.sigmoid_() if self.inplace else x.sigmoid() + + +def tanh(x, inplace: bool = False): + return x.tanh_() if inplace else x.tanh() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Tanh(nn.Module): + def __init__(self, inplace: bool = False): + super(Tanh, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.tanh_() if self.inplace else x.tanh() + + +def hard_swish(x, inplace: bool = False): + inner = F.relu6(x + 3.).div_(6.) + return x.mul_(inner) if inplace else x.mul(inner) + + +class HardSwish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_swish(x, self.inplace) + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. 
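+# Note: hard_swish(x) above is exactly x * hard_sigmoid(x); both use the same
+# ReLU6-based piecewise-linear approximation of the smooth sigmoid/swish curves.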
+ + +class HardSigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_sigmoid(x, self.inplace) + + +def hard_mish(x, inplace: bool = False): + """ Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + if inplace: + return x.mul_(0.5 * (x + 2).clamp(min=0, max=2)) + else: + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +class HardMish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_mish(x, self.inplace) + + +class PReLU(nn.PReLU): + """Applies PReLU (w/ dummy inplace arg) + """ + def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None: + super(PReLU, self).__init__(num_parameters=num_parameters, init=init) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.prelu(input, self.weight) + + +def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: + return F.gelu(x) + + +class GELU(nn.Module): + """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) + """ + def __init__(self, inplace: bool = False): + super(GELU, self).__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.gelu(input) diff --git a/timm/models/layers/activations_jit.py b/timm/models/layers/activations_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..b4a516530ad0abf41f720ac83d02791179bb7b67 --- /dev/null +++ b/timm/models/layers/activations_jit.py @@ -0,0 +1,90 @@ +""" Activations + +A collection of jit-scripted activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not +currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted +versions if they contain in-place ops. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +@torch.jit.script +def swish_jit(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941 + """ + return x.mul(x.sigmoid()) + + +@torch.jit.script +def mish_jit(x, _inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class SwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishJit, self).__init__() + + def forward(self, x): + return swish_jit(x) + + +class MishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(MishJit, self).__init__() + + def forward(self, x): + return mish_jit(x) + + +@torch.jit.script +def hard_sigmoid_jit(x, inplace: bool = False): + # return F.relu6(x + 3.) / 6. + return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSigmoidJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidJit, self).__init__() + + def forward(self, x): + return hard_sigmoid_jit(x) + + +@torch.jit.script +def hard_swish_jit(x, inplace: bool = False): + # return x * (F.relu6(x + 3.) / 6) + return x * (x + 3).clamp(min=0, max=6).div(6.) 
# clamp seems ever so slightly faster? + + +class HardSwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishJit, self).__init__() + + def forward(self, x): + return hard_swish_jit(x) + + +@torch.jit.script +def hard_mish_jit(x, inplace: bool = False): + """ Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +class HardMishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMishJit, self).__init__() + + def forward(self, x): + return hard_mish_jit(x) diff --git a/timm/models/layers/activations_me.py b/timm/models/layers/activations_me.py new file mode 100644 index 0000000000000000000000000000000000000000..0441f7c41fa41eabee70ca9f7857c6cc49e4bf47 --- /dev/null +++ b/timm/models/layers/activations_me.py @@ -0,0 +1,208 @@ +""" Activations (memory-efficient w/ custom autograd) + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +These activations are not compatible with jit scripting or ONNX export of the model, please use either +the JIT or basic versions of the activations. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +@torch.jit.script +def swish_jit_fwd(x): + return x.mul(torch.sigmoid(x)) + + +@torch.jit.script +def swish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) + + +class SwishJitAutoFn(torch.autograd.Function): + """ torch.jit.script optimised Swish w/ memory-efficient checkpoint + Inspired by conversation btw Jeremy Howard & Adam Pazske + https://twitter.com/jeremyphoward/status/1188251041835315200 + """ + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return swish_jit_bwd(x, grad_output) + + +def swish_me(x, inplace=False): + return SwishJitAutoFn.apply(x) + + +class SwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishMe, self).__init__() + + def forward(self, x): + return SwishJitAutoFn.apply(x) + + +@torch.jit.script +def mish_jit_fwd(x): + return x.mul(torch.tanh(F.softplus(x))) + + +@torch.jit.script +def mish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + x_tanh_sp = F.softplus(x).tanh() + return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) + + +class MishJitAutoFn(torch.autograd.Function): + """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + A memory efficient, jit scripted variant of Mish + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return mish_jit_bwd(x, grad_output) + + +def mish_me(x, inplace=False): + return MishJitAutoFn.apply(x) + + +class MishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(MishMe, self).__init__() + + def forward(self, x): + return MishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_sigmoid_jit_fwd(x, inplace: bool = False): + return (x + 3).clamp(min=0, max=6).div(6.) 
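+# As with the Swish/Mish pairs above, the _fwd/_bwd functions are tied together by a
+# torch.autograd.Function subclass (HardSigmoidJitAutoFn below): only the input tensor is
+# saved for backward and the gradient is recomputed from it, which is what makes these
+# variants memory efficient.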
+ + +@torch.jit.script +def hard_sigmoid_jit_bwd(x, grad_output): + m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. + return grad_output * m + + +class HardSigmoidJitAutoFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_sigmoid_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_sigmoid_jit_bwd(x, grad_output) + + +def hard_sigmoid_me(x, inplace: bool = False): + return HardSigmoidJitAutoFn.apply(x) + + +class HardSigmoidMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidMe, self).__init__() + + def forward(self, x): + return HardSigmoidJitAutoFn.apply(x) + + +@torch.jit.script +def hard_swish_jit_fwd(x): + return x * (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_swish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= 3.) + m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m) + return grad_output * m + + +class HardSwishJitAutoFn(torch.autograd.Function): + """A memory efficient, jit-scripted HardSwish activation""" + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_swish_jit_bwd(x, grad_output) + + +def hard_swish_me(x, inplace=False): + return HardSwishJitAutoFn.apply(x) + + +class HardSwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishMe, self).__init__() + + def forward(self, x): + return HardSwishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_mish_jit_fwd(x): + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +@torch.jit.script +def hard_mish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= -2.) + m = torch.where((x >= -2.) & (x <= 0.), x + 1., m) + return grad_output * m + + +class HardMishJitAutoFn(torch.autograd.Function): + """ A memory efficient, jit scripted variant of Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_mish_jit_bwd(x, grad_output) + + +def hard_mish_me(x, inplace: bool = False): + return HardMishJitAutoFn.apply(x) + + +class HardMishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMishMe, self).__init__() + + def forward(self, x): + return HardMishJitAutoFn.apply(x) + + + diff --git a/timm/models/layers/adaptive_avgmax_pool.py b/timm/models/layers/adaptive_avgmax_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..d2bb9f7216a01209ef1e205c4e127f1b6f593a74 --- /dev/null +++ b/timm/models/layers/adaptive_avgmax_pool.py @@ -0,0 +1,119 @@ +""" PyTorch selectable adaptive pooling +Adaptive pooling with the ability to select the type of pooling from: + * 'avg' - Average pooling + * 'max' - Max pooling + * 'avgmax' - Sum of average and max pooling re-scaled by 0.5 + * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim + +Both a functional and a nn.Module version of the pooling is provided. 
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def adaptive_pool_feat_mult(pool_type='avg'): + if pool_type == 'catavgmax': + return 2 + else: + return 1 + + +def adaptive_avgmax_pool2d(x, output_size=1): + x_avg = F.adaptive_avg_pool2d(x, output_size) + x_max = F.adaptive_max_pool2d(x, output_size) + return 0.5 * (x_avg + x_max) + + +def adaptive_catavgmax_pool2d(x, output_size=1): + x_avg = F.adaptive_avg_pool2d(x, output_size) + x_max = F.adaptive_max_pool2d(x, output_size) + return torch.cat((x_avg, x_max), 1) + + +def select_adaptive_pool2d(x, pool_type='avg', output_size=1): + """Selectable global pooling function with dynamic input kernel size + """ + if pool_type == 'avg': + x = F.adaptive_avg_pool2d(x, output_size) + elif pool_type == 'avgmax': + x = adaptive_avgmax_pool2d(x, output_size) + elif pool_type == 'catavgmax': + x = adaptive_catavgmax_pool2d(x, output_size) + elif pool_type == 'max': + x = F.adaptive_max_pool2d(x, output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type + return x + + +class FastAdaptiveAvgPool2d(nn.Module): + def __init__(self, flatten=False): + super(FastAdaptiveAvgPool2d, self).__init__() + self.flatten = flatten + + def forward(self, x): + return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True) + + +class AdaptiveAvgMaxPool2d(nn.Module): + def __init__(self, output_size=1): + super(AdaptiveAvgMaxPool2d, self).__init__() + self.output_size = output_size + + def forward(self, x): + return adaptive_avgmax_pool2d(x, self.output_size) + + +class AdaptiveCatAvgMaxPool2d(nn.Module): + def __init__(self, output_size=1): + super(AdaptiveCatAvgMaxPool2d, self).__init__() + self.output_size = output_size + + def forward(self, x): + return adaptive_catavgmax_pool2d(x, self.output_size) + + +class SelectAdaptivePool2d(nn.Module): + """Selectable global pooling layer with dynamic input kernel size + """ + def __init__(self, output_size=1, pool_type='fast', flatten=False): + super(SelectAdaptivePool2d, self).__init__() + self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing + self.flatten = flatten + if pool_type == '': + self.pool = nn.Identity() # pass through + elif pool_type == 'fast': + assert output_size == 1 + self.pool = FastAdaptiveAvgPool2d(self.flatten) + self.flatten = False + elif pool_type == 'avg': + self.pool = nn.AdaptiveAvgPool2d(output_size) + elif pool_type == 'avgmax': + self.pool = AdaptiveAvgMaxPool2d(output_size) + elif pool_type == 'catavgmax': + self.pool = AdaptiveCatAvgMaxPool2d(output_size) + elif pool_type == 'max': + self.pool = nn.AdaptiveMaxPool2d(output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type + + def is_identity(self): + return self.pool_type == '' + + def forward(self, x): + x = self.pool(x) + if self.flatten: + x = x.flatten(1) + return x + + def feat_mult(self): + return adaptive_pool_feat_mult(self.pool_type) + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + 'pool_type=' + self.pool_type \ + + ', flatten=' + str(self.flatten) + ')' + diff --git a/timm/models/layers/anti_aliasing.py b/timm/models/layers/anti_aliasing.py new file mode 100644 index 0000000000000000000000000000000000000000..9d3837e8c6d89169659afb88385a8c1fac49c213 --- /dev/null +++ b/timm/models/layers/anti_aliasing.py @@ -0,0 +1,60 @@ +import torch +import torch.nn.parallel +import torch.nn as nn +import torch.nn.functional as F + + +class 
AntiAliasDownsampleLayer(nn.Module): + def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False): + super(AntiAliasDownsampleLayer, self).__init__() + if no_jit: + self.op = Downsample(channels, filt_size, stride) + else: + self.op = DownsampleJIT(channels, filt_size, stride) + + # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls + + def forward(self, x): + return self.op(x) + + +@torch.jit.script +class DownsampleJIT(object): + def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2): + self.channels = channels + self.stride = stride + self.filt_size = filt_size + assert self.filt_size == 3 + assert stride == 2 + self.filt = {} # lazy init by device for DataParallel compat + + def _create_filter(self, like: torch.Tensor): + filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device) + filt = filt[:, None] * filt[None, :] + filt = filt / torch.sum(filt) + return filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) + + def __call__(self, input: torch.Tensor): + input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') + filt = self.filt.get(str(input.device), self._create_filter(input)) + return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1]) + + +class Downsample(nn.Module): + def __init__(self, channels=None, filt_size=3, stride=2): + super(Downsample, self).__init__() + self.channels = channels + self.filt_size = filt_size + self.stride = stride + + assert self.filt_size == 3 + filt = torch.tensor([1., 2., 1.]) + filt = filt[:, None] * filt[None, :] + filt = filt / torch.sum(filt) + + # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) + self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) + + def forward(self, input): + input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') + return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1]) diff --git a/timm/models/layers/blur_pool.py b/timm/models/layers/blur_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..399cbe355359c087fafd847f383808ce60c0977c --- /dev/null +++ b/timm/models/layers/blur_pool.py @@ -0,0 +1,58 @@ +""" +BlurPool layer inspired by + - Kornia's Max_BlurPool2d + - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` + +FIXME merge this impl with those in `anti_aliasing.py` + +Hacked together by Chris Ha and Ross Wightman +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from typing import Dict +from .padding import get_padding + + +class BlurPool2d(nn.Module): + r"""Creates a module that computes blurs and downsample a given feature map. + See :cite:`zhang2019shiftinvar` for more details. + Corresponds to the Downsample class, which does blurring and subsampling + + Args: + channels = Number of input channels + filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. + stride (int): downsampling filter stride + + Returns: + torch.Tensor: the transformed tensor. 
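+ Example (illustrative): ``BlurPool2d(channels=64)`` pads reflectively, blurs each channel
+ with a 3x3 binomial filter (the [1, 2, 1]/4 outer product), then subsamples with stride 2.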
+ """ + filt: Dict[str, torch.Tensor] + + def __init__(self, channels, filt_size=3, stride=2) -> None: + super(BlurPool2d, self).__init__() + assert filt_size > 1 + self.channels = channels + self.filt_size = filt_size + self.stride = stride + pad_size = [get_padding(filt_size, stride, dilation=1)] * 4 + self.padding = nn.ReflectionPad2d(pad_size) + self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat + self.filt = {} # lazy init by device for DataParallel compat + + def _create_filter(self, like: torch.Tensor): + blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device) + return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1) + + def _apply(self, fn): + # override nn.Module _apply, reset filter cache if used + self.filt = {} + super(BlurPool2d, self)._apply(fn) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + C = input_tensor.shape[1] + blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor)) + return F.conv2d( + self.padding(input_tensor), blur_filt, stride=self.stride, groups=C) diff --git a/timm/models/layers/cbam.py b/timm/models/layers/cbam.py new file mode 100644 index 0000000000000000000000000000000000000000..44e2fe6da23aae0b884fcb4e6b95c59b19c17fef --- /dev/null +++ b/timm/models/layers/cbam.py @@ -0,0 +1,99 @@ +""" CBAM (sort-of) Attention + +Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 + +WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on +some tasks, especially fine-grained it seems. I may end up removing this impl. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +import torch.nn.functional as F +from .conv_bn_act import ConvBnAct + + +class ChannelAttn(nn.Module): + """ Original CBAM channel attention module, currently avg + max pool variant only. + """ + def __init__(self, channels, reduction=16, act_layer=nn.ReLU): + super(ChannelAttn, self).__init__() + self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False) + self.act = act_layer(inplace=True) + self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False) + + def forward(self, x): + x_avg = x.mean((2, 3), keepdim=True) + x_max = F.adaptive_max_pool2d(x, 1) + x_avg = self.fc2(self.act(self.fc1(x_avg))) + x_max = self.fc2(self.act(self.fc1(x_max))) + x_attn = x_avg + x_max + return x * x_attn.sigmoid() + + +class LightChannelAttn(ChannelAttn): + """An experimental 'lightweight' that sums avg + max pool first + """ + def __init__(self, channels, reduction=16): + super(LightChannelAttn, self).__init__(channels, reduction) + + def forward(self, x): + x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1) + x_attn = self.fc2(self.act(self.fc1(x_pool))) + return x * x_attn.sigmoid() + + +class SpatialAttn(nn.Module): + """ Original CBAM spatial attention module + """ + def __init__(self, kernel_size=7): + super(SpatialAttn, self).__init__() + self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None) + + def forward(self, x): + x_avg = torch.mean(x, dim=1, keepdim=True) + x_max = torch.max(x, dim=1, keepdim=True)[0] + x_attn = torch.cat([x_avg, x_max], dim=1) + x_attn = self.conv(x_attn) + return x * x_attn.sigmoid() + + +class LightSpatialAttn(nn.Module): + """An experimental 'lightweight' variant that sums avg_pool and max_pool results. 
+ """ + def __init__(self, kernel_size=7): + super(LightSpatialAttn, self).__init__() + self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None) + + def forward(self, x): + x_avg = torch.mean(x, dim=1, keepdim=True) + x_max = torch.max(x, dim=1, keepdim=True)[0] + x_attn = 0.5 * x_avg + 0.5 * x_max + x_attn = self.conv(x_attn) + return x * x_attn.sigmoid() + + +class CbamModule(nn.Module): + def __init__(self, channels, spatial_kernel_size=7): + super(CbamModule, self).__init__() + self.channel = ChannelAttn(channels) + self.spatial = SpatialAttn(spatial_kernel_size) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + + +class LightCbamModule(nn.Module): + def __init__(self, channels, spatial_kernel_size=7): + super(LightCbamModule, self).__init__() + self.channel = LightChannelAttn(channels) + self.spatial = LightSpatialAttn(spatial_kernel_size) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + diff --git a/timm/models/layers/classifier.py b/timm/models/layers/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..89fe545819dd19f9aee6fbc12ba59c38e0ca1079 --- /dev/null +++ b/timm/models/layers/classifier.py @@ -0,0 +1,43 @@ +""" Classifier head and layer factory + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn +from torch.nn import functional as F + +from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .linear import Linear + + +def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): + flatten = not use_conv # flatten when we use a Linear layer after pooling + if not pool_type: + assert num_classes == 0 or use_conv,\ + 'Pooling can only be disabled if classifier is also removed or conv classifier is used' + flatten = False # disable flattening if pooling is pass-through (no pooling) + global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten) + num_pooled_features = num_features * global_pool.feat_mult() + if num_classes <= 0: + fc = nn.Identity() # pass-through (no classifier) + elif use_conv: + fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True) + else: + # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue + fc = Linear(num_pooled_features, num_classes, bias=True) + return global_pool, fc + + +class ClassifierHead(nn.Module): + """Classifier head w/ configurable global pooling and dropout.""" + + def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.): + super(ClassifierHead, self).__init__() + self.drop_rate = drop_rate + self.global_pool, self.fc = create_classifier(in_chs, num_classes, pool_type=pool_type) + + def forward(self, x): + x = self.global_pool(x) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) + x = self.fc(x) + return x diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..8b4bbca84d6f12e0fb875b4edb435b976fc649d6 --- /dev/null +++ b/timm/models/layers/cond_conv2d.py @@ -0,0 +1,122 @@ +""" PyTorch Conditionally Parameterized Convolution (CondConv) + +Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference +(https://arxiv.org/abs/1904.04971) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import math +from functools import partial +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .helpers import to_2tuple +from 
.conv2d_same import conv2d_same +from .padding import get_padding_value + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditionally Parameterized Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = to_2tuple(padding_val) + self.dilation = to_2tuple(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + x = x.view(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # 
weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out diff --git a/timm/models/layers/config.py b/timm/models/layers/config.py new file mode 100644 index 0000000000000000000000000000000000000000..f07b9d782ba0597c174dee81097c28280335fdba --- /dev/null +++ b/timm/models/layers/config.py @@ -0,0 +1,115 @@ +""" Model / Layer Config singleton state +""" +from typing import Any, Optional + +__all__ = [ + 'is_exportable', 'is_scriptable', 'is_no_jit', + 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' +] + +# Set to True if prefer to have layers with no jit optimization (includes activations) +_NO_JIT = False + +# Set to True if prefer to have activation layers with no jit optimization +# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying +# the jit flags so far are activations. This will change as more layers are updated and/or added. +_NO_ACTIVATION_JIT = False + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + +# Set to True if wanting to use torch.jit.script on a model +_SCRIPTABLE = False + + +def is_no_jit(): + return _NO_JIT + + +class set_no_jit: + def __init__(self, mode: bool) -> None: + global _NO_JIT + self.prev = _NO_JIT + _NO_JIT = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _NO_JIT + _NO_JIT = self.prev + return False + + +def is_exportable(): + return _EXPORTABLE + + +class set_exportable: + def __init__(self, mode: bool) -> None: + global _EXPORTABLE + self.prev = _EXPORTABLE + _EXPORTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _EXPORTABLE + _EXPORTABLE = self.prev + return False + + +def is_scriptable(): + return _SCRIPTABLE + + +class set_scriptable: + def __init__(self, mode: bool) -> None: + global _SCRIPTABLE + self.prev = _SCRIPTABLE + _SCRIPTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + _SCRIPTABLE = self.prev + return False + + +class set_layer_config: + """ Layer config context manager that allows setting all layer config flags at once. + If a flag arg is None, it will not change the current value. 
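+ Typically used as a context manager, e.g. ``with set_layer_config(scriptable=True): ...``;
+ the previous flag values are restored on __exit__.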
+ """ + def __init__( + self, + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + no_activation_jit: Optional[bool] = None): + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT + if scriptable is not None: + _SCRIPTABLE = scriptable + if exportable is not None: + _EXPORTABLE = exportable + if no_jit is not None: + _NO_JIT = no_jit + if no_activation_jit is not None: + _NO_ACTIVATION_JIT = no_activation_jit + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev + return False diff --git a/timm/models/layers/conv2d_same.py b/timm/models/layers/conv2d_same.py new file mode 100644 index 0000000000000000000000000000000000000000..75f0f98d4ec1e3f4a0dc004b977815afaa25e7fc --- /dev/null +++ b/timm/models/layers/conv2d_same.py @@ -0,0 +1,42 @@ +""" Conv2d w/ Same Padding + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, Optional + +from .padding import pad_same, get_padding_value + + +def conv2d_same( + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + x = pad_same(x, weight.shape[-2:], stride, dilation) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + diff --git a/timm/models/layers/conv_bn_act.py b/timm/models/layers/conv_bn_act.py new file mode 100644 index 0000000000000000000000000000000000000000..907353574339f43e6c2a80c862b59f0ecf025b44 --- /dev/null +++ b/timm/models/layers/conv_bn_act.py @@ -0,0 +1,40 @@ +""" Conv2d + BN + Act + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + +from .create_conv2d import create_conv2d +from .create_norm_act import convert_norm_act_type + + +class ConvBnAct(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, act_layer=nn.ReLU, apply_act=True, + drop_block=None, aa_layer=None): + super(ConvBnAct, self).__init__() + use_aa = aa_layer is not None + + self.conv = create_conv2d( + in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, + padding=padding, dilation=dilation, groups=groups, bias=False) + + # NOTE for backwards compatibility with models that use separate 
norm and act layer definitions + norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs) + self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args) + self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None + + @property + def in_channels(self): + return self.conv.in_channels + + @property + def out_channels(self): + return self.conv.out_channels + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.aa is not None: + x = self.aa(x) + return x diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py new file mode 100644 index 0000000000000000000000000000000000000000..426c36810f3134728ae9e4eddaaaded561fa2c5d --- /dev/null +++ b/timm/models/layers/create_act.py @@ -0,0 +1,133 @@ +""" Activation Factory +Hacked together by / Copyright 2020 Ross Wightman +""" +from .activations import * +from .activations_jit import * +from .activations_me import * +from .config import is_exportable, is_scriptable, is_no_jit + +# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code +# will use native version if present. Eventually, the custom Swish layers will be removed +# and only native 'silu' will be used. +_has_silu = 'silu' in dir(torch.nn.functional) + +_ACT_FN_DEFAULT = dict( + silu=F.silu if _has_silu else swish, + swish=F.silu if _has_silu else swish, + mish=mish, + relu=F.relu, + relu6=F.relu6, + leaky_relu=F.leaky_relu, + elu=F.elu, + celu=F.celu, + selu=F.selu, + gelu=gelu, + sigmoid=sigmoid, + tanh=tanh, + hard_sigmoid=hard_sigmoid, + hard_swish=hard_swish, + hard_mish=hard_mish, +) + +_ACT_FN_JIT = dict( + silu=F.silu if _has_silu else swish_jit, + swish=F.silu if _has_silu else swish_jit, + mish=mish_jit, + hard_sigmoid=hard_sigmoid_jit, + hard_swish=hard_swish_jit, + hard_mish=hard_mish_jit +) + +_ACT_FN_ME = dict( + silu=F.silu if _has_silu else swish_me, + swish=F.silu if _has_silu else swish_me, + mish=mish_me, + hard_sigmoid=hard_sigmoid_me, + hard_swish=hard_swish_me, + hard_mish=hard_mish_me, +) + +_ACT_LAYER_DEFAULT = dict( + silu=nn.SiLU if _has_silu else Swish, + swish=nn.SiLU if _has_silu else Swish, + mish=Mish, + relu=nn.ReLU, + relu6=nn.ReLU6, + leaky_relu=nn.LeakyReLU, + elu=nn.ELU, + prelu=PReLU, + celu=nn.CELU, + selu=nn.SELU, + gelu=GELU, + sigmoid=Sigmoid, + tanh=Tanh, + hard_sigmoid=HardSigmoid, + hard_swish=HardSwish, + hard_mish=HardMish, +) + +_ACT_LAYER_JIT = dict( + silu=nn.SiLU if _has_silu else SwishJit, + swish=nn.SiLU if _has_silu else SwishJit, + mish=MishJit, + hard_sigmoid=HardSigmoidJit, + hard_swish=HardSwishJit, + hard_mish=HardMishJit +) + +_ACT_LAYER_ME = dict( + silu=nn.SiLU if _has_silu else SwishMe, + swish=nn.SiLU if _has_silu else SwishMe, + mish=MishMe, + hard_sigmoid=HardSigmoidMe, + hard_swish=HardSwishMe, + hard_mish=HardMishMe, +) + + +def get_act_fn(name='relu'): + """ Activation Function Factory + Fetching activation fns by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. 
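    Example (illustrative sketch; under the default flag state ``hard_swish`` resolves to the
    memory-efficient custom-autograd implementation)::

        >>> import torch
        >>> act_fn = get_act_fn('hard_swish')
        >>> y = act_fn(torch.randn(2, 8))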
+ """ + if not name: + return None + if not (is_no_jit() or is_exportable() or is_scriptable()): + # If not exporting or scripting the model, first look for a memory-efficient version with + # custom autograd, then fallback + if name in _ACT_FN_ME: + return _ACT_FN_ME[name] + if is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return swish + if not (is_no_jit() or is_exportable()): + if name in _ACT_FN_JIT: + return _ACT_FN_JIT[name] + return _ACT_FN_DEFAULT[name] + + +def get_act_layer(name='relu'): + """ Activation Layer Factory + Fetching activation layers by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if not name: + return None + if not (is_no_jit() or is_exportable() or is_scriptable()): + if name in _ACT_LAYER_ME: + return _ACT_LAYER_ME[name] + if is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return Swish + if not (is_no_jit() or is_exportable()): + if name in _ACT_LAYER_JIT: + return _ACT_LAYER_JIT[name] + return _ACT_LAYER_DEFAULT[name] + + +def create_act_layer(name, inplace=False, **kwargs): + act_layer = get_act_layer(name) + if act_layer is not None: + return act_layer(inplace=inplace, **kwargs) + else: + return None diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..59ecd858e0da49bea53bb114900596bba1d6c1a5 --- /dev/null +++ b/timm/models/layers/create_attn.py @@ -0,0 +1,37 @@ +""" Select AttentionFactory Method + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from .se import SEModule, EffectiveSEModule +from .eca import EcaModule, CecaModule +from .cbam import CbamModule, LightCbamModule + + +def create_attn(attn_type, channels, **kwargs): + module_cls = None + if attn_type is not None: + if isinstance(attn_type, str): + attn_type = attn_type.lower() + if attn_type == 'se': + module_cls = SEModule + elif attn_type == 'ese': + module_cls = EffectiveSEModule + elif attn_type == 'eca': + module_cls = EcaModule + elif attn_type == 'ceca': + module_cls = CecaModule + elif attn_type == 'cbam': + module_cls = CbamModule + elif attn_type == 'lcbam': + module_cls = LightCbamModule + else: + assert False, "Invalid attn module (%s)" % attn_type + elif isinstance(attn_type, bool): + if attn_type: + module_cls = SEModule + else: + module_cls = attn_type + if module_cls is not None: + return module_cls(channels, **kwargs) + return None diff --git a/timm/models/layers/create_conv2d.py b/timm/models/layers/create_conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..0134b05c2717ebaeed3dba32d69f7cf983928e86 --- /dev/null +++ b/timm/models/layers/create_conv2d.py @@ -0,0 +1,30 @@ +""" Create Conv2d Factory Method + +Hacked together by / Copyright 2020 Ross Wightman +""" + +from .mixed_conv2d import MixedConv2d +from .cond_conv2d import CondConv2d +from .conv2d_same import create_conv2d_pad + + +def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): + """ Select a 2d convolution implementation based on arguments + Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. + + Used extensively by EfficientNet, MobileNetv3 and related networks. 
+ """ + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + assert 'groups' not in kwargs # MixedConv groups are defined by kernel list + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. + m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + groups = out_channels if depthwise else kwargs.pop('groups', 1) + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) + return m diff --git a/timm/models/layers/create_norm_act.py b/timm/models/layers/create_norm_act.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7e529e19642128646090aa2b2b862f92dc9686 --- /dev/null +++ b/timm/models/layers/create_norm_act.py @@ -0,0 +1,74 @@ +""" NormAct (Normalizaiton + Activation Layer) Factory + +Create norm + act combo modules that attempt to be backwards compatible with separate norm + act +isntances in models. Where these are used it will be possible to swap separate BN + act layers with +combined modules like IABN or EvoNorms. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import types +import functools + +import torch +import torch.nn as nn + +from .evo_norm import EvoNormBatch2d, EvoNormSample2d +from .norm_act import BatchNormAct2d, GroupNormAct +from .inplace_abn import InplaceAbn + +_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn} +_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type + +def get_norm_act_layer(layer_class): + layer_class = layer_class.replace('_', '').lower() + if layer_class.startswith("batchnorm"): + layer = BatchNormAct2d + elif layer_class.startswith("groupnorm"): + layer = GroupNormAct + elif layer_class == "evonormbatch": + layer = EvoNormBatch2d + elif layer_class == "evonormsample": + layer = EvoNormSample2d + elif layer_class == "iabn" or layer_class == "inplaceabn": + layer = InplaceAbn + else: + assert False, "Invalid norm_act layer (%s)" % layer_class + return layer + + +def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs): + layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu + assert len(layer_parts) in (1, 2) + layer = get_norm_act_layer(layer_parts[0]) + #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection? 
+ layer_instance = layer(num_features, apply_act=apply_act, **kwargs) + if jit: + layer_instance = torch.jit.script(layer_instance) + return layer_instance + + +def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None): + assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) + assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) + norm_act_args = norm_kwargs.copy() if norm_kwargs else {} + if isinstance(norm_layer, str): + norm_act_layer = get_norm_act_layer(norm_layer) + elif norm_layer in _NORM_ACT_TYPES: + norm_act_layer = norm_layer + elif isinstance(norm_layer, (types.FunctionType, functools.partial)): + # assuming this is a lambda/fn/bound partial that creates norm_act layer + norm_act_layer = norm_layer + else: + type_name = norm_layer.__name__.lower() + if type_name.startswith('batchnorm'): + norm_act_layer = BatchNormAct2d + elif type_name.startswith('groupnorm'): + norm_act_layer = GroupNormAct + else: + assert False, f"No equivalent norm_act layer for {type_name}" + if norm_act_layer in _NORM_ACT_REQUIRES_ARG: + # Must pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. + # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types + # It is intended that functions/partial does not trigger this, they should define act. + norm_act_args.update(dict(act_layer=act_layer)) + return norm_act_layer, norm_act_args diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py new file mode 100644 index 0000000000000000000000000000000000000000..6de9e3f729f7f1ca29d4511f6c64733d3169fbec --- /dev/null +++ b/timm/models/layers/drop.py @@ -0,0 +1,168 @@ +""" DropBlock, DropPath + +PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. + +Papers: +DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) + +Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) + +Code: +DropBlock impl inspired by two Tensorflow impl that I liked: + - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 + - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def drop_block_2d( + x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, + with_noise: bool = False, inplace: bool = False, batchwise: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. This layer has been tested on a few training + runs with success, but needs further validation and possibly optimization for lower runtime impact. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + # seed_drop_rate, the gamma parameter + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + # Forces the block to be inside the feature map. 
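    # e.g. for a 14x14 map with clipped_block_size=7, valid seed positions are limited to the central
    # 8x8 region (indices 3..10 in each dim), so no expanded block reaches past the border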
+ w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) + valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ + ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) + valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) + + if batchwise: + # one mask for whole batch, quite a bit faster + uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) + else: + uniform_noise = torch.rand_like(x) + block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) + block_mask = -F.max_pool2d( + -block_mask, + kernel_size=clipped_block_size, # block_size, + stride=1, + padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + if inplace: + x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) + else: + x = x * block_mask + normal_noise * (1 - block_mask) + else: + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +def drop_block_fast_2d( + x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, + gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid + block mask at edges. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + if batchwise: + # one mask for whole batch, quite a bit faster + block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma + else: + # mask per batch element + block_mask = torch.rand_like(x) < gamma + block_mask = F.max_pool2d( + block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + if inplace: + x.mul_(1. - block_mask).add_(normal_noise * block_mask) + else: + x = x * (1. - block_mask) + normal_noise * block_mask + else: + block_mask = 1 - block_mask + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +class DropBlock2d(nn.Module): + """ DropBlock. 
See https://arxiv.org/pdf/1810.12890.pdf + """ + def __init__(self, + drop_prob=0.1, + block_size=7, + gamma_scale=1.0, + with_noise=False, + inplace=False, + batchwise=False, + fast=True): + super(DropBlock2d, self).__init__() + self.drop_prob = drop_prob + self.gamma_scale = gamma_scale + self.block_size = block_size + self.with_noise = with_noise + self.inplace = inplace + self.batchwise = batchwise + self.fast = fast # FIXME finish comparisons of fast vs not + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + if self.fast: + return drop_block_fast_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + else: + return drop_block_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py new file mode 100644 index 0000000000000000000000000000000000000000..3a7f8b8241333c29dbd61dffea561a426e01497f --- /dev/null +++ b/timm/models/layers/eca.py @@ -0,0 +1,107 @@ +""" +ECA module from ECAnet + +paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks +https://arxiv.org/abs/1910.03151 + +Original ECA model borrowed from https://github.com/BangguWu/ECANet + +Modified circular ECA implementation and adaption for use in timm package +by Chris Ha https://github.com/VRandme + +Original License: + +MIT License + +Copyright (c) 2019 BangguWu, Qilong Wang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +import math +from torch import nn +import torch.nn.functional as F + + +class EcaModule(nn.Module): + """Constructs an ECA module. + + Args: + channels: Number of channels of the input feature map for use in adaptive kernel sizes + for actual calculations according to channel. + gamma, beta: when channel is given parameters of mapping function + refer to original paper https://arxiv.org/pdf/1910.03151.pdf + (default=None. if channel size not given, use k_size given for kernel size.) + kernel_size: Adaptive selection of kernel size (default=3) + """ + def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1): + super(EcaModule, self).__init__() + assert kernel_size % 2 == 1 + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) + + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) + + def forward(self, x): + y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv + y = self.conv(y) + y = y.view(x.shape[0], -1, 1, 1).sigmoid() + return x * y.expand_as(x) + + +class CecaModule(nn.Module): + """Constructs a circular ECA module. + + ECA module where the conv uses circular padding rather than zero padding. + Unlike the spatial dimension, the channels do not have inherent ordering nor + locality. Although this module in essence, applies such an assumption, it is unnecessary + to limit the channels on either "edge" from being circularly adapted to each other. + This will fundamentally increase connectivity and possibly increase performance metrics + (accuracy, robustness), without significantly impacting resource metrics + (parameter size, throughput,latency, etc) + + Args: + channels: Number of channels of the input feature map for use in adaptive kernel sizes + for actual calculations according to channel. + gamma, beta: when channel is given parameters of mapping function + refer to original paper https://arxiv.org/pdf/1910.03151.pdf + (default=None. if channel size not given, use k_size given for kernel size.) 
+ kernel_size: Adaptive selection of kernel size (default=3) + """ + + def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1): + super(CecaModule, self).__init__() + assert kernel_size % 2 == 1 + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) + + # PyTorch circular padding mode is buggy as of pytorch 1.4 + # see https://github.com/pytorch/pytorch/pull/17240 + # implement manual circular padding + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False) + self.padding = (kernel_size - 1) // 2 + + def forward(self, x): + y = x.mean((2, 3)).view(x.shape[0], 1, -1) + # Manually implement circular padding, F.pad does not seemed to be bugged + y = F.pad(y, (self.padding, self.padding), mode='circular') + y = self.conv(y) + y = y.view(x.shape[0], -1, 1, 1).sigmoid() + return x * y.expand_as(x) diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9023afd0e81dc8a76871d03141866217d59f4770 --- /dev/null +++ b/timm/models/layers/evo_norm.py @@ -0,0 +1,83 @@ +"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch + +An attempt at getting decent performing EvoNorms running in PyTorch. +While currently faster than other impl, still quite a ways off the built-in BN +in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed). + +Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +import torch.nn as nn + + +class EvoNormBatch2d(nn.Module): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None): + super(EvoNormBatch2d, self).__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.momentum = momentum + self.eps = eps + param_shape = (1, num_features, 1, 1) + self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) + if apply_act: + self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.register_buffer('running_var', torch.ones(1, num_features, 1, 1)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.apply_act: + nn.init.ones_(self.v) + + def forward(self, x): + assert x.dim() == 4, 'expected 4D input' + x_type = x.dtype + if self.training: + var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) + n = x.numel() / x.shape[1] + self.running_var.copy_( + var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum)) + else: + var = self.running_var + + if self.apply_act: + v = self.v.to(dtype=x_type) + d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) + d = d.max((var + self.eps).sqrt().to(dtype=x_type)) + x = x / d + return x * self.weight + self.bias + + +class EvoNormSample2d(nn.Module): + def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None): + super(EvoNormSample2d, self).__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.groups = groups + self.eps = eps + param_shape = (1, num_features, 1, 1) + self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) + if apply_act: + self.v = nn.Parameter(torch.ones(param_shape), 
requires_grad=True) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.apply_act: + nn.init.ones_(self.v) + + def forward(self, x): + assert x.dim() == 4, 'expected 4D input' + B, C, H, W = x.shape + assert C % self.groups == 0 + if self.apply_act: + n = x * (x * self.v).sigmoid() + x = x.reshape(B, self.groups, -1) + x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() + x = x.reshape(B, C, H, W) + return x * self.weight + self.bias diff --git a/timm/models/layers/helpers.py b/timm/models/layers/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..2a7e6e768065a6fac8de5c5982c1b61edc688ffa --- /dev/null +++ b/timm/models/layers/helpers.py @@ -0,0 +1,34 @@ +""" Layer/Module Helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +from itertools import repeat +import torch +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) + +if TORCH_MAJOR == 1 and TORCH_MINOR < 8: + from torch._six import container_abcs +else: + import collections.abc as container_abcs + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, container_abcs.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + + + + diff --git a/timm/models/layers/inplace_abn.py b/timm/models/layers/inplace_abn.py new file mode 100644 index 0000000000000000000000000000000000000000..c7edac6256fbce2cd47a0c63f2492032553fc5f5 --- /dev/null +++ b/timm/models/layers/inplace_abn.py @@ -0,0 +1,87 @@ +import torch +from torch import nn as nn + +try: + from inplace_abn.functions import inplace_abn, inplace_abn_sync + has_iabn = True +except ImportError: + has_iabn = False + + def inplace_abn(x, weight, bias, running_mean, running_var, + training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): + raise ImportError( + "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'") + + def inplace_abn_sync(**kwargs): + inplace_abn(**kwargs) + + +class InplaceAbn(nn.Module): + """Activated Batch Normalization + + This gathers a BatchNorm and an activation function in a single module + + Parameters + ---------- + num_features : int + Number of feature channels in the input and output. + eps : float + Small constant to prevent numerical issues. + momentum : float + Momentum factor applied to compute running statistics. + affine : bool + If `True` apply learned scale and shift transformation after normalization. + act_layer : str or nn.Module type + Name or type of the activation functions, one of: `leaky_relu`, `elu` + act_param : float + Negative slope for the `leaky_relu` activation. 
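    Example (illustrative sketch; running the forward pass requires the optional ``inplace_abn`` package)::

        >>> bn_act = InplaceAbn(64, act_layer='leaky_relu', act_param=0.01)
        >>> y = bn_act(torch.randn(2, 64, 8, 8))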
+ """ + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, + act_layer="leaky_relu", act_param=0.01, drop_block=None): + super(InplaceAbn, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + self.momentum = momentum + if apply_act: + if isinstance(act_layer, str): + assert act_layer in ('leaky_relu', 'elu', 'identity', '') + self.act_name = act_layer if act_layer else 'identity' + else: + # convert act layer passed as type to string + if act_layer == nn.ELU: + self.act_name = 'elu' + elif act_layer == nn.LeakyReLU: + self.act_name = 'leaky_relu' + elif act_layer == nn.Identity: + self.act_name = 'identity' + else: + assert False, f'Invalid act layer {act_layer.__name__} for IABN' + else: + self.act_name = 'identity' + self.act_param = act_param + if self.affine: + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.running_mean, 0) + nn.init.constant_(self.running_var, 1) + if self.affine: + nn.init.constant_(self.weight, 1) + nn.init.constant_(self.bias, 0) + + def forward(self, x): + output = inplace_abn( + x, self.weight, self.bias, self.running_mean, self.running_var, + self.training, self.momentum, self.eps, self.act_name, self.act_param) + if isinstance(output, tuple): + output = output[0] + return output diff --git a/timm/models/layers/linear.py b/timm/models/layers/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..38fe3380b067ea0b275c45ffd689afdeb4598f3c --- /dev/null +++ b/timm/models/layers/linear.py @@ -0,0 +1,19 @@ +""" Linear layer (alternate definition) +""" +import torch +import torch.nn.functional as F +from torch import nn as nn + + +class Linear(nn.Linear): + r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` + + Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting + weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. + """ + def forward(self, input: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting(): + bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None + return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) + else: + return F.linear(input, self.weight, self.bias) diff --git a/timm/models/layers/median_pool.py b/timm/models/layers/median_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..40bd71a7a3840aaebefd2af0a99605b845054cd7 --- /dev/null +++ b/timm/models/layers/median_pool.py @@ -0,0 +1,49 @@ +""" Median Pool +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F +from .helpers import to_2tuple, to_4tuple + + +class MedianPool2d(nn.Module): + """ Median pool (usable as median filter when stride=1) module. 
+ + Args: + kernel_size: size of pooling kernel, int or 2-tuple + stride: pool stride, int or 2-tuple + padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad + same: override padding and enforce same padding, boolean + """ + def __init__(self, kernel_size=3, stride=1, padding=0, same=False): + super(MedianPool2d, self).__init__() + self.k = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) + self.padding = to_4tuple(padding) # convert to l, r, t, b + self.same = same + + def _padding(self, x): + if self.same: + ih, iw = x.size()[2:] + if ih % self.stride[0] == 0: + ph = max(self.k[0] - self.stride[0], 0) + else: + ph = max(self.k[0] - (ih % self.stride[0]), 0) + if iw % self.stride[1] == 0: + pw = max(self.k[1] - self.stride[1], 0) + else: + pw = max(self.k[1] - (iw % self.stride[1]), 0) + pl = pw // 2 + pr = pw - pl + pt = ph // 2 + pb = ph - pt + padding = (pl, pr, pt, pb) + else: + padding = self.padding + return padding + + def forward(self, x): + x = F.pad(x, self._padding(x), mode='reflect') + x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) + x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] + return x diff --git a/timm/models/layers/mixed_conv2d.py b/timm/models/layers/mixed_conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..53d650cdd00c20397928817f48f08dd750d45284 --- /dev/null +++ b/timm/models/layers/mixed_conv2d.py @@ -0,0 +1,51 @@ +""" PyTorch Mixed Convolution + +Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn + +from .conv2d_same import create_conv2d_pad + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +class MixedConv2d(nn.ModuleDict): + """ Mixed Grouped Convolution + + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + conv_groups = out_ch if depthwise else 1 + # use add_module to keep key space clean + self.add_module( + str(idx), + create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [c(x_split[i]) for i, c in enumerate(self.values())] + x = torch.cat(x_out, 1) + return x diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py new file mode 100644 index 0000000000000000000000000000000000000000..bddf9b26d48d92108d6b2831e801bc08598d2bf0 --- /dev/null +++ b/timm/models/layers/norm_act.py @@ -0,0 +1,86 @@ +""" Normalization + Activation Layers +""" +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .create_act import get_act_layer + + +class 
BatchNormAct2d(nn.BatchNorm2d): + """BatchNorm + Activation + + This module performs BatchNorm + Activation in a manner that will remain backwards + compatible with weights trained with separate bn, act. This is why we inherit from BN + instead of composing it as a .bn member. + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): + super(BatchNormAct2d, self).__init__( + num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = None + + def _forward_jit(self, x): + """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function + """ + # exponential_average_factor is self.momentum set to + # (when it is available) only so that if gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + x = F.batch_norm( + x, self.running_mean, self.running_var, self.weight, self.bias, + self.training or not self.track_running_stats, + exponential_average_factor, self.eps) + return x + + @torch.jit.ignore + def _forward_python(self, x): + return super(BatchNormAct2d, self).forward(x) + + def forward(self, x): + # FIXME cannot call parent forward() and maintain jit.script compatibility? 
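        # dispatch: scripted graphs take the re-implemented batch norm path above, eager execution
        # falls back to the stock nn.BatchNorm2d forward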
+ if torch.jit.is_scripting(): + x = self._forward_jit(x) + else: + x = self._forward_python(x) + if self.act is not None: + x = self.act(x) + return x + + +class GroupNormAct(nn.GroupNorm): + + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): + super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + if act_layer is not None and apply_act: + self.act = act_layer(inplace=inplace) + else: + self.act = None + + def forward(self, x): + x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + if self.act is not None: + x = self.act(x) + return x diff --git a/timm/models/layers/padding.py b/timm/models/layers/padding.py new file mode 100644 index 0000000000000000000000000000000000000000..34afc37c6c59c8782ad29c7a779f58177011f891 --- /dev/null +++ b/timm/models/layers/padding.py @@ -0,0 +1,56 @@ +""" Padding Helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +from typing import List, Tuple + +import torch.nn.functional as F + + +# Calculate symmetric padding for a convolution +def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution +def get_same_padding(x: int, k: int, s: int, d: int): + return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) + + +# Can SAME padding for given args be done statically? +def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +# Dynamically pad input x with 'SAME' padding for conv with specified args +def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): + ih, iw = x.size()[-2:] + pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) + return x + + +def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = get_padding(kernel_size, **kwargs) + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = get_padding(kernel_size, **kwargs) + return padding, dynamic diff --git a/timm/models/layers/pool2d_same.py b/timm/models/layers/pool2d_same.py new file mode 100644 index 0000000000000000000000000000000000000000..5fcd0f1f75c2502816cb931d507d4fbf418002f0 --- /dev/null +++ b/timm/models/layers/pool2d_same.py @@ -0,0 +1,71 @@ +""" AvgPool2d w/ Same Padding + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Tuple, Optional + +from .helpers import to_2tuple +from .padding import pad_same, get_padding_value + + 
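# Usage sketch (illustrative): create_pool2d defined below returns the dynamic 'SAME' variants only
# when static padding cannot emulate the TF behaviour, e.g.
#   pool = create_pool2d('avg', kernel_size=3, stride=2, padding='same')   # -> AvgPool2dSame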
+def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), + ceil_mode: bool = False, count_include_pad: bool = True): + # FIXME how to deal with count_include_pad vs not for external padding? + x = pad_same(x, kernel_size, stride) + return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + +class AvgPool2dSame(nn.AvgPool2d): + """ Tensorflow like 'SAME' wrapper for 2D average pooling + """ + def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + def forward(self, x): + return avg_pool2d_same( + x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) + + +def max_pool2d_same( + x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), + dilation: List[int] = (1, 1), ceil_mode: bool = False): + x = pad_same(x, kernel_size, stride, value=-float('inf')) + return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) + + +class MaxPool2dSame(nn.MaxPool2d): + """ Tensorflow like 'SAME' wrapper for 2D max pooling + """ + def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True): + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad) + + def forward(self, x): + return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode) + + +def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): + stride = stride or kernel_size + padding = kwargs.pop('padding', '') + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) + if is_dynamic: + if pool_type == 'avg': + return AvgPool2dSame(kernel_size, stride=stride, **kwargs) + elif pool_type == 'max': + return MaxPool2dSame(kernel_size, stride=stride, **kwargs) + else: + assert False, f'Unsupported pool type {pool_type}' + else: + if pool_type == 'avg': + return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) + elif pool_type == 'max': + return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) + else: + assert False, f'Unsupported pool type {pool_type}' diff --git a/timm/models/layers/se.py b/timm/models/layers/se.py new file mode 100644 index 0000000000000000000000000000000000000000..a896fb71ba09092f77fd6dd8bf36d015f4aa2463 --- /dev/null +++ b/timm/models/layers/se.py @@ -0,0 +1,36 @@ +from torch import nn as nn +from .create_act import create_act_layer + + +class SEModule(nn.Module): + + def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None, + gate_layer='sigmoid'): + super(SEModule, self).__init__() + reduction_channels = reduction_channels or max(channels // reduction, min_channels) + self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True) + self.act = act_layer(inplace=True) + self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + x_se = self.fc1(x_se) + x_se = self.act(x_se) + x_se = self.fc2(x_se) + return x * self.gate(x_se) + + +class EffectiveSEModule(nn.Module): + """ 'Effective Squeeze-Excitation 
+ From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + """ + def __init__(self, channels, gate_layer='hard_sigmoid'): + super(EffectiveSEModule, self).__init__() + self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) + self.gate = create_act_layer(gate_layer, inplace=True) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + x_se = self.fc(x_se) + return x * self.gate(x_se) diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..10bfd0e0d4e1c6e4dce1e69305ff990ddf85cb6f --- /dev/null +++ b/timm/models/layers/selective_kernel.py @@ -0,0 +1,118 @@ +""" Selective Kernel Convolution/Attention + +Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch import nn as nn + +from .conv_bn_act import ConvBnAct + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + """ Selective Kernel Attention Module + + Selective Kernel attention mechanism factored out into its own module. + + """ + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = norm_layer(attn_channels) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x): + assert x.shape[1] == self.num_paths + x = x.sum(1).mean((2, 3), keepdim=True) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + B, C, H, W = x.shape + x = x.view(B, self.num_paths, C // self.num_paths, H, W) + x = torch.softmax(x, dim=1) + return x + + +class SelectiveKernelConv(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, + attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): + """ Selective Kernel Convolution Module + + As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. + + Largest change is the input split, which divides the input channels across each convolution path, this can + be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps + the parameter count from ballooning when the convolutions themselves don't have groups, but still provides + a noteworthy increase in performance over similar param count models without this attention layer. 
-Ross W + + Args: + in_channels (int): module input (feature) channel count + out_channels (int): module output (feature) channel count + kernel_size (int, list): kernel size for each convolution branch + stride (int): stride for convolutions + dilation (int): dilation for module as a whole, impacts dilation of each branch + groups (int): number of groups for each branch + attn_reduction (int, float): reduction factor for attention features + min_attn_channels (int): minimum attention feature channels + keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations + split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, + can be viewed as grouping by path, output expands to module out_channels count + drop_block (nn.Module): drop block module + act_layer (nn.Module): activation layer to use + norm_layer (nn.Module): batchnorm/norm layer to use + """ + super(SelectiveKernelConv, self).__init__() + kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation + _kernel_valid(kernel_size) + if not isinstance(kernel_size, list): + kernel_size = [kernel_size] * 2 + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + self.num_paths = len(kernel_size) + self.in_channels = in_channels + self.out_channels = out_channels + self.split_input = split_input + if self.split_input: + assert in_channels % self.num_paths == 0 + in_channels = in_channels // self.num_paths + groups = min(out_channels, groups) + + conv_kwargs = dict( + stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer) + self.paths = nn.ModuleList([ + ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) + for k, d in zip(kernel_size, dilation)]) + + attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) + self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) + self.drop_block = drop_block + + def forward(self, x): + if self.split_input: + x_split = torch.split(x, self.in_channels // self.num_paths, 1) + x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] + else: + x_paths = [op(x) for op in self.paths] + x = torch.stack(x_paths, dim=1) + x_attn = self.attn(x) + x = x * x_attn + x = torch.sum(x, dim=1) + return x diff --git a/timm/models/layers/separable_conv.py b/timm/models/layers/separable_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..e949ea437191dc1ab7e321664226b3212d0b9099 --- /dev/null +++ b/timm/models/layers/separable_conv.py @@ -0,0 +1,74 @@ +""" Depthwise Separable Conv Modules + +Basic DWS convs. Other variations of DWS exist with batch norm or activations between the +DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 
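As a rough sketch (hypothetical channel counts), SeparableConv2d(32, 64, 3, stride=2) applies a 3x3
depthwise conv over the 32 input channels followed by a 1x1 pointwise conv projecting to 64 channels.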
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + +from .create_conv2d import create_conv2d +from .create_norm_act import convert_norm_act_type + + +class SeparableConvBnAct(nn.Module): + """ Separable Conv w/ trailing Norm and Activation + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, + channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + act_layer=nn.ReLU, apply_act=True, drop_block=None): + super(SeparableConvBnAct, self).__init__() + norm_kwargs = norm_kwargs or {} + + self.conv_dw = create_conv2d( + in_channels, int(in_channels * channel_multiplier), kernel_size, + stride=stride, dilation=dilation, padding=padding, depthwise=True) + + self.conv_pw = create_conv2d( + int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + + norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs) + self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args) + + @property + def in_channels(self): + return self.conv_dw.in_channels + + @property + def out_channels(self): + return self.conv_pw.out_channels + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + if self.bn is not None: + x = self.bn(x) + return x + + +class SeparableConv2d(nn.Module): + """ Separable Conv + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, + channel_multiplier=1.0, pw_kernel_size=1): + super(SeparableConv2d, self).__init__() + + self.conv_dw = create_conv2d( + in_channels, int(in_channels * channel_multiplier), kernel_size, + stride=stride, dilation=dilation, padding=padding, depthwise=True) + + self.conv_pw = create_conv2d( + int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + + @property + def in_channels(self): + return self.conv_dw.in_channels + + @property + def out_channels(self): + return self.conv_pw.out_channels + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + return x diff --git a/timm/models/layers/space_to_depth.py b/timm/models/layers/space_to_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..a7e8e0b2a486d51fe3e4ab0472d89b7f1b92e1dc --- /dev/null +++ b/timm/models/layers/space_to_depth.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn + + +class SpaceToDepth(nn.Module): + def __init__(self, block_size=4): + super().__init__() + assert block_size == 4 + self.bs = block_size + + def forward(self, x): + N, C, H, W = x.size() + x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) + x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) + return x + + +@torch.jit.script +class SpaceToDepthJit(object): + def __call__(self, x: torch.Tensor): + # assuming hard-coded that block_size==4 for acceleration + N, C, H, W = x.size() + x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) + x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) + return x + + +class SpaceToDepthModule(nn.Module): + def __init__(self, no_jit=False): + super().__init__() + if not no_jit: + self.op = SpaceToDepthJit() + else: + self.op = SpaceToDepth() + + def 
forward(self, x): + return self.op(x) + + +class DepthToSpace(nn.Module): + + def __init__(self, block_size): + super().__init__() + self.bs = block_size + + def forward(self, x): + N, C, H, W = x.size() + x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) + x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) + return x diff --git a/timm/models/layers/split_attn.py b/timm/models/layers/split_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..5615aa0b391653f2a93b591820deca7d5b17f115 --- /dev/null +++ b/timm/models/layers/split_attn.py @@ -0,0 +1,88 @@ +""" Split Attention Conv2d (for ResNeSt Models) + +Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt + +Modified for torchscript compat, performance, and consistency with timm by Ross Wightman +""" +import torch +import torch.nn.functional as F +from torch import nn + + +class RadixSoftmax(nn.Module): + def __init__(self, radix, cardinality): + super(RadixSoftmax, self).__init__() + self.radix = radix + self.cardinality = cardinality + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttnConv2d(nn.Module): + """Split-Attention Conv2d + """ + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, + dilation=1, groups=1, bias=False, radix=2, reduction_factor=4, + act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): + super(SplitAttnConv2d, self).__init__() + self.radix = radix + self.drop_block = drop_block + mid_chs = out_channels * radix + attn_chs = max(in_channels * radix // reduction_factor, 32) + + self.conv = nn.Conv2d( + in_channels, mid_chs, kernel_size, stride, padding, dilation, + groups=groups * radix, bias=bias, **kwargs) + self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None + self.act0 = act_layer(inplace=True) + self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) + self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None + self.act1 = act_layer(inplace=True) + self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) + self.rsoftmax = RadixSoftmax(radix, groups) + + @property + def in_channels(self): + return self.conv.in_channels + + @property + def out_channels(self): + return self.fc1.out_channels + + def forward(self, x): + x = self.conv(x) + if self.bn0 is not None: + x = self.bn0(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act0(x) + + B, RC, H, W = x.shape + if self.radix > 1: + x = x.reshape((B, self.radix, RC // self.radix, H, W)) + x_gap = x.sum(dim=1) + else: + x_gap = x + x_gap = F.adaptive_avg_pool2d(x_gap, 1) + x_gap = self.fc1(x_gap) + if self.bn1 is not None: + x_gap = self.bn1(x_gap) + x_gap = self.act1(x_gap) + x_attn = self.fc2(x_gap) + + x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) + if self.radix > 1: + out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) + else: + out = x * x_attn + return out.contiguous() diff --git a/timm/models/layers/split_batchnorm.py b/timm/models/layers/split_batchnorm.py new file mode 100644 index 
0000000000000000000000000000000000000000..830781b335161f8d6dd74c9458070bb1fa88a918 --- /dev/null +++ b/timm/models/layers/split_batchnorm.py @@ -0,0 +1,75 @@ +""" Split BatchNorm + +A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through +a separate BN layer. The first split is passed through the parent BN layers with weight/bias +keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' +namespace. + +This allows easily removing the auxiliary BN layers after training to efficiently +achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, +'Disentangled Learning via An Auxiliary BN' + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn + + +class SplitBatchNorm2d(torch.nn.BatchNorm2d): + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, + track_running_stats=True, num_splits=2): + super().__init__(num_features, eps, momentum, affine, track_running_stats) + assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' + self.num_splits = num_splits + self.aux_bn = nn.ModuleList([ + nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) + + def forward(self, input: torch.Tensor): + if self.training: # aux BN only relevant while training + split_size = input.shape[0] // self.num_splits + assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" + split_input = input.split(split_size) + x = [super().forward(split_input[0])] + for i, a in enumerate(self.aux_bn): + x.append(a(split_input[i + 1])) + return torch.cat(x, dim=0) + else: + return super().forward(input) + + +def convert_splitbn_model(module, num_splits=2): + """ + Recursively traverse module and its children to replace all instances of + ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. 
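# --- Illustrative aside (not part of the docstring above): a minimal sketch of the auxiliary-BN
# behaviour described for SplitBatchNorm2d. The toy network and batch sizes are assumptions.
import torch
import torch.nn as nn
from timm.models.layers.split_batchnorm import convert_splitbn_model

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
net = convert_splitbn_model(net, num_splits=2)   # BatchNorm2d -> SplitBatchNorm2d with one aux BN
net.train()
clean, adv = torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32)
out = net(torch.cat([clean, adv], dim=0))        # first half normalized by the main BN, second half by the aux BN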
+ Args: + module (torch.nn.Module): input module + num_splits: number of separate batchnorm layers to split input across + Example:: + >>> # model is an instance of torch.nn.Module + >>> model = timm.models.convert_splitbn_model(model, num_splits=2) + """ + mod = module + if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): + return module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + mod = SplitBatchNorm2d( + module.num_features, module.eps, module.momentum, module.affine, + module.track_running_stats, num_splits=num_splits) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + mod.num_batches_tracked = module.num_batches_tracked + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + for aux in mod.aux_bn: + aux.running_mean = module.running_mean.clone() + aux.running_var = module.running_var.clone() + aux.num_batches_tracked = module.num_batches_tracked.clone() + if module.affine: + aux.weight.data = module.weight.data.clone().detach() + aux.bias.data = module.bias.data.clone().detach() + for name, child in module.named_children(): + mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) + del module + return mod diff --git a/timm/models/layers/test_time_pool.py b/timm/models/layers/test_time_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..93ece264f5d60d4ebfdb160ae38b06c88fd789bf --- /dev/null +++ b/timm/models/layers/test_time_pool.py @@ -0,0 +1,49 @@ +""" Test Time Pooling (Average-Max Pool) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import logging +from torch import nn +import torch.nn.functional as F + +from .adaptive_avgmax_pool import adaptive_avgmax_pool2d + + +_logger = logging.getLogger(__name__) + + +class TestTimePoolHead(nn.Module): + def __init__(self, base, original_pool=7): + super(TestTimePoolHead, self).__init__() + self.base = base + self.original_pool = original_pool + base_fc = self.base.get_classifier() + if isinstance(base_fc, nn.Conv2d): + self.fc = base_fc + else: + self.fc = nn.Conv2d( + self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) + self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) + self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) + self.base.reset_classifier(0) # delete original fc layer + + def forward(self, x): + x = self.base.forward_features(x) + x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) + x = self.fc(x) + x = adaptive_avgmax_pool2d(x, 1) + return x.view(x.size(0), -1) + + +def apply_test_time_pool(model, config): + test_time_pool = False + if not hasattr(model, 'default_cfg') or not model.default_cfg: + return model, False + if (config['input_size'][-1] > model.default_cfg['input_size'][-1] and + config['input_size'][-2] > model.default_cfg['input_size'][-2]): + _logger.info('Target input size %s > pretrained default %s, using test time pooling' % + (str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:]))) + model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) + test_time_pool = True + return model, test_time_pool diff --git a/timm/models/layers/weight_init.py b/timm/models/layers/weight_init.py new file mode 100644 index 0000000000000000000000000000000000000000..d731029ff1708776758d9af60e5b205818940a6d --- /dev/null +++ b/timm/models/layers/weight_init.py @@ -0,0 +1,60 @@ +import torch +import math +import warnings + + 
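# --- Illustrative aside for timm/models/layers/test_time_pool.py above (not part of the diff):
# apply_test_time_pool wraps the classifier in TestTimePoolHead when the evaluation resolution
# exceeds the model's pretrained default. 'resnet50' and the 288x288 size are assumed examples.
import timm
from timm.models.layers.test_time_pool import apply_test_time_pool

model = timm.create_model('resnet50', pretrained=False).eval()
model, enabled = apply_test_time_pool(model, {'input_size': (3, 288, 288)})
print(enabled)   # True: 288 > the 224 default, so logits are average-max pooled over the larger feature map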
+def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py new file mode 100644 index 0000000000000000000000000000000000000000..8a48ce728b065033470b297f83378d456554c4f4 --- /dev/null +++ b/timm/models/mobilenetv3.py @@ -0,0 +1,444 @@ + +""" MobileNet V3 + +A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl. 
+ +Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import List + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT +from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights +from .features import FeatureInfo, FeatureHooks +from .helpers import build_model_with_cfg, default_cfg_for_features +from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid +from .registry import register_model + +__all__ = ['MobileNetV3'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'mobilenetv3_large_075': _cfg(url=''), + 'mobilenetv3_large_100': _cfg( + interpolation='bicubic', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'), + 'mobilenetv3_small_075': _cfg(url=''), + 'mobilenetv3_small_100': _cfg(url=''), + 'mobilenetv3_rw': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', + interpolation='bicubic'), + 'tf_mobilenetv3_large_075': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_large_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_large_minimal_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_075': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_100': _cfg( + url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_minimal_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), +} + +_DEBUG = False + + +class MobileNetV3(nn.Module): + """ MobiletNet-V3 + + Based on my EfficientNet implementation and building blocks, this model utilizes the MobileNet-v3 specific + 'efficient head', where global pooling is done before the head convolution without a final batch-norm + layer before the classifier. 
+ + Paper: https://arxiv.org/abs/1905.02244 + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True, + channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'): + super(MobileNetV3, self).__init__() + + self.num_classes = num_classes + self.num_features = num_features + self.drop_rate = drop_rate + + # Stem + stem_size = round_channels(stem_size, channel_multiplier) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs, + norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = builder.features + head_chs = builder.in_chs + + # Head + Pooling + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + num_pooled_chs = head_chs * self.global_pool.feat_mult() + self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) + self.act2 = act_layer(inplace=True) + self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + efficientnet_init_weights(self) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([self.global_pool, self.conv_head, self.act2]) + layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + # cannot meaningfully change pooling of efficient head after creation + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.forward_features(x) + if not self.global_pool.is_identity(): + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +class MobileNetV3Features(nn.Module): + """ MobileNetV3 Feature Extractor + + A work-in-progress feature extraction module for MobileNet-V3 to use as a backbone for segmentation + and object detection models. 
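# --- Illustrative aside (not part of the docstring above): a minimal usage sketch of the two classes,
# assuming the vendored timm behaves like upstream. 'mobilenetv3_large_100' is one of the variants
# registered further down in this file.
import torch
import timm

clf = timm.create_model('mobilenetv3_large_100', pretrained=False)
logits = clf(torch.randn(1, 3, 224, 224))          # (1, 1000); global pooling happens before conv_head
backbone = timm.create_model('mobilenetv3_large_100', pretrained=False, features_only=True)
features = backbone(torch.randn(1, 3, 224, 224))   # list of intermediate feature maps (MobileNetV3Features)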
+ """ + + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', + in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='', + act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None): + super(MobileNetV3Features, self).__init__() + norm_kwargs = norm_kwargs or {} + self.drop_rate = drop_rate + + # Stem + stem_size = round_channels(stem_size, channel_multiplier) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs, + norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = FeatureInfo(builder.features, out_indices) + self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} + + efficientnet_init_weights(self) + + # Register feature extraction hooks with FeatureHooks helper + self.feature_hooks = None + if feature_location != 'bottleneck': + hooks = self.feature_info.get_dicts(keys=('module', 'hook_type')) + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) + + def forward(self, x) -> List[torch.Tensor]: + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + if self.feature_hooks is None: + features = [] + if 0 in self._stage_out_idx: + features.append(x) # add stem out + for i, b in enumerate(self.blocks): + x = b(x) + if i + 1 in self._stage_out_idx: + features.append(x) + return features + else: + self.blocks(x) + out = self.feature_hooks.get_output(x.device) + return list(out.values()) + + +def _create_mnv3(model_kwargs, variant, pretrained=False): + features_only = False + model_cls = MobileNetV3 + if model_kwargs.pop('features_only', False): + features_only = True + model_kwargs.pop('num_classes', 0) + model_kwargs.pop('num_features', 0) + model_kwargs.pop('head_conv', None) + model_kwargs.pop('head_bias', None) + model_cls = MobileNetV3Features + model = build_model_with_cfg( + model_cls, variant, pretrained, default_cfg=default_cfgs[variant], + pretrained_strict=not features_only, **model_kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model + + +def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 model. + + Ref impl: ? + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. 
+ """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + head_bias=False, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1), + **kwargs, + ) + model = _create_mnv3(model_kwargs, variant, pretrained) + return model + + +def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 model. + + Ref impl: ? + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + if 'small' in variant: + num_features = 1024 + if 'minimal' in variant: + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16'], + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'], + # stage 2, 28x28 in + ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'], + # stage 3, 14x14 in + ['ir_r2_k3_s1_e3_c48'], + # stage 4, 14x14in + ['ir_r3_k3_s2_e6_c96'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], + ] + else: + act_layer = resolve_act_layer(kwargs, 'hard_swish') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu + # stage 2, 28x28 in + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish + # stage 3, 14x14 in + ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish + # stage 4, 14x14in + ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], # hard-swish + ] + else: + num_features = 1280 + if 'minimal' in variant: + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k3_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112'], + # stage 5, 14x14in + ['ir_r3_k3_s2_e6_c160'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + else: + act_layer = resolve_act_layer(kwargs, 'hard_swish') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=num_features, + stem_size=16, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=act_layer, + 
se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8), + **kwargs, + ) + model = _create_mnv3(model_kwargs, variant, pretrained) + return model + + +@register_model +def mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_rw(pretrained=False, **kwargs): + """ MobileNet V3 """ + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py new file mode 100644 index 0000000000000000000000000000000000000000..18b3725fc2bd9f879213854dc2a2e6e751ac87f5 --- /dev/null +++ b/timm/models/nasnet.py @@ -0,0 +1,562 @@ +""" + +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import build_model_with_cfg +from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier +from .registry import register_model + +__all__ = ['NASNetALarge'] + +default_cfgs = { + 
'nasnetalarge': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth', + 'input_size': (3, 331, 331), + 'pool_size': (11, 11), + 'crop_pct': 0.911, + 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5), + 'num_classes': 1001, + 'first_conv': 'conv0.conv', + 'classifier': 'last_linear', + }, +} + + +class ActConvBn(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''): + super(ActConvBn, self).__init__() + self.act = nn.ReLU() + self.conv = create_conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1) + + def forward(self, x): + x = self.act(x) + x = self.conv(x) + x = self.bn(x) + return x + + +class SeparableConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''): + super(SeparableConv2d, self).__init__() + self.depthwise_conv2d = create_conv2d( + in_channels, in_channels, kernel_size=kernel_size, + stride=stride, padding=padding, groups=in_channels) + self.pointwise_conv2d = create_conv2d( + in_channels, out_channels, kernel_size=1, padding=0) + + def forward(self, x): + x = self.depthwise_conv2d(x) + x = self.pointwise_conv2d(x) + return x + + +class BranchSeparables(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_type='', stem_cell=False): + super(BranchSeparables, self).__init__() + middle_channels = out_channels if stem_cell else in_channels + self.act_1 = nn.ReLU() + self.separable_1 = SeparableConv2d( + in_channels, middle_channels, kernel_size, stride=stride, padding=pad_type) + self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001, momentum=0.1) + self.act_2 = nn.ReLU(inplace=True) + self.separable_2 = SeparableConv2d( + middle_channels, out_channels, kernel_size, stride=1, padding=pad_type) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1) + + def forward(self, x): + x = self.act_1(x) + x = self.separable_1(x) + x = self.bn_sep_1(x) + x = self.act_2(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class CellStem0(nn.Module): + def __init__(self, stem_size, num_channels=42, pad_type=''): + super(CellStem0, self).__init__() + self.num_channels = num_channels + self.stem_size = stem_size + self.conv_1x1 = ActConvBn(self.stem_size, self.num_channels, 1, stride=1) + + self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True) + + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True) + + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(self.stem_size, self.num_channels, 5, 2, pad_type, stem_cell=True) + + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) + + def forward(self, x): + x1 = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x1) + x_comb_iter_0_right = self.comb_iter_0_right(x) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = 
self.comb_iter_1_left(x1) + x_comb_iter_1_right = self.comb_iter_1_right(x) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x1) + x_comb_iter_2_right = self.comb_iter_2_right(x) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x1) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class CellStem1(nn.Module): + + def __init__(self, stem_size, num_channels, pad_type=''): + super(CellStem1, self).__init__() + self.num_channels = num_channels + self.stem_size = stem_size + self.conv_1x1 = ActConvBn(2 * self.num_channels, self.num_channels, 1, stride=1) + + self.act = nn.ReLU() + self.path_1 = nn.Sequential() + self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False)) + + self.path_2 = nn.Sequential() + self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1))) + self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False)) + + self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1) + + self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type) + + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) + + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) + + def forward(self, x_conv0, x_stem_0): + x_left = self.conv_1x1(x_stem_0) + + x_relu = self.act(x_conv0) + # path 1 + x_path1 = self.path_1(x_relu) + # path 2 + x_path2 = self.path_2(x_relu) + # final path + x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + + x_comb_iter_0_left = self.comb_iter_0_left(x_left) + x_comb_iter_0_right = self.comb_iter_0_right(x_right) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_right) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_left) + x_comb_iter_2_right = self.comb_iter_2_right(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_left) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, 
x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class FirstCell(nn.Module): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(FirstCell, self).__init__() + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1) + + self.act = nn.ReLU() + self.path_1 = nn.Sequential() + self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_1.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False)) + + self.path_2 = nn.Sequential() + self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1))) + self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_2.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False)) + + self.final_path_bn = nn.BatchNorm2d(out_chs_left * 2, eps=0.001, momentum=0.1) + + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + + self.comb_iter_1_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + + def forward(self, x, x_prev): + x_relu = self.act(x_prev) + x_path1 = self.path_1(x_relu) + x_path2 = self.path_2(x_relu) + x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class NormalCell(nn.Module): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(NormalCell, self).__init__() + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) + + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type) + + self.comb_iter_1_left = BranchSeparables(out_chs_left, out_chs_left, 5, 1, pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_3_left = create_pool2d('avg', 3, 1, 
count_include_pad=False, padding=pad_type) + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class ReductionCell0(nn.Module): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(ReductionCell0, self).__init__() + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) + + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class ReductionCell1(nn.Module): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(ReductionCell1, self).__init__() + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 
1, stride=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) + + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class NASNetALarge(nn.Module): + """NASNetALarge (6 @ 4032) """ + + def __init__(self, num_classes=1000, in_chans=1, stem_size=96, channel_multiplier=2, + num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'): + super(NASNetALarge, self).__init__() + self.num_classes = num_classes + self.stem_size = stem_size + self.num_features = num_features + self.channel_multiplier = channel_multiplier + self.drop_rate = drop_rate + assert output_stride == 32 + + channels = self.num_features // 24 + # 24 is default value for the architecture + + self.conv0 = ConvBnAct( + in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2, + norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None) + + self.cell_stem_0 = CellStem0( + self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type) + self.cell_stem_1 = CellStem1( + self.stem_size, num_channels=channels // channel_multiplier, pad_type=pad_type) + + self.cell_0 = FirstCell( + in_chs_left=channels, out_chs_left=channels // 2, + in_chs_right=2 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_1 = NormalCell( + in_chs_left=2 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_2 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_3 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_4 = NormalCell( + 
in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_5 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + + self.reduction_cell_0 = ReductionCell0( + in_chs_left=6 * channels, out_chs_left=2 * channels, + in_chs_right=6 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_6 = FirstCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=8 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_7 = NormalCell( + in_chs_left=8 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_8 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_9 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_10 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_11 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + + self.reduction_cell_1 = ReductionCell1( + in_chs_left=12 * channels, out_chs_left=4 * channels, + in_chs_right=12 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_12 = FirstCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=16 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_13 = NormalCell( + in_chs_left=16 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_14 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_15 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_16 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_17 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.act = nn.ReLU(inplace=True) + self.feature_info = [ + dict(num_chs=96, reduction=2, module='conv0'), + dict(num_chs=168, reduction=4, module='cell_stem_1.conv_1x1.act'), + dict(num_chs=1008, reduction=8, module='reduction_cell_0.conv_1x1.act'), + dict(num_chs=2016, reduction=16, module='reduction_cell_1.conv_1x1.act'), + dict(num_chs=4032, reduction=32, module='act'), + ] + + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x_conv0 = self.conv0(x) + + x_stem_0 = self.cell_stem_0(x_conv0) + x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0) + + x_cell_0 = self.cell_0(x_stem_1, x_stem_0) + x_cell_1 = 
self.cell_1(x_cell_0, x_stem_1) + x_cell_2 = self.cell_2(x_cell_1, x_cell_0) + x_cell_3 = self.cell_3(x_cell_2, x_cell_1) + x_cell_4 = self.cell_4(x_cell_3, x_cell_2) + x_cell_5 = self.cell_5(x_cell_4, x_cell_3) + + x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4) + x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4) + x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0) + x_cell_8 = self.cell_8(x_cell_7, x_cell_6) + x_cell_9 = self.cell_9(x_cell_8, x_cell_7) + x_cell_10 = self.cell_10(x_cell_9, x_cell_8) + x_cell_11 = self.cell_11(x_cell_10, x_cell_9) + + x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10) + x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10) + x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1) + x_cell_14 = self.cell_14(x_cell_13, x_cell_12) + x_cell_15 = self.cell_15(x_cell_14, x_cell_13) + x_cell_16 = self.cell_16(x_cell_15, x_cell_14) + x_cell_17 = self.cell_17(x_cell_16, x_cell_15) + x = self.act(x_cell_17) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, self.drop_rate, training=self.training) + x = self.last_linear(x) + return x + + +def _create_nasnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + NASNetALarge, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model + **kwargs) + + +@register_model +def nasnetalarge(pretrained=False, **kwargs): + """NASNet-A large model architecture. + """ + model_kwargs = dict(pad_type='same', **kwargs) + return _create_nasnet('nasnetalarge', pretrained, **model_kwargs) diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py new file mode 100644 index 0000000000000000000000000000000000000000..5f1e177f5a8e31981b681c0293bba274f861fb6f --- /dev/null +++ b/timm/models/pnasnet.py @@ -0,0 +1,347 @@ +""" + pnasnet5large implementation grabbed from Cadene's pretrained models + Additional credit to https://github.com/creafz + + https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py + +""" +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import build_model_with_cfg +from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier +from .registry import register_model + +__all__ = ['PNASNet5Large'] + +default_cfgs = { + 'pnasnet5large': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth', + 'input_size': (3, 331, 331), + 'pool_size': (11, 11), + 'crop_pct': 0.911, + 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5), + 'num_classes': 1001, + 'first_conv': 'conv_0.conv', + 'classifier': 'last_linear', + }, +} + + +class SeparableConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''): + super(SeparableConv2d, self).__init__() + self.depthwise_conv2d = create_conv2d( + in_channels, in_channels, kernel_size=kernel_size, + stride=stride, padding=padding, groups=in_channels) + self.pointwise_conv2d = create_conv2d( + in_channels, out_channels, kernel_size=1, padding=padding) + + def forward(self, x): + x = self.depthwise_conv2d(x) + x = self.pointwise_conv2d(x) + return x + + +class BranchSeparables(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, stem_cell=False, padding=''): + 
super(BranchSeparables, self).__init__() + middle_channels = out_channels if stem_cell else in_channels + self.act_1 = nn.ReLU() + self.separable_1 = SeparableConv2d( + in_channels, middle_channels, kernel_size, stride=stride, padding=padding) + self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001) + self.act_2 = nn.ReLU() + self.separable_2 = SeparableConv2d( + middle_channels, out_channels, kernel_size, stride=1, padding=padding) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.act_1(x) + x = self.separable_1(x) + x = self.bn_sep_1(x) + x = self.act_2(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class ActConvBn(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''): + super(ActConvBn, self).__init__() + self.act = nn.ReLU() + self.conv = create_conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.act(x) + x = self.conv(x) + x = self.bn(x) + return x + + +class FactorizedReduction(nn.Module): + + def __init__(self, in_channels, out_channels, padding=''): + super(FactorizedReduction, self).__init__() + self.act = nn.ReLU() + self.path_1 = nn.Sequential(OrderedDict([ + ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), + ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)), + ])) + self.path_2 = nn.Sequential(OrderedDict([ + ('pad', nn.ZeroPad2d((-1, 1, -1, 1))), # shift + ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), + ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)), + ])) + self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.act(x) + x_path1 = self.path_1(x) + x_path2 = self.path_2(x) + out = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + return out + + +class CellBase(nn.Module): + + def cell_forward(self, x_left, x_right): + x_comb_iter_0_left = self.comb_iter_0_left(x_left) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_right) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_left = self.comb_iter_3_left(x_comb_iter_2) + x_comb_iter_3_right = self.comb_iter_3_right(x_right) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_left) + if self.comb_iter_4_right is not None: + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + else: + x_comb_iter_4_right = x_right + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class CellStem0(CellBase): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(CellStem0, self).__init__() + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type) + + self.comb_iter_0_left = BranchSeparables( + in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=pad_type) + self.comb_iter_0_right = 
nn.Sequential(OrderedDict([ + ('max_pool', create_pool2d('max', 3, stride=2, padding=pad_type)), + ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)), + ('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)), + ])) + + self.comb_iter_1_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=pad_type) + self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=pad_type) + + self.comb_iter_2_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=pad_type) + self.comb_iter_2_right = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=pad_type) + + self.comb_iter_3_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=3, padding=pad_type) + self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables( + in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=pad_type) + self.comb_iter_4_right = ActConvBn( + out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=pad_type) + + def forward(self, x_left): + x_right = self.conv_1x1(x_left) + x_out = self.cell_forward(x_left, x_right) + return x_out + + +class Cell(CellBase): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type='', + is_reduction=False, match_prev_layer_dims=False): + super(Cell, self).__init__() + + # If `is_reduction` is set to `True` stride 2 is used for + # convolution and pooling layers to reduce the spatial size of + # the output of a cell approximately by a factor of 2. + stride = 2 if is_reduction else 1 + + # If `match_prev_layer_dimensions` is set to `True` + # `FactorizedReduction` is used to reduce the spatial size + # of the left input of a cell approximately by a factor of 2. 
+        self.match_prev_layer_dimensions = match_prev_layer_dims
+        if match_prev_layer_dims:
+            self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=pad_type)
+        else:
+            self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)
+        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type)
+
+        self.comb_iter_0_left = BranchSeparables(
+            out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=pad_type)
+        self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
+
+        self.comb_iter_1_left = BranchSeparables(
+            out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=pad_type)
+        self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
+
+        self.comb_iter_2_left = BranchSeparables(
+            out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=pad_type)
+        self.comb_iter_2_right = BranchSeparables(
+            out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=pad_type)
+
+        self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3)
+        self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
+
+        self.comb_iter_4_left = BranchSeparables(
+            out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=pad_type)
+        if is_reduction:
+            self.comb_iter_4_right = ActConvBn(
+                out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=pad_type)
+        else:
+            self.comb_iter_4_right = None
+
+    def forward(self, x_left, x_right):
+        x_left = self.conv_prev_1x1(x_left)
+        x_right = self.conv_1x1(x_right)
+        x_out = self.cell_forward(x_left, x_right)
+        return x_out
+
+
+class PNASNet5Large(nn.Module):
+    def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', pad_type=''):
+        super(PNASNet5Large, self).__init__()
+        self.num_classes = num_classes
+        self.drop_rate = drop_rate
+        self.num_features = 4320
+        assert output_stride == 32
+
+        self.conv_0 = ConvBnAct(
+            in_chans, 96, kernel_size=3, stride=2, padding=0,
+            norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None)
+
+        self.cell_stem_0 = CellStem0(
+            in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type)
+
+        self.cell_stem_1 = Cell(
+            in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, pad_type=pad_type,
+            match_prev_layer_dims=True, is_reduction=True)
+        self.cell_0 = Cell(
+            in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, pad_type=pad_type,
+            match_prev_layer_dims=True)
+        self.cell_1 = Cell(
+            in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)
+        self.cell_2 = Cell(
+            in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)
+        self.cell_3 = Cell(
+            in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)
+
+        self.cell_4 = Cell(
+            in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, pad_type=pad_type,
+            is_reduction=True)
+        self.cell_5 = Cell(
+            in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type,
+            match_prev_layer_dims=True)
+        self.cell_6 = Cell(
+            in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type)
+        self.cell_7 = Cell(
+            in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type)
+
+        self.cell_8 = Cell(
+            in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, pad_type=pad_type,
+            is_reduction=True)
+        self.cell_9 = Cell(
+            in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type,
+            match_prev_layer_dims=True)
+        self.cell_10 = Cell(
+            in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type)
+        self.cell_11 = Cell(
+            in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type)
+        self.act = nn.ReLU()
+        self.feature_info = [
+            dict(num_chs=96, reduction=2, module='conv_0'),
+            dict(num_chs=270, reduction=4, module='cell_stem_1.conv_1x1.act'),
+            dict(num_chs=1080, reduction=8, module='cell_4.conv_1x1.act'),
+            dict(num_chs=2160, reduction=16, module='cell_8.conv_1x1.act'),
+            dict(num_chs=4320, reduction=32, module='act'),
+        ]
+
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
+
+    def get_classifier(self):
+        return self.last_linear
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
+
+    def forward_features(self, x):
+        x_conv_0 = self.conv_0(x)
+        x_stem_0 = self.cell_stem_0(x_conv_0)
+        x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0)
+        x_cell_0 = self.cell_0(x_stem_0, x_stem_1)
+        x_cell_1 = self.cell_1(x_stem_1, x_cell_0)
+        x_cell_2 = self.cell_2(x_cell_0, x_cell_1)
+        x_cell_3 = self.cell_3(x_cell_1, x_cell_2)
+        x_cell_4 = self.cell_4(x_cell_2, x_cell_3)
+        x_cell_5 = self.cell_5(x_cell_3, x_cell_4)
+        x_cell_6 = self.cell_6(x_cell_4, x_cell_5)
+        x_cell_7 = self.cell_7(x_cell_5, x_cell_6)
+        x_cell_8 = self.cell_8(x_cell_6, x_cell_7)
+        x_cell_9 = self.cell_9(x_cell_7, x_cell_8)
+        x_cell_10 = self.cell_10(x_cell_8, x_cell_9)
+        x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
+        x = self.act(x_cell_11)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.global_pool(x)
+        if self.drop_rate > 0:
+            x = F.dropout(x, self.drop_rate, training=self.training)
+        x = self.last_linear(x)
+        return x
+
+
+def _create_pnasnet(variant, pretrained=False, **kwargs):
+    return build_model_with_cfg(
+        PNASNet5Large, variant, pretrained, default_cfg=default_cfgs[variant],
+        feature_cfg=dict(feature_cls='hook', no_rewrite=True),  # not possible to re-write this model
+        **kwargs)
+
+
+@register_model
+def pnasnet5large(pretrained=False, **kwargs):
+    r"""PNASNet-5 model architecture from the
+    `"Progressive Neural Architecture Search"
+    <https://arxiv.org/abs/1712.00559>`_ paper.
+ """ + model_kwargs = dict(pad_type='same', **kwargs) + return _create_pnasnet('pnasnet5large', pretrained, **model_kwargs) diff --git a/timm/models/pruned/ecaresnet101d_pruned.txt b/timm/models/pruned/ecaresnet101d_pruned.txt new file mode 100644 index 0000000000000000000000000000000000000000..2589b2f9dd3f0d1e02e1d5ddc1fbcd5c143e02c6 --- /dev/null +++ b/timm/models/pruned/ecaresnet101d_pruned.txt @@ -0,0 +1 @@ +conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[45, 64, 1, 1]***layer1.0.bn1.weight:[45]***layer1.0.conv2.weight:[25, 45, 3, 3]***layer1.0.bn2.weight:[25]***layer1.0.conv3.weight:[26, 25, 1, 1]***layer1.0.bn3.weight:[26]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[26, 64, 1, 1]***layer1.0.downsample.2.weight:[26]***layer1.1.conv1.weight:[53, 26, 1, 1]***layer1.1.bn1.weight:[53]***layer1.1.conv2.weight:[20, 53, 3, 3]***layer1.1.bn2.weight:[20]***layer1.1.conv3.weight:[26, 20, 1, 1]***layer1.1.bn3.weight:[26]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[60, 26, 1, 1]***layer1.2.bn1.weight:[60]***layer1.2.conv2.weight:[27, 60, 3, 3]***layer1.2.bn2.weight:[27]***layer1.2.conv3.weight:[26, 27, 1, 1]***layer1.2.bn3.weight:[26]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[81, 26, 1, 1]***layer2.0.bn1.weight:[81]***layer2.0.conv2.weight:[24, 81, 3, 3]***layer2.0.bn2.weight:[24]***layer2.0.conv3.weight:[142, 24, 1, 1]***layer2.0.bn3.weight:[142]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[142, 26, 1, 1]***layer2.0.downsample.2.weight:[142]***layer2.1.conv1.weight:[93, 142, 1, 1]***layer2.1.bn1.weight:[93]***layer2.1.conv2.weight:[49, 93, 3, 3]***layer2.1.bn2.weight:[49]***layer2.1.conv3.weight:[142, 49, 1, 1]***layer2.1.bn3.weight:[142]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[102, 142, 1, 1]***layer2.2.bn1.weight:[102]***layer2.2.conv2.weight:[54, 102, 3, 3]***layer2.2.bn2.weight:[54]***layer2.2.conv3.weight:[142, 54, 1, 1]***layer2.2.bn3.weight:[142]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[122, 142, 1, 1]***layer2.3.bn1.weight:[122]***layer2.3.conv2.weight:[78, 122, 3, 3]***layer2.3.bn2.weight:[78]***layer2.3.conv3.weight:[142, 78, 1, 1]***layer2.3.bn3.weight:[142]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[101, 142, 1, 1]***layer3.0.bn1.weight:[101]***layer3.0.conv2.weight:[25, 101, 3, 3]***layer3.0.bn2.weight:[25]***layer3.0.conv3.weight:[278, 25, 1, 1]***layer3.0.bn3.weight:[278]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[278, 142, 1, 1]***layer3.0.downsample.2.weight:[278]***layer3.1.conv1.weight:[239, 278, 1, 1]***layer3.1.bn1.weight:[239]***layer3.1.conv2.weight:[160, 239, 3, 3]***layer3.1.bn2.weight:[160]***layer3.1.conv3.weight:[278, 160, 1, 1]***layer3.1.bn3.weight:[278]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[234, 278, 1, 1]***layer3.2.bn1.weight:[234]***layer3.2.conv2.weight:[156, 234, 3, 3]***layer3.2.bn2.weight:[156]***layer3.2.conv3.weight:[278, 156, 1, 1]***layer3.2.bn3.weight:[278]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[250, 278, 1, 1]***layer3.3.bn1.weight:[250]***layer3.3.conv2.weight:[176, 250, 3, 3]***layer3.3.bn2.weight:[176]***layer3.3.conv3.weight:[278, 176, 1, 1]***layer3.3.bn3.weight:[278]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[253, 278, 1, 1]***layer3.4.bn1.weight:[253]***layer3.4.conv2.weight:[191, 253, 3, 
3]***layer3.4.bn2.weight:[191]***layer3.4.conv3.weight:[278, 191, 1, 1]***layer3.4.bn3.weight:[278]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[251, 278, 1, 1]***layer3.5.bn1.weight:[251]***layer3.5.conv2.weight:[175, 251, 3, 3]***layer3.5.bn2.weight:[175]***layer3.5.conv3.weight:[278, 175, 1, 1]***layer3.5.bn3.weight:[278]***layer3.5.se.conv.weight:[1, 1, 5]***layer3.6.conv1.weight:[230, 278, 1, 1]***layer3.6.bn1.weight:[230]***layer3.6.conv2.weight:[128, 230, 3, 3]***layer3.6.bn2.weight:[128]***layer3.6.conv3.weight:[278, 128, 1, 1]***layer3.6.bn3.weight:[278]***layer3.6.se.conv.weight:[1, 1, 5]***layer3.7.conv1.weight:[244, 278, 1, 1]***layer3.7.bn1.weight:[244]***layer3.7.conv2.weight:[154, 244, 3, 3]***layer3.7.bn2.weight:[154]***layer3.7.conv3.weight:[278, 154, 1, 1]***layer3.7.bn3.weight:[278]***layer3.7.se.conv.weight:[1, 1, 5]***layer3.8.conv1.weight:[244, 278, 1, 1]***layer3.8.bn1.weight:[244]***layer3.8.conv2.weight:[159, 244, 3, 3]***layer3.8.bn2.weight:[159]***layer3.8.conv3.weight:[278, 159, 1, 1]***layer3.8.bn3.weight:[278]***layer3.8.se.conv.weight:[1, 1, 5]***layer3.9.conv1.weight:[238, 278, 1, 1]***layer3.9.bn1.weight:[238]***layer3.9.conv2.weight:[97, 238, 3, 3]***layer3.9.bn2.weight:[97]***layer3.9.conv3.weight:[278, 97, 1, 1]***layer3.9.bn3.weight:[278]***layer3.9.se.conv.weight:[1, 1, 5]***layer3.10.conv1.weight:[244, 278, 1, 1]***layer3.10.bn1.weight:[244]***layer3.10.conv2.weight:[149, 244, 3, 3]***layer3.10.bn2.weight:[149]***layer3.10.conv3.weight:[278, 149, 1, 1]***layer3.10.bn3.weight:[278]***layer3.10.se.conv.weight:[1, 1, 5]***layer3.11.conv1.weight:[253, 278, 1, 1]***layer3.11.bn1.weight:[253]***layer3.11.conv2.weight:[181, 253, 3, 3]***layer3.11.bn2.weight:[181]***layer3.11.conv3.weight:[278, 181, 1, 1]***layer3.11.bn3.weight:[278]***layer3.11.se.conv.weight:[1, 1, 5]***layer3.12.conv1.weight:[245, 278, 1, 1]***layer3.12.bn1.weight:[245]***layer3.12.conv2.weight:[119, 245, 3, 3]***layer3.12.bn2.weight:[119]***layer3.12.conv3.weight:[278, 119, 1, 1]***layer3.12.bn3.weight:[278]***layer3.12.se.conv.weight:[1, 1, 5]***layer3.13.conv1.weight:[255, 278, 1, 1]***layer3.13.bn1.weight:[255]***layer3.13.conv2.weight:[216, 255, 3, 3]***layer3.13.bn2.weight:[216]***layer3.13.conv3.weight:[278, 216, 1, 1]***layer3.13.bn3.weight:[278]***layer3.13.se.conv.weight:[1, 1, 5]***layer3.14.conv1.weight:[256, 278, 1, 1]***layer3.14.bn1.weight:[256]***layer3.14.conv2.weight:[201, 256, 3, 3]***layer3.14.bn2.weight:[201]***layer3.14.conv3.weight:[278, 201, 1, 1]***layer3.14.bn3.weight:[278]***layer3.14.se.conv.weight:[1, 1, 5]***layer3.15.conv1.weight:[253, 278, 1, 1]***layer3.15.bn1.weight:[253]***layer3.15.conv2.weight:[149, 253, 3, 3]***layer3.15.bn2.weight:[149]***layer3.15.conv3.weight:[278, 149, 1, 1]***layer3.15.bn3.weight:[278]***layer3.15.se.conv.weight:[1, 1, 5]***layer3.16.conv1.weight:[254, 278, 1, 1]***layer3.16.bn1.weight:[254]***layer3.16.conv2.weight:[141, 254, 3, 3]***layer3.16.bn2.weight:[141]***layer3.16.conv3.weight:[278, 141, 1, 1]***layer3.16.bn3.weight:[278]***layer3.16.se.conv.weight:[1, 1, 5]***layer3.17.conv1.weight:[256, 278, 1, 1]***layer3.17.bn1.weight:[256]***layer3.17.conv2.weight:[190, 256, 3, 3]***layer3.17.bn2.weight:[190]***layer3.17.conv3.weight:[278, 190, 1, 1]***layer3.17.bn3.weight:[278]***layer3.17.se.conv.weight:[1, 1, 5]***layer3.18.conv1.weight:[256, 278, 1, 1]***layer3.18.bn1.weight:[256]***layer3.18.conv2.weight:[217, 256, 3, 3]***layer3.18.bn2.weight:[217]***layer3.18.conv3.weight:[278, 217, 1, 
1]***layer3.18.bn3.weight:[278]***layer3.18.se.conv.weight:[1, 1, 5]***layer3.19.conv1.weight:[255, 278, 1, 1]***layer3.19.bn1.weight:[255]***layer3.19.conv2.weight:[156, 255, 3, 3]***layer3.19.bn2.weight:[156]***layer3.19.conv3.weight:[278, 156, 1, 1]***layer3.19.bn3.weight:[278]***layer3.19.se.conv.weight:[1, 1, 5]***layer3.20.conv1.weight:[256, 278, 1, 1]***layer3.20.bn1.weight:[256]***layer3.20.conv2.weight:[155, 256, 3, 3]***layer3.20.bn2.weight:[155]***layer3.20.conv3.weight:[278, 155, 1, 1]***layer3.20.bn3.weight:[278]***layer3.20.se.conv.weight:[1, 1, 5]***layer3.21.conv1.weight:[256, 278, 1, 1]***layer3.21.bn1.weight:[256]***layer3.21.conv2.weight:[232, 256, 3, 3]***layer3.21.bn2.weight:[232]***layer3.21.conv3.weight:[278, 232, 1, 1]***layer3.21.bn3.weight:[278]***layer3.21.se.conv.weight:[1, 1, 5]***layer3.22.conv1.weight:[256, 278, 1, 1]***layer3.22.bn1.weight:[256]***layer3.22.conv2.weight:[214, 256, 3, 3]***layer3.22.bn2.weight:[214]***layer3.22.conv3.weight:[278, 214, 1, 1]***layer3.22.bn3.weight:[278]***layer3.22.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[499, 278, 1, 1]***layer4.0.bn1.weight:[499]***layer4.0.conv2.weight:[289, 499, 3, 3]***layer4.0.bn2.weight:[289]***layer4.0.conv3.weight:[2042, 289, 1, 1]***layer4.0.bn3.weight:[2042]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2042, 278, 1, 1]***layer4.0.downsample.2.weight:[2042]***layer4.1.conv1.weight:[512, 2042, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[512, 512, 3, 3]***layer4.1.bn2.weight:[512]***layer4.1.conv3.weight:[2042, 512, 1, 1]***layer4.1.bn3.weight:[2042]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2042, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[502, 512, 3, 3]***layer4.2.bn2.weight:[502]***layer4.2.conv3.weight:[2042, 502, 1, 1]***layer4.2.bn3.weight:[2042]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2042]***layer1_2_conv3_M.weight:[256, 26]***layer2_3_conv3_M.weight:[512, 142]***layer3_22_conv3_M.weight:[1024, 278]***layer4_2_conv3_M.weight:[2048, 2042] \ No newline at end of file diff --git a/timm/models/pruned/ecaresnet50d_pruned.txt b/timm/models/pruned/ecaresnet50d_pruned.txt new file mode 100644 index 0000000000000000000000000000000000000000..9a8b2bf50e0631dce74d66a1a98e26cae10572a7 --- /dev/null +++ b/timm/models/pruned/ecaresnet50d_pruned.txt @@ -0,0 +1 @@ +conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[47, 64, 1, 1]***layer1.0.bn1.weight:[47]***layer1.0.conv2.weight:[18, 47, 3, 3]***layer1.0.bn2.weight:[18]***layer1.0.conv3.weight:[19, 18, 1, 1]***layer1.0.bn3.weight:[19]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[19, 64, 1, 1]***layer1.0.downsample.2.weight:[19]***layer1.1.conv1.weight:[52, 19, 1, 1]***layer1.1.bn1.weight:[52]***layer1.1.conv2.weight:[22, 52, 3, 3]***layer1.1.bn2.weight:[22]***layer1.1.conv3.weight:[19, 22, 1, 1]***layer1.1.bn3.weight:[19]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[64, 19, 1, 1]***layer1.2.bn1.weight:[64]***layer1.2.conv2.weight:[35, 64, 3, 3]***layer1.2.bn2.weight:[35]***layer1.2.conv3.weight:[19, 35, 1, 1]***layer1.2.bn3.weight:[19]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[85, 19, 1, 1]***layer2.0.bn1.weight:[85]***layer2.0.conv2.weight:[37, 85, 3, 3]***layer2.0.bn2.weight:[37]***layer2.0.conv3.weight:[171, 37, 1, 1]***layer2.0.bn3.weight:[171]***layer2.0.se.conv.weight:[1, 
1, 5]***layer2.0.downsample.1.weight:[171, 19, 1, 1]***layer2.0.downsample.2.weight:[171]***layer2.1.conv1.weight:[107, 171, 1, 1]***layer2.1.bn1.weight:[107]***layer2.1.conv2.weight:[80, 107, 3, 3]***layer2.1.bn2.weight:[80]***layer2.1.conv3.weight:[171, 80, 1, 1]***layer2.1.bn3.weight:[171]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[120, 171, 1, 1]***layer2.2.bn1.weight:[120]***layer2.2.conv2.weight:[85, 120, 3, 3]***layer2.2.bn2.weight:[85]***layer2.2.conv3.weight:[171, 85, 1, 1]***layer2.2.bn3.weight:[171]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[125, 171, 1, 1]***layer2.3.bn1.weight:[125]***layer2.3.conv2.weight:[87, 125, 3, 3]***layer2.3.bn2.weight:[87]***layer2.3.conv3.weight:[171, 87, 1, 1]***layer2.3.bn3.weight:[171]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[198, 171, 1, 1]***layer3.0.bn1.weight:[198]***layer3.0.conv2.weight:[126, 198, 3, 3]***layer3.0.bn2.weight:[126]***layer3.0.conv3.weight:[818, 126, 1, 1]***layer3.0.bn3.weight:[818]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[818, 171, 1, 1]***layer3.0.downsample.2.weight:[818]***layer3.1.conv1.weight:[255, 818, 1, 1]***layer3.1.bn1.weight:[255]***layer3.1.conv2.weight:[232, 255, 3, 3]***layer3.1.bn2.weight:[232]***layer3.1.conv3.weight:[818, 232, 1, 1]***layer3.1.bn3.weight:[818]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[256, 818, 1, 1]***layer3.2.bn1.weight:[256]***layer3.2.conv2.weight:[233, 256, 3, 3]***layer3.2.bn2.weight:[233]***layer3.2.conv3.weight:[818, 233, 1, 1]***layer3.2.bn3.weight:[818]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[253, 818, 1, 1]***layer3.3.bn1.weight:[253]***layer3.3.conv2.weight:[235, 253, 3, 3]***layer3.3.bn2.weight:[235]***layer3.3.conv3.weight:[818, 235, 1, 1]***layer3.3.bn3.weight:[818]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[256, 818, 1, 1]***layer3.4.bn1.weight:[256]***layer3.4.conv2.weight:[225, 256, 3, 3]***layer3.4.bn2.weight:[225]***layer3.4.conv3.weight:[818, 225, 1, 1]***layer3.4.bn3.weight:[818]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[256, 818, 1, 1]***layer3.5.bn1.weight:[256]***layer3.5.conv2.weight:[239, 256, 3, 3]***layer3.5.bn2.weight:[239]***layer3.5.conv3.weight:[818, 239, 1, 1]***layer3.5.bn3.weight:[818]***layer3.5.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[492, 818, 1, 1]***layer4.0.bn1.weight:[492]***layer4.0.conv2.weight:[237, 492, 3, 3]***layer4.0.bn2.weight:[237]***layer4.0.conv3.weight:[2022, 237, 1, 1]***layer4.0.bn3.weight:[2022]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2022, 818, 1, 1]***layer4.0.downsample.2.weight:[2022]***layer4.1.conv1.weight:[512, 2022, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[500, 512, 3, 3]***layer4.1.bn2.weight:[500]***layer4.1.conv3.weight:[2022, 500, 1, 1]***layer4.1.bn3.weight:[2022]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2022, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[490, 512, 3, 3]***layer4.2.bn2.weight:[490]***layer4.2.conv3.weight:[2022, 490, 1, 1]***layer4.2.bn3.weight:[2022]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2022]***layer1_2_conv3_M.weight:[256, 19]***layer2_3_conv3_M.weight:[512, 171]***layer3_5_conv3_M.weight:[1024, 818]***layer4_2_conv3_M.weight:[2048, 2022] \ No newline at end of file diff --git a/timm/models/pruned/efficientnet_b1_pruned.txt b/timm/models/pruned/efficientnet_b1_pruned.txt new file mode 100644 index 
0000000000000000000000000000000000000000..0972b527612b283fd242cc5eaeb6e767ea106c66 --- /dev/null +++ b/timm/models/pruned/efficientnet_b1_pruned.txt @@ -0,0 +1 @@ +conv_stem.weight:[32, 3, 3, 3]***bn1.weight:[32]***bn1.bias:[32]***bn1.running_mean:[32]***bn1.running_var:[32]***bn1.num_batches_tracked:[]***blocks.0.0.conv_dw.weight:[32, 1, 3, 3]***blocks.0.0.bn1.weight:[32]***blocks.0.0.bn1.bias:[32]***blocks.0.0.bn1.running_mean:[32]***blocks.0.0.bn1.running_var:[32]***blocks.0.0.bn1.num_batches_tracked:[]***blocks.0.0.se.conv_reduce.weight:[8, 32, 1, 1]***blocks.0.0.se.conv_reduce.bias:[8]***blocks.0.0.se.conv_expand.weight:[32, 8, 1, 1]***blocks.0.0.se.conv_expand.bias:[32]***blocks.0.0.conv_pw.weight:[16, 32, 1, 1]***blocks.0.0.bn2.weight:[16]***blocks.0.0.bn2.bias:[16]***blocks.0.0.bn2.running_mean:[16]***blocks.0.0.bn2.running_var:[16]***blocks.0.0.bn2.num_batches_tracked:[]***blocks.0.1.conv_dw.weight:[16, 1, 3, 3]***blocks.0.1.bn1.weight:[16]***blocks.0.1.bn1.bias:[16]***blocks.0.1.bn1.running_mean:[16]***blocks.0.1.bn1.running_var:[16]***blocks.0.1.bn1.num_batches_tracked:[]***blocks.0.1.se.conv_reduce.weight:[4, 16, 1, 1]***blocks.0.1.se.conv_reduce.bias:[4]***blocks.0.1.se.conv_expand.weight:[16, 4, 1, 1]***blocks.0.1.se.conv_expand.bias:[16]***blocks.0.1.conv_pw.weight:[16, 16, 1, 1]***blocks.0.1.bn2.weight:[16]***blocks.0.1.bn2.bias:[16]***blocks.0.1.bn2.running_mean:[16]***blocks.0.1.bn2.running_var:[16]***blocks.0.1.bn2.num_batches_tracked:[]***blocks.1.0.conv_pw.weight:[48, 16, 1, 1]***blocks.1.0.bn1.weight:[48]***blocks.1.0.bn1.bias:[48]***blocks.1.0.bn1.running_mean:[48]***blocks.1.0.bn1.running_var:[48]***blocks.1.0.bn1.num_batches_tracked:[]***blocks.1.0.conv_dw.weight:[48, 1, 3, 3]***blocks.1.0.bn2.weight:[48]***blocks.1.0.bn2.bias:[48]***blocks.1.0.bn2.running_mean:[48]***blocks.1.0.bn2.running_var:[48]***blocks.1.0.bn2.num_batches_tracked:[]***blocks.1.0.se.conv_reduce.weight:[4, 48, 1, 1]***blocks.1.0.se.conv_reduce.bias:[4]***blocks.1.0.se.conv_expand.weight:[48, 4, 1, 1]***blocks.1.0.se.conv_expand.bias:[48]***blocks.1.0.conv_pwl.weight:[12, 48, 1, 1]***blocks.1.0.bn3.weight:[12]***blocks.1.0.bn3.bias:[12]***blocks.1.0.bn3.running_mean:[12]***blocks.1.0.bn3.running_var:[12]***blocks.1.0.bn3.num_batches_tracked:[]***blocks.1.1.conv_pw.weight:[62, 12, 1, 1]***blocks.1.1.bn1.weight:[62]***blocks.1.1.bn1.bias:[62]***blocks.1.1.bn1.running_mean:[62]***blocks.1.1.bn1.running_var:[62]***blocks.1.1.bn1.num_batches_tracked:[]***blocks.1.1.conv_dw.weight:[62, 1, 3, 3]***blocks.1.1.bn2.weight:[62]***blocks.1.1.bn2.bias:[62]***blocks.1.1.bn2.running_mean:[62]***blocks.1.1.bn2.running_var:[62]***blocks.1.1.bn2.num_batches_tracked:[]***blocks.1.1.se.conv_reduce.weight:[6, 62, 1, 1]***blocks.1.1.se.conv_reduce.bias:[6]***blocks.1.1.se.conv_expand.weight:[62, 6, 1, 1]***blocks.1.1.se.conv_expand.bias:[62]***blocks.1.1.conv_pwl.weight:[12, 62, 1, 1]***blocks.1.1.bn3.weight:[12]***blocks.1.1.bn3.bias:[12]***blocks.1.1.bn3.running_mean:[12]***blocks.1.1.bn3.running_var:[12]***blocks.1.1.bn3.num_batches_tracked:[]***blocks.1.2.conv_pw.weight:[48, 12, 1, 1]***blocks.1.2.bn1.weight:[48]***blocks.1.2.bn1.bias:[48]***blocks.1.2.bn1.running_mean:[48]***blocks.1.2.bn1.running_var:[48]***blocks.1.2.bn1.num_batches_tracked:[]***blocks.1.2.conv_dw.weight:[48, 1, 3, 3]***blocks.1.2.bn2.weight:[48]***blocks.1.2.bn2.bias:[48]***blocks.1.2.bn2.running_mean:[48]***blocks.1.2.bn2.running_var:[48]***blocks.1.2.bn2.num_batches_tracked:[]***blocks.1.2.se.conv_reduce.weight:[6, 48, 1, 
1]***blocks.1.2.se.conv_reduce.bias:[6]***blocks.1.2.se.conv_expand.weight:[48, 6, 1, 1]***blocks.1.2.se.conv_expand.bias:[48]***blocks.1.2.conv_pwl.weight:[12, 48, 1, 1]***blocks.1.2.bn3.weight:[12]***blocks.1.2.bn3.bias:[12]***blocks.1.2.bn3.running_mean:[12]***blocks.1.2.bn3.running_var:[12]***blocks.1.2.bn3.num_batches_tracked:[]***blocks.2.0.conv_pw.weight:[70, 12, 1, 1]***blocks.2.0.bn1.weight:[70]***blocks.2.0.bn1.bias:[70]***blocks.2.0.bn1.running_mean:[70]***blocks.2.0.bn1.running_var:[70]***blocks.2.0.bn1.num_batches_tracked:[]***blocks.2.0.conv_dw.weight:[70, 1, 5, 5]***blocks.2.0.bn2.weight:[70]***blocks.2.0.bn2.bias:[70]***blocks.2.0.bn2.running_mean:[70]***blocks.2.0.bn2.running_var:[70]***blocks.2.0.bn2.num_batches_tracked:[]***blocks.2.0.se.conv_reduce.weight:[6, 70, 1, 1]***blocks.2.0.se.conv_reduce.bias:[6]***blocks.2.0.se.conv_expand.weight:[70, 6, 1, 1]***blocks.2.0.se.conv_expand.bias:[70]***blocks.2.0.conv_pwl.weight:[35, 70, 1, 1]***blocks.2.0.bn3.weight:[35]***blocks.2.0.bn3.bias:[35]***blocks.2.0.bn3.running_mean:[35]***blocks.2.0.bn3.running_var:[35]***blocks.2.0.bn3.num_batches_tracked:[]***blocks.2.1.conv_pw.weight:[61, 35, 1, 1]***blocks.2.1.bn1.weight:[61]***blocks.2.1.bn1.bias:[61]***blocks.2.1.bn1.running_mean:[61]***blocks.2.1.bn1.running_var:[61]***blocks.2.1.bn1.num_batches_tracked:[]***blocks.2.1.conv_dw.weight:[61, 1, 5, 5]***blocks.2.1.bn2.weight:[61]***blocks.2.1.bn2.bias:[61]***blocks.2.1.bn2.running_mean:[61]***blocks.2.1.bn2.running_var:[61]***blocks.2.1.bn2.num_batches_tracked:[]***blocks.2.1.se.conv_reduce.weight:[10, 61, 1, 1]***blocks.2.1.se.conv_reduce.bias:[10]***blocks.2.1.se.conv_expand.weight:[61, 10, 1, 1]***blocks.2.1.se.conv_expand.bias:[61]***blocks.2.1.conv_pwl.weight:[35, 61, 1, 1]***blocks.2.1.bn3.weight:[35]***blocks.2.1.bn3.bias:[35]***blocks.2.1.bn3.running_mean:[35]***blocks.2.1.bn3.running_var:[35]***blocks.2.1.bn3.num_batches_tracked:[]***blocks.2.2.conv_pw.weight:[51, 35, 1, 1]***blocks.2.2.bn1.weight:[51]***blocks.2.2.bn1.bias:[51]***blocks.2.2.bn1.running_mean:[51]***blocks.2.2.bn1.running_var:[51]***blocks.2.2.bn1.num_batches_tracked:[]***blocks.2.2.conv_dw.weight:[51, 1, 5, 5]***blocks.2.2.bn2.weight:[51]***blocks.2.2.bn2.bias:[51]***blocks.2.2.bn2.running_mean:[51]***blocks.2.2.bn2.running_var:[51]***blocks.2.2.bn2.num_batches_tracked:[]***blocks.2.2.se.conv_reduce.weight:[10, 51, 1, 1]***blocks.2.2.se.conv_reduce.bias:[10]***blocks.2.2.se.conv_expand.weight:[51, 10, 1, 1]***blocks.2.2.se.conv_expand.bias:[51]***blocks.2.2.conv_pwl.weight:[35, 51, 1, 1]***blocks.2.2.bn3.weight:[35]***blocks.2.2.bn3.bias:[35]***blocks.2.2.bn3.running_mean:[35]***blocks.2.2.bn3.running_var:[35]***blocks.2.2.bn3.num_batches_tracked:[]***blocks.3.0.conv_pw.weight:[175, 35, 1, 1]***blocks.3.0.bn1.weight:[175]***blocks.3.0.bn1.bias:[175]***blocks.3.0.bn1.running_mean:[175]***blocks.3.0.bn1.running_var:[175]***blocks.3.0.bn1.num_batches_tracked:[]***blocks.3.0.conv_dw.weight:[175, 1, 3, 3]***blocks.3.0.bn2.weight:[175]***blocks.3.0.bn2.bias:[175]***blocks.3.0.bn2.running_mean:[175]***blocks.3.0.bn2.running_var:[175]***blocks.3.0.bn2.num_batches_tracked:[]***blocks.3.0.se.conv_reduce.weight:[10, 175, 1, 1]***blocks.3.0.se.conv_reduce.bias:[10]***blocks.3.0.se.conv_expand.weight:[175, 10, 1, 1]***blocks.3.0.se.conv_expand.bias:[175]***blocks.3.0.conv_pwl.weight:[74, 175, 1, 
1]***blocks.3.0.bn3.weight:[74]***blocks.3.0.bn3.bias:[74]***blocks.3.0.bn3.running_mean:[74]***blocks.3.0.bn3.running_var:[74]***blocks.3.0.bn3.num_batches_tracked:[]***blocks.3.1.conv_pw.weight:[188, 74, 1, 1]***blocks.3.1.bn1.weight:[188]***blocks.3.1.bn1.bias:[188]***blocks.3.1.bn1.running_mean:[188]***blocks.3.1.bn1.running_var:[188]***blocks.3.1.bn1.num_batches_tracked:[]***blocks.3.1.conv_dw.weight:[188, 1, 3, 3]***blocks.3.1.bn2.weight:[188]***blocks.3.1.bn2.bias:[188]***blocks.3.1.bn2.running_mean:[188]***blocks.3.1.bn2.running_var:[188]***blocks.3.1.bn2.num_batches_tracked:[]***blocks.3.1.se.conv_reduce.weight:[20, 188, 1, 1]***blocks.3.1.se.conv_reduce.bias:[20]***blocks.3.1.se.conv_expand.weight:[188, 20, 1, 1]***blocks.3.1.se.conv_expand.bias:[188]***blocks.3.1.conv_pwl.weight:[74, 188, 1, 1]***blocks.3.1.bn3.weight:[74]***blocks.3.1.bn3.bias:[74]***blocks.3.1.bn3.running_mean:[74]***blocks.3.1.bn3.running_var:[74]***blocks.3.1.bn3.num_batches_tracked:[]***blocks.3.2.conv_pw.weight:[137, 74, 1, 1]***blocks.3.2.bn1.weight:[137]***blocks.3.2.bn1.bias:[137]***blocks.3.2.bn1.running_mean:[137]***blocks.3.2.bn1.running_var:[137]***blocks.3.2.bn1.num_batches_tracked:[]***blocks.3.2.conv_dw.weight:[137, 1, 3, 3]***blocks.3.2.bn2.weight:[137]***blocks.3.2.bn2.bias:[137]***blocks.3.2.bn2.running_mean:[137]***blocks.3.2.bn2.running_var:[137]***blocks.3.2.bn2.num_batches_tracked:[]***blocks.3.2.se.conv_reduce.weight:[20, 137, 1, 1]***blocks.3.2.se.conv_reduce.bias:[20]***blocks.3.2.se.conv_expand.weight:[137, 20, 1, 1]***blocks.3.2.se.conv_expand.bias:[137]***blocks.3.2.conv_pwl.weight:[74, 137, 1, 1]***blocks.3.2.bn3.weight:[74]***blocks.3.2.bn3.bias:[74]***blocks.3.2.bn3.running_mean:[74]***blocks.3.2.bn3.running_var:[74]***blocks.3.2.bn3.num_batches_tracked:[]***blocks.3.3.conv_pw.weight:[164, 74, 1, 1]***blocks.3.3.bn1.weight:[164]***blocks.3.3.bn1.bias:[164]***blocks.3.3.bn1.running_mean:[164]***blocks.3.3.bn1.running_var:[164]***blocks.3.3.bn1.num_batches_tracked:[]***blocks.3.3.conv_dw.weight:[164, 1, 3, 3]***blocks.3.3.bn2.weight:[164]***blocks.3.3.bn2.bias:[164]***blocks.3.3.bn2.running_mean:[164]***blocks.3.3.bn2.running_var:[164]***blocks.3.3.bn2.num_batches_tracked:[]***blocks.3.3.se.conv_reduce.weight:[20, 164, 1, 1]***blocks.3.3.se.conv_reduce.bias:[20]***blocks.3.3.se.conv_expand.weight:[164, 20, 1, 1]***blocks.3.3.se.conv_expand.bias:[164]***blocks.3.3.conv_pwl.weight:[74, 164, 1, 1]***blocks.3.3.bn3.weight:[74]***blocks.3.3.bn3.bias:[74]***blocks.3.3.bn3.running_mean:[74]***blocks.3.3.bn3.running_var:[74]***blocks.3.3.bn3.num_batches_tracked:[]***blocks.4.0.conv_pw.weight:[399, 74, 1, 1]***blocks.4.0.bn1.weight:[399]***blocks.4.0.bn1.bias:[399]***blocks.4.0.bn1.running_mean:[399]***blocks.4.0.bn1.running_var:[399]***blocks.4.0.bn1.num_batches_tracked:[]***blocks.4.0.conv_dw.weight:[399, 1, 5, 5]***blocks.4.0.bn2.weight:[399]***blocks.4.0.bn2.bias:[399]***blocks.4.0.bn2.running_mean:[399]***blocks.4.0.bn2.running_var:[399]***blocks.4.0.bn2.num_batches_tracked:[]***blocks.4.0.se.conv_reduce.weight:[20, 399, 1, 1]***blocks.4.0.se.conv_reduce.bias:[20]***blocks.4.0.se.conv_expand.weight:[399, 20, 1, 1]***blocks.4.0.se.conv_expand.bias:[399]***blocks.4.0.conv_pwl.weight:[67, 399, 1, 1]***blocks.4.0.bn3.weight:[67]***blocks.4.0.bn3.bias:[67]***blocks.4.0.bn3.running_mean:[67]***blocks.4.0.bn3.running_var:[67]***blocks.4.0.bn3.num_batches_tracked:[]***blocks.4.1.conv_pw.weight:[201, 67, 1, 
1]***blocks.4.1.bn1.weight:[201]***blocks.4.1.bn1.bias:[201]***blocks.4.1.bn1.running_mean:[201]***blocks.4.1.bn1.running_var:[201]***blocks.4.1.bn1.num_batches_tracked:[]***blocks.4.1.conv_dw.weight:[201, 1, 5, 5]***blocks.4.1.bn2.weight:[201]***blocks.4.1.bn2.bias:[201]***blocks.4.1.bn2.running_mean:[201]***blocks.4.1.bn2.running_var:[201]***blocks.4.1.bn2.num_batches_tracked:[]***blocks.4.1.se.conv_reduce.weight:[28, 201, 1, 1]***blocks.4.1.se.conv_reduce.bias:[28]***blocks.4.1.se.conv_expand.weight:[201, 28, 1, 1]***blocks.4.1.se.conv_expand.bias:[201]***blocks.4.1.conv_pwl.weight:[67, 201, 1, 1]***blocks.4.1.bn3.weight:[67]***blocks.4.1.bn3.bias:[67]***blocks.4.1.bn3.running_mean:[67]***blocks.4.1.bn3.running_var:[67]***blocks.4.1.bn3.num_batches_tracked:[]***blocks.4.2.conv_pw.weight:[160, 67, 1, 1]***blocks.4.2.bn1.weight:[160]***blocks.4.2.bn1.bias:[160]***blocks.4.2.bn1.running_mean:[160]***blocks.4.2.bn1.running_var:[160]***blocks.4.2.bn1.num_batches_tracked:[]***blocks.4.2.conv_dw.weight:[160, 1, 5, 5]***blocks.4.2.bn2.weight:[160]***blocks.4.2.bn2.bias:[160]***blocks.4.2.bn2.running_mean:[160]***blocks.4.2.bn2.running_var:[160]***blocks.4.2.bn2.num_batches_tracked:[]***blocks.4.2.se.conv_reduce.weight:[28, 160, 1, 1]***blocks.4.2.se.conv_reduce.bias:[28]***blocks.4.2.se.conv_expand.weight:[160, 28, 1, 1]***blocks.4.2.se.conv_expand.bias:[160]***blocks.4.2.conv_pwl.weight:[67, 160, 1, 1]***blocks.4.2.bn3.weight:[67]***blocks.4.2.bn3.bias:[67]***blocks.4.2.bn3.running_mean:[67]***blocks.4.2.bn3.running_var:[67]***blocks.4.2.bn3.num_batches_tracked:[]***blocks.4.3.conv_pw.weight:[213, 67, 1, 1]***blocks.4.3.bn1.weight:[213]***blocks.4.3.bn1.bias:[213]***blocks.4.3.bn1.running_mean:[213]***blocks.4.3.bn1.running_var:[213]***blocks.4.3.bn1.num_batches_tracked:[]***blocks.4.3.conv_dw.weight:[213, 1, 5, 5]***blocks.4.3.bn2.weight:[213]***blocks.4.3.bn2.bias:[213]***blocks.4.3.bn2.running_mean:[213]***blocks.4.3.bn2.running_var:[213]***blocks.4.3.bn2.num_batches_tracked:[]***blocks.4.3.se.conv_reduce.weight:[28, 213, 1, 1]***blocks.4.3.se.conv_reduce.bias:[28]***blocks.4.3.se.conv_expand.weight:[213, 28, 1, 1]***blocks.4.3.se.conv_expand.bias:[213]***blocks.4.3.conv_pwl.weight:[67, 213, 1, 1]***blocks.4.3.bn3.weight:[67]***blocks.4.3.bn3.bias:[67]***blocks.4.3.bn3.running_mean:[67]***blocks.4.3.bn3.running_var:[67]***blocks.4.3.bn3.num_batches_tracked:[]***blocks.5.0.conv_pw.weight:[637, 67, 1, 1]***blocks.5.0.bn1.weight:[637]***blocks.5.0.bn1.bias:[637]***blocks.5.0.bn1.running_mean:[637]***blocks.5.0.bn1.running_var:[637]***blocks.5.0.bn1.num_batches_tracked:[]***blocks.5.0.conv_dw.weight:[637, 1, 5, 5]***blocks.5.0.bn2.weight:[637]***blocks.5.0.bn2.bias:[637]***blocks.5.0.bn2.running_mean:[637]***blocks.5.0.bn2.running_var:[637]***blocks.5.0.bn2.num_batches_tracked:[]***blocks.5.0.se.conv_reduce.weight:[27, 637, 1, 1]***blocks.5.0.se.conv_reduce.bias:[27]***blocks.5.0.se.conv_expand.weight:[637, 27, 1, 1]***blocks.5.0.se.conv_expand.bias:[637]***blocks.5.0.conv_pwl.weight:[192, 637, 1, 1]***blocks.5.0.bn3.weight:[192]***blocks.5.0.bn3.bias:[192]***blocks.5.0.bn3.running_mean:[192]***blocks.5.0.bn3.running_var:[192]***blocks.5.0.bn3.num_batches_tracked:[]***blocks.5.1.conv_pw.weight:[806, 192, 1, 1]***blocks.5.1.bn1.weight:[806]***blocks.5.1.bn1.bias:[806]***blocks.5.1.bn1.running_mean:[806]***blocks.5.1.bn1.running_var:[806]***blocks.5.1.bn1.num_batches_tracked:[]***blocks.5.1.conv_dw.weight:[806, 1, 5, 
5]***blocks.5.1.bn2.weight:[806]***blocks.5.1.bn2.bias:[806]***blocks.5.1.bn2.running_mean:[806]***blocks.5.1.bn2.running_var:[806]***blocks.5.1.bn2.num_batches_tracked:[]***blocks.5.1.se.conv_reduce.weight:[48, 806, 1, 1]***blocks.5.1.se.conv_reduce.bias:[48]***blocks.5.1.se.conv_expand.weight:[806, 48, 1, 1]***blocks.5.1.se.conv_expand.bias:[806]***blocks.5.1.conv_pwl.weight:[192, 806, 1, 1]***blocks.5.1.bn3.weight:[192]***blocks.5.1.bn3.bias:[192]***blocks.5.1.bn3.running_mean:[192]***blocks.5.1.bn3.running_var:[192]***blocks.5.1.bn3.num_batches_tracked:[]***blocks.5.2.conv_pw.weight:[798, 192, 1, 1]***blocks.5.2.bn1.weight:[798]***blocks.5.2.bn1.bias:[798]***blocks.5.2.bn1.running_mean:[798]***blocks.5.2.bn1.running_var:[798]***blocks.5.2.bn1.num_batches_tracked:[]***blocks.5.2.conv_dw.weight:[798, 1, 5, 5]***blocks.5.2.bn2.weight:[798]***blocks.5.2.bn2.bias:[798]***blocks.5.2.bn2.running_mean:[798]***blocks.5.2.bn2.running_var:[798]***blocks.5.2.bn2.num_batches_tracked:[]***blocks.5.2.se.conv_reduce.weight:[48, 798, 1, 1]***blocks.5.2.se.conv_reduce.bias:[48]***blocks.5.2.se.conv_expand.weight:[798, 48, 1, 1]***blocks.5.2.se.conv_expand.bias:[798]***blocks.5.2.conv_pwl.weight:[192, 798, 1, 1]***blocks.5.2.bn3.weight:[192]***blocks.5.2.bn3.bias:[192]***blocks.5.2.bn3.running_mean:[192]***blocks.5.2.bn3.running_var:[192]***blocks.5.2.bn3.num_batches_tracked:[]***blocks.5.3.conv_pw.weight:[891, 192, 1, 1]***blocks.5.3.bn1.weight:[891]***blocks.5.3.bn1.bias:[891]***blocks.5.3.bn1.running_mean:[891]***blocks.5.3.bn1.running_var:[891]***blocks.5.3.bn1.num_batches_tracked:[]***blocks.5.3.conv_dw.weight:[891, 1, 5, 5]***blocks.5.3.bn2.weight:[891]***blocks.5.3.bn2.bias:[891]***blocks.5.3.bn2.running_mean:[891]***blocks.5.3.bn2.running_var:[891]***blocks.5.3.bn2.num_batches_tracked:[]***blocks.5.3.se.conv_reduce.weight:[48, 891, 1, 1]***blocks.5.3.se.conv_reduce.bias:[48]***blocks.5.3.se.conv_expand.weight:[891, 48, 1, 1]***blocks.5.3.se.conv_expand.bias:[891]***blocks.5.3.conv_pwl.weight:[192, 891, 1, 1]***blocks.5.3.bn3.weight:[192]***blocks.5.3.bn3.bias:[192]***blocks.5.3.bn3.running_mean:[192]***blocks.5.3.bn3.running_var:[192]***blocks.5.3.bn3.num_batches_tracked:[]***blocks.5.4.conv_pw.weight:[990, 192, 1, 1]***blocks.5.4.bn1.weight:[990]***blocks.5.4.bn1.bias:[990]***blocks.5.4.bn1.running_mean:[990]***blocks.5.4.bn1.running_var:[990]***blocks.5.4.bn1.num_batches_tracked:[]***blocks.5.4.conv_dw.weight:[990, 1, 5, 5]***blocks.5.4.bn2.weight:[990]***blocks.5.4.bn2.bias:[990]***blocks.5.4.bn2.running_mean:[990]***blocks.5.4.bn2.running_var:[990]***blocks.5.4.bn2.num_batches_tracked:[]***blocks.5.4.se.conv_reduce.weight:[48, 990, 1, 1]***blocks.5.4.se.conv_reduce.bias:[48]***blocks.5.4.se.conv_expand.weight:[990, 48, 1, 1]***blocks.5.4.se.conv_expand.bias:[990]***blocks.5.4.conv_pwl.weight:[192, 990, 1, 1]***blocks.5.4.bn3.weight:[192]***blocks.5.4.bn3.bias:[192]***blocks.5.4.bn3.running_mean:[192]***blocks.5.4.bn3.running_var:[192]***blocks.5.4.bn3.num_batches_tracked:[]***blocks.6.0.conv_pw.weight:[1152, 192, 1, 1]***blocks.6.0.bn1.weight:[1152]***blocks.6.0.bn1.bias:[1152]***blocks.6.0.bn1.running_mean:[1152]***blocks.6.0.bn1.running_var:[1152]***blocks.6.0.bn1.num_batches_tracked:[]***blocks.6.0.conv_dw.weight:[1152, 1, 3, 3]***blocks.6.0.bn2.weight:[1152]***blocks.6.0.bn2.bias:[1152]***blocks.6.0.bn2.running_mean:[1152]***blocks.6.0.bn2.running_var:[1152]***blocks.6.0.bn2.num_batches_tracked:[]***blocks.6.0.se.conv_reduce.weight:[48, 1152, 1, 
1]***blocks.6.0.se.conv_reduce.bias:[48]***blocks.6.0.se.conv_expand.weight:[1152, 48, 1, 1]***blocks.6.0.se.conv_expand.bias:[1152]***blocks.6.0.conv_pwl.weight:[320, 1152, 1, 1]***blocks.6.0.bn3.weight:[320]***blocks.6.0.bn3.bias:[320]***blocks.6.0.bn3.running_mean:[320]***blocks.6.0.bn3.running_var:[320]***blocks.6.0.bn3.num_batches_tracked:[]***blocks.6.1.conv_pw.weight:[1912, 320, 1, 1]***blocks.6.1.bn1.weight:[1912]***blocks.6.1.bn1.bias:[1912]***blocks.6.1.bn1.running_mean:[1912]***blocks.6.1.bn1.running_var:[1912]***blocks.6.1.bn1.num_batches_tracked:[]***blocks.6.1.conv_dw.weight:[1912, 1, 3, 3]***blocks.6.1.bn2.weight:[1912]***blocks.6.1.bn2.bias:[1912]***blocks.6.1.bn2.running_mean:[1912]***blocks.6.1.bn2.running_var:[1912]***blocks.6.1.bn2.num_batches_tracked:[]***blocks.6.1.se.conv_reduce.weight:[80, 1912, 1, 1]***blocks.6.1.se.conv_reduce.bias:[80]***blocks.6.1.se.conv_expand.weight:[1912, 80, 1, 1]***blocks.6.1.se.conv_expand.bias:[1912]***blocks.6.1.conv_pwl.weight:[320, 1912, 1, 1]***blocks.6.1.bn3.weight:[320]***blocks.6.1.bn3.bias:[320]***blocks.6.1.bn3.running_mean:[320]***blocks.6.1.bn3.running_var:[320]***blocks.6.1.bn3.num_batches_tracked:[]***conv_head.weight:[1280, 320, 1, 1]***bn2.weight:[1280]***bn2.bias:[1280]***bn2.running_mean:[1280]***bn2.running_var:[1280]***bn2.num_batches_tracked:[]***classifier.weight:[1000, 1280]***classifier.bias:[1000] \ No newline at end of file diff --git a/timm/models/pruned/efficientnet_b2_pruned.txt b/timm/models/pruned/efficientnet_b2_pruned.txt new file mode 100644 index 0000000000000000000000000000000000000000..6e3fadee3e9f92eaade96afd8691a5e4437551ee --- /dev/null +++ b/timm/models/pruned/efficientnet_b2_pruned.txt @@ -0,0 +1 @@ +conv_stem.weight:[32, 3, 3, 3]***bn1.weight:[32]***bn1.bias:[32]***bn1.running_mean:[32]***bn1.running_var:[32]***bn1.num_batches_tracked:[]***blocks.0.0.conv_dw.weight:[32, 1, 3, 3]***blocks.0.0.bn1.weight:[32]***blocks.0.0.bn1.bias:[32]***blocks.0.0.bn1.running_mean:[32]***blocks.0.0.bn1.running_var:[32]***blocks.0.0.bn1.num_batches_tracked:[]***blocks.0.0.se.conv_reduce.weight:[8, 32, 1, 1]***blocks.0.0.se.conv_reduce.bias:[8]***blocks.0.0.se.conv_expand.weight:[32, 8, 1, 1]***blocks.0.0.se.conv_expand.bias:[32]***blocks.0.0.conv_pw.weight:[16, 32, 1, 1]***blocks.0.0.bn2.weight:[16]***blocks.0.0.bn2.bias:[16]***blocks.0.0.bn2.running_mean:[16]***blocks.0.0.bn2.running_var:[16]***blocks.0.0.bn2.num_batches_tracked:[]***blocks.0.1.conv_dw.weight:[16, 1, 3, 3]***blocks.0.1.bn1.weight:[16]***blocks.0.1.bn1.bias:[16]***blocks.0.1.bn1.running_mean:[16]***blocks.0.1.bn1.running_var:[16]***blocks.0.1.bn1.num_batches_tracked:[]***blocks.0.1.se.conv_reduce.weight:[4, 16, 1, 1]***blocks.0.1.se.conv_reduce.bias:[4]***blocks.0.1.se.conv_expand.weight:[16, 4, 1, 1]***blocks.0.1.se.conv_expand.bias:[16]***blocks.0.1.conv_pw.weight:[16, 16, 1, 1]***blocks.0.1.bn2.weight:[16]***blocks.0.1.bn2.bias:[16]***blocks.0.1.bn2.running_mean:[16]***blocks.0.1.bn2.running_var:[16]***blocks.0.1.bn2.num_batches_tracked:[]***blocks.1.0.conv_pw.weight:[54, 16, 1, 1]***blocks.1.0.bn1.weight:[54]***blocks.1.0.bn1.bias:[54]***blocks.1.0.bn1.running_mean:[54]***blocks.1.0.bn1.running_var:[54]***blocks.1.0.bn1.num_batches_tracked:[]***blocks.1.0.conv_dw.weight:[54, 1, 3, 3]***blocks.1.0.bn2.weight:[54]***blocks.1.0.bn2.bias:[54]***blocks.1.0.bn2.running_mean:[54]***blocks.1.0.bn2.running_var:[54]***blocks.1.0.bn2.num_batches_tracked:[]***blocks.1.0.se.conv_reduce.weight:[4, 54, 1, 
1]***blocks.1.0.se.conv_reduce.bias:[4]***blocks.1.0.se.conv_expand.weight:[54, 4, 1, 1]***blocks.1.0.se.conv_expand.bias:[54]***blocks.1.0.conv_pwl.weight:[17, 54, 1, 1]***blocks.1.0.bn3.weight:[17]***blocks.1.0.bn3.bias:[17]***blocks.1.0.bn3.running_mean:[17]***blocks.1.0.bn3.running_var:[17]***blocks.1.0.bn3.num_batches_tracked:[]***blocks.1.1.conv_pw.weight:[69, 17, 1, 1]***blocks.1.1.bn1.weight:[69]***blocks.1.1.bn1.bias:[69]***blocks.1.1.bn1.running_mean:[69]***blocks.1.1.bn1.running_var:[69]***blocks.1.1.bn1.num_batches_tracked:[]***blocks.1.1.conv_dw.weight:[69, 1, 3, 3]***blocks.1.1.bn2.weight:[69]***blocks.1.1.bn2.bias:[69]***blocks.1.1.bn2.running_mean:[69]***blocks.1.1.bn2.running_var:[69]***blocks.1.1.bn2.num_batches_tracked:[]***blocks.1.1.se.conv_reduce.weight:[6, 69, 1, 1]***blocks.1.1.se.conv_reduce.bias:[6]***blocks.1.1.se.conv_expand.weight:[69, 6, 1, 1]***blocks.1.1.se.conv_expand.bias:[69]***blocks.1.1.conv_pwl.weight:[17, 69, 1, 1]***blocks.1.1.bn3.weight:[17]***blocks.1.1.bn3.bias:[17]***blocks.1.1.bn3.running_mean:[17]***blocks.1.1.bn3.running_var:[17]***blocks.1.1.bn3.num_batches_tracked:[]***blocks.1.2.conv_pw.weight:[61, 17, 1, 1]***blocks.1.2.bn1.weight:[61]***blocks.1.2.bn1.bias:[61]***blocks.1.2.bn1.running_mean:[61]***blocks.1.2.bn1.running_var:[61]***blocks.1.2.bn1.num_batches_tracked:[]***blocks.1.2.conv_dw.weight:[61, 1, 3, 3]***blocks.1.2.bn2.weight:[61]***blocks.1.2.bn2.bias:[61]***blocks.1.2.bn2.running_mean:[61]***blocks.1.2.bn2.running_var:[61]***blocks.1.2.bn2.num_batches_tracked:[]***blocks.1.2.se.conv_reduce.weight:[6, 61, 1, 1]***blocks.1.2.se.conv_reduce.bias:[6]***blocks.1.2.se.conv_expand.weight:[61, 6, 1, 1]***blocks.1.2.se.conv_expand.bias:[61]***blocks.1.2.conv_pwl.weight:[17, 61, 1, 1]***blocks.1.2.bn3.weight:[17]***blocks.1.2.bn3.bias:[17]***blocks.1.2.bn3.running_mean:[17]***blocks.1.2.bn3.running_var:[17]***blocks.1.2.bn3.num_batches_tracked:[]***blocks.2.0.conv_pw.weight:[86, 17, 1, 1]***blocks.2.0.bn1.weight:[86]***blocks.2.0.bn1.bias:[86]***blocks.2.0.bn1.running_mean:[86]***blocks.2.0.bn1.running_var:[86]***blocks.2.0.bn1.num_batches_tracked:[]***blocks.2.0.conv_dw.weight:[86, 1, 5, 5]***blocks.2.0.bn2.weight:[86]***blocks.2.0.bn2.bias:[86]***blocks.2.0.bn2.running_mean:[86]***blocks.2.0.bn2.running_var:[86]***blocks.2.0.bn2.num_batches_tracked:[]***blocks.2.0.se.conv_reduce.weight:[6, 86, 1, 1]***blocks.2.0.se.conv_reduce.bias:[6]***blocks.2.0.se.conv_expand.weight:[86, 6, 1, 1]***blocks.2.0.se.conv_expand.bias:[86]***blocks.2.0.conv_pwl.weight:[42, 86, 1, 1]***blocks.2.0.bn3.weight:[42]***blocks.2.0.bn3.bias:[42]***blocks.2.0.bn3.running_mean:[42]***blocks.2.0.bn3.running_var:[42]***blocks.2.0.bn3.num_batches_tracked:[]***blocks.2.1.conv_pw.weight:[72, 42, 1, 1]***blocks.2.1.bn1.weight:[72]***blocks.2.1.bn1.bias:[72]***blocks.2.1.bn1.running_mean:[72]***blocks.2.1.bn1.running_var:[72]***blocks.2.1.bn1.num_batches_tracked:[]***blocks.2.1.conv_dw.weight:[72, 1, 5, 5]***blocks.2.1.bn2.weight:[72]***blocks.2.1.bn2.bias:[72]***blocks.2.1.bn2.running_mean:[72]***blocks.2.1.bn2.running_var:[72]***blocks.2.1.bn2.num_batches_tracked:[]***blocks.2.1.se.conv_reduce.weight:[12, 72, 1, 1]***blocks.2.1.se.conv_reduce.bias:[12]***blocks.2.1.se.conv_expand.weight:[72, 12, 1, 1]***blocks.2.1.se.conv_expand.bias:[72]***blocks.2.1.conv_pwl.weight:[42, 72, 1, 
1]***blocks.2.1.bn3.weight:[42]***blocks.2.1.bn3.bias:[42]***blocks.2.1.bn3.running_mean:[42]***blocks.2.1.bn3.running_var:[42]***blocks.2.1.bn3.num_batches_tracked:[]***blocks.2.2.conv_pw.weight:[98, 42, 1, 1]***blocks.2.2.bn1.weight:[98]***blocks.2.2.bn1.bias:[98]***blocks.2.2.bn1.running_mean:[98]***blocks.2.2.bn1.running_var:[98]***blocks.2.2.bn1.num_batches_tracked:[]***blocks.2.2.conv_dw.weight:[98, 1, 5, 5]***blocks.2.2.bn2.weight:[98]***blocks.2.2.bn2.bias:[98]***blocks.2.2.bn2.running_mean:[98]***blocks.2.2.bn2.running_var:[98]***blocks.2.2.bn2.num_batches_tracked:[]***blocks.2.2.se.conv_reduce.weight:[12, 98, 1, 1]***blocks.2.2.se.conv_reduce.bias:[12]***blocks.2.2.se.conv_expand.weight:[98, 12, 1, 1]***blocks.2.2.se.conv_expand.bias:[98]***blocks.2.2.conv_pwl.weight:[42, 98, 1, 1]***blocks.2.2.bn3.weight:[42]***blocks.2.2.bn3.bias:[42]***blocks.2.2.bn3.running_mean:[42]***blocks.2.2.bn3.running_var:[42]***blocks.2.2.bn3.num_batches_tracked:[]***blocks.3.0.conv_pw.weight:[245, 42, 1, 1]***blocks.3.0.bn1.weight:[245]***blocks.3.0.bn1.bias:[245]***blocks.3.0.bn1.running_mean:[245]***blocks.3.0.bn1.running_var:[245]***blocks.3.0.bn1.num_batches_tracked:[]***blocks.3.0.conv_dw.weight:[245, 1, 3, 3]***blocks.3.0.bn2.weight:[245]***blocks.3.0.bn2.bias:[245]***blocks.3.0.bn2.running_mean:[245]***blocks.3.0.bn2.running_var:[245]***blocks.3.0.bn2.num_batches_tracked:[]***blocks.3.0.se.conv_reduce.weight:[12, 245, 1, 1]***blocks.3.0.se.conv_reduce.bias:[12]***blocks.3.0.se.conv_expand.weight:[245, 12, 1, 1]***blocks.3.0.se.conv_expand.bias:[245]***blocks.3.0.conv_pwl.weight:[85, 245, 1, 1]***blocks.3.0.bn3.weight:[85]***blocks.3.0.bn3.bias:[85]***blocks.3.0.bn3.running_mean:[85]***blocks.3.0.bn3.running_var:[85]***blocks.3.0.bn3.num_batches_tracked:[]***blocks.3.1.conv_pw.weight:[274, 85, 1, 1]***blocks.3.1.bn1.weight:[274]***blocks.3.1.bn1.bias:[274]***blocks.3.1.bn1.running_mean:[274]***blocks.3.1.bn1.running_var:[274]***blocks.3.1.bn1.num_batches_tracked:[]***blocks.3.1.conv_dw.weight:[274, 1, 3, 3]***blocks.3.1.bn2.weight:[274]***blocks.3.1.bn2.bias:[274]***blocks.3.1.bn2.running_mean:[274]***blocks.3.1.bn2.running_var:[274]***blocks.3.1.bn2.num_batches_tracked:[]***blocks.3.1.se.conv_reduce.weight:[22, 274, 1, 1]***blocks.3.1.se.conv_reduce.bias:[22]***blocks.3.1.se.conv_expand.weight:[274, 22, 1, 1]***blocks.3.1.se.conv_expand.bias:[274]***blocks.3.1.conv_pwl.weight:[85, 274, 1, 1]***blocks.3.1.bn3.weight:[85]***blocks.3.1.bn3.bias:[85]***blocks.3.1.bn3.running_mean:[85]***blocks.3.1.bn3.running_var:[85]***blocks.3.1.bn3.num_batches_tracked:[]***blocks.3.2.conv_pw.weight:[254, 85, 1, 1]***blocks.3.2.bn1.weight:[254]***blocks.3.2.bn1.bias:[254]***blocks.3.2.bn1.running_mean:[254]***blocks.3.2.bn1.running_var:[254]***blocks.3.2.bn1.num_batches_tracked:[]***blocks.3.2.conv_dw.weight:[254, 1, 3, 3]***blocks.3.2.bn2.weight:[254]***blocks.3.2.bn2.bias:[254]***blocks.3.2.bn2.running_mean:[254]***blocks.3.2.bn2.running_var:[254]***blocks.3.2.bn2.num_batches_tracked:[]***blocks.3.2.se.conv_reduce.weight:[22, 254, 1, 1]***blocks.3.2.se.conv_reduce.bias:[22]***blocks.3.2.se.conv_expand.weight:[254, 22, 1, 1]***blocks.3.2.se.conv_expand.bias:[254]***blocks.3.2.conv_pwl.weight:[85, 254, 1, 1]***blocks.3.2.bn3.weight:[85]***blocks.3.2.bn3.bias:[85]***blocks.3.2.bn3.running_mean:[85]***blocks.3.2.bn3.running_var:[85]***blocks.3.2.bn3.num_batches_tracked:[]***blocks.3.3.conv_pw.weight:[292, 85, 1, 
1]***blocks.3.3.bn1.weight:[292]***blocks.3.3.bn1.bias:[292]***blocks.3.3.bn1.running_mean:[292]***blocks.3.3.bn1.running_var:[292]***blocks.3.3.bn1.num_batches_tracked:[]***blocks.3.3.conv_dw.weight:[292, 1, 3, 3]***blocks.3.3.bn2.weight:[292]***blocks.3.3.bn2.bias:[292]***blocks.3.3.bn2.running_mean:[292]***blocks.3.3.bn2.running_var:[292]***blocks.3.3.bn2.num_batches_tracked:[]***blocks.3.3.se.conv_reduce.weight:[22, 292, 1, 1]***blocks.3.3.se.conv_reduce.bias:[22]***blocks.3.3.se.conv_expand.weight:[292, 22, 1, 1]***blocks.3.3.se.conv_expand.bias:[292]***blocks.3.3.conv_pwl.weight:[85, 292, 1, 1]***blocks.3.3.bn3.weight:[85]***blocks.3.3.bn3.bias:[85]***blocks.3.3.bn3.running_mean:[85]***blocks.3.3.bn3.running_var:[85]***blocks.3.3.bn3.num_batches_tracked:[]***blocks.4.0.conv_pw.weight:[502, 85, 1, 1]***blocks.4.0.bn1.weight:[502]***blocks.4.0.bn1.bias:[502]***blocks.4.0.bn1.running_mean:[502]***blocks.4.0.bn1.running_var:[502]***blocks.4.0.bn1.num_batches_tracked:[]***blocks.4.0.conv_dw.weight:[502, 1, 5, 5]***blocks.4.0.bn2.weight:[502]***blocks.4.0.bn2.bias:[502]***blocks.4.0.bn2.running_mean:[502]***blocks.4.0.bn2.running_var:[502]***blocks.4.0.bn2.num_batches_tracked:[]***blocks.4.0.se.conv_reduce.weight:[22, 502, 1, 1]***blocks.4.0.se.conv_reduce.bias:[22]***blocks.4.0.se.conv_expand.weight:[502, 22, 1, 1]***blocks.4.0.se.conv_expand.bias:[502]***blocks.4.0.conv_pwl.weight:[116, 502, 1, 1]***blocks.4.0.bn3.weight:[116]***blocks.4.0.bn3.bias:[116]***blocks.4.0.bn3.running_mean:[116]***blocks.4.0.bn3.running_var:[116]***blocks.4.0.bn3.num_batches_tracked:[]***blocks.4.1.conv_pw.weight:[315, 116, 1, 1]***blocks.4.1.bn1.weight:[315]***blocks.4.1.bn1.bias:[315]***blocks.4.1.bn1.running_mean:[315]***blocks.4.1.bn1.running_var:[315]***blocks.4.1.bn1.num_batches_tracked:[]***blocks.4.1.conv_dw.weight:[315, 1, 5, 5]***blocks.4.1.bn2.weight:[315]***blocks.4.1.bn2.bias:[315]***blocks.4.1.bn2.running_mean:[315]***blocks.4.1.bn2.running_var:[315]***blocks.4.1.bn2.num_batches_tracked:[]***blocks.4.1.se.conv_reduce.weight:[30, 315, 1, 1]***blocks.4.1.se.conv_reduce.bias:[30]***blocks.4.1.se.conv_expand.weight:[315, 30, 1, 1]***blocks.4.1.se.conv_expand.bias:[315]***blocks.4.1.conv_pwl.weight:[116, 315, 1, 1]***blocks.4.1.bn3.weight:[116]***blocks.4.1.bn3.bias:[116]***blocks.4.1.bn3.running_mean:[116]***blocks.4.1.bn3.running_var:[116]***blocks.4.1.bn3.num_batches_tracked:[]***blocks.4.2.conv_pw.weight:[354, 116, 1, 1]***blocks.4.2.bn1.weight:[354]***blocks.4.2.bn1.bias:[354]***blocks.4.2.bn1.running_mean:[354]***blocks.4.2.bn1.running_var:[354]***blocks.4.2.bn1.num_batches_tracked:[]***blocks.4.2.conv_dw.weight:[354, 1, 5, 5]***blocks.4.2.bn2.weight:[354]***blocks.4.2.bn2.bias:[354]***blocks.4.2.bn2.running_mean:[354]***blocks.4.2.bn2.running_var:[354]***blocks.4.2.bn2.num_batches_tracked:[]***blocks.4.2.se.conv_reduce.weight:[30, 354, 1, 1]***blocks.4.2.se.conv_reduce.bias:[30]***blocks.4.2.se.conv_expand.weight:[354, 30, 1, 1]***blocks.4.2.se.conv_expand.bias:[354]***blocks.4.2.conv_pwl.weight:[116, 354, 1, 1]***blocks.4.2.bn3.weight:[116]***blocks.4.2.bn3.bias:[116]***blocks.4.2.bn3.running_mean:[116]***blocks.4.2.bn3.running_var:[116]***blocks.4.2.bn3.num_batches_tracked:[]***blocks.4.3.conv_pw.weight:[443, 116, 1, 1]***blocks.4.3.bn1.weight:[443]***blocks.4.3.bn1.bias:[443]***blocks.4.3.bn1.running_mean:[443]***blocks.4.3.bn1.running_var:[443]***blocks.4.3.bn1.num_batches_tracked:[]***blocks.4.3.conv_dw.weight:[443, 1, 5, 
5]***blocks.4.3.bn2.weight:[443]***blocks.4.3.bn2.bias:[443]***blocks.4.3.bn2.running_mean:[443]***blocks.4.3.bn2.running_var:[443]***blocks.4.3.bn2.num_batches_tracked:[]***blocks.4.3.se.conv_reduce.weight:[30, 443, 1, 1]***blocks.4.3.se.conv_reduce.bias:[30]***blocks.4.3.se.conv_expand.weight:[443, 30, 1, 1]***blocks.4.3.se.conv_expand.bias:[443]***blocks.4.3.conv_pwl.weight:[116, 443, 1, 1]***blocks.4.3.bn3.weight:[116]***blocks.4.3.bn3.bias:[116]***blocks.4.3.bn3.running_mean:[116]***blocks.4.3.bn3.running_var:[116]***blocks.4.3.bn3.num_batches_tracked:[]***blocks.5.0.conv_pw.weight:[719, 116, 1, 1]***blocks.5.0.bn1.weight:[719]***blocks.5.0.bn1.bias:[719]***blocks.5.0.bn1.running_mean:[719]***blocks.5.0.bn1.running_var:[719]***blocks.5.0.bn1.num_batches_tracked:[]***blocks.5.0.conv_dw.weight:[719, 1, 5, 5]***blocks.5.0.bn2.weight:[719]***blocks.5.0.bn2.bias:[719]***blocks.5.0.bn2.running_mean:[719]***blocks.5.0.bn2.running_var:[719]***blocks.5.0.bn2.num_batches_tracked:[]***blocks.5.0.se.conv_reduce.weight:[30, 719, 1, 1]***blocks.5.0.se.conv_reduce.bias:[30]***blocks.5.0.se.conv_expand.weight:[719, 30, 1, 1]***blocks.5.0.se.conv_expand.bias:[719]***blocks.5.0.conv_pwl.weight:[208, 719, 1, 1]***blocks.5.0.bn3.weight:[208]***blocks.5.0.bn3.bias:[208]***blocks.5.0.bn3.running_mean:[208]***blocks.5.0.bn3.running_var:[208]***blocks.5.0.bn3.num_batches_tracked:[]***blocks.5.1.conv_pw.weight:[1148, 208, 1, 1]***blocks.5.1.bn1.weight:[1148]***blocks.5.1.bn1.bias:[1148]***blocks.5.1.bn1.running_mean:[1148]***blocks.5.1.bn1.running_var:[1148]***blocks.5.1.bn1.num_batches_tracked:[]***blocks.5.1.conv_dw.weight:[1148, 1, 5, 5]***blocks.5.1.bn2.weight:[1148]***blocks.5.1.bn2.bias:[1148]***blocks.5.1.bn2.running_mean:[1148]***blocks.5.1.bn2.running_var:[1148]***blocks.5.1.bn2.num_batches_tracked:[]***blocks.5.1.se.conv_reduce.weight:[52, 1148, 1, 1]***blocks.5.1.se.conv_reduce.bias:[52]***blocks.5.1.se.conv_expand.weight:[1148, 52, 1, 1]***blocks.5.1.se.conv_expand.bias:[1148]***blocks.5.1.conv_pwl.weight:[208, 1148, 1, 1]***blocks.5.1.bn3.weight:[208]***blocks.5.1.bn3.bias:[208]***blocks.5.1.bn3.running_mean:[208]***blocks.5.1.bn3.running_var:[208]***blocks.5.1.bn3.num_batches_tracked:[]***blocks.5.2.conv_pw.weight:[1160, 208, 1, 1]***blocks.5.2.bn1.weight:[1160]***blocks.5.2.bn1.bias:[1160]***blocks.5.2.bn1.running_mean:[1160]***blocks.5.2.bn1.running_var:[1160]***blocks.5.2.bn1.num_batches_tracked:[]***blocks.5.2.conv_dw.weight:[1160, 1, 5, 5]***blocks.5.2.bn2.weight:[1160]***blocks.5.2.bn2.bias:[1160]***blocks.5.2.bn2.running_mean:[1160]***blocks.5.2.bn2.running_var:[1160]***blocks.5.2.bn2.num_batches_tracked:[]***blocks.5.2.se.conv_reduce.weight:[52, 1160, 1, 1]***blocks.5.2.se.conv_reduce.bias:[52]***blocks.5.2.se.conv_expand.weight:[1160, 52, 1, 1]***blocks.5.2.se.conv_expand.bias:[1160]***blocks.5.2.conv_pwl.weight:[208, 1160, 1, 1]***blocks.5.2.bn3.weight:[208]***blocks.5.2.bn3.bias:[208]***blocks.5.2.bn3.running_mean:[208]***blocks.5.2.bn3.running_var:[208]***blocks.5.2.bn3.num_batches_tracked:[]***blocks.5.3.conv_pw.weight:[1182, 208, 1, 1]***blocks.5.3.bn1.weight:[1182]***blocks.5.3.bn1.bias:[1182]***blocks.5.3.bn1.running_mean:[1182]***blocks.5.3.bn1.running_var:[1182]***blocks.5.3.bn1.num_batches_tracked:[]***blocks.5.3.conv_dw.weight:[1182, 1, 5, 5]***blocks.5.3.bn2.weight:[1182]***blocks.5.3.bn2.bias:[1182]***blocks.5.3.bn2.running_mean:[1182]***blocks.5.3.bn2.running_var:[1182]***blocks.5.3.bn2.num_batches_tracked:[]***blocks.5.3.se.conv_reduce.weight:[52, 1182, 1, 
1]***blocks.5.3.se.conv_reduce.bias:[52]***blocks.5.3.se.conv_expand.weight:[1182, 52, 1, 1]***blocks.5.3.se.conv_expand.bias:[1182]***blocks.5.3.conv_pwl.weight:[208, 1182, 1, 1]***blocks.5.3.bn3.weight:[208]***blocks.5.3.bn3.bias:[208]***blocks.5.3.bn3.running_mean:[208]***blocks.5.3.bn3.running_var:[208]***blocks.5.3.bn3.num_batches_tracked:[]***blocks.5.4.conv_pw.weight:[1228, 208, 1, 1]***blocks.5.4.bn1.weight:[1228]***blocks.5.4.bn1.bias:[1228]***blocks.5.4.bn1.running_mean:[1228]***blocks.5.4.bn1.running_var:[1228]***blocks.5.4.bn1.num_batches_tracked:[]***blocks.5.4.conv_dw.weight:[1228, 1, 5, 5]***blocks.5.4.bn2.weight:[1228]***blocks.5.4.bn2.bias:[1228]***blocks.5.4.bn2.running_mean:[1228]***blocks.5.4.bn2.running_var:[1228]***blocks.5.4.bn2.num_batches_tracked:[]***blocks.5.4.se.conv_reduce.weight:[52, 1228, 1, 1]***blocks.5.4.se.conv_reduce.bias:[52]***blocks.5.4.se.conv_expand.weight:[1228, 52, 1, 1]***blocks.5.4.se.conv_expand.bias:[1228]***blocks.5.4.conv_pwl.weight:[208, 1228, 1, 1]***blocks.5.4.bn3.weight:[208]***blocks.5.4.bn3.bias:[208]***blocks.5.4.bn3.running_mean:[208]***blocks.5.4.bn3.running_var:[208]***blocks.5.4.bn3.num_batches_tracked:[]***blocks.6.0.conv_pw.weight:[1248, 208, 1, 1]***blocks.6.0.bn1.weight:[1248]***blocks.6.0.bn1.bias:[1248]***blocks.6.0.bn1.running_mean:[1248]***blocks.6.0.bn1.running_var:[1248]***blocks.6.0.bn1.num_batches_tracked:[]***blocks.6.0.conv_dw.weight:[1248, 1, 3, 3]***blocks.6.0.bn2.weight:[1248]***blocks.6.0.bn2.bias:[1248]***blocks.6.0.bn2.running_mean:[1248]***blocks.6.0.bn2.running_var:[1248]***blocks.6.0.bn2.num_batches_tracked:[]***blocks.6.0.se.conv_reduce.weight:[52, 1248, 1, 1]***blocks.6.0.se.conv_reduce.bias:[52]***blocks.6.0.se.conv_expand.weight:[1248, 52, 1, 1]***blocks.6.0.se.conv_expand.bias:[1248]***blocks.6.0.conv_pwl.weight:[352, 1248, 1, 1]***blocks.6.0.bn3.weight:[352]***blocks.6.0.bn3.bias:[352]***blocks.6.0.bn3.running_mean:[352]***blocks.6.0.bn3.running_var:[352]***blocks.6.0.bn3.num_batches_tracked:[]***blocks.6.1.conv_pw.weight:[2112, 352, 1, 1]***blocks.6.1.bn1.weight:[2112]***blocks.6.1.bn1.bias:[2112]***blocks.6.1.bn1.running_mean:[2112]***blocks.6.1.bn1.running_var:[2112]***blocks.6.1.bn1.num_batches_tracked:[]***blocks.6.1.conv_dw.weight:[2112, 1, 3, 3]***blocks.6.1.bn2.weight:[2112]***blocks.6.1.bn2.bias:[2112]***blocks.6.1.bn2.running_mean:[2112]***blocks.6.1.bn2.running_var:[2112]***blocks.6.1.bn2.num_batches_tracked:[]***blocks.6.1.se.conv_reduce.weight:[88, 2112, 1, 1]***blocks.6.1.se.conv_reduce.bias:[88]***blocks.6.1.se.conv_expand.weight:[2112, 88, 1, 1]***blocks.6.1.se.conv_expand.bias:[2112]***blocks.6.1.conv_pwl.weight:[352, 2112, 1, 1]***blocks.6.1.bn3.weight:[352]***blocks.6.1.bn3.bias:[352]***blocks.6.1.bn3.running_mean:[352]***blocks.6.1.bn3.running_var:[352]***blocks.6.1.bn3.num_batches_tracked:[]***conv_head.weight:[1408, 352, 1, 1]***bn2.weight:[1408]***bn2.bias:[1408]***bn2.running_mean:[1408]***bn2.running_var:[1408]***bn2.num_batches_tracked:[]***classifier.weight:[1000, 1408]***classifier.bias:[1000] \ No newline at end of file diff --git a/timm/models/pruned/efficientnet_b3_pruned.txt b/timm/models/pruned/efficientnet_b3_pruned.txt new file mode 100644 index 0000000000000000000000000000000000000000..489781736de08e5cf40bf76528a735fff4a3f61c --- /dev/null +++ b/timm/models/pruned/efficientnet_b3_pruned.txt @@ -0,0 +1 @@ +conv_stem.weight:[40, 3, 3, 
3]***bn1.weight:[40]***bn1.bias:[40]***bn1.running_mean:[40]***bn1.running_var:[40]***bn1.num_batches_tracked:[]***blocks.0.0.conv_dw.weight:[40, 1, 3, 3]***blocks.0.0.bn1.weight:[40]***blocks.0.0.bn1.bias:[40]***blocks.0.0.bn1.running_mean:[40]***blocks.0.0.bn1.running_var:[40]***blocks.0.0.bn1.num_batches_tracked:[]***blocks.0.0.se.conv_reduce.weight:[10, 40, 1, 1]***blocks.0.0.se.conv_reduce.bias:[10]***blocks.0.0.se.conv_expand.weight:[40, 10, 1, 1]***blocks.0.0.se.conv_expand.bias:[40]***blocks.0.0.conv_pw.weight:[24, 40, 1, 1]***blocks.0.0.bn2.weight:[24]***blocks.0.0.bn2.bias:[24]***blocks.0.0.bn2.running_mean:[24]***blocks.0.0.bn2.running_var:[24]***blocks.0.0.bn2.num_batches_tracked:[]***blocks.0.1.conv_dw.weight:[24, 1, 3, 3]***blocks.0.1.bn1.weight:[24]***blocks.0.1.bn1.bias:[24]***blocks.0.1.bn1.running_mean:[24]***blocks.0.1.bn1.running_var:[24]***blocks.0.1.bn1.num_batches_tracked:[]***blocks.0.1.se.conv_reduce.weight:[6, 24, 1, 1]***blocks.0.1.se.conv_reduce.bias:[6]***blocks.0.1.se.conv_expand.weight:[24, 6, 1, 1]***blocks.0.1.se.conv_expand.bias:[24]***blocks.0.1.conv_pw.weight:[24, 24, 1, 1]***blocks.0.1.bn2.weight:[24]***blocks.0.1.bn2.bias:[24]***blocks.0.1.bn2.running_mean:[24]***blocks.0.1.bn2.running_var:[24]***blocks.0.1.bn2.num_batches_tracked:[]***blocks.1.0.conv_pw.weight:[27, 24, 1, 1]***blocks.1.0.bn1.weight:[27]***blocks.1.0.bn1.bias:[27]***blocks.1.0.bn1.running_mean:[27]***blocks.1.0.bn1.running_var:[27]***blocks.1.0.bn1.num_batches_tracked:[]***blocks.1.0.conv_dw.weight:[27, 1, 3, 3]***blocks.1.0.bn2.weight:[27]***blocks.1.0.bn2.bias:[27]***blocks.1.0.bn2.running_mean:[27]***blocks.1.0.bn2.running_var:[27]***blocks.1.0.bn2.num_batches_tracked:[]***blocks.1.0.se.conv_reduce.weight:[6, 27, 1, 1]***blocks.1.0.se.conv_reduce.bias:[6]***blocks.1.0.se.conv_expand.weight:[27, 6, 1, 1]***blocks.1.0.se.conv_expand.bias:[27]***blocks.1.0.conv_pwl.weight:[12, 27, 1, 1]***blocks.1.0.bn3.weight:[12]***blocks.1.0.bn3.bias:[12]***blocks.1.0.bn3.running_mean:[12]***blocks.1.0.bn3.running_var:[12]***blocks.1.0.bn3.num_batches_tracked:[]***blocks.1.1.conv_pw.weight:[49, 12, 1, 1]***blocks.1.1.bn1.weight:[49]***blocks.1.1.bn1.bias:[49]***blocks.1.1.bn1.running_mean:[49]***blocks.1.1.bn1.running_var:[49]***blocks.1.1.bn1.num_batches_tracked:[]***blocks.1.1.conv_dw.weight:[49, 1, 3, 3]***blocks.1.1.bn2.weight:[49]***blocks.1.1.bn2.bias:[49]***blocks.1.1.bn2.running_mean:[49]***blocks.1.1.bn2.running_var:[49]***blocks.1.1.bn2.num_batches_tracked:[]***blocks.1.1.se.conv_reduce.weight:[8, 49, 1, 1]***blocks.1.1.se.conv_reduce.bias:[8]***blocks.1.1.se.conv_expand.weight:[49, 8, 1, 1]***blocks.1.1.se.conv_expand.bias:[49]***blocks.1.1.conv_pwl.weight:[12, 49, 1, 1]***blocks.1.1.bn3.weight:[12]***blocks.1.1.bn3.bias:[12]***blocks.1.1.bn3.running_mean:[12]***blocks.1.1.bn3.running_var:[12]***blocks.1.1.bn3.num_batches_tracked:[]***blocks.1.2.conv_pw.weight:[48, 12, 1, 1]***blocks.1.2.bn1.weight:[48]***blocks.1.2.bn1.bias:[48]***blocks.1.2.bn1.running_mean:[48]***blocks.1.2.bn1.running_var:[48]***blocks.1.2.bn1.num_batches_tracked:[]***blocks.1.2.conv_dw.weight:[48, 1, 3, 3]***blocks.1.2.bn2.weight:[48]***blocks.1.2.bn2.bias:[48]***blocks.1.2.bn2.running_mean:[48]***blocks.1.2.bn2.running_var:[48]***blocks.1.2.bn2.num_batches_tracked:[]***blocks.1.2.se.conv_reduce.weight:[8, 48, 1, 1]***blocks.1.2.se.conv_reduce.bias:[8]***blocks.1.2.se.conv_expand.weight:[48, 8, 1, 1]***blocks.1.2.se.conv_expand.bias:[48]***blocks.1.2.conv_pwl.weight:[12, 48, 1, 
1]***blocks.1.2.bn3.weight:[12]***blocks.1.2.bn3.bias:[12]***blocks.1.2.bn3.running_mean:[12]***blocks.1.2.bn3.running_var:[12]***blocks.1.2.bn3.num_batches_tracked:[]***blocks.2.0.conv_pw.weight:[83, 12, 1, 1]***blocks.2.0.bn1.weight:[83]***blocks.2.0.bn1.bias:[83]***blocks.2.0.bn1.running_mean:[83]***blocks.2.0.bn1.running_var:[83]***blocks.2.0.bn1.num_batches_tracked:[]***blocks.2.0.conv_dw.weight:[83, 1, 5, 5]***blocks.2.0.bn2.weight:[83]***blocks.2.0.bn2.bias:[83]***blocks.2.0.bn2.running_mean:[83]***blocks.2.0.bn2.running_var:[83]***blocks.2.0.bn2.num_batches_tracked:[]***blocks.2.0.se.conv_reduce.weight:[8, 83, 1, 1]***blocks.2.0.se.conv_reduce.bias:[8]***blocks.2.0.se.conv_expand.weight:[83, 8, 1, 1]***blocks.2.0.se.conv_expand.bias:[83]***blocks.2.0.conv_pwl.weight:[40, 83, 1, 1]***blocks.2.0.bn3.weight:[40]***blocks.2.0.bn3.bias:[40]***blocks.2.0.bn3.running_mean:[40]***blocks.2.0.bn3.running_var:[40]***blocks.2.0.bn3.num_batches_tracked:[]***blocks.2.1.conv_pw.weight:[90, 40, 1, 1]***blocks.2.1.bn1.weight:[90]***blocks.2.1.bn1.bias:[90]***blocks.2.1.bn1.running_mean:[90]***blocks.2.1.bn1.running_var:[90]***blocks.2.1.bn1.num_batches_tracked:[]***blocks.2.1.conv_dw.weight:[90, 1, 5, 5]***blocks.2.1.bn2.weight:[90]***blocks.2.1.bn2.bias:[90]***blocks.2.1.bn2.running_mean:[90]***blocks.2.1.bn2.running_var:[90]***blocks.2.1.bn2.num_batches_tracked:[]***blocks.2.1.se.conv_reduce.weight:[12, 90, 1, 1]***blocks.2.1.se.conv_reduce.bias:[12]***blocks.2.1.se.conv_expand.weight:[90, 12, 1, 1]***blocks.2.1.se.conv_expand.bias:[90]***blocks.2.1.conv_pwl.weight:[40, 90, 1, 1]***blocks.2.1.bn3.weight:[40]***blocks.2.1.bn3.bias:[40]***blocks.2.1.bn3.running_mean:[40]***blocks.2.1.bn3.running_var:[40]***blocks.2.1.bn3.num_batches_tracked:[]***blocks.2.2.conv_pw.weight:[85, 40, 1, 1]***blocks.2.2.bn1.weight:[85]***blocks.2.2.bn1.bias:[85]***blocks.2.2.bn1.running_mean:[85]***blocks.2.2.bn1.running_var:[85]***blocks.2.2.bn1.num_batches_tracked:[]***blocks.2.2.conv_dw.weight:[85, 1, 5, 5]***blocks.2.2.bn2.weight:[85]***blocks.2.2.bn2.bias:[85]***blocks.2.2.bn2.running_mean:[85]***blocks.2.2.bn2.running_var:[85]***blocks.2.2.bn2.num_batches_tracked:[]***blocks.2.2.se.conv_reduce.weight:[12, 85, 1, 1]***blocks.2.2.se.conv_reduce.bias:[12]***blocks.2.2.se.conv_expand.weight:[85, 12, 1, 1]***blocks.2.2.se.conv_expand.bias:[85]***blocks.2.2.conv_pwl.weight:[40, 85, 1, 1]***blocks.2.2.bn3.weight:[40]***blocks.2.2.bn3.bias:[40]***blocks.2.2.bn3.running_mean:[40]***blocks.2.2.bn3.running_var:[40]***blocks.2.2.bn3.num_batches_tracked:[]***blocks.3.0.conv_pw.weight:[215, 40, 1, 1]***blocks.3.0.bn1.weight:[215]***blocks.3.0.bn1.bias:[215]***blocks.3.0.bn1.running_mean:[215]***blocks.3.0.bn1.running_var:[215]***blocks.3.0.bn1.num_batches_tracked:[]***blocks.3.0.conv_dw.weight:[215, 1, 3, 3]***blocks.3.0.bn2.weight:[215]***blocks.3.0.bn2.bias:[215]***blocks.3.0.bn2.running_mean:[215]***blocks.3.0.bn2.running_var:[215]***blocks.3.0.bn2.num_batches_tracked:[]***blocks.3.0.se.conv_reduce.weight:[12, 215, 1, 1]***blocks.3.0.se.conv_reduce.bias:[12]***blocks.3.0.se.conv_expand.weight:[215, 12, 1, 1]***blocks.3.0.se.conv_expand.bias:[215]***blocks.3.0.conv_pwl.weight:[93, 215, 1, 1]***blocks.3.0.bn3.weight:[93]***blocks.3.0.bn3.bias:[93]***blocks.3.0.bn3.running_mean:[93]***blocks.3.0.bn3.running_var:[93]***blocks.3.0.bn3.num_batches_tracked:[]***blocks.3.1.conv_pw.weight:[261, 93, 1, 
1]***blocks.3.1.bn1.weight:[261]***blocks.3.1.bn1.bias:[261]***blocks.3.1.bn1.running_mean:[261]***blocks.3.1.bn1.running_var:[261]***blocks.3.1.bn1.num_batches_tracked:[]***blocks.3.1.conv_dw.weight:[261, 1, 3, 3]***blocks.3.1.bn2.weight:[261]***blocks.3.1.bn2.bias:[261]***blocks.3.1.bn2.running_mean:[261]***blocks.3.1.bn2.running_var:[261]***blocks.3.1.bn2.num_batches_tracked:[]***blocks.3.1.se.conv_reduce.weight:[24, 261, 1, 1]***blocks.3.1.se.conv_reduce.bias:[24]***blocks.3.1.se.conv_expand.weight:[261, 24, 1, 1]***blocks.3.1.se.conv_expand.bias:[261]***blocks.3.1.conv_pwl.weight:[93, 261, 1, 1]***blocks.3.1.bn3.weight:[93]***blocks.3.1.bn3.bias:[93]***blocks.3.1.bn3.running_mean:[93]***blocks.3.1.bn3.running_var:[93]***blocks.3.1.bn3.num_batches_tracked:[]***blocks.3.2.conv_pw.weight:[219, 93, 1, 1]***blocks.3.2.bn1.weight:[219]***blocks.3.2.bn1.bias:[219]***blocks.3.2.bn1.running_mean:[219]***blocks.3.2.bn1.running_var:[219]***blocks.3.2.bn1.num_batches_tracked:[]***blocks.3.2.conv_dw.weight:[219, 1, 3, 3]***blocks.3.2.bn2.weight:[219]***blocks.3.2.bn2.bias:[219]***blocks.3.2.bn2.running_mean:[219]***blocks.3.2.bn2.running_var:[219]***blocks.3.2.bn2.num_batches_tracked:[]***blocks.3.2.se.conv_reduce.weight:[24, 219, 1, 1]***blocks.3.2.se.conv_reduce.bias:[24]***blocks.3.2.se.conv_expand.weight:[219, 24, 1, 1]***blocks.3.2.se.conv_expand.bias:[219]***blocks.3.2.conv_pwl.weight:[93, 219, 1, 1]***blocks.3.2.bn3.weight:[93]***blocks.3.2.bn3.bias:[93]***blocks.3.2.bn3.running_mean:[93]***blocks.3.2.bn3.running_var:[93]***blocks.3.2.bn3.num_batches_tracked:[]***blocks.3.3.conv_pw.weight:[254, 93, 1, 1]***blocks.3.3.bn1.weight:[254]***blocks.3.3.bn1.bias:[254]***blocks.3.3.bn1.running_mean:[254]***blocks.3.3.bn1.running_var:[254]***blocks.3.3.bn1.num_batches_tracked:[]***blocks.3.3.conv_dw.weight:[254, 1, 3, 3]***blocks.3.3.bn2.weight:[254]***blocks.3.3.bn2.bias:[254]***blocks.3.3.bn2.running_mean:[254]***blocks.3.3.bn2.running_var:[254]***blocks.3.3.bn2.num_batches_tracked:[]***blocks.3.3.se.conv_reduce.weight:[24, 254, 1, 1]***blocks.3.3.se.conv_reduce.bias:[24]***blocks.3.3.se.conv_expand.weight:[254, 24, 1, 1]***blocks.3.3.se.conv_expand.bias:[254]***blocks.3.3.conv_pwl.weight:[93, 254, 1, 1]***blocks.3.3.bn3.weight:[93]***blocks.3.3.bn3.bias:[93]***blocks.3.3.bn3.running_mean:[93]***blocks.3.3.bn3.running_var:[93]***blocks.3.3.bn3.num_batches_tracked:[]***blocks.3.4.conv_pw.weight:[236, 93, 1, 1]***blocks.3.4.bn1.weight:[236]***blocks.3.4.bn1.bias:[236]***blocks.3.4.bn1.running_mean:[236]***blocks.3.4.bn1.running_var:[236]***blocks.3.4.bn1.num_batches_tracked:[]***blocks.3.4.conv_dw.weight:[236, 1, 3, 3]***blocks.3.4.bn2.weight:[236]***blocks.3.4.bn2.bias:[236]***blocks.3.4.bn2.running_mean:[236]***blocks.3.4.bn2.running_var:[236]***blocks.3.4.bn2.num_batches_tracked:[]***blocks.3.4.se.conv_reduce.weight:[24, 236, 1, 1]***blocks.3.4.se.conv_reduce.bias:[24]***blocks.3.4.se.conv_expand.weight:[236, 24, 1, 1]***blocks.3.4.se.conv_expand.bias:[236]***blocks.3.4.conv_pwl.weight:[93, 236, 1, 1]***blocks.3.4.bn3.weight:[93]***blocks.3.4.bn3.bias:[93]***blocks.3.4.bn3.running_mean:[93]***blocks.3.4.bn3.running_var:[93]***blocks.3.4.bn3.num_batches_tracked:[]***blocks.4.0.conv_pw.weight:[480, 93, 1, 1]***blocks.4.0.bn1.weight:[480]***blocks.4.0.bn1.bias:[480]***blocks.4.0.bn1.running_mean:[480]***blocks.4.0.bn1.running_var:[480]***blocks.4.0.bn1.num_batches_tracked:[]***blocks.4.0.conv_dw.weight:[480, 1, 5, 
5]***blocks.4.0.bn2.weight:[480]***blocks.4.0.bn2.bias:[480]***blocks.4.0.bn2.running_mean:[480]***blocks.4.0.bn2.running_var:[480]***blocks.4.0.bn2.num_batches_tracked:[]***blocks.4.0.se.conv_reduce.weight:[24, 480, 1, 1]***blocks.4.0.se.conv_reduce.bias:[24]***blocks.4.0.se.conv_expand.weight:[480, 24, 1, 1]***blocks.4.0.se.conv_expand.bias:[480]***blocks.4.0.conv_pwl.weight:[120, 480, 1, 1]***blocks.4.0.bn3.weight:[120]***blocks.4.0.bn3.bias:[120]***blocks.4.0.bn3.running_mean:[120]***blocks.4.0.bn3.running_var:[120]***blocks.4.0.bn3.num_batches_tracked:[]***blocks.4.1.conv_pw.weight:[235, 120, 1, 1]***blocks.4.1.bn1.weight:[235]***blocks.4.1.bn1.bias:[235]***blocks.4.1.bn1.running_mean:[235]***blocks.4.1.bn1.running_var:[235]***blocks.4.1.bn1.num_batches_tracked:[]***blocks.4.1.conv_dw.weight:[235, 1, 5, 5]***blocks.4.1.bn2.weight:[235]***blocks.4.1.bn2.bias:[235]***blocks.4.1.bn2.running_mean:[235]***blocks.4.1.bn2.running_var:[235]***blocks.4.1.bn2.num_batches_tracked:[]***blocks.4.1.se.conv_reduce.weight:[34, 235, 1, 1]***blocks.4.1.se.conv_reduce.bias:[34]***blocks.4.1.se.conv_expand.weight:[235, 34, 1, 1]***blocks.4.1.se.conv_expand.bias:[235]***blocks.4.1.conv_pwl.weight:[120, 235, 1, 1]***blocks.4.1.bn3.weight:[120]***blocks.4.1.bn3.bias:[120]***blocks.4.1.bn3.running_mean:[120]***blocks.4.1.bn3.running_var:[120]***blocks.4.1.bn3.num_batches_tracked:[]***blocks.4.2.conv_pw.weight:[217, 120, 1, 1]***blocks.4.2.bn1.weight:[217]***blocks.4.2.bn1.bias:[217]***blocks.4.2.bn1.running_mean:[217]***blocks.4.2.bn1.running_var:[217]***blocks.4.2.bn1.num_batches_tracked:[]***blocks.4.2.conv_dw.weight:[217, 1, 5, 5]***blocks.4.2.bn2.weight:[217]***blocks.4.2.bn2.bias:[217]***blocks.4.2.bn2.running_mean:[217]***blocks.4.2.bn2.running_var:[217]***blocks.4.2.bn2.num_batches_tracked:[]***blocks.4.2.se.conv_reduce.weight:[34, 217, 1, 1]***blocks.4.2.se.conv_reduce.bias:[34]***blocks.4.2.se.conv_expand.weight:[217, 34, 1, 1]***blocks.4.2.se.conv_expand.bias:[217]***blocks.4.2.conv_pwl.weight:[120, 217, 1, 1]***blocks.4.2.bn3.weight:[120]***blocks.4.2.bn3.bias:[120]***blocks.4.2.bn3.running_mean:[120]***blocks.4.2.bn3.running_var:[120]***blocks.4.2.bn3.num_batches_tracked:[]***blocks.4.3.conv_pw.weight:[226, 120, 1, 1]***blocks.4.3.bn1.weight:[226]***blocks.4.3.bn1.bias:[226]***blocks.4.3.bn1.running_mean:[226]***blocks.4.3.bn1.running_var:[226]***blocks.4.3.bn1.num_batches_tracked:[]***blocks.4.3.conv_dw.weight:[226, 1, 5, 5]***blocks.4.3.bn2.weight:[226]***blocks.4.3.bn2.bias:[226]***blocks.4.3.bn2.running_mean:[226]***blocks.4.3.bn2.running_var:[226]***blocks.4.3.bn2.num_batches_tracked:[]***blocks.4.3.se.conv_reduce.weight:[33, 226, 1, 1]***blocks.4.3.se.conv_reduce.bias:[33]***blocks.4.3.se.conv_expand.weight:[226, 33, 1, 1]***blocks.4.3.se.conv_expand.bias:[226]***blocks.4.3.conv_pwl.weight:[120, 226, 1, 1]***blocks.4.3.bn3.weight:[120]***blocks.4.3.bn3.bias:[120]***blocks.4.3.bn3.running_mean:[120]***blocks.4.3.bn3.running_var:[120]***blocks.4.3.bn3.num_batches_tracked:[]***blocks.4.4.conv_pw.weight:[340, 120, 1, 1]***blocks.4.4.bn1.weight:[340]***blocks.4.4.bn1.bias:[340]***blocks.4.4.bn1.running_mean:[340]***blocks.4.4.bn1.running_var:[340]***blocks.4.4.bn1.num_batches_tracked:[]***blocks.4.4.conv_dw.weight:[340, 1, 5, 5]***blocks.4.4.bn2.weight:[340]***blocks.4.4.bn2.bias:[340]***blocks.4.4.bn2.running_mean:[340]***blocks.4.4.bn2.running_var:[340]***blocks.4.4.bn2.num_batches_tracked:[]***blocks.4.4.se.conv_reduce.weight:[34, 340, 1, 
1]***blocks.4.4.se.conv_reduce.bias:[34]***blocks.4.4.se.conv_expand.weight:[340, 34, 1, 1]***blocks.4.4.se.conv_expand.bias:[340]***blocks.4.4.conv_pwl.weight:[120, 340, 1, 1]***blocks.4.4.bn3.weight:[120]***blocks.4.4.bn3.bias:[120]***blocks.4.4.bn3.running_mean:[120]***blocks.4.4.bn3.running_var:[120]***blocks.4.4.bn3.num_batches_tracked:[]***blocks.5.0.conv_pw.weight:[802, 120, 1, 1]***blocks.5.0.bn1.weight:[802]***blocks.5.0.bn1.bias:[802]***blocks.5.0.bn1.running_mean:[802]***blocks.5.0.bn1.running_var:[802]***blocks.5.0.bn1.num_batches_tracked:[]***blocks.5.0.conv_dw.weight:[802, 1, 5, 5]***blocks.5.0.bn2.weight:[802]***blocks.5.0.bn2.bias:[802]***blocks.5.0.bn2.running_mean:[802]***blocks.5.0.bn2.running_var:[802]***blocks.5.0.bn2.num_batches_tracked:[]***blocks.5.0.se.conv_reduce.weight:[34, 802, 1, 1]***blocks.5.0.se.conv_reduce.bias:[34]***blocks.5.0.se.conv_expand.weight:[802, 34, 1, 1]***blocks.5.0.se.conv_expand.bias:[802]***blocks.5.0.conv_pwl.weight:[232, 802, 1, 1]***blocks.5.0.bn3.weight:[232]***blocks.5.0.bn3.bias:[232]***blocks.5.0.bn3.running_mean:[232]***blocks.5.0.bn3.running_var:[232]***blocks.5.0.bn3.num_batches_tracked:[]***blocks.5.1.conv_pw.weight:[1030, 232, 1, 1]***blocks.5.1.bn1.weight:[1030]***blocks.5.1.bn1.bias:[1030]***blocks.5.1.bn1.running_mean:[1030]***blocks.5.1.bn1.running_var:[1030]***blocks.5.1.bn1.num_batches_tracked:[]***blocks.5.1.conv_dw.weight:[1030, 1, 5, 5]***blocks.5.1.bn2.weight:[1030]***blocks.5.1.bn2.bias:[1030]***blocks.5.1.bn2.running_mean:[1030]***blocks.5.1.bn2.running_var:[1030]***blocks.5.1.bn2.num_batches_tracked:[]***blocks.5.1.se.conv_reduce.weight:[58, 1030, 1, 1]***blocks.5.1.se.conv_reduce.bias:[58]***blocks.5.1.se.conv_expand.weight:[1030, 58, 1, 1]***blocks.5.1.se.conv_expand.bias:[1030]***blocks.5.1.conv_pwl.weight:[232, 1030, 1, 1]***blocks.5.1.bn3.weight:[232]***blocks.5.1.bn3.bias:[232]***blocks.5.1.bn3.running_mean:[232]***blocks.5.1.bn3.running_var:[232]***blocks.5.1.bn3.num_batches_tracked:[]***blocks.5.2.conv_pw.weight:[924, 232, 1, 1]***blocks.5.2.bn1.weight:[924]***blocks.5.2.bn1.bias:[924]***blocks.5.2.bn1.running_mean:[924]***blocks.5.2.bn1.running_var:[924]***blocks.5.2.bn1.num_batches_tracked:[]***blocks.5.2.conv_dw.weight:[924, 1, 5, 5]***blocks.5.2.bn2.weight:[924]***blocks.5.2.bn2.bias:[924]***blocks.5.2.bn2.running_mean:[924]***blocks.5.2.bn2.running_var:[924]***blocks.5.2.bn2.num_batches_tracked:[]***blocks.5.2.se.conv_reduce.weight:[58, 924, 1, 1]***blocks.5.2.se.conv_reduce.bias:[58]***blocks.5.2.se.conv_expand.weight:[924, 58, 1, 1]***blocks.5.2.se.conv_expand.bias:[924]***blocks.5.2.conv_pwl.weight:[232, 924, 1, 1]***blocks.5.2.bn3.weight:[232]***blocks.5.2.bn3.bias:[232]***blocks.5.2.bn3.running_mean:[232]***blocks.5.2.bn3.running_var:[232]***blocks.5.2.bn3.num_batches_tracked:[]***blocks.5.3.conv_pw.weight:[1016, 232, 1, 1]***blocks.5.3.bn1.weight:[1016]***blocks.5.3.bn1.bias:[1016]***blocks.5.3.bn1.running_mean:[1016]***blocks.5.3.bn1.running_var:[1016]***blocks.5.3.bn1.num_batches_tracked:[]***blocks.5.3.conv_dw.weight:[1016, 1, 5, 5]***blocks.5.3.bn2.weight:[1016]***blocks.5.3.bn2.bias:[1016]***blocks.5.3.bn2.running_mean:[1016]***blocks.5.3.bn2.running_var:[1016]***blocks.5.3.bn2.num_batches_tracked:[]***blocks.5.3.se.conv_reduce.weight:[58, 1016, 1, 1]***blocks.5.3.se.conv_reduce.bias:[58]***blocks.5.3.se.conv_expand.weight:[1016, 58, 1, 1]***blocks.5.3.se.conv_expand.bias:[1016]***blocks.5.3.conv_pwl.weight:[232, 1016, 1, 
1]***blocks.5.3.bn3.weight:[232]***blocks.5.3.bn3.bias:[232]***blocks.5.3.bn3.running_mean:[232]***blocks.5.3.bn3.running_var:[232]***blocks.5.3.bn3.num_batches_tracked:[]***blocks.5.4.conv_pw.weight:[1130, 232, 1, 1]***blocks.5.4.bn1.weight:[1130]***blocks.5.4.bn1.bias:[1130]***blocks.5.4.bn1.running_mean:[1130]***blocks.5.4.bn1.running_var:[1130]***blocks.5.4.bn1.num_batches_tracked:[]***blocks.5.4.conv_dw.weight:[1130, 1, 5, 5]***blocks.5.4.bn2.weight:[1130]***blocks.5.4.bn2.bias:[1130]***blocks.5.4.bn2.running_mean:[1130]***blocks.5.4.bn2.running_var:[1130]***blocks.5.4.bn2.num_batches_tracked:[]***blocks.5.4.se.conv_reduce.weight:[58, 1130, 1, 1]***blocks.5.4.se.conv_reduce.bias:[58]***blocks.5.4.se.conv_expand.weight:[1130, 58, 1, 1]***blocks.5.4.se.conv_expand.bias:[1130]***blocks.5.4.conv_pwl.weight:[232, 1130, 1, 1]***blocks.5.4.bn3.weight:[232]***blocks.5.4.bn3.bias:[232]***blocks.5.4.bn3.running_mean:[232]***blocks.5.4.bn3.running_var:[232]***blocks.5.4.bn3.num_batches_tracked:[]***blocks.5.5.conv_pw.weight:[1266, 232, 1, 1]***blocks.5.5.bn1.weight:[1266]***blocks.5.5.bn1.bias:[1266]***blocks.5.5.bn1.running_mean:[1266]***blocks.5.5.bn1.running_var:[1266]***blocks.5.5.bn1.num_batches_tracked:[]***blocks.5.5.conv_dw.weight:[1266, 1, 5, 5]***blocks.5.5.bn2.weight:[1266]***blocks.5.5.bn2.bias:[1266]***blocks.5.5.bn2.running_mean:[1266]***blocks.5.5.bn2.running_var:[1266]***blocks.5.5.bn2.num_batches_tracked:[]***blocks.5.5.se.conv_reduce.weight:[58, 1266, 1, 1]***blocks.5.5.se.conv_reduce.bias:[58]***blocks.5.5.se.conv_expand.weight:[1266, 58, 1, 1]***blocks.5.5.se.conv_expand.bias:[1266]***blocks.5.5.conv_pwl.weight:[232, 1266, 1, 1]***blocks.5.5.bn3.weight:[232]***blocks.5.5.bn3.bias:[232]***blocks.5.5.bn3.running_mean:[232]***blocks.5.5.bn3.running_var:[232]***blocks.5.5.bn3.num_batches_tracked:[]***blocks.6.0.conv_pw.weight:[1392, 232, 1, 1]***blocks.6.0.bn1.weight:[1392]***blocks.6.0.bn1.bias:[1392]***blocks.6.0.bn1.running_mean:[1392]***blocks.6.0.bn1.running_var:[1392]***blocks.6.0.bn1.num_batches_tracked:[]***blocks.6.0.conv_dw.weight:[1392, 1, 3, 3]***blocks.6.0.bn2.weight:[1392]***blocks.6.0.bn2.bias:[1392]***blocks.6.0.bn2.running_mean:[1392]***blocks.6.0.bn2.running_var:[1392]***blocks.6.0.bn2.num_batches_tracked:[]***blocks.6.0.se.conv_reduce.weight:[58, 1392, 1, 1]***blocks.6.0.se.conv_reduce.bias:[58]***blocks.6.0.se.conv_expand.weight:[1392, 58, 1, 1]***blocks.6.0.se.conv_expand.bias:[1392]***blocks.6.0.conv_pwl.weight:[384, 1392, 1, 1]***blocks.6.0.bn3.weight:[384]***blocks.6.0.bn3.bias:[384]***blocks.6.0.bn3.running_mean:[384]***blocks.6.0.bn3.running_var:[384]***blocks.6.0.bn3.num_batches_tracked:[]***blocks.6.1.conv_pw.weight:[2301, 384, 1, 1]***blocks.6.1.bn1.weight:[2301]***blocks.6.1.bn1.bias:[2301]***blocks.6.1.bn1.running_mean:[2301]***blocks.6.1.bn1.running_var:[2301]***blocks.6.1.bn1.num_batches_tracked:[]***blocks.6.1.conv_dw.weight:[2301, 1, 3, 3]***blocks.6.1.bn2.weight:[2301]***blocks.6.1.bn2.bias:[2301]***blocks.6.1.bn2.running_mean:[2301]***blocks.6.1.bn2.running_var:[2301]***blocks.6.1.bn2.num_batches_tracked:[]***blocks.6.1.se.conv_reduce.weight:[96, 2301, 1, 1]***blocks.6.1.se.conv_reduce.bias:[96]***blocks.6.1.se.conv_expand.weight:[2301, 96, 1, 1]***blocks.6.1.se.conv_expand.bias:[2301]***blocks.6.1.conv_pwl.weight:[384, 2301, 1, 1]***blocks.6.1.bn3.weight:[384]***blocks.6.1.bn3.bias:[384]***blocks.6.1.bn3.running_mean:[384]***blocks.6.1.bn3.running_var:[384]***blocks.6.1.bn3.num_batches_tracked:[]***conv_head.weight:[1536, 384, 1, 
1]***bn2.weight:[1536]***bn2.bias:[1536]***bn2.running_mean:[1536]***bn2.running_var:[1536]***bn2.num_batches_tracked:[]***classifier.weight:[1000, 1536]***classifier.bias:[1000] \ No newline at end of file diff --git a/timm/models/registry.py b/timm/models/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..3317eecee3a00312207cdf398f636faf566736dc --- /dev/null +++ b/timm/models/registry.py @@ -0,0 +1,107 @@ +""" Model Registry +Hacked together by / Copyright 2020 Ross Wightman +""" + +import sys +import re +import fnmatch +from collections import defaultdict + +__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules'] + +_module_to_models = defaultdict(set) # dict of sets to check membership of model in module +_model_to_module = {} # mapping of model names to module names +_model_entrypoints = {} # mapping of model names to entrypoint fns +_model_has_pretrained = set() # set of model names that have pretrained weight url present + + +def register_model(fn): + # lookup containing module + mod = sys.modules[fn.__module__] + module_name_split = fn.__module__.split('.') + module_name = module_name_split[-1] if len(module_name_split) else '' + + # add model to __all__ in module + model_name = fn.__name__ + if hasattr(mod, '__all__'): + mod.__all__.append(model_name) + else: + mod.__all__ = [model_name] + + # add entries to registry dict/sets + _model_entrypoints[model_name] = fn + _model_to_module[model_name] = module_name + _module_to_models[module_name].add(model_name) + has_pretrained = False # check if model has a pretrained url to allow filtering on this + if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: + # this will catch all models that have entrypoint matching cfg key, but miss any aliasing + # entrypoints or non-matching combos + has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] + if has_pretrained: + _model_has_pretrained.add(model_name) + return fn + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def list_models(filter='', module='', pretrained=False, exclude_filters=''): + """ Return list of available model names, sorted alphabetically + + Args: + filter (str) - Wildcard filter string that works with fnmatch + module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') + pretrained (bool) - Include only models with pretrained weights if True + exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter + + Example: + model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' + model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module + """ + if module: + models = list(_module_to_models[module]) + else: + models = _model_entrypoints.keys() + if filter: + models = fnmatch.filter(models, filter) # include these models + if exclude_filters: + if not isinstance(exclude_filters, list): + exclude_filters = [exclude_filters] + for xf in exclude_filters: + exclude_models = fnmatch.filter(models, xf) # exclude these models + if len(exclude_models): + models = set(models).difference(exclude_models) + if pretrained: + models = _model_has_pretrained.intersection(models) + return list(sorted(models, key=_natural_key)) + + +def is_model(model_name): + """ Check if a model name exists + """ + return model_name in _model_entrypoints + + +def 
model_entrypoint(model_name): + """Fetch a model entrypoint for specified model name + """ + return _model_entrypoints[model_name] + + +def list_modules(): + """ Return list of module names that contain models / model entrypoints + """ + modules = _module_to_models.keys() + return list(sorted(modules)) + + +def is_model_in_modules(model_name, module_names): + """Check if a model exists within a subset of modules + Args: + model_name (str) - name of model to check + module_names (tuple, list, set) - names of modules to search in + """ + assert isinstance(module_names, (tuple, list, set)) + return any(model_name in _module_to_models[n] for n in module_names) + diff --git a/timm/models/regnet.py b/timm/models/regnet.py new file mode 100644 index 0000000000000000000000000000000000000000..68bb817cbff2487cff9172c78c8de6af62d3bc53 --- /dev/null +++ b/timm/models/regnet.py @@ -0,0 +1,477 @@ +"""RegNet + +Paper: `Designing Network Design Spaces` - https://arxiv.org/abs/2003.13678 +Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py + +Based on original PyTorch impl linked above, but re-wrote to use my own blocks (adapted from ResNet here) +and cleaned up with more descriptive variable names. + +Weights from original impl have been modified +* first layer from BGR -> RGB as most PyTorch models are +* removed training specific dict entries from checkpoints and keep model state_dict only +* remap names to match the ones here + +Hacked together by / Copyright 2020 Ross Wightman +""" +import numpy as np +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule, DropPath +from .registry import register_model + + +def _mcfg(**kwargs): + cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32) + cfg.update(**kwargs) + return cfg + + +# Model FLOPS = three trailing digits * 10^8 +model_cfgs = dict( + regnetx_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13), + regnetx_004=_mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22), + regnetx_006=_mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16), + regnetx_008=_mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16), + regnetx_016=_mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18), + regnetx_032=_mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25), + regnetx_040=_mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23), + regnetx_064=_mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17), + regnetx_080=_mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23), + regnetx_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19), + regnetx_160=_mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22), + regnetx_320=_mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23), + regnety_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25), + regnety_004=_mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25), + regnety_006=_mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25), + regnety_008=_mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25), + regnety_016=_mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25), + regnety_032=_mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25), + regnety_040=_mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25), + regnety_064=_mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25), + regnety_080=_mcfg(w0=192, wa=76.82, wm=2.19, 
group_w=56, depth=17, se_ratio=0.25), + regnety_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25), + regnety_160=_mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25), + regnety_320=_mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25), +) + + +def _cfg(url=''): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + } + + +default_cfgs = dict( + regnetx_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth'), + regnetx_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth'), + regnetx_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth'), + regnetx_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth'), + regnetx_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth'), + regnetx_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth'), + regnetx_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth'), + regnetx_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth'), + regnetx_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth'), + regnetx_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth'), + regnetx_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth'), + regnetx_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth'), + regnety_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth'), + regnety_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth'), + regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'), + regnety_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth'), + regnety_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth'), + regnety_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth'), + regnety_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth'), + regnety_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'), + regnety_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth'), + regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'), + 
regnety_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth'), + regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'), +) + + +def quantize_float(f, q): + """Converts a float to closest non-zero int divisible by q.""" + return int(round(f / q) * q) + + +def adjust_widths_groups_comp(widths, bottle_ratios, groups): + """Adjusts the compatibility of widths and groups.""" + bottleneck_widths = [int(w * b) for w, b in zip(widths, bottle_ratios)] + groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_widths)] + bottleneck_widths = [quantize_float(w_bot, g) for w_bot, g in zip(bottleneck_widths, groups)] + widths = [int(w_bot / b) for w_bot, b in zip(bottleneck_widths, bottle_ratios)] + return widths, groups + + +def generate_regnet(width_slope, width_initial, width_mult, depth, q=8): + """Generates per block widths from RegNet parameters.""" + assert width_slope >= 0 and width_initial > 0 and width_mult > 1 and width_initial % q == 0 + widths_cont = np.arange(depth) * width_slope + width_initial + width_exps = np.round(np.log(widths_cont / width_initial) / np.log(width_mult)) + widths = width_initial * np.power(width_mult, width_exps) + widths = np.round(np.divide(widths, q)) * q + num_stages, max_stage = len(np.unique(widths)), width_exps.max() + 1 + widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist() + return widths, num_stages, max_stage, widths_cont + + +class Bottleneck(nn.Module): + """ RegNet Bottleneck + + This is almost exactly the same as a ResNet Bottlneck. The main difference is the SE block is moved from + after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels. 
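+    For example (illustrative numbers only): with out_chs=240, bottleneck_ratio=1. and group_width=48, the grouped 3x3 conv below runs with bottleneck_chs=240 and groups=240 // 48 = 5; when se_ratio is non-zero, the SE block squeezes to int(round(in_chs * se_ratio)) channels and expands back to bottleneck_chs.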
+ """ + + def __init__(self, in_chs, out_chs, stride=1, dilation=1, bottleneck_ratio=1, group_width=1, se_ratio=0.25, + downsample=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, + drop_block=None, drop_path=None): + super(Bottleneck, self).__init__() + bottleneck_chs = int(round(out_chs * bottleneck_ratio)) + groups = bottleneck_chs // group_width + + cargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block) + self.conv1 = ConvBnAct(in_chs, bottleneck_chs, kernel_size=1, **cargs) + self.conv2 = ConvBnAct( + bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation, + groups=groups, **cargs) + if se_ratio: + se_channels = int(round(in_chs * se_ratio)) + self.se = SEModule(bottleneck_chs, reduction_channels=se_channels) + else: + self.se = None + cargs['act_layer'] = None + self.conv3 = ConvBnAct(bottleneck_chs, out_chs, kernel_size=1, **cargs) + self.act3 = act_layer(inplace=True) + self.downsample = downsample + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv3.bn.weight) + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.conv2(x) + if self.se is not None: + x = self.se(x) + x = self.conv3(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act3(x) + return x + + +def downsample_conv( + in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size + dilation = dilation if kernel_size > 1 else 1 + return ConvBnAct( + in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, act_layer=None) + + +def downsample_avg( + in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None): + """ AvgPool Downsampling as in 'D' ResNet variants. 
This is not in RegNet space but I might experiment.""" + norm_layer = norm_layer or nn.BatchNorm2d + avg_stride = stride if dilation == 1 else 1 + pool = nn.Identity() + if stride > 1 or dilation > 1: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + return nn.Sequential(*[ + pool, ConvBnAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, act_layer=None)]) + + +class RegStage(nn.Module): + """Stage (sequence of blocks w/ the same output shape).""" + + def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width, + block_fn=Bottleneck, se_ratio=0., drop_path_rates=None, drop_block=None): + super(RegStage, self).__init__() + block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args + first_dilation = 1 if dilation in (1, 2) else 2 + for i in range(depth): + block_stride = stride if i == 0 else 1 + block_in_chs = in_chs if i == 0 else out_chs + block_dilation = first_dilation if i == 0 else dilation + if drop_path_rates is not None and drop_path_rates[i] > 0.: + drop_path = DropPath(drop_path_rates[i]) + else: + drop_path = None + if (block_in_chs != out_chs) or (block_stride != 1): + proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation) + else: + proj_block = None + + name = "b{}".format(i + 1) + self.add_module( + name, block_fn( + block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio, + downsample=proj_block, drop_block=drop_block, drop_path=drop_path, **block_kwargs) + ) + + def forward(self, x): + for block in self.children(): + x = block(x) + return x + + +class RegNet(nn.Module): + """RegNet model. + + Paper: https://arxiv.org/abs/2003.13678 + Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py + """ + + def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., + drop_path_rate=0., zero_init_last_bn=True): + super().__init__() + # TODO add drop block, drop path, anti-aliasing, custom bn/act args + self.num_classes = num_classes + self.drop_rate = drop_rate + assert output_stride in (8, 16, 32) + + # Construct the stem + stem_width = cfg['stem_width'] + self.stem = ConvBnAct(in_chans, stem_width, 3, stride=2) + self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')] + + # Construct the stages + prev_width = stem_width + curr_stride = 2 + stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate) + se_ratio = cfg['se_ratio'] + for i, stage_args in enumerate(stage_params): + stage_name = "s{}".format(i + 1) + self.add_module(stage_name, RegStage(prev_width, **stage_args, se_ratio=se_ratio)) + prev_width = stage_args['out_chs'] + curr_stride *= stage_args['stride'] + self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)] + + # Construct the head + self.num_features = prev_width + self.head = ClassifierHead( + in_chs=prev_width, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, mean=0.0, std=0.01) + nn.init.zeros_(m.bias) + if zero_init_last_bn: + for m in self.modules(): + if 
hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() + + def _get_stage_params(self, cfg, default_stride=2, output_stride=32, drop_path_rate=0.): + # Generate RegNet ws per block + w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth'] + widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d) + + # Convert to per stage format + stage_widths, stage_depths = np.unique(widths, return_counts=True) + + # Use the same group width, bottleneck mult and stride for each stage + stage_groups = [cfg['group_w'] for _ in range(num_stages)] + stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)] + stage_strides = [] + stage_dilations = [] + net_stride = 2 + dilation = 1 + for _ in range(num_stages): + if net_stride >= output_stride: + dilation *= default_stride + stride = 1 + else: + stride = default_stride + net_stride *= stride + stage_strides.append(stride) + stage_dilations.append(dilation) + stage_dpr = np.split(np.linspace(0, drop_path_rate, d), np.cumsum(stage_depths[:-1])) + + # Adjust the compatibility of ws and gws + stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups) + param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rates'] + stage_params = [ + dict(zip(param_names, params)) for params in + zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups, + stage_dpr)] + return stage_params + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + for block in list(self.children())[:-1]: + x = block(x) + return x + + def forward(self, x): + for block in self.children(): + x = block(x) + return x + + +def _create_regnet(variant, pretrained, **kwargs): + return build_model_with_cfg( + RegNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant], **kwargs) + + +@register_model +def regnetx_002(pretrained=False, **kwargs): + """RegNetX-200MF""" + return _create_regnet('regnetx_002', pretrained, **kwargs) + + +@register_model +def regnetx_004(pretrained=False, **kwargs): + """RegNetX-400MF""" + return _create_regnet('regnetx_004', pretrained, **kwargs) + + +@register_model +def regnetx_006(pretrained=False, **kwargs): + """RegNetX-600MF""" + return _create_regnet('regnetx_006', pretrained, **kwargs) + + +@register_model +def regnetx_008(pretrained=False, **kwargs): + """RegNetX-800MF""" + return _create_regnet('regnetx_008', pretrained, **kwargs) + + +@register_model +def regnetx_016(pretrained=False, **kwargs): + """RegNetX-1.6GF""" + return _create_regnet('regnetx_016', pretrained, **kwargs) + + +@register_model +def regnetx_032(pretrained=False, **kwargs): + """RegNetX-3.2GF""" + return _create_regnet('regnetx_032', pretrained, **kwargs) + + +@register_model +def regnetx_040(pretrained=False, **kwargs): + """RegNetX-4.0GF""" + return _create_regnet('regnetx_040', pretrained, **kwargs) + + +@register_model +def regnetx_064(pretrained=False, **kwargs): + """RegNetX-6.4GF""" + return _create_regnet('regnetx_064', pretrained, **kwargs) + + +@register_model +def regnetx_080(pretrained=False, **kwargs): + """RegNetX-8.0GF""" + return _create_regnet('regnetx_080', pretrained, **kwargs) + + +@register_model +def regnetx_120(pretrained=False, **kwargs): + """RegNetX-12GF""" + return 
_create_regnet('regnetx_120', pretrained, **kwargs) + + +@register_model +def regnetx_160(pretrained=False, **kwargs): + """RegNetX-16GF""" + return _create_regnet('regnetx_160', pretrained, **kwargs) + + +@register_model +def regnetx_320(pretrained=False, **kwargs): + """RegNetX-32GF""" + return _create_regnet('regnetx_320', pretrained, **kwargs) + + +@register_model +def regnety_002(pretrained=False, **kwargs): + """RegNetY-200MF""" + return _create_regnet('regnety_002', pretrained, **kwargs) + + +@register_model +def regnety_004(pretrained=False, **kwargs): + """RegNetY-400MF""" + return _create_regnet('regnety_004', pretrained, **kwargs) + + +@register_model +def regnety_006(pretrained=False, **kwargs): + """RegNetY-600MF""" + return _create_regnet('regnety_006', pretrained, **kwargs) + + +@register_model +def regnety_008(pretrained=False, **kwargs): + """RegNetY-800MF""" + return _create_regnet('regnety_008', pretrained, **kwargs) + + +@register_model +def regnety_016(pretrained=False, **kwargs): + """RegNetY-1.6GF""" + return _create_regnet('regnety_016', pretrained, **kwargs) + + +@register_model +def regnety_032(pretrained=False, **kwargs): + """RegNetY-3.2GF""" + return _create_regnet('regnety_032', pretrained, **kwargs) + + +@register_model +def regnety_040(pretrained=False, **kwargs): + """RegNetY-4.0GF""" + return _create_regnet('regnety_040', pretrained, **kwargs) + + +@register_model +def regnety_064(pretrained=False, **kwargs): + """RegNetY-6.4GF""" + return _create_regnet('regnety_064', pretrained, **kwargs) + + +@register_model +def regnety_080(pretrained=False, **kwargs): + """RegNetY-8.0GF""" + return _create_regnet('regnety_080', pretrained, **kwargs) + + +@register_model +def regnety_120(pretrained=False, **kwargs): + """RegNetY-12GF""" + return _create_regnet('regnety_120', pretrained, **kwargs) + + +@register_model +def regnety_160(pretrained=False, **kwargs): + """RegNetY-16GF""" + return _create_regnet('regnety_160', pretrained, **kwargs) + + +@register_model +def regnety_320(pretrained=False, **kwargs): + """RegNetY-32GF""" + return _create_regnet('regnety_320', pretrained, **kwargs) diff --git a/timm/models/res2net.py b/timm/models/res2net.py new file mode 100644 index 0000000000000000000000000000000000000000..6e51d49129af867774c51f6642460b45490ac903 --- /dev/null +++ b/timm/models/res2net.py @@ -0,0 +1,214 @@ +""" Res2Net and Res2NeXt +Adapted from Official Pytorch impl at: https://github.com/gasvn/Res2Net/ +Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169 +""" +import math + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .registry import register_model +from .resnet import ResNet + +__all__ = [] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'res2net50_26w_4s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth'), + 'res2net50_48w_2s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth'), + 'res2net50_14w_8s': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth'), + 'res2net50_26w_6s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth'), + 'res2net50_26w_8s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth'), + 'res2net101_26w_4s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth'), + 'res2next50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth'), +} + + +class Bottle2neck(nn.Module): + """ Res2Net/Res2NeXT Bottleneck + Adapted from https://github.com/gasvn/Res2Net/blob/master/res2net.py + """ + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None, + act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_): + super(Bottle2neck, self).__init__() + self.scale = scale + self.is_first = stride > 1 or downsample is not None + self.num_scales = max(1, scale - 1) + width = int(math.floor(planes * (base_width / 64.0))) * cardinality + self.width = width + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + + self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) + self.bn1 = norm_layer(width * scale) + + convs = [] + bns = [] + for i in range(self.num_scales): + convs.append(nn.Conv2d( + width, width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, bias=False)) + bns.append(norm_layer(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + if self.is_first: + # FIXME this should probably have count_include_pad=False, but hurts original weights + self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) + else: + self.pool = None + + self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False) + self.bn3 = norm_layer(outplanes) + self.se = attn_layer(outplanes) if attn_layer is not None else None + + self.relu = act_layer(inplace=True) + self.downsample = downsample + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + spx = torch.split(out, self.width, 1) + spo = [] + sp = spx[0] # redundant, for torchscript + for i, (conv, bn) in enumerate(zip(self.convs, self.bns)): + if i == 0 or self.is_first: + sp = spx[i] + else: + sp = sp + spx[i] + sp = conv(sp) + sp = bn(sp) + sp = self.relu(sp) + spo.append(sp) + if self.scale > 1: + if self.pool is not None: + # self.is_first == True, None check for torchscript + spo.append(self.pool(spx[-1])) + else: + spo.append(spx[-1]) + out = torch.cat(spo, 1) + + out = self.conv3(out) + out = self.bn3(out) + + if self.se is not None: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +def _create_res2net(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ResNet, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs) + + +@register_model +def res2net50_26w_4s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 26w4s model. 
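+    The 26w4s naming reflects base_width=26 and scale=4 in the model_args below, with a [3, 4, 6, 3] stage layout.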
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4), **kwargs) + return _create_res2net('res2net50_26w_4s', pretrained, **model_args) + + +@register_model +def res2net101_26w_4s(pretrained=False, **kwargs): + """Constructs a Res2Net-101 26w4s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4), **kwargs) + return _create_res2net('res2net101_26w_4s', pretrained, **model_args) + + +@register_model +def res2net50_26w_6s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 26w6s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6), **kwargs) + return _create_res2net('res2net50_26w_6s', pretrained, **model_args) + + +@register_model +def res2net50_26w_8s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 26w8s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8), **kwargs) + return _create_res2net('res2net50_26w_8s', pretrained, **model_args) + + +@register_model +def res2net50_48w_2s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 48w2s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2), **kwargs) + return _create_res2net('res2net50_48w_2s', pretrained, **model_args) + + +@register_model +def res2net50_14w_8s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 14w8s model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8), **kwargs) + return _create_res2net('res2net50_14w_8s', pretrained, **model_args) + + +@register_model +def res2next50(pretrained=False, **kwargs): + """Construct Res2NeXt-50 4s + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4), **kwargs) + return _create_res2net('res2next50', pretrained, **model_args) diff --git a/timm/models/resnest.py b/timm/models/resnest.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8bb348302956a1578facdef39368c04a376641 --- /dev/null +++ b/timm/models/resnest.py @@ -0,0 +1,236 @@ +""" ResNeSt Models + +Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt by Hang Zhang + +Modified for torchscript compat, and consistency with timm by Ross Wightman +""" +import torch +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import SplitAttnConv2d +from .registry import register_model +from .resnet import ResNet + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1.0', 'classifier': 'fc', + **kwargs + } + +default_cfgs = { + 'resnest14d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth'), + 'resnest26d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth'), + 'resnest50d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth'), + 'resnest101e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'resnest200e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth', + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=0.909, interpolation='bicubic'), + 'resnest269e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth', + input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.928, interpolation='bicubic'), + 'resnest50d_4s2x40d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth', + interpolation='bicubic'), + 'resnest50d_1s4x24d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth', + interpolation='bicubic') +} + + +class ResNestBottleneck(nn.Module): + """ResNet Bottleneck + """ + # pylint: disable=unused-argument + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + 
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(ResNestBottleneck, self).__init__() + assert reduce_first == 1 # not supported + assert attn_layer is None # not supported + assert aa_layer is None # TODO not yet supported + assert drop_path is None # TODO not yet supported + + group_width = int(planes * (base_width / 64.)) * cardinality + first_dilation = first_dilation or dilation + if avd and (stride > 1 or is_first): + avd_stride = stride + stride = 1 + else: + avd_stride = 0 + self.radix = radix + self.drop_block = drop_block + + self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) + self.bn1 = norm_layer(group_width) + self.act1 = act_layer(inplace=True) + self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None + + if self.radix >= 1: + self.conv2 = SplitAttnConv2d( + group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block) + self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness + self.act2 = None + else: + self.conv2 = nn.Conv2d( + group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, bias=False) + self.bn2 = norm_layer(group_width) + self.act2 = act_layer(inplace=True) + self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None + + self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes*4) + self.act3 = act_layer(inplace=True) + self.downsample = downsample + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + if self.drop_block is not None: + out = self.drop_block(out) + out = self.act1(out) + + if self.avd_first is not None: + out = self.avd_first(out) + + out = self.conv2(out) + if self.bn2 is not None: + out = self.bn2(out) + if self.drop_block is not None: + out = self.drop_block(out) + out = self.act2(out) + + if self.avd_last is not None: + out = self.avd_last(out) + + out = self.conv3(out) + out = self.bn3(out) + if self.drop_block is not None: + out = self.drop_block(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act3(out) + return out + + +def _create_resnest(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) + + +@register_model +def resnest14d(pretrained=False, **kwargs): + """ ResNeSt-14d model. Weights ported from GluonCV. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[1, 1, 1, 1], + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest14d', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest26d(pretrained=False, **kwargs): + """ ResNeSt-26d model. Weights ported from GluonCV. 
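+    Typically created through the registry, e.g. timm.create_model('resnest26d', pretrained=True) (illustrative usage).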
+ """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[2, 2, 2, 2], + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest26d', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest50d(pretrained=False, **kwargs): + """ ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest50d', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest101e(pretrained=False, **kwargs): + """ ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 23, 3], + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest101e', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest200e(pretrained=False, **kwargs): + """ ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 24, 36, 3], + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest200e', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest269e(pretrained=False, **kwargs): + """ ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. 
+ """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 30, 48, 8], + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest269e', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest50d_4s2x40d(pretrained=False, **kwargs): + """ResNeSt-50 4s2x40d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], + stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2, + block_args=dict(radix=4, avd=True, avd_first=True), **kwargs) + return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest50d_1s4x24d(pretrained=False, **kwargs): + """ResNeSt-50 1s4x24d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], + stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4, + block_args=dict(radix=1, avd=True, avd_first=True), **kwargs) + return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/resnet.py b/timm/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c2cc55fd842436f4dde8c26980342fc8e85f79b5 --- /dev/null +++ b/timm/models/resnet.py @@ -0,0 +1,1227 @@ +"""PyTorch ResNet + +This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with +additional dropout and dynamic global avg/max pool. + +ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman +Copyright 2020 Ross Wightman +""" +import math +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, create_attn, create_classifier +from .registry import register_model + +__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + # ResNet and Wide ResNet + 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), + 'resnet18d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'resnet34': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'), + 'resnet34d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34d_ra2-f8dcfcaf.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'resnet26': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26-9aa10e23.pth', + interpolation='bicubic'), + 'resnet26d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'resnet50': 
_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth', + interpolation='bicubic'), + 'resnet50d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'resnet66d': _cfg(url='', interpolation='bicubic', first_conv='conv1.0'), + 'resnet101': _cfg(url='', interpolation='bicubic'), + 'resnet101d': _cfg(url='', interpolation='bicubic', first_conv='conv1.0'), + 'resnet152': _cfg(url='', interpolation='bicubic'), + 'resnet152d': _cfg(url='', interpolation='bicubic', first_conv='conv1.0'), + 'resnet200': _cfg(url='', interpolation='bicubic'), + 'resnet200d': _cfg(url='', interpolation='bicubic', first_conv='conv1.0'), + 'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'), + 'tv_resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'), + 'tv_resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'), + 'tv_resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'), + 'wide_resnet50_2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/wide_resnet50_racm-8234f177.pth', + interpolation='bicubic'), + 'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'), + + # ResNeXt + 'resnext50_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d_ra-d733960d.pth', + interpolation='bicubic'), + 'resnext50d_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'resnext101_32x4d': _cfg(url=''), + 'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'), + 'resnext101_64x4d': _cfg(url=''), + 'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'), + + # ResNeXt models - Weakly Supervised Pretraining on Instagram Hashtags + # from https://github.com/facebookresearch/WSL-Images + # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only. + 'ig_resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth'), + 'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'), + 'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'), + 'ig_resnext101_32x48d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth'), + + # Semi-Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models + # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
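 + # Naming note (not in the upstream comments): 'ssl_' entries below were pre-trained semi-supervised and 'swsl_' semi-weakly supervised, as described in the corresponding entrypoint docstrings further down; entries left with url='' have no released weights wired up here.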
+ 'ssl_resnet18': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth'), + 'ssl_resnet50': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth'), + 'ssl_resnext50_32x4d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth'), + 'ssl_resnext101_32x4d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth'), + 'ssl_resnext101_32x8d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth'), + 'ssl_resnext101_32x16d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth'), + + # Semi-Weakly Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models + # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only. + 'swsl_resnet18': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth'), + 'swsl_resnet50': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth'), + 'swsl_resnext50_32x4d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth'), + 'swsl_resnext101_32x4d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth'), + 'swsl_resnext101_32x8d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth'), + 'swsl_resnext101_32x16d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth'), + + # Squeeze-Excitation ResNets, to eventually replace the models in senet.py + 'seresnet18': _cfg( + url='', + interpolation='bicubic'), + 'seresnet34': _cfg( + url='', + interpolation='bicubic'), + 'seresnet50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet50_ra_224-8efdb4bb.pth', + interpolation='bicubic'), + 'seresnet50tn': _cfg( + url='', + interpolation='bicubic', + first_conv='conv1.0'), + 'seresnet101': _cfg( + url='', + interpolation='bicubic'), + 'seresnet152': _cfg( + url='', + interpolation='bicubic'), + + # Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py + 'seresnext26_32x4d': _cfg( + url='', + interpolation='bicubic'), + 'seresnext26d_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'seresnext26t_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26t_32x4d-361bc1c4.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'seresnext26tn_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'seresnext50_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext50_32x4d_racm-a304a460.pth', + interpolation='bicubic'), + 'seresnext101_32x4d': _cfg( + url='',
+ interpolation='bicubic'), + 'seresnext101_32x8d': _cfg( + url='', + interpolation='bicubic'), + 'senet154': _cfg( + url='', + interpolation='bicubic', + first_conv='conv1.0'), + + # Efficient Channel Attention ResNets + 'ecaresnet18': _cfg(), + 'ecaresnet50': _cfg(), + 'ecaresnetlight': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth', + interpolation='bicubic'), + 'ecaresnet50d': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet50d_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet101d': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet101d_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth', + interpolation='bicubic', + first_conv='conv1.0'), + + # Efficient Channel Attention ResNeXts + 'ecaresnext26tn_32x4d': _cfg( + url='', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnext50_32x4d': _cfg( + url='', + interpolation='bicubic'), + + # ResNets with anti-aliasing blur pool + 'resnetblur18': _cfg( + interpolation='bicubic'), + 'resnetblur50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth', + interpolation='bicubic') +} + + +def get_padding(kernel_size, stride, dilation=1): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(BasicBlock, self).__init__() + + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock does not support changing base width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation) + + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation, + dilation=first_dilation, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None + + self.conv2 = nn.Conv2d( + first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) + self.bn2 = norm_layer(outplanes) + + self.se = create_attn(attn_layer, outplanes) + + self.act2 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn2.weight) + + def forward(self, x): + residual = x + + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) + if self.aa is not None: + x 
= self.aa(x) + + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) + + if self.se is not None: + x = self.se(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act2(x) + + return x + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(Bottleneck, self).__init__() + + width = int(math.floor(planes * (base_width / 64)) * cardinality) + first_planes = width // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation) + + self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + + self.conv2 = nn.Conv2d( + first_planes, width, kernel_size=3, stride=1 if use_aa else stride, + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) + self.bn2 = norm_layer(width) + self.act2 = act_layer(inplace=True) + self.aa = aa_layer(channels=width, stride=stride) if use_aa else None + + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) + self.bn3 = norm_layer(outplanes) + + self.se = create_attn(attn_layer, outplanes) + + self.act3 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + + def forward(self, x): + residual = x + + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act2(x) + if self.aa is not None: + x = self.aa(x) + + x = self.conv3(x) + x = self.bn3(x) + if self.drop_block is not None: + x = self.drop_block(x) + + if self.se is not None: + x = self.se(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act3(x) + + return x + + +def downsample_conv( + in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size + first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1 + p = get_padding(kernel_size, stride, first_dilation) + + return nn.Sequential(*[ + nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False), + norm_layer(out_channels) + ]) + + +def downsample_avg( + in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + avg_stride = stride if dilation == 1 else 1 + if stride == 1 and dilation == 1: + pool = nn.Identity() + else: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + + return nn.Sequential(*[ + pool, + nn.Conv2d(in_channels, 
out_channels, 1, stride=1, padding=0, bias=False), + norm_layer(out_channels) + ]) + + +def drop_blocks(drop_block_rate=0.): + return [ + None, None, + DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None, + DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None] + + +def make_blocks( + block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32, + down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs): + stages = [] + feature_info = [] + net_num_blocks = sum(block_repeats) + net_block_idx = 0 + net_stride = 4 + dilation = prev_dilation = 1 + for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))): + stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it + stride = 1 if stage_idx == 0 else 2 + if net_stride >= output_stride: + dilation *= stride + stride = 1 + else: + net_stride *= stride + + downsample = None + if stride != 1 or inplanes != planes * block_fn.expansion: + down_kwargs = dict( + in_channels=inplanes, out_channels=planes * block_fn.expansion, kernel_size=down_kernel_size, + stride=stride, dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer')) + downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs) + + block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs) + blocks = [] + for block_idx in range(num_blocks): + downsample = downsample if block_idx == 0 else None + stride = stride if block_idx == 0 else 1 + block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule + blocks.append(block_fn( + inplanes, planes, stride, downsample, first_dilation=prev_dilation, + drop_path=DropPath(block_dpr) if block_dpr > 0. else None, **block_kwargs)) + prev_dilation = dilation + inplanes = planes * block_fn.expansion + net_block_idx += 1 + + stages.append((stage_name, nn.Sequential(*blocks))) + feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name)) + + return stages, feature_info + + +class ResNet(nn.Module): + """ResNet / ResNeXt / SE-ResNeXt / SE-Net + + This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that + * have > 1 stride in the 3x3 conv layer of bottleneck + * have conv-bn-act ordering + + This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s + variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the + 'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default. 
+ + ResNet variants (the same modifications can be used in SE/ResNeXt models as well): + * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b + * c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64) + * d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample + * e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample + * s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128) + * t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample + * tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample + + ResNeXt + * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths + * same c,d, e, s variants as ResNet can be enabled + + SE-ResNeXt + * normal - 7x7 stem, stem_width = 64 + * same c, d, e, s variants as ResNet can be enabled + + SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64, + reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block + + Parameters + ---------- + block : Block + Class for the residual block. Options are BasicBlockGl, BottleneckGl. + layers : list of int + Numbers of layers in each block + num_classes : int, default 1000 + Number of classification classes. + in_chans : int, default 3 + Number of input (color) channels. + cardinality : int, default 1 + Number of convolution groups for 3x3 conv in Bottleneck. + base_width : int, default 64 + Factor determining bottleneck channels. `planes * base_width / 64 * cardinality` + stem_width : int, default 64 + Number of channels in stem convolutions + stem_type : str, default '' + The type of stem: + * '', default - a single 7x7 conv with a width of stem_width + * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2 + * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width//4 * 6, stem_width * 2 + * 'deep_tiered_narrow' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2 + block_reduce_first: int, default 1 + Reduction factor for first convolution output width of residual blocks, + 1 for all archs except senets, where 2 + down_kernel_size: int, default 1 + Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets + avg_down : bool, default False + Whether to use average pooling for projection skip connection between stages/downsample. + output_stride : int, default 32 + Set the output stride of the network, 32, 16, or 8. Typically used in segmentation. + act_layer : nn.Module, activation layer + norm_layer : nn.Module, normalization layer + aa_layer : nn.Module, anti-aliasing layer + drop_rate : float, default 0. + Dropout probability before classifier, for training + global_pool : str, default 'avg' + Global pooling type. 
One of 'avg', 'max', 'avgmax', 'catavgmax' + """ + + def __init__(self, block, layers, num_classes=1000, in_chans=3, + cardinality=1, base_width=64, stem_width=64, stem_type='', + output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0., + drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None): + block_args = block_args or dict() + assert output_stride in (8, 16, 32) + self.num_classes = num_classes + self.drop_rate = drop_rate + super(ResNet, self).__init__() + + # Stem + deep_stem = 'deep' in stem_type + inplanes = stem_width * 2 if deep_stem else 64 + if deep_stem: + stem_chs_1 = stem_chs_2 = stem_width + if 'tiered' in stem_type: + stem_chs_1 = 3 * (stem_width // 4) + stem_chs_2 = stem_width if 'narrow' in stem_type else 6 * (stem_width // 4) + self.conv1 = nn.Sequential(*[ + nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False), + norm_layer(stem_chs_1), + act_layer(inplace=True), + nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False), + norm_layer(stem_chs_2), + act_layer(inplace=True), + nn.Conv2d(stem_chs_2, inplanes, 3, stride=1, padding=1, bias=False)]) + else: + self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(inplanes) + self.act1 = act_layer(inplace=True) + self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')] + + # Stem Pooling + if aa_layer is not None: + self.maxpool = nn.Sequential(*[ + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + aa_layer(channels=inplanes, stride=2)]) + else: + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + # Feature Blocks + channels = [64, 128, 256, 512] + stage_modules, stage_feature_info = make_blocks( + block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width, + output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down, + down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, + drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args) + for stage in stage_modules: + self.add_module(*stage) # layer1, layer2, etc + self.feature_info.extend(stage_feature_info) + + # Head (Pooling and Classifier) + self.num_features = 512 * block.expansion + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + for n, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) 
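 + # Descriptive note: zero_init_last_bn (handled just below) zeroes the gamma of the final BN in every residual block via the per-block zero_init_last_bn() hooks, so each block starts out as an identity mapping, the residual-branch init trick from the 'Bag of Tricks' paper cited in the class docstring above.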
+ if zero_init_last_bn: + for m in self.modules(): + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) + x = self.fc(x) + return x + + +def _create_resnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) + + +@register_model +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('resnet18', pretrained, **model_args) + + +@register_model +def resnet18d(pretrained=False, **kwargs): + """Constructs a ResNet-18-D model. + """ + model_args = dict( + block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet18d', pretrained, **model_args) + + +@register_model +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + """ + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('resnet34', pretrained, **model_args) + + +@register_model +def resnet34d(pretrained=False, **kwargs): + """Constructs a ResNet-34-D model. + """ + model_args = dict( + block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet34d', pretrained, **model_args) + + +@register_model +def resnet26(pretrained=False, **kwargs): + """Constructs a ResNet-26 model. + """ + model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('resnet26', pretrained, **model_args) + + +@register_model +def resnet26d(pretrained=False, **kwargs): + """Constructs a ResNet-26-D model. + """ + model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet26d', pretrained, **model_args) + + +@register_model +def resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('resnet50', pretrained, **model_args) + + +@register_model +def resnet50d(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet50d', pretrained, **model_args) + + +@register_model +def resnet66d(pretrained=False, **kwargs): + """Constructs a ResNet-66-D model. + """ + model_args = dict(block=BasicBlock, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet66d', pretrained, **model_args) + + +@register_model +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. 
+ """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs) + return _create_resnet('resnet101', pretrained, **model_args) + + +@register_model +def resnet101d(pretrained=False, **kwargs): + """Constructs a ResNet-101-D model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet101d', pretrained, **model_args) + + +@register_model +def resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs) + return _create_resnet('resnet152', pretrained, **model_args) + + +@register_model +def resnet152d(pretrained=False, **kwargs): + """Constructs a ResNet-152-D model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet152d', pretrained, **model_args) + + +@register_model +def resnet200(pretrained=False, **kwargs): + """Constructs a ResNet-200 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], **kwargs) + return _create_resnet('resnet200', pretrained, **model_args) + + +@register_model +def resnet200d(pretrained=False, **kwargs): + """Constructs a ResNet-200-D model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet200d', pretrained, **model_args) + + +@register_model +def tv_resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model with original Torchvision weights. + """ + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('tv_resnet34', pretrained, **model_args) + + +@register_model +def tv_resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model with original Torchvision weights. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('tv_resnet50', pretrained, **model_args) + + +@register_model +def tv_resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model w/ Torchvision pretrained weights. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs) + return _create_resnet('tv_resnet101', pretrained, **model_args) + + +@register_model +def tv_resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model w/ Torchvision pretrained weights. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs) + return _create_resnet('tv_resnet152', pretrained, **model_args) + + +@register_model +def wide_resnet50_2(pretrained=False, **kwargs): + """Constructs a Wide ResNet-50-2 model. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs) + return _create_resnet('wide_resnet50_2', pretrained, **model_args) + + +@register_model +def wide_resnet101_2(pretrained=False, **kwargs): + """Constructs a Wide ResNet-101-2 model. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same. 
+ """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128, **kwargs) + return _create_resnet('wide_resnet101_2', pretrained, **model_args) + + +@register_model +def resnext50_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt50-32x4d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('resnext50_32x4d', pretrained, **model_args) + + +@register_model +def resnext50d_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnext50d_32x4d', pretrained, **model_args) + + +@register_model +def resnext101_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt-101 32x4d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('resnext101_32x4d', pretrained, **model_args) + + +@register_model +def resnext101_32x8d(pretrained=False, **kwargs): + """Constructs a ResNeXt-101 32x8d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('resnext101_32x8d', pretrained, **model_args) + + +@register_model +def resnext101_64x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt101-64x4d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs) + return _create_resnet('resnext101_64x4d', pretrained, **model_args) + + +@register_model +def tv_resnext50_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt50-32x4d model with original Torchvision weights. 
+ """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('tv_resnext50_32x4d', pretrained, **model_args) + + +@register_model +def ig_resnext101_32x8d(pretrained=True, **kwargs): + """Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data + and finetuned on ImageNet from Figure 5 in + `"Exploring the Limits of Weakly Supervised Pretraining" `_ + Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('ig_resnext101_32x8d', pretrained, **model_args) + + +@register_model +def ig_resnext101_32x16d(pretrained=True, **kwargs): + """Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data + and finetuned on ImageNet from Figure 5 in + `"Exploring the Limits of Weakly Supervised Pretraining" `_ + Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) + return _create_resnet('ig_resnext101_32x16d', pretrained, **model_args) + + +@register_model +def ig_resnext101_32x32d(pretrained=True, **kwargs): + """Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data + and finetuned on ImageNet from Figure 5 in + `"Exploring the Limits of Weakly Supervised Pretraining" `_ + Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32, **kwargs) + return _create_resnet('ig_resnext101_32x32d', pretrained, **model_args) + + +@register_model +def ig_resnext101_32x48d(pretrained=True, **kwargs): + """Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data + and finetuned on ImageNet from Figure 5 in + `"Exploring the Limits of Weakly Supervised Pretraining" `_ + Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=48, **kwargs) + return _create_resnet('ig_resnext101_32x48d', pretrained, **model_args) + + +@register_model +def ssl_resnet18(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNet-18 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('ssl_resnet18', pretrained, **model_args) + + +@register_model +def ssl_resnet50(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNet-50 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('ssl_resnet50', pretrained, **model_args) + + +@register_model +def ssl_resnext50_32x4d(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNeXt-50 32x4 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + 
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('ssl_resnext50_32x4d', pretrained, **model_args) + + +@register_model +def ssl_resnext101_32x4d(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNeXt-101 32x4 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('ssl_resnext101_32x4d', pretrained, **model_args) + + +@register_model +def ssl_resnext101_32x8d(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNeXt-101 32x8 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('ssl_resnext101_32x8d', pretrained, **model_args) + + +@register_model +def ssl_resnext101_32x16d(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNeXt-101 32x16 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) + return _create_resnet('ssl_resnext101_32x16d', pretrained, **model_args) + + +@register_model +def swsl_resnet18(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised Resnet-18 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('swsl_resnet18', pretrained, **model_args) + + +@register_model +def swsl_resnet50(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNet-50 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('swsl_resnet50', pretrained, **model_args) + + +@register_model +def swsl_resnext50_32x4d(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNeXt-50 32x4 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('swsl_resnext50_32x4d', pretrained, **model_args) + + +@register_model +def swsl_resnext101_32x4d(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNeXt-101 32x4 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. 
+ `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('swsl_resnext101_32x4d', pretrained, **model_args) + + +@register_model +def swsl_resnext101_32x8d(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNeXt-101 32x8 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('swsl_resnext101_32x8d', pretrained, **model_args) + + +@register_model +def swsl_resnext101_32x16d(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNeXt-101 32x16 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) + return _create_resnet('swsl_resnext101_32x16d', pretrained, **model_args) + + +@register_model +def ecaresnet18(pretrained=False, **kwargs): + """ Constructs an ECA-ResNet-18 model. + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet18', pretrained, **model_args) + + +@register_model +def ecaresnet50(pretrained=False, **kwargs): + """Constructs an ECA-ResNet-50 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet50', pretrained, **model_args) + + +@register_model +def ecaresnet50d(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model with eca. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet50d', pretrained, **model_args) + + +@register_model +def ecaresnet50d_pruned(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model pruned with eca. + The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **model_args) + + +@register_model +def ecaresnetlight(pretrained=False, **kwargs): + """Constructs a ResNet-50-D light model with eca. + """ + model_args = dict( + block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnetlight', pretrained, **model_args) + + +@register_model +def ecaresnet101d(pretrained=False, **kwargs): + """Constructs a ResNet-101-D model with eca. 
+ """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet101d', pretrained, **model_args) + + +@register_model +def ecaresnet101d_pruned(pretrained=False, **kwargs): + """Constructs a ResNet-101-D model pruned with eca. + The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args) + + +@register_model +def ecaresnext26tn_32x4d(pretrained=False, **kwargs): + """Constructs an ECA-ResNeXt-26-TN model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels + in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant. + this model replaces SE module with the ECA module + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnext26tn_32x4d', pretrained, **model_args) + + +@register_model +def resnetblur18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model with blur anti-aliasing + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs) + return _create_resnet('resnetblur18', pretrained, **model_args) + + +@register_model +def resnetblur50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model with blur anti-aliasing + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs) + return _create_resnet('resnetblur50', pretrained, **model_args) + + +@register_model +def seresnet18(pretrained=False, **kwargs): + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet18', pretrained, **model_args) + + +@register_model +def seresnet34(pretrained=False, **kwargs): + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet34', pretrained, **model_args) + + +@register_model +def seresnet50(pretrained=False, **kwargs): + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet50', pretrained, **model_args) + + +@register_model +def seresnet50tn(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet50tn', pretrained, **model_args) + + +@register_model +def seresnet101(pretrained=False, **kwargs): + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet101', pretrained, **model_args) + + +@register_model +def seresnet152(pretrained=False, **kwargs): + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet152', pretrained, **model_args) + + +@register_model +def seresnext26_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, 
base_width=4, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext26_32x4d', pretrained, **model_args) + + +@register_model +def seresnext26d_32x4d(pretrained=False, **kwargs): + """Constructs a SE-ResNeXt-26-D model.` + This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for + combination of deep stem and avg_pool in downsample. + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext26d_32x4d', pretrained, **model_args) + + +@register_model +def seresnext26t_32x4d(pretrained=False, **kwargs): + """Constructs a SE-ResNet-26-T model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels + in the deep stem. + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext26t_32x4d', pretrained, **model_args) + + +@register_model +def seresnext26tn_32x4d(pretrained=False, **kwargs): + """Constructs a SE-ResNeXt-26-TN model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels + in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant. + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext26tn_32x4d', pretrained, **model_args) + + +@register_model +def seresnext50_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext50_32x4d', pretrained, **model_args) + + +@register_model +def seresnext101_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext101_32x4d', pretrained, **model_args) + + +@register_model +def seresnext101_32x8d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext101_32x8d', pretrained, **model_args) + + +@register_model +def senet154(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', + down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('senet154', pretrained, **model_args) diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6444b3c8c415f31ef67d7864c7906be9c3b3d154 --- /dev/null +++ b/timm/models/rexnet.py @@ -0,0 +1,262 @@ +""" ReXNet + +A PyTorch impl of `ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network` - +https://arxiv.org/abs/2007.00992 + +Adapted from original impl at https://github.com/clovaai/rexnet +Copyright (c) 2020-present NAVER Corp. 
MIT license + +Changes for timm, feature extraction, and rounded channel variant hacked together by Ross Wightman +Copyright 2020 Ross Wightman +""" + +import torch.nn as nn +from math import ceil + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath +from .registry import register_model +from .efficientnet_builder import efficientnet_init_weights + + +def _cfg(url=''): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + } + + +default_cfgs = dict( + rexnet_100=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_100-1b4dddf4.pth'), + rexnet_130=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_130-590d768e.pth'), + rexnet_150=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_150-bd1a6aa8.pth'), + rexnet_200=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_200-8c0b7f2d.pth'), + rexnetr_100=_cfg( + url=''), + rexnetr_130=_cfg( + url=''), + rexnetr_150=_cfg( + url=''), + rexnetr_200=_cfg( + url=''), +) + + +def make_divisible(v, divisor=8, min_value=None): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + return new_v + + +class SEWithNorm(nn.Module): + + def __init__(self, channels, se_ratio=1 / 12., act_layer=nn.ReLU, divisor=1, reduction_channels=None, + gate_layer='sigmoid'): + super(SEWithNorm, self).__init__() + reduction_channels = reduction_channels or make_divisible(int(channels * se_ratio), divisor=divisor) + self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True) + self.bn = nn.BatchNorm2d(reduction_channels) + self.act = act_layer(inplace=True) + self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + x_se = self.fc1(x_se) + x_se = self.bn(x_se) + x_se = self.act(x_se) + x_se = self.fc2(x_se) + return x * self.gate(x_se) + + +class LinearBottleneck(nn.Module): + def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1, drop_path=None): + super(LinearBottleneck, self).__init__() + self.use_shortcut = stride == 1 and in_chs <= out_chs + self.in_channels = in_chs + self.out_channels = out_chs + + if exp_ratio != 1.: + dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div) + self.conv_exp = ConvBnAct(in_chs, dw_chs, act_layer="swish") + else: + dw_chs = in_chs + self.conv_exp = None + + self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False) + self.se = SEWithNorm(dw_chs, se_ratio=se_ratio, divisor=ch_div) if se_ratio > 0. 
else None + self.act_dw = nn.ReLU6() + + self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False) + self.drop_path = drop_path + + def feat_channels(self, exp=False): + return self.conv_dw.out_channels if exp else self.out_channels + + def forward(self, x): + shortcut = x + if self.conv_exp is not None: + x = self.conv_exp(x) + x = self.conv_dw(x) + if self.se is not None: + x = self.se(x) + x = self.act_dw(x) + x = self.conv_pwl(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.use_shortcut: + x[:, 0:self.in_channels] += shortcut + return x + + +def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se_ratio=0., ch_div=1): + layers = [1, 2, 2, 3, 3, 5] + strides = [1, 2, 2, 2, 1, 2] + layers = [ceil(element * depth_mult) for element in layers] + strides = sum([[element] + [1] * (layers[idx] - 1) for idx, element in enumerate(strides)], []) + exp_ratios = [1] * layers[0] + [6] * sum(layers[1:]) + depth = sum(layers[:]) * 3 + base_chs = initial_chs / width_mult if width_mult < 1.0 else initial_chs + + # The following channel configuration is a simple instance to make each layer become an expand layer. + out_chs_list = [] + for i in range(depth // 3): + out_chs_list.append(make_divisible(round(base_chs * width_mult), divisor=ch_div)) + base_chs += final_chs / (depth // 3 * 1.0) + + se_ratios = [0.] * (layers[0] + layers[1]) + [se_ratio] * sum(layers[2:]) + + return list(zip(out_chs_list, exp_ratios, strides, se_ratios)) + + +def _build_blocks( + block_cfg, prev_chs, width_mult, ch_div=1, drop_path_rate=0., feature_location='bottleneck'): + feat_exp = feature_location == 'expansion' + feat_chs = [prev_chs] + feature_info = [] + curr_stride = 2 + features = [] + num_blocks = len(block_cfg) + for block_idx, (chs, exp_ratio, stride, se_ratio) in enumerate(block_cfg): + if stride > 1: + fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}' + if block_idx > 0 and feat_exp: + fname += '.act_dw' + feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)] + curr_stride *= stride + block_dpr = drop_path_rate * block_idx / (num_blocks - 1) # stochastic depth linear decay rule + drop_path = DropPath(block_dpr) if block_dpr > 0. 
else None + features.append(LinearBottleneck( + in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, se_ratio=se_ratio, + ch_div=ch_div, drop_path=drop_path)) + prev_chs = chs + feat_chs += [features[-1].feat_channels(feat_exp)] + pen_chs = make_divisible(1280 * width_mult, divisor=ch_div) + feature_info += [dict( + num_chs=pen_chs if feat_exp else feat_chs[-1], reduction=curr_stride, + module=f'features.{len(features) - int(not feat_exp)}')] + features.append(ConvBnAct(prev_chs, pen_chs, act_layer="swish")) + return features, feature_info + + +class ReXNetV1(nn.Module): + def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, + initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12., + ch_div=1, drop_rate=0.2, drop_path_rate=0., feature_location='bottleneck'): + super(ReXNetV1, self).__init__() + self.drop_rate = drop_rate + self.num_classes = num_classes + + assert output_stride == 32 # FIXME support dilation + stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32 + stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div) + self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer='swish') + + block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div) + features, self.feature_info = _build_blocks( + block_cfg, stem_chs, width_mult, ch_div, drop_path_rate, feature_location) + self.num_features = features[-1].out_channels + self.features = nn.Sequential(*features) + + self.head = ClassifierHead(self.num_features, num_classes, global_pool, drop_rate) + + efficientnet_init_weights(self) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.features(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_rexnet(variant, pretrained, **kwargs): + feature_cfg = dict(flatten_sequential=True) + if kwargs.get('feature_location', '') == 'expansion': + feature_cfg['feature_cls'] = 'hook' + return build_model_with_cfg( + ReXNetV1, variant, pretrained, default_cfg=default_cfgs[variant], feature_cfg=feature_cfg, **kwargs) + + +@register_model +def rexnet_100(pretrained=False, **kwargs): + """ReXNet V1 1.0x""" + return _create_rexnet('rexnet_100', pretrained, **kwargs) + + +@register_model +def rexnet_130(pretrained=False, **kwargs): + """ReXNet V1 1.3x""" + return _create_rexnet('rexnet_130', pretrained, width_mult=1.3, **kwargs) + + +@register_model +def rexnet_150(pretrained=False, **kwargs): + """ReXNet V1 1.5x""" + return _create_rexnet('rexnet_150', pretrained, width_mult=1.5, **kwargs) + + +@register_model +def rexnet_200(pretrained=False, **kwargs): + """ReXNet V1 2.0x""" + return _create_rexnet('rexnet_200', pretrained, width_mult=2.0, **kwargs) + + +@register_model +def rexnetr_100(pretrained=False, **kwargs): + """ReXNet V1 1.0x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_100', pretrained, ch_div=8, **kwargs) + + +@register_model +def rexnetr_130(pretrained=False, **kwargs): + """ReXNet V1 1.3x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_130', pretrained, width_mult=1.3, ch_div=8, **kwargs) + + +@register_model +def rexnetr_150(pretrained=False, **kwargs): + """ReXNet V1 1.5x w/ rounded (mod 8) channels""" + return 
_create_rexnet('rexnetr_150', pretrained, width_mult=1.5, ch_div=8, **kwargs) + + +@register_model +def rexnetr_200(pretrained=False, **kwargs): + """ReXNet V1 2.0x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_200', pretrained, width_mult=2.0, ch_div=8, **kwargs) diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py new file mode 100644 index 0000000000000000000000000000000000000000..73bc7732833c1d373f7f32ab5da695852be63c5f --- /dev/null +++ b/timm/models/selecsls.py @@ -0,0 +1,359 @@ +"""PyTorch SelecSLS Net example for ImageNet Classification +License: CC BY 4.0 (https://creativecommons.org/licenses/by/4.0/legalcode) +Author: Dushyant Mehta (@mehtadushy) + +SelecSLS (core) Network Architecture as proposed in "XNect: Real-time Multi-person 3D +Human Pose Estimation with a Single RGB Camera, Mehta et al." +https://arxiv.org/abs/1907.00837 + +Based on ResNet implementation in https://github.com/rwightman/pytorch-image-models +and SelecSLS Net implementation in https://github.com/mehtadushy/SelecSLS-Pytorch +""" +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'selecsls42': _cfg( + url='', + interpolation='bicubic'), + 'selecsls42b': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls42b-8af30141.pth', + interpolation='bicubic'), + 'selecsls60': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls60-bbf87526.pth', + interpolation='bicubic'), + 'selecsls60b': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls60b-94e619b5.pth', + interpolation='bicubic'), + 'selecsls84': _cfg( + url='', + interpolation='bicubic'), +} + + +class SequentialList(nn.Sequential): + + def __init__(self, *args): + super(SequentialList, self).__init__(*args) + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (List[torch.Tensor]) -> (List[torch.Tensor]) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> (List[torch.Tensor]) + pass + + def forward(self, x) -> List[torch.Tensor]: + for module in self: + x = module(x) + return x + + +class SelectSeq(nn.Module): + def __init__(self, mode='index', index=0): + super(SelectSeq, self).__init__() + self.mode = mode + self.index = index + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (List[torch.Tensor]) -> (torch.Tensor) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (Tuple[torch.Tensor]) -> (torch.Tensor) + pass + + def forward(self, x) -> torch.Tensor: + if self.mode == 'index': + return x[self.index] + else: + return torch.cat(x, dim=1) + + +def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1): + if padding is None: + padding = ((stride - 1) + dilation * (k - 1)) // 2 + return nn.Sequential( + 
nn.Conv2d(in_chs, out_chs, k, stride, padding=padding, dilation=dilation, bias=False), + nn.BatchNorm2d(out_chs), + nn.ReLU(inplace=True) + ) + + +class SelecSLSBlock(nn.Module): + def __init__(self, in_chs, skip_chs, mid_chs, out_chs, is_first, stride, dilation=1): + super(SelecSLSBlock, self).__init__() + self.stride = stride + self.is_first = is_first + assert stride in [1, 2] + + # Process input with 4 conv blocks with the same number of input and output channels + self.conv1 = conv_bn(in_chs, mid_chs, 3, stride, dilation=dilation) + self.conv2 = conv_bn(mid_chs, mid_chs, 1) + self.conv3 = conv_bn(mid_chs, mid_chs // 2, 3) + self.conv4 = conv_bn(mid_chs // 2, mid_chs, 1) + self.conv5 = conv_bn(mid_chs, mid_chs // 2, 3) + self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1) + + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: + if not isinstance(x, list): + x = [x] + assert len(x) in [1, 2] + + d1 = self.conv1(x[0]) + d2 = self.conv3(self.conv2(d1)) + d3 = self.conv5(self.conv4(d2)) + if self.is_first: + out = self.conv6(torch.cat([d1, d2, d3], 1)) + return [out, out] + else: + return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)), x[1]] + + +class SelecSLS(nn.Module): + """SelecSLS42 / SelecSLS60 / SelecSLS84 + + Parameters + ---------- + cfg : network config dictionary specifying block type, feature, and head args + num_classes : int, default 1000 + Number of classification classes. + in_chans : int, default 3 + Number of input (color) channels. + drop_rate : float, default 0. + Dropout probability before classifier, for training + global_pool : str, default 'avg' + Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' + """ + + def __init__(self, cfg, num_classes=1000, in_chans=3, drop_rate=0.0, global_pool='avg'): + self.num_classes = num_classes + self.drop_rate = drop_rate + super(SelecSLS, self).__init__() + + self.stem = conv_bn(in_chans, 32, stride=2) + self.features = SequentialList(*[cfg['block'](*block_args) for block_args in cfg['features']]) + self.from_seq = SelectSeq() # from List[tensor] -> Tensor in module compatible way + self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']]) + self.num_features = cfg['num_features'] + self.feature_info = cfg['feature_info'] + + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + for n, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) 
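Editor's note — a minimal usage sketch (not part of the patch) of the list-based skip convention implemented by SelecSLSBlock above: the first block of a stage (is_first=True) returns [out, out], and later blocks consume [features, skip] and pass the skip tensor through unchanged. Import path and config values are taken from this diff; treat the snippet as illustrative only.

import torch
from timm.models.selecsls import SelecSLSBlock  # module path as added in this diff

# First two entries of the selecsls42 'features' config defined below:
# (in_chs, skip_chs, mid_chs, out_chs, is_first, stride)
blk0 = SelecSLSBlock(32, 0, 64, 64, is_first=True, stride=2)
blk1 = SelecSLSBlock(64, 64, 64, 128, is_first=False, stride=1)

x = torch.randn(1, 32, 112, 112)   # stem output for a 224x224 input (stride 2)
out0 = blk0(x)                     # -> [out, out], each (1, 64, 56, 56)
out1 = blk1(out0)                  # -> [(1, 128, 56, 56), skip passed through unchanged]
print(out0[0].shape, out1[0].shape, out1[1].shape)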
+ + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.stem(x) + x = self.features(x) + x = self.head(self.from_seq(x)) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + return x + + +def _create_selecsls(variant, pretrained, model_kwargs): + cfg = {} + feature_info = [dict(num_chs=32, reduction=2, module='stem.2')] + if variant.startswith('selecsls42'): + cfg['block'] = SelecSLSBlock + # Define configuration of the network after the initial neck + cfg['features'] = [ + # in_chs, skip_chs, mid_chs, out_chs, is_first, stride + (32, 0, 64, 64, True, 2), + (64, 64, 64, 128, False, 1), + (128, 0, 144, 144, True, 2), + (144, 144, 144, 288, False, 1), + (288, 0, 304, 304, True, 2), + (304, 304, 304, 480, False, 1), + ] + feature_info.extend([ + dict(num_chs=128, reduction=4, module='features.1'), + dict(num_chs=288, reduction=8, module='features.3'), + dict(num_chs=480, reduction=16, module='features.5'), + ]) + # Head can be replaced with alternative configurations depending on the problem + feature_info.append(dict(num_chs=1024, reduction=32, module='head.1')) + if variant == 'selecsls42b': + cfg['head'] = [ + (480, 960, 3, 2), + (960, 1024, 3, 1), + (1024, 1280, 3, 2), + (1280, 1024, 1, 1), + ] + feature_info.append(dict(num_chs=1024, reduction=64, module='head.3')) + cfg['num_features'] = 1024 + else: + cfg['head'] = [ + (480, 960, 3, 2), + (960, 1024, 3, 1), + (1024, 1024, 3, 2), + (1024, 1280, 1, 1), + ] + feature_info.append(dict(num_chs=1280, reduction=64, module='head.3')) + cfg['num_features'] = 1280 + + elif variant.startswith('selecsls60'): + cfg['block'] = SelecSLSBlock + # Define configuration of the network after the initial neck + cfg['features'] = [ + # in_chs, skip_chs, mid_chs, out_chs, is_first, stride + (32, 0, 64, 64, True, 2), + (64, 64, 64, 128, False, 1), + (128, 0, 128, 128, True, 2), + (128, 128, 128, 128, False, 1), + (128, 128, 128, 288, False, 1), + (288, 0, 288, 288, True, 2), + (288, 288, 288, 288, False, 1), + (288, 288, 288, 288, False, 1), + (288, 288, 288, 416, False, 1), + ] + feature_info.extend([ + dict(num_chs=128, reduction=4, module='features.1'), + dict(num_chs=288, reduction=8, module='features.4'), + dict(num_chs=416, reduction=16, module='features.8'), + ]) + # Head can be replaced with alternative configurations depending on the problem + feature_info.append(dict(num_chs=1024, reduction=32, module='head.1')) + if variant == 'selecsls60b': + cfg['head'] = [ + (416, 756, 3, 2), + (756, 1024, 3, 1), + (1024, 1280, 3, 2), + (1280, 1024, 1, 1), + ] + feature_info.append(dict(num_chs=1024, reduction=64, module='head.3')) + cfg['num_features'] = 1024 + else: + cfg['head'] = [ + (416, 756, 3, 2), + (756, 1024, 3, 1), + (1024, 1024, 3, 2), + (1024, 1280, 1, 1), + ] + feature_info.append(dict(num_chs=1280, reduction=64, module='head.3')) + cfg['num_features'] = 1280 + + elif variant == 'selecsls84': + cfg['block'] = SelecSLSBlock + # Define configuration of the network after the initial neck + cfg['features'] = [ + # in_chs, skip_chs, mid_chs, out_chs, is_first, stride + (32, 0, 64, 64, True, 2), + (64, 64, 64, 144, False, 1), + (144, 0, 144, 144, True, 2), + (144, 144, 
144, 144, False, 1), + (144, 144, 144, 144, False, 1), + (144, 144, 144, 144, False, 1), + (144, 144, 144, 304, False, 1), + (304, 0, 304, 304, True, 2), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 512, False, 1), + ] + feature_info.extend([ + dict(num_chs=144, reduction=4, module='features.1'), + dict(num_chs=304, reduction=8, module='features.6'), + dict(num_chs=512, reduction=16, module='features.12'), + ]) + # Head can be replaced with alternative configurations depending on the problem + cfg['head'] = [ + (512, 960, 3, 2), + (960, 1024, 3, 1), + (1024, 1024, 3, 2), + (1024, 1280, 3, 1), + ] + cfg['num_features'] = 1280 + feature_info.extend([ + dict(num_chs=1024, reduction=32, module='head.1'), + dict(num_chs=1280, reduction=64, module='head.3') + ]) + else: + raise ValueError('Invalid net configuration ' + variant + ' !!!') + cfg['feature_info'] = feature_info + + # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises? + return build_model_with_cfg( + SelecSLS, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=cfg, + feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True), **model_kwargs) + + +@register_model +def selecsls42(pretrained=False, **kwargs): + """Constructs a SelecSLS42 model. + """ + return _create_selecsls('selecsls42', pretrained, kwargs) + + +@register_model +def selecsls42b(pretrained=False, **kwargs): + """Constructs a SelecSLS42_B model. + """ + return _create_selecsls('selecsls42b', pretrained, kwargs) + + +@register_model +def selecsls60(pretrained=False, **kwargs): + """Constructs a SelecSLS60 model. + """ + return _create_selecsls('selecsls60', pretrained, kwargs) + + +@register_model +def selecsls60b(pretrained=False, **kwargs): + """Constructs a SelecSLS60_B model. + """ + return _create_selecsls('selecsls60b', pretrained, kwargs) + + +@register_model +def selecsls84(pretrained=False, **kwargs): + """Constructs a SelecSLS84 model. + """ + return _create_selecsls('selecsls84', pretrained, kwargs) diff --git a/timm/models/senet.py b/timm/models/senet.py new file mode 100644 index 0000000000000000000000000000000000000000..8073229a721930e279f07a46590eda71d5756904 --- /dev/null +++ b/timm/models/senet.py @@ -0,0 +1,465 @@ +""" +SEResNet implementation from Cadene's pretrained models +https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py +Additional credit to https://github.com/creafz + +Original model: https://github.com/hujie-frank/SENet + +ResNet code gently borrowed from +https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py + +FIXME I'm deprecating this model and moving them to ResNet as I don't want to maintain duplicate +support for extras like dilation, switchable BN/activations, feature extraction, etc that don't exist here. 
+""" +import math +from collections import OrderedDict + +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['SENet'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'layer0.conv1', 'classifier': 'last_linear', + **kwargs + } + + +default_cfgs = { + 'legacy_senet154': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'), + 'legacy_seresnet18': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth', + interpolation='bicubic'), + 'legacy_seresnet34': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth'), + 'legacy_seresnet50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'), + 'legacy_seresnet101': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'), + 'legacy_seresnet152': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'), + 'legacy_seresnext26_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth', + interpolation='bicubic'), + 'legacy_seresnext50_32x4d': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'), + 'legacy_seresnext101_32x4d': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth'), +} + + +def _weight_init(m): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) + + +class SEModule(nn.Module): + + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + module_input = x + x = x.mean((2, 3), keepdim=True) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class Bottleneck(nn.Module): + """ + Base class for bottlenecks that implements `forward()` method. + """ + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = self.se_module(out) + residual + out = self.relu(out) + + return out + + +class SEBottleneck(Bottleneck): + """ + Bottleneck for SENet154. 
+ """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None): + super(SEBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes * 2) + self.conv2 = nn.Conv2d( + planes * 2, planes * 4, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes * 4) + self.conv3 = nn.Conv2d( + planes * 4, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNetBottleneck(Bottleneck): + """ + ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe + implementation and uses `stride=stride` in `conv1` and not in `conv2` + (the latter is used in the torchvision implementation of ResNet). + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None): + super(SEResNetBottleneck, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=1, bias=False, stride=stride) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNeXtBottleneck(Bottleneck): + """ + ResNeXt bottleneck type C with a Squeeze-and-Excitation module. + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None, base_width=4): + super(SEResNeXtBottleneck, self).__init__() + width = math.floor(planes * (base_width / 64)) * groups + self.conv1 = nn.Conv2d( + inplanes, width, kernel_size=1, bias=False, stride=1) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = nn.Conv2d( + width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNetBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): + super(SEResNetBlock, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes, reduction=reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = self.se_module(out) + residual + out = self.relu(out) + + return out + + +class SENet(nn.Module): + + def __init__(self, block, layers, groups, reduction, drop_rate=0.2, + in_chans=3, 
inplanes=64, input_3x3=False, downsample_kernel_size=1, + downsample_padding=0, num_classes=1000, global_pool='avg'): + """ + Parameters + ---------- + block (nn.Module): Bottleneck class. + - For SENet154: SEBottleneck + - For SE-ResNet models: SEResNetBottleneck + - For SE-ResNeXt models: SEResNeXtBottleneck + layers (list of ints): Number of residual blocks for 4 layers of the + network (layer1...layer4). + groups (int): Number of groups for the 3x3 convolution in each + bottleneck block. + - For SENet154: 64 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 32 + reduction (int): Reduction ratio for Squeeze-and-Excitation modules. + - For all models: 16 + dropout_p (float or None): Drop probability for the Dropout layer. + If `None` the Dropout layer is not used. + - For SENet154: 0.2 + - For SE-ResNet models: None + - For SE-ResNeXt models: None + inplanes (int): Number of input channels for layer1. + - For SENet154: 128 + - For SE-ResNet models: 64 + - For SE-ResNeXt models: 64 + input_3x3 (bool): If `True`, use three 3x3 convolutions instead of + a single 7x7 convolution in layer0. + - For SENet154: True + - For SE-ResNet models: False + - For SE-ResNeXt models: False + downsample_kernel_size (int): Kernel size for downsampling convolutions + in layer2, layer3 and layer4. + - For SENet154: 3 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 1 + downsample_padding (int): Padding for downsampling convolutions in + layer2, layer3 and layer4. + - For SENet154: 1 + - For SE-ResNet models: 0 + - For SE-ResNeXt models: 0 + num_classes (int): Number of outputs in `last_linear` layer. + - For all models: 1000 + """ + super(SENet, self).__init__() + self.inplanes = inplanes + self.num_classes = num_classes + self.drop_rate = drop_rate + if input_3x3: + layer0_modules = [ + ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False)), + ('bn1', nn.BatchNorm2d(64)), + ('relu1', nn.ReLU(inplace=True)), + ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)), + ('bn2', nn.BatchNorm2d(64)), + ('relu2', nn.ReLU(inplace=True)), + ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False)), + ('bn3', nn.BatchNorm2d(inplanes)), + ('relu3', nn.ReLU(inplace=True)), + ] + else: + layer0_modules = [ + ('conv1', nn.Conv2d( + in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)), + ('bn1', nn.BatchNorm2d(inplanes)), + ('relu1', nn.ReLU(inplace=True)), + ] + self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) + # To preserve compatibility with Caffe weights `ceil_mode=True` is used instead of `padding=1`. 
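Editor's note — an illustrative sketch (not part of the patch) of how the per-variant values documented in the docstring above map onto the SENet constructor; it mirrors the legacy_seresnext50_32x4d entrypoint registered later in this file.

import torch
from timm.models.senet import SENet, SEResNeXtBottleneck  # paths as added in this diff

# SE-ResNeXt settings from the docstring: groups=32, reduction=16, inplanes=64,
# input_3x3=False, 1x1 downsample convolutions with no padding.
model = SENet(
    SEResNeXtBottleneck, layers=[3, 4, 6, 3], groups=32, reduction=16,
    inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0,
    num_classes=1000)
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])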
+ self.pool0 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.feature_info = [dict(num_chs=inplanes, reduction=2, module='layer0')] + self.layer1 = self._make_layer( + block, + planes=64, + blocks=layers[0], + groups=groups, + reduction=reduction, + downsample_kernel_size=1, + downsample_padding=0 + ) + self.feature_info += [dict(num_chs=64 * block.expansion, reduction=4, module='layer1')] + self.layer2 = self._make_layer( + block, + planes=128, + blocks=layers[1], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.feature_info += [dict(num_chs=128 * block.expansion, reduction=8, module='layer2')] + self.layer3 = self._make_layer( + block, + planes=256, + blocks=layers[2], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.feature_info += [dict(num_chs=256 * block.expansion, reduction=16, module='layer3')] + self.layer4 = self._make_layer( + block, + planes=512, + blocks=layers[3], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.feature_info += [dict(num_chs=512 * block.expansion, reduction=32, module='layer4')] + self.num_features = 512 * block.expansion + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + for m in self.modules(): + _weight_init(m) + + def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, + downsample_kernel_size=1, downsample_padding=0): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, planes * block.expansion, kernel_size=downsample_kernel_size, + stride=stride, padding=downsample_padding, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, groups, reduction, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups, reduction)) + + return nn.Sequential(*layers) + + def get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.layer0(x) + x = self.pool0(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def logits(self, x): + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.last_linear(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.logits(x) + return x + + +def _create_senet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + SENet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) + + +@register_model +def legacy_seresnet18(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBlock, layers=[2, 2, 2, 2], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet18', pretrained, **model_args) + + +@register_model +def legacy_seresnet34(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBlock, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs) + return 
_create_senet('legacy_seresnet34', pretrained, **model_args) + + +@register_model +def legacy_seresnet50(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBottleneck, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet50', pretrained, **model_args) + + +@register_model +def legacy_seresnet101(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBottleneck, layers=[3, 4, 23, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet101', pretrained, **model_args) + + +@register_model +def legacy_seresnet152(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBottleneck, layers=[3, 8, 36, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet152', pretrained, **model_args) + + +@register_model +def legacy_senet154(pretrained=False, **kwargs): + model_args = dict( + block=SEBottleneck, layers=[3, 8, 36, 3], groups=64, reduction=16, + downsample_kernel_size=3, downsample_padding=1, inplanes=128, input_3x3=True, **kwargs) + return _create_senet('legacy_senet154', pretrained, **model_args) + + +@register_model +def legacy_seresnext26_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=SEResNeXtBottleneck, layers=[2, 2, 2, 2], groups=32, reduction=16, **kwargs) + return _create_senet('legacy_seresnext26_32x4d', pretrained, **model_args) + + +@register_model +def legacy_seresnext50_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=SEResNeXtBottleneck, layers=[3, 4, 6, 3], groups=32, reduction=16, **kwargs) + return _create_senet('legacy_seresnext50_32x4d', pretrained, **model_args) + + +@register_model +def legacy_seresnext101_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=SEResNeXtBottleneck, layers=[3, 4, 23, 3], groups=32, reduction=16, **kwargs) + return _create_senet('legacy_seresnext101_32x4d', pretrained, **model_args) diff --git a/timm/models/sknet.py b/timm/models/sknet.py new file mode 100644 index 0000000000000000000000000000000000000000..6c654922e109658c04de78e2d433d4ebc16645cc --- /dev/null +++ b/timm/models/sknet.py @@ -0,0 +1,218 @@ +""" Selective Kernel Networks (ResNet base) + +Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) + +This was inspired by reading 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268) +and a streamlined impl at https://github.com/clovaai/assembled-cnn but I ended up building something closer +to the original paper with some modifications of my own to better balance param count vs accuracy. 
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +import math + +from torch import nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import SelectiveKernelConv, ConvBnAct, create_attn +from .registry import register_model +from .resnet import ResNet + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'skresnet18': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'), + 'skresnet34': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'), + 'skresnet50': _cfg(), + 'skresnet50d': _cfg( + first_conv='conv1.0'), + 'skresnext50_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth'), +} + + +class SelectiveKernelBasic(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(SelectiveKernelBasic, self).__init__() + + sk_kwargs = sk_kwargs or {} + conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock doest not support changing base width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + + self.conv1 = SelectiveKernelConv( + inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv2 = ConvBnAct( + first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs) + self.se = create_attn(attn_layer, outplanes) + self.act = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv2.bn.weight) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.conv2(x) + if self.se is not None: + x = self.se(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act(x) + return x + + +class SelectiveKernelBottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, + drop_block=None, drop_path=None): + super(SelectiveKernelBottleneck, self).__init__() + + sk_kwargs = sk_kwargs or {} + conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + width = int(math.floor(planes * (base_width / 64)) * cardinality) + first_planes = width // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or 
dilation + + self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) + self.conv2 = SelectiveKernelConv( + first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, + **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv3 = ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs) + self.se = create_attn(attn_layer, outplanes) + self.act = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv3.bn.weight) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + if self.se is not None: + x = self.se(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act(x) + return x + + +def _create_skresnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) + + +@register_model +def skresnet18(pretrained=False, **kwargs): + """Constructs a Selective Kernel ResNet-18 model. + + Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this + variation splits the input channels to the selective convolutions to keep param count down. + """ + sk_kwargs = dict( + min_attn_channels=16, + attn_reduction=8, + split_input=True) + model_args = dict( + block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet18', pretrained, **model_args) + + +@register_model +def skresnet34(pretrained=False, **kwargs): + """Constructs a Selective Kernel ResNet-34 model. + + Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this + variation splits the input channels to the selective convolutions to keep param count down. + """ + sk_kwargs = dict( + min_attn_channels=16, + attn_reduction=8, + split_input=True) + model_args = dict( + block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet34', pretrained, **model_args) + + +@register_model +def skresnet50(pretrained=False, **kwargs): + """Constructs a Select Kernel ResNet-50 model. + + Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this + variation splits the input channels to the selective convolutions to keep param count down. + """ + sk_kwargs = dict(split_input=True) + model_args = dict( + block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet50', pretrained, **model_args) + + +@register_model +def skresnet50d(pretrained=False, **kwargs): + """Constructs a Select Kernel ResNet-50-D model. + + Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this + variation splits the input channels to the selective convolutions to keep param count down. 
+ """ + sk_kwargs = dict(split_input=True) + model_args = dict( + block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet50d', pretrained, **model_args) + + +@register_model +def skresnext50_32x4d(pretrained=False, **kwargs): + """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to + the SKNet-50 model in the Select Kernel Paper + """ + model_args = dict( + block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnext50_32x4d', pretrained, **model_args) + diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e371292f7ded9af1a18c0ff91b6b2aff886ce77f --- /dev/null +++ b/timm/models/tresnet.py @@ -0,0 +1,293 @@ +""" +TResNet: High Performance GPU-Dedicated Architecture +https://arxiv.org/pdf/2003.13630.pdf + +Original model: https://github.com/mrT23/TResNet + +""" +import copy +from collections import OrderedDict +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import build_model_with_cfg +from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule +from .registry import register_model + +__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': (0, 0, 0), 'std': (1, 1, 1), + 'first_conv': 'body.conv1.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = { + 'tresnet_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_80_8-dbc13962.pth'), + 'tresnet_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_81_5-235b486c.pth'), + 'tresnet_xl': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth'), + 'tresnet_m_448': _cfg( + input_size=(3, 448, 448), pool_size=(14, 14), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth'), + 'tresnet_l_448': _cfg( + input_size=(3, 448, 448), pool_size=(14, 14), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth'), + 'tresnet_xl_448': _cfg( + input_size=(3, 448, 448), pool_size=(14, 14), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth') +} + + +def IABN2Float(module: nn.Module) -> nn.Module: + """If `module` is IABN don't use half precision.""" + if isinstance(module, InplaceAbn): + module.float() + for child in module.children(): + IABN2Float(child) + return module + + +def conv2d_iabn(ni, nf, stride, kernel_size=3, groups=1, act_layer="leaky_relu", act_param=1e-2): + return nn.Sequential( + nn.Conv2d( + ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups, bias=False), + InplaceAbn(nf, act_layer=act_layer, act_param=act_param) + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, aa_layer=None): + super(BasicBlock, self).__init__() 
+ if stride == 1: + self.conv1 = conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3) + else: + if aa_layer is None: + self.conv1 = conv2d_iabn(inplanes, planes, stride=2, act_param=1e-3) + else: + self.conv1 = nn.Sequential( + conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3), + aa_layer(channels=planes, filt_size=3, stride=2)) + + self.conv2 = conv2d_iabn(planes, planes, stride=1, act_layer="identity") + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + reduction_chs = max(planes * self.expansion // 4, 64) + self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None + + def forward(self, x): + if self.downsample is not None: + residual = self.downsample(x) + else: + residual = x + + out = self.conv1(x) + out = self.conv2(out) + + if self.se is not None: + out = self.se(out) + + out += residual + out = self.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, + act_layer="leaky_relu", aa_layer=None): + super(Bottleneck, self).__init__() + self.conv1 = conv2d_iabn( + inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer, act_param=1e-3) + if stride == 1: + self.conv2 = conv2d_iabn( + planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3) + else: + if aa_layer is None: + self.conv2 = conv2d_iabn( + planes, planes, kernel_size=3, stride=2, act_layer=act_layer, act_param=1e-3) + else: + self.conv2 = nn.Sequential( + conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3), + aa_layer(channels=planes, filt_size=3, stride=2)) + + reduction_chs = max(planes * self.expansion // 8, 64) + self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None + + self.conv3 = conv2d_iabn( + planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity") + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + if self.downsample is not None: + residual = self.downsample(x) + else: + residual = x + + out = self.conv1(x) + out = self.conv2(out) + if self.se is not None: + out = self.se(out) + + out = self.conv3(out) + out = out + residual # no inplace + out = self.relu(out) + + return out + + +class TResNet(nn.Module): + def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False, + global_pool='fast', drop_rate=0.): + self.num_classes = num_classes + self.drop_rate = drop_rate + super(TResNet, self).__init__() + + # JIT layers + space_to_depth = SpaceToDepthModule() + aa_layer = partial(AntiAliasDownsampleLayer, no_jit=no_aa_jit) + + # TResnet stages + self.inplanes = int(64 * width_factor) + self.planes = int(64 * width_factor) + conv1 = conv2d_iabn(in_chans * 16, self.planes, stride=1, kernel_size=3) + layer1 = self._make_layer( + BasicBlock, self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer) # 56x56 + layer2 = self._make_layer( + BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer) # 28x28 + layer3 = self._make_layer( + Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer) # 14x14 + layer4 = self._make_layer( + Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer) # 7x7 + + # body + self.body = nn.Sequential(OrderedDict([ + ('SpaceToDepth', space_to_depth), + ('conv1', conv1), + ('layer1', layer1), + ('layer2', layer2), + ('layer3', 
layer3), + ('layer4', layer4)])) + + self.feature_info = [ + dict(num_chs=self.planes, reduction=2, module=''), # Not with S2D? + dict(num_chs=self.planes, reduction=4, module='body.layer1'), + dict(num_chs=self.planes * 2, reduction=8, module='body.layer2'), + dict(num_chs=self.planes * 4 * Bottleneck.expansion, reduction=16, module='body.layer3'), + dict(num_chs=self.planes * 8 * Bottleneck.expansion, reduction=32, module='body.layer4'), + ] + + # head + self.num_features = (self.planes * 8) * Bottleneck.expansion + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + # model initilization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InplaceAbn): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # residual connections special initialization + for m in self.modules(): + if isinstance(m, BasicBlock): + m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight)) # BN to zero + if isinstance(m, Bottleneck): + m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight)) # BN to zero + if isinstance(m, nn.Linear): + m.weight.data.normal_(0, 0.01) + + def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + layers = [] + if stride == 2: + # avg pooling before 1x1 conv + layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False)) + layers += [conv2d_iabn( + self.inplanes, planes * block.expansion, kernel_size=1, stride=1, act_layer="identity")] + downsample = nn.Sequential(*layers) + + layers = [] + layers.append(block( + self.inplanes, planes, stride, downsample, use_se=use_se, aa_layer=aa_layer)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block(self.inplanes, planes, use_se=use_se, aa_layer=aa_layer)) + return nn.Sequential(*layers) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='fast'): + self.head = ClassifierHead( + self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + return self.body(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_tresnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + TResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, + feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True), **kwargs) + + +@register_model +def tresnet_m(pretrained=False, **kwargs): + model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs) + return _create_tresnet('tresnet_m', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_l(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs) + return _create_tresnet('tresnet_l', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_xl(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs) + return _create_tresnet('tresnet_xl', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_m_448(pretrained=False, **kwargs): + model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs) + return _create_tresnet('tresnet_m_448', pretrained=pretrained, **model_kwargs) + + 
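Editor's note — a small sketch (not part of the patch) of the stem used by the body above: SpaceToDepth folds each 4x4 pixel block into channels, which is why conv1 takes in_chans * 16 input channels. Module paths are those added elsewhere in this diff; treat the snippet as illustrative.

import torch
from timm.models.layers import SpaceToDepthModule  # layer used by TResNet's body above
from timm.models.tresnet import tresnet_m          # entrypoint registered above

x = torch.randn(1, 3, 224, 224)
print(SpaceToDepthModule()(x).shape)   # torch.Size([1, 48, 56, 56]): 3 * 16 channels at 1/4 resolution

model = tresnet_m(pretrained=False)
print(model(x).shape)                  # torch.Size([1, 1000])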
+@register_model +def tresnet_l_448(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs) + return _create_tresnet('tresnet_l_448', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_xl_448(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs) + return _create_tresnet('tresnet_xl_448', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f05a87f2135ca5d279f4ed96fc149b3b1fe0cc6b --- /dev/null +++ b/timm/models/vision_transformer.py @@ -0,0 +1,431 @@ +""" Vision Transformer (ViT) in PyTorch + +A PyTorch implement of Vision Transformers as described in +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 + +The official jax code is released and available at https://github.com/google-research/vision_transformer + +Status/TODO: +* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights. +* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches. +* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code. +* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future. + +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +from functools import partial + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import load_pretrained +from .layers import DropPath, to_2tuple, trunc_normal_ +from .resnet import resnet26d, resnet50d +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models + 'vit_small_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', + ), + 'vit_base_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + ), + 'vit_base_patch16_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + 'vit_base_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + 'vit_large_patch16_224': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_large_patch16_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + 'vit_large_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + 'vit_huge_patch16_224': _cfg(), + 'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)), + # hybrid models + 'vit_small_resnet26d_224': _cfg(), + 'vit_small_resnet50d_s3_224': _cfg(), + 'vit_base_resnet26d_224': _cfg(), + 'vit_base_resnet50d_224': _cfg(), +} + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, input_shape='bchw'): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.input_shape = input_shape + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + if self.input_shape == 'bhwc': + x = x.permute(0, 3, 1, 2) + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature + # map for all networks, the feature metadata has reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of each stage that isn't captured. 
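+            # Put the backbone in eval mode for this dummy forward pass so BatchNorm running stats are not updated, then restore the original training mode afterwards.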
+ training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here + #self.repr = nn.Linear(embed_dim, representation_size) + #self.repr_act = nn.Tanh() + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _conv_filter(state_dict, patch_size=16): + """ convert patch embedding weight from manual patchify + linear proj to 
conv""" + out_dict = {} + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k: + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + return out_dict + + +@register_model +def vit_small_patch16_224(pretrained=False, **kwargs): + if pretrained: + # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model + kwargs.setdefault('qk_scale', 768 ** -0.5) + model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs) + model.default_cfg = default_cfgs['vit_small_patch16_224'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) + return model + + +@register_model +def vit_base_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_base_patch16_224'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) + return model + + +@register_model +def vit_base_patch16_384(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_base_patch16_384'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def vit_base_patch32_384(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_base_patch32_384'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def vit_large_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_large_patch16_224'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def vit_large_patch16_384(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_large_patch16_384'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def vit_large_patch32_384(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=384, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_large_patch32_384'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def vit_huge_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer(patch_size=16, embed_dim=1280, depth=32, 
num_heads=16, mlp_ratio=4, **kwargs) + model.default_cfg = default_cfgs['vit_huge_patch16_224'] + return model + + +@register_model +def vit_huge_patch32_384(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=384, patch_size=32, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs) + model.default_cfg = default_cfgs['vit_huge_patch32_384'] + return model + + +@register_model +def vit_small_resnet26d_224(pretrained=False, **kwargs): + pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing + backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4]) + model = VisionTransformer( + img_size=224, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) + model.default_cfg = default_cfgs['vit_small_resnet26d_224'] + return model + + +@register_model +def vit_small_resnet50d_s3_224(pretrained=False, **kwargs): + pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing + backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[3]) + model = VisionTransformer( + img_size=224, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) + model.default_cfg = default_cfgs['vit_small_resnet50d_s3_224'] + return model + + +@register_model +def vit_base_resnet26d_224(pretrained=False, **kwargs): + pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing + backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4]) + model = VisionTransformer( + img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) + model.default_cfg = default_cfgs['vit_base_resnet26d_224'] + return model + + +@register_model +def vit_base_resnet50d_224(pretrained=False, **kwargs): + pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing + backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[4]) + model = VisionTransformer( + img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) + model.default_cfg = default_cfgs['vit_base_resnet50d_224'] + return model diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f544433c4596b0d593b0633bebae965d4aa0094a --- /dev/null +++ b/timm/models/vovnet.py @@ -0,0 +1,403 @@ +""" VoVNet (V1 & V2) + +Papers: +* `An Energy and GPU-Computation Efficient Backbone Network` - https://arxiv.org/abs/1904.09730 +* `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + +Looked at https://github.com/youngwanLEE/vovnet-detectron2 & +https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py +for some reference, rewrote most of the code. 
+ +Hacked together by / Copyright 2020 Ross Wightman +""" + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .registry import register_model +from .helpers import build_model_with_cfg +from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, ClassifierHead, DropPath,\ + create_attn, create_norm_act, get_norm_act_layer + + +# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 & +# https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py +model_cfgs = dict( + vovnet39a=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 2, 2], + residual=False, + depthwise=False, + attn='', + ), + vovnet57a=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 4, 3], + residual=False, + depthwise=False, + attn='', + + ), + ese_vovnet19b_slim_dw=dict( + stem_chs=[64, 64, 64], + stage_conv_chs=[64, 80, 96, 112], + stage_out_chs=[112, 256, 384, 512], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=True, + attn='ese', + + ), + ese_vovnet19b_dw=dict( + stem_chs=[64, 64, 64], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=True, + attn='ese', + ), + ese_vovnet19b_slim=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[64, 80, 96, 112], + stage_out_chs=[112, 256, 384, 512], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=False, + attn='ese', + ), + ese_vovnet19b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=False, + attn='ese', + + ), + ese_vovnet39b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 2, 2], + residual=True, + depthwise=False, + attn='ese', + ), + ese_vovnet57b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 4, 3], + residual=True, + depthwise=False, + attn='ese', + + ), + ese_vovnet99b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 3, 9, 3], + residual=True, + depthwise=False, + attn='ese', + ), + eca_vovnet39b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 2, 2], + residual=True, + depthwise=False, + attn='eca', + ), +) +model_cfgs['ese_vovnet39b_evos'] = model_cfgs['ese_vovnet39b'] +model_cfgs['ese_vovnet99b_iabn'] = model_cfgs['ese_vovnet99b'] + + +def _cfg(url=''): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0.conv', 'classifier': 'head.fc', + } + + +default_cfgs = dict( + vovnet39a=_cfg(url=''), + vovnet57a=_cfg(url=''), + ese_vovnet19b_slim_dw=_cfg(url=''), + ese_vovnet19b_dw=_cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet19b_dw-a8741004.pth'), + ese_vovnet19b_slim=_cfg(url=''), + ese_vovnet39b=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet39b-f912fe73.pth'), + ese_vovnet57b=_cfg(url=''), + ese_vovnet99b=_cfg(url=''), + eca_vovnet39b=_cfg(url=''), + ese_vovnet39b_evos=_cfg(url=''), + ese_vovnet99b_iabn=_cfg(url=''), +) + + +class SequentialAppendList(nn.Sequential): + def __init__(self, *args): + super(SequentialAppendList, self).__init__(*args) + + def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor: + for i, module in enumerate(self): + if i == 0: + concat_list.append(module(x)) + else: + concat_list.append(module(concat_list[-1])) + x = torch.cat(concat_list, dim=1) + return x + + +class OsaBlock(nn.Module): + + def __init__(self, in_chs, mid_chs, out_chs, layer_per_block, residual=False, + depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None): + super(OsaBlock, self).__init__() + + self.residual = residual + self.depthwise = depthwise + conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer) + + next_in_chs = in_chs + if self.depthwise and next_in_chs != mid_chs: + assert not residual + self.conv_reduction = ConvBnAct(next_in_chs, mid_chs, 1, **conv_kwargs) + else: + self.conv_reduction = None + + mid_convs = [] + for i in range(layer_per_block): + if self.depthwise: + conv = SeparableConvBnAct(mid_chs, mid_chs, **conv_kwargs) + else: + conv = ConvBnAct(next_in_chs, mid_chs, 3, **conv_kwargs) + next_in_chs = mid_chs + mid_convs.append(conv) + self.conv_mid = SequentialAppendList(*mid_convs) + + # feature aggregation + next_in_chs = in_chs + layer_per_block * mid_chs + self.conv_concat = ConvBnAct(next_in_chs, out_chs, **conv_kwargs) + + if attn: + self.attn = create_attn(attn, out_chs) + else: + self.attn = None + + self.drop_path = drop_path + + def forward(self, x): + output = [x] + if self.conv_reduction is not None: + x = self.conv_reduction(x) + x = self.conv_mid(x, output) + x = self.conv_concat(x) + if self.attn is not None: + x = self.attn(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.residual: + x = x + output[0] + return x + + +class OsaStage(nn.Module): + + def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True, + residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, + drop_path_rates=None): + super(OsaStage, self).__init__() + + if downsample: + self.pool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) + else: + self.pool = None + + blocks = [] + for i in range(block_per_stage): + last_block = i == block_per_stage - 1 + if drop_path_rates is not None and drop_path_rates[i] > 0.: + drop_path = DropPath(drop_path_rates[i]) + else: + drop_path = None + blocks += [OsaBlock( + in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise, + attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path) + ] + in_chs = out_chs + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + if self.pool is not None: + x = self.pool(x) + x = self.blocks(x) + return x + + +class VovNet(nn.Module): + + def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4, + output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.): + """ VovNet (v2) + 
""" + super(VovNet, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + assert stem_stride in (4, 2) + assert output_stride == 32 # FIXME support dilation + + stem_chs = cfg["stem_chs"] + stage_conv_chs = cfg["stage_conv_chs"] + stage_out_chs = cfg["stage_out_chs"] + block_per_stage = cfg["block_per_stage"] + layer_per_block = cfg["layer_per_block"] + conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer) + + # Stem module + last_stem_stride = stem_stride // 2 + conv_type = SeparableConvBnAct if cfg["depthwise"] else ConvBnAct + self.stem = nn.Sequential(*[ + ConvBnAct(in_chans, stem_chs[0], 3, stride=2, **conv_kwargs), + conv_type(stem_chs[0], stem_chs[1], 3, stride=1, **conv_kwargs), + conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, **conv_kwargs), + ]) + self.feature_info = [dict( + num_chs=stem_chs[1], reduction=2, module=f'stem.{1 if stem_stride == 4 else 2}')] + current_stride = stem_stride + + # OSA stages + stage_dpr = torch.split(torch.linspace(0, drop_path_rate, sum(block_per_stage)), block_per_stage) + in_ch_list = stem_chs[-1:] + stage_out_chs[:-1] + stage_args = dict(residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], **conv_kwargs) + stages = [] + for i in range(4): # num_stages + downsample = stem_stride == 2 or i > 0 # first stage has no stride/downsample if stem_stride is 4 + stages += [OsaStage( + in_ch_list[i], stage_conv_chs[i], stage_out_chs[i], block_per_stage[i], layer_per_block, + downsample=downsample, drop_path_rates=stage_dpr[i], **stage_args) + ] + self.num_features = stage_out_chs[i] + current_stride *= 2 if downsample else 1 + self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')] + + self.stages = nn.Sequential(*stages) + + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + for n, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) 
+ elif isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + return self.stages(x) + + def forward(self, x): + x = self.forward_features(x) + return self.head(x) + + +def _create_vovnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + VovNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), **kwargs) + + +@register_model +def vovnet39a(pretrained=False, **kwargs): + return _create_vovnet('vovnet39a', pretrained=pretrained, **kwargs) + + +@register_model +def vovnet57a(pretrained=False, **kwargs): + return _create_vovnet('vovnet57a', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet19b_slim_dw(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet19b_slim_dw', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet19b_dw(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet19b_dw', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet19b_slim(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet19b_slim', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet39b(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet39b', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet57b(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet57b', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet99b(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet99b', pretrained=pretrained, **kwargs) + + +@register_model +def eca_vovnet39b(pretrained=False, **kwargs): + return _create_vovnet('eca_vovnet39b', pretrained=pretrained, **kwargs) + + +# Experimental Models + +@register_model +def ese_vovnet39b_evos(pretrained=False, **kwargs): + def norm_act_fn(num_features, **nkwargs): + return create_norm_act('EvoNormSample', num_features, jit=False, **nkwargs) + return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs) + + +@register_model +def ese_vovnet99b_iabn(pretrained=False, **kwargs): + norm_layer = get_norm_act_layer('iabn') + return _create_vovnet( + 'ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, act_layer=nn.LeakyReLU, **kwargs) diff --git a/timm/models/xception.py b/timm/models/xception.py new file mode 100644 index 0000000000000000000000000000000000000000..a61548dc5f8ddba37cb183aaa3345960ef7b5a24 --- /dev/null +++ b/timm/models/xception.py @@ -0,0 +1,230 @@ +""" +Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch) + +@author: tstandley +Adapted by cadene + +Creates an Xception Model as defined in: + +Francois Chollet +Xception: Deep Learning with Depthwise Separable Convolutions +https://arxiv.org/pdf/1610.02357.pdf + +This weights ported from the Keras implementation. 
Achieves the following performance on the validation set: + +Loss:0.9173 Prec@1:78.892 Prec@5:94.292 + +REMEMBER to set your image size to 3x299x299 for both test and validation + +normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5]) + +The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 +""" + +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['Xception'] + +default_cfgs = { + 'xception': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth', + 'input_size': (3, 299, 299), + 'pool_size': (10, 10), + 'crop_pct': 0.8975, + 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5), + 'num_classes': 1000, + 'first_conv': 'conv1', + 'classifier': 'fc' + # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 + } +} + + +class SeparableConv2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1): + super(SeparableConv2d, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=False) + self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.pointwise(x) + return x + + +class Block(nn.Module): + def __init__(self, in_channels, out_channels, reps, strides=1, start_with_relu=True, grow_first=True): + super(Block, self).__init__() + + if out_channels != in_channels or strides != 1: + self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=strides, bias=False) + self.skipbn = nn.BatchNorm2d(out_channels) + else: + self.skip = None + + rep = [] + for i in range(reps): + if grow_first: + inc = in_channels if i == 0 else out_channels + outc = out_channels + else: + inc = in_channels + outc = in_channels if i < (reps - 1) else out_channels + rep.append(nn.ReLU(inplace=True)) + rep.append(SeparableConv2d(inc, outc, 3, stride=1, padding=1)) + rep.append(nn.BatchNorm2d(outc)) + + if not start_with_relu: + rep = rep[1:] + else: + rep[0] = nn.ReLU(inplace=False) + + if strides != 1: + rep.append(nn.MaxPool2d(3, strides, 1)) + self.rep = nn.Sequential(*rep) + + def forward(self, inp): + x = self.rep(inp) + + if self.skip is not None: + skip = self.skip(inp) + skip = self.skipbn(skip) + else: + skip = inp + + x += skip + return x + + +class Xception(nn.Module): + """ + Xception optimized for the ImageNet dataset, as specified in + https://arxiv.org/pdf/1610.02357.pdf + """ + + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'): + """ Constructor + Args: + num_classes: number of classes + """ + super(Xception, self).__init__() + self.drop_rate = drop_rate + self.global_pool = global_pool + self.num_classes = num_classes + self.num_features = 2048 + + self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32, 64, 3, bias=False) + self.bn2 = nn.BatchNorm2d(64) + self.act2 = nn.ReLU(inplace=True) + + self.block1 = Block(64, 128, 2, 2, start_with_relu=False) + self.block2 = Block(128, 256, 2, 2) + self.block3 = Block(256, 728, 2, 2) + + self.block4 = Block(728, 728, 3, 1) + self.block5 = Block(728, 728, 3, 1) + self.block6 = 
Block(728, 728, 3, 1) + self.block7 = Block(728, 728, 3, 1) + + self.block8 = Block(728, 728, 3, 1) + self.block9 = Block(728, 728, 3, 1) + self.block10 = Block(728, 728, 3, 1) + self.block11 = Block(728, 728, 3, 1) + + self.block12 = Block(728, 1024, 2, 2, grow_first=False) + + self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) + self.bn3 = nn.BatchNorm2d(1536) + self.act3 = nn.ReLU(inplace=True) + + self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1) + self.bn4 = nn.BatchNorm2d(self.num_features) + self.act4 = nn.ReLU(inplace=True) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='act2'), + dict(num_chs=128, reduction=4, module='block2.rep.0'), + dict(num_chs=256, reduction=8, module='block3.rep.0'), + dict(num_chs=728, reduction=16, module='block12.rep.0'), + dict(num_chs=2048, reduction=32, module='act4'), + ] + + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + # #------- init weights -------- + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + x = self.block12(x) + + x = self.conv3(x) + x = self.bn3(x) + x = self.act3(x) + + x = self.conv4(x) + x = self.bn4(x) + x = self.act4(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate: + F.dropout(x, self.drop_rate, training=self.training) + x = self.fc(x) + return x + + +def _xception(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + Xception, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_cls='hook'), **kwargs) + + +@register_model +def xception(pretrained=False, **kwargs): + return _xception('xception', pretrained=pretrained, **kwargs) diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b21576288153d1a8494f1f389bf2469b820a76 --- /dev/null +++ b/timm/models/xception_aligned.py @@ -0,0 +1,240 @@ +"""Pytorch impl of Aligned Xception 41, 65, 71 + +This is a correct, from scratch impl of Aligned Xception (Deeplab) models compatible with TF weights at +https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md + +Hacked together by / Copyright 2020 Ross Wightman +""" +from collections import OrderedDict + +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, ConvBnAct, create_conv2d +from .layers.helpers import to_3tuple +from .registry import register_model + +__all__ = ['XceptionAligned'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 
'input_size': (3, 299, 299), 'pool_size': (10, 10), + 'crop_pct': 0.903, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'stem.0.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = dict( + xception41=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'), + xception65=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'), + xception71=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'), +) + + +class SeparableConv2d(nn.Module): + def __init__( + self, inplanes, planes, kernel_size=3, stride=1, dilation=1, padding='', + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None): + super(SeparableConv2d, self).__init__() + norm_kwargs = norm_kwargs if norm_kwargs is not None else {} + self.kernel_size = kernel_size + self.dilation = dilation + + # depthwise convolution + self.conv_dw = create_conv2d( + inplanes, inplanes, kernel_size, stride=stride, + padding=padding, dilation=dilation, depthwise=True) + self.bn_dw = norm_layer(inplanes, **norm_kwargs) + if act_layer is not None: + self.act_dw = act_layer(inplace=True) + else: + self.act_dw = None + + # pointwise convolution + self.conv_pw = create_conv2d(inplanes, planes, kernel_size=1) + self.bn_pw = norm_layer(planes, **norm_kwargs) + if act_layer is not None: + self.act_pw = act_layer(inplace=True) + else: + self.act_pw = None + + def forward(self, x): + x = self.conv_dw(x) + x = self.bn_dw(x) + if self.act_dw is not None: + x = self.act_dw(x) + x = self.conv_pw(x) + x = self.bn_pw(x) + if self.act_pw is not None: + x = self.act_pw(x) + return x + + +class XceptionModule(nn.Module): + def __init__( + self, in_chs, out_chs, stride=1, dilation=1, pad_type='', + start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None, norm_kwargs=None): + super(XceptionModule, self).__init__() + norm_kwargs = norm_kwargs if norm_kwargs is not None else {} + out_chs = to_3tuple(out_chs) + self.in_channels = in_chs + self.out_channels = out_chs[-1] + self.no_skip = no_skip + if not no_skip and (self.out_channels != self.in_channels or stride != 1): + self.shortcut = ConvBnAct( + in_chs, self.out_channels, 1, stride=stride, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, act_layer=None) + else: + self.shortcut = None + + separable_act_layer = None if start_with_relu else act_layer + self.stack = nn.Sequential() + for i in range(3): + if start_with_relu: + self.stack.add_module(f'act{i + 1}', nn.ReLU(inplace=i > 0)) + self.stack.add_module(f'conv{i + 1}', SeparableConv2d( + in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type, + act_layer=separable_act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs)) + in_chs = out_chs[i] + + def forward(self, x): + skip = x + x = self.stack(x) + if self.shortcut is not None: + skip = self.shortcut(skip) + if not self.no_skip: + x = x + skip + return x + + +class XceptionAligned(nn.Module): + """Modified Aligned Xception + """ + + def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_rate=0., global_pool='avg'): + super(XceptionAligned, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + assert output_stride in (8, 16, 32) + norm_kwargs = norm_kwargs if 
norm_kwargs is not None else {} + + layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs) + self.stem = nn.Sequential(*[ + ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **layer_args), + ConvBnAct(32, 64, kernel_size=3, stride=1, **layer_args) + ]) + + curr_dilation = 1 + curr_stride = 2 + self.feature_info = [] + self.blocks = nn.Sequential() + for i, b in enumerate(block_cfg): + b['dilation'] = curr_dilation + if b['stride'] > 1: + self.feature_info += [dict( + num_chs=to_3tuple(b['out_chs'])[-2], reduction=curr_stride, module=f'blocks.{i}.stack.act3')] + next_stride = curr_stride * b['stride'] + if next_stride > output_stride: + curr_dilation *= b['stride'] + b['stride'] = 1 + else: + curr_stride = next_stride + self.blocks.add_module(str(i), XceptionModule(**b, **layer_args)) + self.num_features = self.blocks[-1].out_channels + + self.feature_info += [dict( + num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))] + + self.head = ClassifierHead( + in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.blocks(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _xception(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + XceptionAligned, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True, feature_cls='hook'), **kwargs) + + +@register_model +def xception41(pretrained=False, **kwargs): + """ Modified Aligned Xception-41 + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 8), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), + ] + model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs) + return _xception('xception41', pretrained=pretrained, **model_args) + + +@register_model +def xception65(pretrained=False, **kwargs): + """ Modified Aligned Xception-65 + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 16), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), + ] + model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs) + return _xception('xception65', pretrained=pretrained, **model_args) + + +@register_model +def xception71(pretrained=False, **kwargs): + """ Modified Aligned Xception-71 + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=1), + dict(in_chs=256, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=1), + dict(in_chs=728, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 16), + # exit flow + 
dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), + ] + model_args = dict(block_cfg=block_cfg, norm_kwargs=dict(eps=.001, momentum=.1), **kwargs) + return _xception('xception71', pretrained=pretrained, **model_args) diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..33e4907f99a74bebdaadb0cd5435d734e2fed684 --- /dev/null +++ b/timm/optim/__init__.py @@ -0,0 +1,13 @@ +from .adamp import AdamP +from .adamw import AdamW +from .adafactor import Adafactor +from .adahessian import Adahessian +from .lookahead import Lookahead +from .nadam import Nadam +from .novograd import NovoGrad +from .nvnovograd import NvNovoGrad +from .radam import RAdam +from .rmsprop_tf import RMSpropTF +from .sgdp import SGDP + +from .optim_factory import create_optimizer \ No newline at end of file diff --git a/timm/optim/adafactor.py b/timm/optim/adafactor.py new file mode 100644 index 0000000000000000000000000000000000000000..088ce3acd82e2be1b393afafa05f48435e538a1a --- /dev/null +++ b/timm/optim/adafactor.py @@ -0,0 +1,174 @@ +""" Adafactor Optimizer + +Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py + +Original header/copyright below. + +""" +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +import math + + +class Adafactor(torch.optim.Optimizer): + """Implements Adafactor algorithm. + This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` + (see https://arxiv.org/abs/1804.04235) + + Note that this optimizer internally adjusts the learning rate depending on the + *scale_parameter*, *relative_step* and *warmup_init* options. + + To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. 
+ + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional): external learning rate (default: None) + eps (tuple[float, float]): regularization constants for square gradient + and parameter scale respectively (default: (1e-30, 1e-3)) + clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0) + decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8) + beta1 (float): coefficient used for computing running averages of gradient (default: None) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True) + relative_step (bool): if True, time-dependent learning rate is computed + instead of external learning rate (default: True) + warmup_init (bool): time-dependent learning rate computation depends on + whether warm-up initialization is being used (default: False) + """ + + def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0, + decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False): + relative_step = lr is None + if warmup_init and not relative_step: + raise ValueError('warmup_init requires relative_step=True') + + beta1 = None if betas is None else betas[0] # make it compat with standard betas arg + defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate, + beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter, + relative_step=relative_step, warmup_init=warmup_init) + super(Adafactor, self).__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + if param_group['relative_step']: + min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2 + lr_t = min(min_step, 1.0 / math.sqrt(param_state['step'])) + param_scale = 1.0 + if param_group['scale_parameter']: + param_scale = max(param_group['eps_scale'], param_state['RMS']) + param_group['lr'] = lr_t * param_scale + return param_group['lr'] + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group['beta1'] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError('Adafactor does not support sparse gradients.') + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state['step'] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(grad) + if factored: + state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad) + state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state['exp_avg_sq'] = torch.zeros_like(grad) + + state['RMS'] = 0 + else: + if use_first_moment: + state['exp_avg'] = state['exp_avg'].to(grad) + if factored: + state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) + state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) + else: + state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state['step'] += 1 + state['RMS'] = self._rms(p_data_fp32) + lr_t = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) + update = grad ** 2 + group['eps'] + if factored: + exp_avg_sq_row = state['exp_avg_sq_row'] + exp_avg_sq_col = state['exp_avg_sq_col'] + + exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1)) + exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2)) + #exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) # pytorch 1.6+ + #exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state['exp_avg_sq'] + + exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update) + #exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) # pytorch 1.6+ + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0)) + update.mul_(lr_t) + + if use_first_moment: + exp_avg = state['exp_avg'] + exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update) + #exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) # pytorch 1.6+ + update = exp_avg + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32) + #p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) # pytorch 1.6+ + + p_data_fp32.add_(-update) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss \ No newline at end of file diff --git a/timm/optim/adahessian.py b/timm/optim/adahessian.py new file mode 100644 index 0000000000000000000000000000000000000000..985c67ca686a65f61f5c5b1a7db3e5bba815a19b --- /dev/null +++ b/timm/optim/adahessian.py @@ -0,0 +1,156 @@ +""" AdaHessian Optimizer + +Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py +Originally licensed MIT, Copyright 2020, David Samuel +""" +import torch + + +class Adahessian(torch.optim.Optimizer): + """ + Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second OrderOptimizer for Machine Learning" + + Arguments: + params (iterable): 
iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional): learning rate (default: 0.1) + betas ((float, float), optional): coefficients used for computing running averages of gradient and the + squared hessian trace (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0) + hessian_power (float, optional): exponent of the hessian trace (default: 1.0) + update_each (int, optional): compute the hessian trace approximation only after *this* number of steps + (to save time) (default: 1) + n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1) + """ + + def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, + hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= hessian_power <= 1.0: + raise ValueError(f"Invalid Hessian power value: {hessian_power}") + + self.n_samples = n_samples + self.update_each = update_each + self.avg_conv_kernel = avg_conv_kernel + + # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training + self.seed = 2147483647 + self.generator = torch.Generator().manual_seed(self.seed) + + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power) + super(Adahessian, self).__init__(params, defaults) + + for p in self.get_params(): + p.hess = 0.0 + self.state[p]["hessian step"] = 0 + + @property + def is_second_order(self): + return True + + def get_params(self): + """ + Gets all parameters in all param_groups with gradients + """ + + return (p for group in self.param_groups for p in group['params'] if p.requires_grad) + + def zero_hessian(self): + """ + Zeros out the accumalated hessian traces. + """ + + for p in self.get_params(): + if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0: + p.hess.zero_() + + @torch.no_grad() + def set_hessian(self): + """ + Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter. 
+ """ + + params = [] + for p in filter(lambda p: p.grad is not None, self.get_params()): + if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step + params.append(p) + self.state[p]["hessian step"] += 1 + + if len(params) == 0: + return + + if self.generator.device != params[0].device: # hackish way of casting the generator to the right device + self.generator = torch.Generator(params[0].device).manual_seed(self.seed) + + grads = [p.grad for p in params] + + for i in range(self.n_samples): + # Rademacher distribution {-1.0, 1.0} + zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params] + h_zs = torch.autograd.grad( + grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1) + for h_z, z, p in zip(h_zs, zs, params): + p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step. + Arguments: + closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None) + """ + + loss = None + if closure is not None: + loss = closure() + + self.zero_hessian() + self.set_hessian() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None or p.hess is None: + continue + + if self.avg_conv_kernel and p.dim() == 4: + p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone() + + # Perform correct stepweight decay as in AdamW + p.mul_(1 - group['lr'] * group['weight_decay']) + + state = self.state[p] + + # State initialization + if len(state) == 1: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of Hessian diagonal square values + state['exp_hessian_diag_sq'] = torch.zeros_like(p) + + exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq'] + beta1, beta2 = group['betas'] + state['step'] += 1 + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1) + exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + k = group['hessian_power'] + denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps']) + + # make update + step_size = group['lr'] / bias_correction1 + p.addcdiv_(exp_avg, denom, value=-step_size) + + return loss diff --git a/timm/optim/adamp.py b/timm/optim/adamp.py new file mode 100644 index 0000000000000000000000000000000000000000..468c3e865e0ceb6fb2bf22f9388237a783314f07 --- /dev/null +++ b/timm/optim/adamp.py @@ -0,0 +1,107 @@ +""" +AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py + +Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 +Code: https://github.com/clovaai/AdamP + +Copyright (c) 2020-present NAVER Corp. 
+MIT license +""" + +import torch +import torch.nn as nn +from torch.optim.optimizer import Optimizer, required +import math + +class AdamP(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, + delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) + super(AdamP, self).__init__(params, defaults) + + def _channel_view(self, x): + return x.view(x.size(0), -1) + + def _layer_view(self, x): + return x.view(1, -1) + + def _cosine_similarity(self, x, y, eps, view_func): + x = view_func(x) + y = view_func(y) + + x_norm = x.norm(dim=1).add_(eps) + y_norm = y.norm(dim=1).add_(eps) + dot = (x * y).sum(dim=1) + + return dot.abs() / x_norm / y_norm + + def _projection(self, p, grad, perturb, delta, wd_ratio, eps): + wd = 1 + expand_size = [-1] + [1] * (len(p.shape) - 1) + for view_func in [self._channel_view, self._layer_view]: + + cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) + + if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): + p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) + perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) + wd = wd_ratio + + return perturb, wd + + return perturb, wd + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad.data + beta1, beta2 = group['betas'] + nesterov = group['nesterov'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p.data) + state['exp_avg_sq'] = torch.zeros_like(p.data) + + # Adam + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + step_size = group['lr'] / bias_correction1 + + if nesterov: + perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom + else: + perturb = exp_avg / denom + + # Projection + wd_ratio = 1 + if len(p.shape) > 1: + perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps']) + + # Weight decay + if group['weight_decay'] > 0: + p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio) + + # Step + p.data.add_(-step_size, perturb) + + return loss diff --git a/timm/optim/adamw.py b/timm/optim/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..66f9a959de586356a29ace2f9c57d3fee8d1057a --- /dev/null +++ b/timm/optim/adamw.py @@ -0,0 +1,117 @@ +""" AdamW Optimizer +Impl copied from PyTorch master +""" +import math +import torch +from torch.optim.optimizer import Optimizer + + +class AdamW(Optimizer): + r"""Implements AdamW algorithm. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 
+ + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + # Perform optimization step + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. 
of gradient + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + step_size = group['lr'] / bias_correction1 + + p.data.addcdiv_(-step_size, exp_avg, denom) + + return loss diff --git a/timm/optim/lookahead.py b/timm/optim/lookahead.py new file mode 100644 index 0000000000000000000000000000000000000000..6b5b7f38ec8cb6594e3986b66223fa2881daeca3 --- /dev/null +++ b/timm/optim/lookahead.py @@ -0,0 +1,92 @@ +""" Lookahead Optimizer Wrapper. +Implementation modified from: https://github.com/alphadl/lookahead.pytorch +Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch.optim.optimizer import Optimizer +from collections import defaultdict + + +class Lookahead(Optimizer): + def __init__(self, base_optimizer, alpha=0.5, k=6): + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) + self.base_optimizer = base_optimizer + self.param_groups = self.base_optimizer.param_groups + self.defaults = base_optimizer.defaults + self.defaults.update(defaults) + self.state = defaultdict(dict) + # manually add our defaults to the param groups + for name, default in defaults.items(): + for group in self.param_groups: + group.setdefault(name, default) + + def update_slow(self, group): + for fast_p in group["params"]: + if fast_p.grad is None: + continue + param_state = self.state[fast_p] + if 'slow_buffer' not in param_state: + param_state['slow_buffer'] = torch.empty_like(fast_p.data) + param_state['slow_buffer'].copy_(fast_p.data) + slow = param_state['slow_buffer'] + slow.add_(group['lookahead_alpha'], fast_p.data - slow) + fast_p.data.copy_(slow) + + def sync_lookahead(self): + for group in self.param_groups: + self.update_slow(group) + + def step(self, closure=None): + #assert id(self.param_groups) == id(self.base_optimizer.param_groups) + loss = self.base_optimizer.step(closure) + for group in self.param_groups: + group['lookahead_step'] += 1 + if group['lookahead_step'] % group['lookahead_k'] == 0: + self.update_slow(group) + return loss + + def state_dict(self): + fast_state_dict = self.base_optimizer.state_dict() + slow_state = { + (id(k) if isinstance(k, torch.Tensor) else k): v + for k, v in self.state.items() + } + fast_state = fast_state_dict['state'] + param_groups = fast_state_dict['param_groups'] + return { + 'state': fast_state, + 'slow_state': slow_state, + 'param_groups': param_groups, + } + + def load_state_dict(self, state_dict): + fast_state_dict = { + 'state': state_dict['state'], + 'param_groups': state_dict['param_groups'], + } + self.base_optimizer.load_state_dict(fast_state_dict) + + # We want to restore the slow state, but share param_groups reference + # with base_optimizer. 
This is a bit redundant but least code + slow_state_new = False + if 'slow_state' not in state_dict: + print('Loading state_dict from optimizer without Lookahead applied.') + state_dict['slow_state'] = defaultdict(dict) + slow_state_new = True + slow_state_dict = { + 'state': state_dict['slow_state'], + 'param_groups': state_dict['param_groups'], # this is pointless but saves code + } + super(Lookahead, self).load_state_dict(slow_state_dict) + self.param_groups = self.base_optimizer.param_groups # make both ref same container + if slow_state_new: + # reapply defaults to catch missing lookahead specific ones + for name, default in self.defaults.items(): + for group in self.param_groups: + group.setdefault(name, default) diff --git a/timm/optim/nadam.py b/timm/optim/nadam.py new file mode 100644 index 0000000000000000000000000000000000000000..d994d1b83485c9b068de73f5f3cf2efb1e5bec39 --- /dev/null +++ b/timm/optim/nadam.py @@ -0,0 +1,88 @@ +import torch +from torch.optim import Optimizer + + +class Nadam(Optimizer): + """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). + + It has been proposed in `Incorporating Nesterov Momentum into Adam`__. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + schedule_decay (float, optional): momentum schedule decay (default: 4e-3) + + __ http://cs229.stanford.edu/proj2015/054_report.pdf + __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf + + Originally taken from: https://github.com/pytorch/pytorch/pull/1408 + NOTE: Has potential issues but does work well on some problems. + """ + + def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, schedule_decay=4e-3): + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, schedule_decay=schedule_decay) + super(Nadam, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['m_schedule'] = 1. + state['exp_avg'] = grad.new().resize_as_(grad).zero_() + state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() + + # Warming momentum schedule + m_schedule = state['m_schedule'] + schedule_decay = group['schedule_decay'] + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + eps = group['eps'] + state['step'] += 1 + t = state['step'] + + if group['weight_decay'] != 0: + grad = grad.add(group['weight_decay'], p.data) + + momentum_cache_t = beta1 * \ + (1. - 0.5 * (0.96 ** (t * schedule_decay))) + momentum_cache_t_1 = beta1 * \ + (1. 
- 0.5 * (0.96 ** ((t + 1) * schedule_decay))) + m_schedule_new = m_schedule * momentum_cache_t + m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 + state['m_schedule'] = m_schedule_new + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1. - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad) + exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t) + denom = exp_avg_sq_prime.sqrt_().add_(eps) + + p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom) + p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom) + + return loss diff --git a/timm/optim/novograd.py b/timm/optim/novograd.py new file mode 100644 index 0000000000000000000000000000000000000000..4137c6aa9406360d29f5f7234ebbdef294404d0e --- /dev/null +++ b/timm/optim/novograd.py @@ -0,0 +1,77 @@ +"""NovoGrad Optimizer. +Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd +Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` + - https://arxiv.org/abs/1905.11286 +""" + +import torch +from torch.optim.optimizer import Optimizer +import math + + +class NovoGrad(Optimizer): + def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super(NovoGrad, self).__init__(params, defaults) + self._lr = lr + self._beta1 = betas[0] + self._beta2 = betas[1] + self._eps = eps + self._wd = weight_decay + self._grad_averaging = grad_averaging + + self._momentum_initialized = False + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + if not self._momentum_initialized: + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + state = self.state[p] + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('NovoGrad does not support sparse gradients') + + v = torch.norm(grad)**2 + m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data + state['step'] = 0 + state['v'] = v + state['m'] = m + state['grad_ema'] = None + self._momentum_initialized = True + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + state = self.state[p] + state['step'] += 1 + + step, v, m = state['step'], state['v'], state['m'] + grad_ema = state['grad_ema'] + + grad = p.grad.data + g2 = torch.norm(grad)**2 + grad_ema = g2 if grad_ema is None else grad_ema * \ + self._beta2 + g2 * (1. - self._beta2) + grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps) + + if self._grad_averaging: + grad *= (1. - self._beta1) + + g2 = torch.norm(grad)**2 + v = self._beta2*v + (1. - self._beta2)*g2 + m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data) + bias_correction1 = 1 - self._beta1 ** step + bias_correction2 = 1 - self._beta2 ** step + step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + + state['v'], state['m'] = v, m + state['grad_ema'] = grad_ema + p.data.add_(-step_size, m) + return loss diff --git a/timm/optim/nvnovograd.py b/timm/optim/nvnovograd.py new file mode 100644 index 0000000000000000000000000000000000000000..323312d2fc36d028124f7a7ec604d248e71503cd --- /dev/null +++ b/timm/optim/nvnovograd.py @@ -0,0 +1,118 @@ +""" Nvidia NovoGrad Optimizer. 
+Original impl by Nvidia from Jasper example: + - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper +Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` + - https://arxiv.org/abs/1905.11286 +""" + +import torch +from torch.optim.optimizer import Optimizer +import math + + +class NvNovoGrad(Optimizer): + """ + Implements Novograd algorithm. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.95, 0.98)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging: gradient averaging + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + """ + + def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, + weight_decay=0, grad_averaging=False, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + amsgrad=amsgrad) + + super(NvNovoGrad, self).__init__(params, defaults) + + def __setstate__(self, state): + super(NvNovoGrad, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Sparse gradients are not supported.') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + norm = torch.sum(torch.pow(grad, 2)) + + if exp_avg_sq == 0: + exp_avg_sq.copy_(norm) + else: + exp_avg_sq.mul_(beta2).add_(1 - beta2, norm) + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. 
of gradient + denom = max_exp_avg_sq.sqrt().add_(group['eps']) + else: + denom = exp_avg_sq.sqrt().add_(group['eps']) + + grad.div_(denom) + if group['weight_decay'] != 0: + grad.add_(group['weight_decay'], p.data) + if group['grad_averaging']: + grad.mul_(1 - beta1) + exp_avg.mul_(beta1).add_(grad) + + p.data.add_(-group['lr'], exp_avg) + + return loss diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..ecc61c5f51469470df508abc158141d2bd1ff9a4 --- /dev/null +++ b/timm/optim/optim_factory.py @@ -0,0 +1,120 @@ +""" Optimizer Factory w/ Custom Weight Decay +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch import optim as optim + +from .adafactor import Adafactor +from .adahessian import Adahessian +from .adamp import AdamP +from .lookahead import Lookahead +from .nadam import Nadam +from .novograd import NovoGrad +from .nvnovograd import NvNovoGrad +from .radam import RAdam +from .rmsprop_tf import RMSpropTF +from .sgdp import SGDP + +try: + from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD + has_apex = True +except ImportError: + has_apex = False + + +def add_weight_decay(model, weight_decay=1e-5, skip_list=()): + decay = [] + no_decay = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: + no_decay.append(param) + else: + decay.append(param) + return [ + {'params': no_decay, 'weight_decay': 0.}, + {'params': decay, 'weight_decay': weight_decay}] + + +def create_optimizer(args, model, filter_bias_and_bn=True): + opt_lower = args.opt.lower() + weight_decay = args.weight_decay + if weight_decay and filter_bias_and_bn: + skip = {} + if hasattr(model, 'no_weight_decay'): + skip = model.no_weight_decay() + parameters = add_weight_decay(model, weight_decay, skip) + weight_decay = 0. 
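        # [Illustrative note, not part of the original file] At this point `parameters`
        # is the two-group list built by add_weight_decay() above, roughly:
        #   [{'params': [biases, 1-d weights, model.no_weight_decay() names], 'weight_decay': 0.},
        #    {'params': [all remaining weights], 'weight_decay': args.weight_decay}]
        # and the local `weight_decay` is zeroed so only the per-group values apply.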
+ else: + parameters = model.parameters() + + if 'fused' in opt_lower: + assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' + + opt_args = dict(lr=args.lr, weight_decay=weight_decay) + if hasattr(args, 'opt_eps') and args.opt_eps is not None: + opt_args['eps'] = args.opt_eps + if hasattr(args, 'opt_betas') and args.opt_betas is not None: + opt_args['betas'] = args.opt_betas + + opt_split = opt_lower.split('_') + opt_lower = opt_split[-1] + if opt_lower == 'sgd' or opt_lower == 'nesterov': + opt_args.pop('eps', None) + optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) + elif opt_lower == 'momentum': + opt_args.pop('eps', None) + optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) + elif opt_lower == 'adam': + optimizer = optim.Adam(parameters, **opt_args) + elif opt_lower == 'adamw': + optimizer = optim.AdamW(parameters, **opt_args) + elif opt_lower == 'nadam': + optimizer = Nadam(parameters, **opt_args) + elif opt_lower == 'radam': + optimizer = RAdam(parameters, **opt_args) + elif opt_lower == 'adamp': + optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) + elif opt_lower == 'sgdp': + optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) + elif opt_lower == 'adadelta': + optimizer = optim.Adadelta(parameters, **opt_args) + elif opt_lower == 'adafactor': + if not args.lr: + opt_args['lr'] = None + optimizer = Adafactor(parameters, **opt_args) + elif opt_lower == 'adahessian': + optimizer = Adahessian(parameters, **opt_args) + elif opt_lower == 'rmsprop': + optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) + elif opt_lower == 'rmsproptf': + optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) + elif opt_lower == 'novograd': + optimizer = NovoGrad(parameters, **opt_args) + elif opt_lower == 'nvnovograd': + optimizer = NvNovoGrad(parameters, **opt_args) + elif opt_lower == 'fusedsgd': + opt_args.pop('eps', None) + optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) + elif opt_lower == 'fusedmomentum': + opt_args.pop('eps', None) + optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) + elif opt_lower == 'fusedadam': + optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) + elif opt_lower == 'fusedadamw': + optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) + elif opt_lower == 'fusedlamb': + optimizer = FusedLAMB(parameters, **opt_args) + elif opt_lower == 'fusednovograd': + opt_args.setdefault('betas', (0.95, 0.98)) + optimizer = FusedNovoGrad(parameters, **opt_args) + else: + assert False and "Invalid optimizer" + raise ValueError + + if len(opt_split) > 1: + if opt_split[0] == 'lookahead': + optimizer = Lookahead(optimizer) + + return optimizer diff --git a/timm/optim/radam.py b/timm/optim/radam.py new file mode 100644 index 0000000000000000000000000000000000000000..9987a334460286b1a6c8ec6d57ee023596a74219 --- /dev/null +++ b/timm/optim/radam.py @@ -0,0 +1,152 @@ +"""RAdam Optimizer. 
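[Illustrative aside, not part of the committed diff] A minimal sketch of driving the optimizer factory that ends just above from an argparse-style namespace; the attribute names are the ones create_optimizer() actually reads, and the values are made up for the example.

import argparse
import torch
from timm.optim.optim_factory import create_optimizer

args = argparse.Namespace(
    opt='lookahead_adamw',   # '<prefix>_<base>'; a 'lookahead_' prefix wraps the base optimizer
    lr=1e-3,
    weight_decay=0.05,
    momentum=0.9,            # only read by the SGD/momentum/RMSprop branches
    opt_eps=1e-8,
    opt_betas=None,
)
model = torch.nn.Linear(128, 10)
optimizer = create_optimizer(args, model)  # AdamW inside a Lookahead wrapper; bias/norm params get no decay
print(type(optimizer).__name__)            # -> Lookahead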
+Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam +Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265 +""" +import math +import torch +from torch.optim.optimizer import Optimizer, required + + +class RAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.buffer = [[None, None, None] for ind in range(10)] + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + buffered = self.buffer[int(state['step'] % 10)] + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = group['lr'] * math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = group['lr'] / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + else: + p_data_fp32.add_(-step_size, exp_avg) + + p.data.copy_(p_data_fp32) + + return loss + + +class PlainRAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + + super(PlainRAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(PlainRAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = 
state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = group['lr'] * math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + else: + step_size = group['lr'] / (1 - beta1 ** state['step']) + p_data_fp32.add_(-step_size, exp_avg) + + p.data.copy_(p_data_fp32) + + return loss diff --git a/timm/optim/rmsprop_tf.py b/timm/optim/rmsprop_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..5115555cd26040e3af297a6e79e7bd5e4d202623 --- /dev/null +++ b/timm/optim/rmsprop_tf.py @@ -0,0 +1,136 @@ +""" RMSProp modified to behave like Tensorflow impl + +Originally cut & paste from PyTorch RMSProp +https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py +Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE + +Modifications Copyright 2020 Ross Wightman +""" + +import torch +from torch.optim import Optimizer + + +class RMSpropTF(Optimizer): + """Implements RMSprop algorithm (TensorFlow style epsilon) + + NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt + and a few other modifications to closer match Tensorflow for matching hyper-params. + + Noteworthy changes include: + 1. Epsilon applied inside square-root + 2. square_avg initialized to ones + 3. LR scaling of update accumulated in momentum buffer + + Proposed by G. Hinton in his + `course `_. + + The centered version first appears in `Generating Sequences + With Recurrent Neural Networks `_. 
+ + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-2) + momentum (float, optional): momentum factor (default: 0) + alpha (float, optional): smoothing (decay) constant (default: 0.9) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-10) + centered (bool, optional) : if ``True``, compute the centered RMSProp, + the gradient is normalized by an estimation of its variance + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101 + lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer + update as per defaults in Tensorflow + + """ + + def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, + decoupled_decay=False, lr_in_momentum=True): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= momentum: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= alpha: + raise ValueError("Invalid alpha value: {}".format(alpha)) + + defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay, + decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum) + super(RMSpropTF, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RMSpropTF, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('momentum', 0) + group.setdefault('centered', False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('RMSprop does not support sparse gradients') + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero + if group['momentum'] > 0: + state['momentum_buffer'] = torch.zeros_like(p.data) + if group['centered']: + state['grad_avg'] = torch.zeros_like(p.data) + + square_avg = state['square_avg'] + one_minus_alpha = 1. 
- group['alpha'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + if 'decoupled_decay' in group and group['decoupled_decay']: + p.data.add_(-group['weight_decay'], p.data) + else: + grad = grad.add(group['weight_decay'], p.data) + + # Tensorflow order of ops for updating squared avg + square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg) + # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original + + if group['centered']: + grad_avg = state['grad_avg'] + grad_avg.add_(one_minus_alpha, grad - grad_avg) + # grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original + avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() # eps moved in sqrt + else: + avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt + + if group['momentum'] > 0: + buf = state['momentum_buffer'] + # Tensorflow accumulates the LR scaling in the momentum buffer + if 'lr_in_momentum' in group and group['lr_in_momentum']: + buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg) + p.data.add_(-buf) + else: + # PyTorch scales the param update by LR + buf.mul_(group['momentum']).addcdiv_(grad, avg) + p.data.add_(-group['lr'], buf) + else: + p.data.addcdiv_(-group['lr'], grad, avg) + + return loss diff --git a/timm/optim/sgdp.py b/timm/optim/sgdp.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a94aa332d7030a70e888342eb6cc4623d69836 --- /dev/null +++ b/timm/optim/sgdp.py @@ -0,0 +1,96 @@ +""" +SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py + +Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 +Code: https://github.com/clovaai/AdamP + +Copyright (c) 2020-present NAVER Corp. +MIT license +""" + +import torch +import torch.nn as nn +from torch.optim.optimizer import Optimizer, required +import math + +class SGDP(Optimizer): + def __init__(self, params, lr=required, momentum=0, dampening=0, + weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): + defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, + nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) + super(SGDP, self).__init__(params, defaults) + + def _channel_view(self, x): + return x.view(x.size(0), -1) + + def _layer_view(self, x): + return x.view(1, -1) + + def _cosine_similarity(self, x, y, eps, view_func): + x = view_func(x) + y = view_func(y) + + x_norm = x.norm(dim=1).add_(eps) + y_norm = y.norm(dim=1).add_(eps) + dot = (x * y).sum(dim=1) + + return dot.abs() / x_norm / y_norm + + def _projection(self, p, grad, perturb, delta, wd_ratio, eps): + wd = 1 + expand_size = [-1] + [1] * (len(p.shape) - 1) + for view_func in [self._channel_view, self._layer_view]: + + cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) + + if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): + p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) + perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) + wd = wd_ratio + + return perturb, wd + + return perturb, wd + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + state = self.state[p] + + # State initialization + if 
len(state) == 0: + state['momentum'] = torch.zeros_like(p.data) + + # SGD + buf = state['momentum'] + buf.mul_(momentum).add_(1 - dampening, grad) + if nesterov: + d_p = grad + momentum * buf + else: + d_p = buf + + # Projection + wd_ratio = 1 + if len(p.shape) > 1: + d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) + + # Weight decay + if weight_decay != 0: + p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) + + # Step + p.data.add_(-group['lr'], d_p) + + return loss diff --git a/timm/scheduler/__init__.py b/timm/scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a7789826229f66e1220cb6149902ba9c411b537 --- /dev/null +++ b/timm/scheduler/__init__.py @@ -0,0 +1,5 @@ +from .cosine_lr import CosineLRScheduler +from .plateau_lr import PlateauLRScheduler +from .step_lr import StepLRScheduler +from .tanh_lr import TanhLRScheduler +from .scheduler_factory import create_scheduler diff --git a/timm/scheduler/cosine_lr.py b/timm/scheduler/cosine_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..1532f092b5cc8c0af5125967cfb84b32ce03ca4a --- /dev/null +++ b/timm/scheduler/cosine_lr.py @@ -0,0 +1,116 @@ +""" Cosine Scheduler + +Cosine LR schedule with warmup, cycle/restarts, noise. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import math +import numpy as np +import torch + +from .scheduler import Scheduler + + +_logger = logging.getLogger(__name__) + + +class CosineLRScheduler(Scheduler): + """ + Cosine decay with restarts. + This is described in the paper https://arxiv.org/abs/1608.03983. + + Inspiration from + https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + t_initial: int, + t_mul: float = 1., + lr_min: float = 0., + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + cycle_limit=0, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + assert t_initial > 0 + assert lr_min >= 0 + if t_initial == 1 and t_mul == 1 and decay_rate == 1: + _logger.warning("Cosine annealing scheduler will have no effect on the learning " + "rate since t_initial = t_mul = eta_mul = 1.") + self.t_initial = t_initial + self.t_mul = t_mul + self.lr_min = lr_min + self.decay_rate = decay_rate + self.cycle_limit = cycle_limit + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + self.t_in_epochs = t_in_epochs + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + + if self.t_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) + t_i = self.t_mul ** i * self.t_initial + t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial + else: + i = t // self.t_initial + t_i = self.t_initial + t_curr = t - (self.t_initial * i) + + gamma = 
self.decay_rate ** i + lr_min = self.lr_min * gamma + lr_max_values = [v * gamma for v in self.base_values] + + if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): + lrs = [ + lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values + ] + else: + lrs = [self.lr_min for _ in self.base_values] + + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None + + def get_cycle_length(self, cycles=0): + if not cycles: + cycles = self.cycle_limit + cycles = max(1, cycles) + if self.t_mul == 1.0: + return self.t_initial * cycles + else: + return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) diff --git a/timm/scheduler/plateau_lr.py b/timm/scheduler/plateau_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..4f2cacb65a1bf23d10aa6fd296f74579571043cf --- /dev/null +++ b/timm/scheduler/plateau_lr.py @@ -0,0 +1,113 @@ +""" Plateau Scheduler + +Adapts PyTorch plateau scheduler and allows application of noise, warmup. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch + +from .scheduler import Scheduler + + +class PlateauLRScheduler(Scheduler): + """Decay the LR by a factor every time the validation loss plateaus.""" + + def __init__(self, + optimizer, + decay_rate=0.1, + patience_t=10, + verbose=True, + threshold=1e-4, + cooldown_t=0, + warmup_t=0, + warmup_lr_init=0, + lr_min=0, + mode='max', + noise_range_t=None, + noise_type='normal', + noise_pct=0.67, + noise_std=1.0, + noise_seed=None, + initialize=True, + ): + super().__init__(optimizer, 'lr', initialize=initialize) + + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, + patience=patience_t, + factor=decay_rate, + verbose=verbose, + threshold=threshold, + cooldown=cooldown_t, + mode=mode, + min_lr=lr_min + ) + + self.noise_range = noise_range_t + self.noise_pct = noise_pct + self.noise_type = noise_type + self.noise_std = noise_std + self.noise_seed = noise_seed if noise_seed is not None else 42 + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + self.restore_lr = None + + def state_dict(self): + return { + 'best': self.lr_scheduler.best, + 'last_epoch': self.lr_scheduler.last_epoch, + } + + def load_state_dict(self, state_dict): + self.lr_scheduler.best = state_dict['best'] + if 'last_epoch' in state_dict: + self.lr_scheduler.last_epoch = state_dict['last_epoch'] + + # override the base class step fn completely + def step(self, epoch, metric=None): + if epoch <= self.warmup_t: + lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] + super().update_groups(lrs) + else: + if self.restore_lr is not None: + # restore actual LR from before our last noise perturbation before stepping base + for i, param_group in enumerate(self.optimizer.param_groups): + param_group['lr'] = self.restore_lr[i] + self.restore_lr = None + + self.lr_scheduler.step(metric, epoch) # step the base scheduler + + if self.noise_range is not None: + if isinstance(self.noise_range, (list, tuple)): + apply_noise = self.noise_range[0] <= epoch < self.noise_range[1] + 
else: + apply_noise = epoch >= self.noise_range + if apply_noise: + self._apply_noise(epoch) + + def _apply_noise(self, epoch): + g = torch.Generator() + g.manual_seed(self.noise_seed + epoch) + if self.noise_type == 'normal': + while True: + # resample if noise out of percent limit, brute force but shouldn't spin much + noise = torch.randn(1, generator=g).item() + if abs(noise) < self.noise_pct: + break + else: + noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + + # apply the noise on top of previous LR, cache the old value so we can restore for normal + # stepping of base scheduler + restore_lr = [] + for i, param_group in enumerate(self.optimizer.param_groups): + old_lr = float(param_group['lr']) + restore_lr.append(old_lr) + new_lr = old_lr + old_lr * noise + param_group['lr'] = new_lr + self.restore_lr = restore_lr diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..21d51509c87a0783c6b61986c574a3ed5366e165 --- /dev/null +++ b/timm/scheduler/scheduler.py @@ -0,0 +1,105 @@ +from typing import Dict, Any + +import torch + + +class Scheduler: + """ Parameter Scheduler Base Class + A scheduler base class that can be used to schedule any optimizer parameter groups. + + Unlike the builtin PyTorch schedulers, this is intended to be consistently called + * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value + * At the END of each optimizer update, after incrementing the update count, to calculate next update's value + + The schedulers built on this should try to remain as stateless as possible (for simplicity). + + This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' + and -1 values for special behaviour. All epoch and update counts must be tracked in the training + code and explicitly passed in to the schedulers on the corresponding step or step_update call. + + Based on ideas from: + * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler + * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + param_group_field: str, + noise_range_t=None, + noise_type='normal', + noise_pct=0.67, + noise_std=1.0, + noise_seed=None, + initialize: bool = True) -> None: + self.optimizer = optimizer + self.param_group_field = param_group_field + self._initial_param_group_field = f"initial_{param_group_field}" + if initialize: + for i, group in enumerate(self.optimizer.param_groups): + if param_group_field not in group: + raise KeyError(f"{param_group_field} missing from param_groups[{i}]") + group.setdefault(self._initial_param_group_field, group[param_group_field]) + else: + for i, group in enumerate(self.optimizer.param_groups): + if self._initial_param_group_field not in group: + raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") + self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] + self.metric = None # any point to having this for all? 
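        # [Illustrative note, not part of the original file] noise_range_t may be a single
        # value (noise applied from that epoch/update onward) or a (start, end) pair
        # (noise applied while start <= t < end); see _add_noise() below.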
+ self.noise_range_t = noise_range_t + self.noise_pct = noise_pct + self.noise_type = noise_type + self.noise_std = noise_std + self.noise_seed = noise_seed if noise_seed is not None else 42 + self.update_groups(self.base_values) + + def state_dict(self) -> Dict[str, Any]: + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.__dict__.update(state_dict) + + def get_epoch_values(self, epoch: int): + return None + + def get_update_values(self, num_updates: int): + return None + + def step(self, epoch: int, metric: float = None) -> None: + self.metric = metric + values = self.get_epoch_values(epoch) + if values is not None: + values = self._add_noise(values, epoch) + self.update_groups(values) + + def step_update(self, num_updates: int, metric: float = None): + self.metric = metric + values = self.get_update_values(num_updates) + if values is not None: + values = self._add_noise(values, num_updates) + self.update_groups(values) + + def update_groups(self, values): + if not isinstance(values, (list, tuple)): + values = [values] * len(self.optimizer.param_groups) + for param_group, value in zip(self.optimizer.param_groups, values): + param_group[self.param_group_field] = value + + def _add_noise(self, lrs, t): + if self.noise_range_t is not None: + if isinstance(self.noise_range_t, (list, tuple)): + apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] + else: + apply_noise = t >= self.noise_range_t + if apply_noise: + g = torch.Generator() + g.manual_seed(self.noise_seed + t) + if self.noise_type == 'normal': + while True: + # resample if noise out of percent limit, brute force but shouldn't spin much + noise = torch.randn(1, generator=g).item() + if abs(noise) < self.noise_pct: + break + else: + noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + lrs = [v + v * noise for v in lrs] + return lrs diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7748f42280b846ab159fb18d7cda09d1890123 --- /dev/null +++ b/timm/scheduler/scheduler_factory.py @@ -0,0 +1,87 @@ +""" Scheduler Factory +Hacked together by / Copyright 2020 Ross Wightman +""" +from .cosine_lr import CosineLRScheduler +from .tanh_lr import TanhLRScheduler +from .step_lr import StepLRScheduler +from .plateau_lr import PlateauLRScheduler + + +def create_scheduler(args, optimizer): + num_epochs = args.epochs + + if getattr(args, 'lr_noise', None) is not None: + lr_noise = getattr(args, 'lr_noise') + if isinstance(lr_noise, (list, tuple)): + noise_range = [n * num_epochs for n in lr_noise] + if len(noise_range) == 1: + noise_range = noise_range[0] + else: + noise_range = lr_noise * num_epochs + else: + noise_range = None + + lr_scheduler = None + if args.sched == 'cosine': + lr_scheduler = CosineLRScheduler( + optimizer, + t_initial=num_epochs, + t_mul=getattr(args, 'lr_cycle_mul', 1.), + lr_min=args.min_lr, + decay_rate=args.decay_rate, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + cycle_limit=getattr(args, 'lr_cycle_limit', 1), + t_in_epochs=True, + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) + num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs + elif args.sched == 'tanh': + lr_scheduler = TanhLRScheduler( + optimizer, + t_initial=num_epochs, + 
t_mul=getattr(args, 'lr_cycle_mul', 1.), + lr_min=args.min_lr, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + cycle_limit=getattr(args, 'lr_cycle_limit', 1), + t_in_epochs=True, + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) + num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs + elif args.sched == 'step': + lr_scheduler = StepLRScheduler( + optimizer, + decay_t=args.decay_epochs, + decay_rate=args.decay_rate, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) + elif args.sched == 'plateau': + mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max' + lr_scheduler = PlateauLRScheduler( + optimizer, + decay_rate=args.decay_rate, + patience_t=args.patience_epochs, + lr_min=args.min_lr, + mode=mode, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + cooldown_t=0, + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) + + return lr_scheduler, num_epochs diff --git a/timm/scheduler/step_lr.py b/timm/scheduler/step_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..f797e1a8cf35999531dd5f1ccbbe09a9d0cf30a9 --- /dev/null +++ b/timm/scheduler/step_lr.py @@ -0,0 +1,63 @@ +""" Step Scheduler + +Basic step LR schedule with warmup, noise. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +import torch + +from .scheduler import Scheduler + + +class StepLRScheduler(Scheduler): + """ + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + decay_t: float, + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + self.decay_t = decay_t + self.decay_rate = decay_rate + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.t_in_epochs = t_in_epochs + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None diff --git a/timm/scheduler/tanh_lr.py b/timm/scheduler/tanh_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc338bb1df7a564d9207b32ab0f59cdf1ef4c59 --- /dev/null +++ b/timm/scheduler/tanh_lr.py @@ -0,0 +1,120 @@ +""" TanH Scheduler + +TanH schedule with warmup, cycle/restarts, noise. 
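[Illustrative aside, not part of the committed diff] A minimal sketch of the scheduler factory together with the base-class convention described above (call step() with the epoch index at the end of each epoch, and step_update() with the global update count after each optimizer update). It assumes the vendored package is importable as timm.scheduler; the args fields shown are the ones create_scheduler() reads on the cosine branch.

import argparse
import torch
from timm.scheduler.scheduler_factory import create_scheduler

args = argparse.Namespace(
    sched='cosine', epochs=100, min_lr=1e-5, decay_rate=0.1,
    warmup_lr=1e-6, warmup_epochs=5, cooldown_epochs=10,
    lr_noise=None, seed=42,
)
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler, num_epochs = create_scheduler(args, optimizer)  # num_epochs = cycle length + cooldown

num_updates = 0
for epoch in range(num_epochs):
    for _ in range(10):                     # pretend 10 batches per epoch (loss/backward omitted)
        optimizer.step()
        num_updates += 1
        scheduler.step_update(num_updates)  # no-op for this config since t_in_epochs=True
    scheduler.step(epoch + 1)               # sets the LR for the next epoch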
+ +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import math +import numpy as np +import torch + +from .scheduler import Scheduler + + +_logger = logging.getLogger(__name__) + + +class TanhLRScheduler(Scheduler): + """ + Hyberbolic-Tangent decay with restarts. + This is described in the paper https://arxiv.org/abs/1806.01593 + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + t_initial: int, + lb: float = -6., + ub: float = 4., + t_mul: float = 1., + lr_min: float = 0., + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + cycle_limit=0, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + assert t_initial > 0 + assert lr_min >= 0 + assert lb < ub + assert cycle_limit >= 0 + assert warmup_t >= 0 + assert warmup_lr_init >= 0 + self.lb = lb + self.ub = ub + self.t_initial = t_initial + self.t_mul = t_mul + self.lr_min = lr_min + self.decay_rate = decay_rate + self.cycle_limit = cycle_limit + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + self.t_in_epochs = t_in_epochs + if self.warmup_t: + t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + + if self.t_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) + t_i = self.t_mul ** i * self.t_initial + t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial + else: + i = t // self.t_initial + t_i = self.t_initial + t_curr = t - (self.t_initial * i) + + if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): + gamma = self.decay_rate ** i + lr_min = self.lr_min * gamma + lr_max_values = [v * gamma for v in self.base_values] + + tr = t_curr / t_i + lrs = [ + lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(self.lb * (1. 
- tr) + self.ub * tr)) + for lr_max in lr_max_values + ] + else: + lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values] + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None + + def get_cycle_length(self, cycles=0): + if not cycles: + cycles = self.cycle_limit + cycles = max(1, cycles) + if self.t_mul == 1.0: + return self.t_initial * cycles + else: + return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f7c4b055537e2d7d0aeb22e2abd9fb09c141099 --- /dev/null +++ b/timm/utils/__init__.py @@ -0,0 +1,10 @@ +from .checkpoint_saver import CheckpointSaver +from .cuda import ApexScaler, NativeScaler +from .distributed import distribute_bn, reduce_tensor +from .jit import set_jit_legacy +from .log import setup_default_logging, FormatterNoInfo +from .metrics import AverageMeter, accuracy +from .misc import natural_key, add_bool_arg +from .model import unwrap_model, get_state_dict +from .model_ema import ModelEma, ModelEmaV2 +from .summary import update_summary, get_outdir diff --git a/timm/utils/checkpoint_saver.py b/timm/utils/checkpoint_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..51896e782dc9c2c1263c2fe5f9901c9bd54a0e54 --- /dev/null +++ b/timm/utils/checkpoint_saver.py @@ -0,0 +1,153 @@ +""" Checkpoint Saver + +Track top-n training checkpoints and maintain recovery checkpoints on specified intervals. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import glob +import operator +import os +import logging + +import torch + +from .model import unwrap_model, get_state_dict + + +_logger = logging.getLogger(__name__) + + +class CheckpointSaver: + def __init__( + self, + model, + optimizer, + args=None, + model_ema=None, + amp_scaler=None, + checkpoint_prefix='checkpoint', + recovery_prefix='recovery', + checkpoint_dir='', + recovery_dir='', + decreasing=False, + max_history=10, + unwrap_fn=unwrap_model): + + # objects to save state_dicts of + self.model = model + self.optimizer = optimizer + self.args = args + self.model_ema = model_ema + self.amp_scaler = amp_scaler + + # state + self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness + self.best_epoch = None + self.best_metric = None + self.curr_recovery_file = '' + self.last_recovery_file = '' + + # config + self.checkpoint_dir = checkpoint_dir + self.recovery_dir = recovery_dir + self.save_prefix = checkpoint_prefix + self.recovery_prefix = recovery_prefix + self.extension = '.pth.tar' + self.decreasing = decreasing # a lower metric is better if True + self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs + self.max_history = max_history + self.unwrap_fn = unwrap_fn + assert self.max_history >= 1 + + def save_checkpoint(self, epoch, metric=None): + assert epoch >= 0 + tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension) + last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) + self._save(tmp_save_path, epoch, metric) + if os.path.exists(last_save_path): + os.unlink(last_save_path) # required for Windows support. 
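        # [Illustrative note, not part of the original file] The checkpoint just written to
        # 'tmp.pth.tar' is renamed to 'last.pth.tar' below, then hard-linked to
        # 'checkpoint-<epoch>.pth.tar' (and to 'model_best.pth.tar' when the metric improves),
        # so the kept top-n history shares storage with the most recent save.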
+ os.rename(tmp_save_path, last_save_path) + worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None + if (len(self.checkpoint_files) < self.max_history + or metric is None or self.cmp(metric, worst_file[1])): + if len(self.checkpoint_files) >= self.max_history: + self._cleanup_checkpoints(1) + filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension + save_path = os.path.join(self.checkpoint_dir, filename) + os.link(last_save_path, save_path) + self.checkpoint_files.append((save_path, metric)) + self.checkpoint_files = sorted( + self.checkpoint_files, key=lambda x: x[1], + reverse=not self.decreasing) # sort in descending order if a lower metric is not better + + checkpoints_str = "Current checkpoints:\n" + for c in self.checkpoint_files: + checkpoints_str += ' {}\n'.format(c) + _logger.info(checkpoints_str) + + if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): + self.best_epoch = epoch + self.best_metric = metric + best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension) + if os.path.exists(best_save_path): + os.unlink(best_save_path) + os.link(last_save_path, best_save_path) + + return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) + + def _save(self, save_path, epoch, metric=None): + save_state = { + 'epoch': epoch, + 'arch': type(self.model).__name__.lower(), + 'state_dict': get_state_dict(self.model, self.unwrap_fn), + 'optimizer': self.optimizer.state_dict(), + 'version': 2, # version < 2 increments epoch before save + } + if self.args is not None: + save_state['arch'] = self.args.model + save_state['args'] = self.args + if self.amp_scaler is not None: + save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict() + if self.model_ema is not None: + save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn) + if metric is not None: + save_state['metric'] = metric + torch.save(save_state, save_path) + + def _cleanup_checkpoints(self, trim=0): + trim = min(len(self.checkpoint_files), trim) + delete_index = self.max_history - trim + if delete_index <= 0 or len(self.checkpoint_files) <= delete_index: + return + to_delete = self.checkpoint_files[delete_index:] + for d in to_delete: + try: + _logger.debug("Cleaning checkpoint: {}".format(d)) + os.remove(d[0]) + except Exception as e: + _logger.error("Exception '{}' while deleting checkpoint".format(e)) + self.checkpoint_files = self.checkpoint_files[:delete_index] + + def save_recovery(self, epoch, batch_idx=0): + assert epoch >= 0 + filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension + save_path = os.path.join(self.recovery_dir, filename) + self._save(save_path, epoch) + if os.path.exists(self.last_recovery_file): + try: + _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) + os.remove(self.last_recovery_file) + except Exception as e: + _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file)) + self.last_recovery_file = self.curr_recovery_file + self.curr_recovery_file = save_path + + def find_recovery(self): + recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix) + files = glob.glob(recovery_path + '*' + self.extension) + files = sorted(files) + if len(files): + return files[0] + else: + return '' diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..bcd29f5801864472e0e91a803bccf531d55aac8b --- /dev/null +++ 
b/timm/utils/cuda.py @@ -0,0 +1,53 @@ +""" CUDA / AMP utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch + +try: + from apex import amp + has_apex = True +except ImportError: + amp = None + has_apex = False + + +class ApexScaler: + state_dict_key = "amp" + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward(create_graph=create_graph) + if clip_grad is not None: + torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad) + optimizer.step() + + def state_dict(self): + if 'state_dict' in amp.__dict__: + return amp.state_dict() + + def load_state_dict(self, state_dict): + if 'load_state_dict' in amp.__dict__: + amp.load_state_dict(state_dict) + + +class NativeScaler: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): + self._scaler.scale(loss).backward(create_graph=create_graph) + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + self._scaler.step(optimizer) + self._scaler.update() + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5dba8c1de5a6ff53638207521377fdfbc4f239 --- /dev/null +++ b/timm/utils/distributed.py @@ -0,0 +1,28 @@ +""" Distributed training/validation utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch import distributed as dist + +from .model import unwrap_model + + +def reduce_tensor(tensor, n): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= n + return rt + + +def distribute_bn(model, world_size, reduce=False): + # ensure every node has the same running bn stats + for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True): + if ('running_mean' in bn_name) or ('running_var' in bn_name): + if reduce: + # average bn stats across whole group + torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM) + bn_buf /= float(world_size) + else: + # broadcast bn stats from rank 0 to whole group + torch.distributed.broadcast(bn_buf, 0) diff --git a/timm/utils/jit.py b/timm/utils/jit.py new file mode 100644 index 0000000000000000000000000000000000000000..185ab7a0d852b9a1c469cfbfff108dbafbb02466 --- /dev/null +++ b/timm/utils/jit.py @@ -0,0 +1,18 @@ +""" JIT scripting/tracing utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch + + +def set_jit_legacy(): + """ Set JIT executor to legacy w/ support for op fusion + This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes + in the JIT exectutor. These API are not supported so could change. + """ + # + assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!" 
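[Illustrative aside, not part of the committed diff] Stepping back to the NativeScaler defined in timm/utils/cuda.py above, a minimal mixed-precision training-step sketch; it assumes a CUDA device is available and that the vendored package is importable as timm.utils.cuda.

import torch
import torch.nn.functional as F
from timm.utils.cuda import NativeScaler

device = 'cuda'
model = torch.nn.Linear(64, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_scaler = NativeScaler()

x = torch.randn(16, 64, device=device)
y = torch.randint(0, 10, (16,), device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = F.cross_entropy(model(x), y)
# scales the loss, runs backward, optionally clips, then steps the optimizer and updates the scale
loss_scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())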
+ torch._C._jit_set_profiling_executor(False) + torch._C._jit_set_profiling_mode(False) + torch._C._jit_override_can_fuse_on_gpu(True) + #torch._C._jit_set_texpr_fuser_enabled(True) diff --git a/timm/utils/log.py b/timm/utils/log.py new file mode 100644 index 0000000000000000000000000000000000000000..c99469e0884f3e45905ef7c7f0d1e491092697ad --- /dev/null +++ b/timm/utils/log.py @@ -0,0 +1,28 @@ +""" Logging helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import logging.handlers + + +class FormatterNoInfo(logging.Formatter): + def __init__(self, fmt='%(levelname)s: %(message)s'): + logging.Formatter.__init__(self, fmt) + + def format(self, record): + if record.levelno == logging.INFO: + return str(record.getMessage()) + return logging.Formatter.format(self, record) + + +def setup_default_logging(default_level=logging.INFO, log_path=''): + console_handler = logging.StreamHandler() + console_handler.setFormatter(FormatterNoInfo()) + logging.root.addHandler(console_handler) + logging.root.setLevel(default_level) + if log_path: + file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3) + file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s") + file_handler.setFormatter(file_formatter) + logging.root.addHandler(file_handler) diff --git a/timm/utils/metrics.py b/timm/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..8e0b1f9989a9dc95708a0dbb42e747f9a8565378 --- /dev/null +++ b/timm/utils/metrics.py @@ -0,0 +1,32 @@ +""" Eval metrics and related + +Hacked together by / Copyright 2020 Ross Wightman +""" + + +class AverageMeter: + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + return [correct[:k].reshape(-1).float().sum(0) * 100. 
/ batch_size for k in topk] diff --git a/timm/utils/misc.py b/timm/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..39c0097c60ed602547f832f1f8dafbe37f156064 --- /dev/null +++ b/timm/utils/misc.py @@ -0,0 +1,18 @@ +""" Misc utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import re + + +def natural_key(string_): + """See http://www.codinghorror.com/blog/archives/001018.html""" + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def add_bool_arg(parser, name, default=False, help=''): + dest_name = name.replace('-', '_') + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) + group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) + parser.set_defaults(**{dest_name: default}) diff --git a/timm/utils/model.py b/timm/utils/model.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd42806c37e62bd1c8741c5a0cb934e813b2682 --- /dev/null +++ b/timm/utils/model.py @@ -0,0 +1,16 @@ +""" Model / state_dict utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +from .model_ema import ModelEma + + +def unwrap_model(model): + if isinstance(model, ModelEma): + return unwrap_model(model.ema) + else: + return model.module if hasattr(model, 'module') else model + + +def get_state_dict(model, unwrap_fn=unwrap_model): + return unwrap_fn(model).state_dict() diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py new file mode 100644 index 0000000000000000000000000000000000000000..073d5c5ea1a4afc5aa3817b6354b2566f8cc2cf5 --- /dev/null +++ b/timm/utils/model_ema.py @@ -0,0 +1,126 @@ +""" Exponential Moving Average (EMA) of model updates + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn + +_logger = logging.getLogger(__name__) + + +class ModelEma: + """ Model Exponential Moving Average (DEPRECATED) + + Keep a moving average of everything in the model state_dict (parameters and buffers). + This version is deprecated, it does not work with scripted models. Will be removed eventually. + + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + A smoothed version of the weights is necessary for some training schemes to perform well. + E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use + RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA + smoothing of weights to match results. Pay attention to the decay constant you are using + relative to your update count per epoch. + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. Validation will have to be done manually in a separate + process, or after the training stops converging. + + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. 
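A minimal usage sketch (editor-added, not part of the diff): `model`, `loader`, `criterion` and `optimizer` are placeholders assumed to exist elsewhere; only the construction/update pattern comes from the class above.

# Illustrative only: keep an EMA copy and refresh it after every optimizer step.
model_ema = ModelEma(model, decay=0.9998, device='cpu')  # hold the EMA copy on CPU
for images, targets in loader:
    loss = criterion(model(images), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model_ema.update(model)  # blend current weights into the shadow copy
# evaluate with the averaged weights, e.g. validate(model_ema.ema)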
+ """ + def __init__(self, model, decay=0.9999, device='', resume=''): + # make a copy of the model for accumulating moving average of weights + self.ema = deepcopy(model) + self.ema.eval() + self.decay = decay + self.device = device # perform ema on different device from model if set + if device: + self.ema.to(device=device) + self.ema_has_module = hasattr(self.ema, 'module') + if resume: + self._load_checkpoint(resume) + for p in self.ema.parameters(): + p.requires_grad_(False) + + def _load_checkpoint(self, checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + assert isinstance(checkpoint, dict) + if 'state_dict_ema' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict_ema'].items(): + # ema model may have been wrapped by DataParallel, and need module prefix + if self.ema_has_module: + name = 'module.' + k if not k.startswith('module') else k + else: + name = k + new_state_dict[name] = v + self.ema.load_state_dict(new_state_dict) + _logger.info("Loaded state_dict_ema") + else: + _logger.warning("Failed to find state_dict_ema, starting from loaded model weights") + + def update(self, model): + # correct a mismatch in state dict keys + needs_module = hasattr(model, 'module') and not self.ema_has_module + with torch.no_grad(): + msd = model.state_dict() + for k, ema_v in self.ema.state_dict().items(): + if needs_module: + k = 'module.' + k + model_v = msd[k].detach() + if self.device: + model_v = model_v.to(device=self.device) + ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) + + +class ModelEmaV2(nn.Module): + """ Model Exponential Moving Average V2 + + Keep a moving average of everything in the model state_dict (parameters and buffers). + V2 of this module is simpler, it does not match params/buffers based on name but simply + iterates in order. It works with torchscript (JIT of full model). + + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + A smoothed version of the weights is necessary for some training schemes to perform well. + E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use + RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA + smoothing of weights to match results. Pay attention to the decay constant you are using + relative to your update count per epoch. + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. Validation will have to be done manually in a separate + process, or after the training stops converging. + + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. 
+ """ + def __init__(self, model, decay=0.9999, device=None): + super(ModelEmaV2, self).__init__() + # make a copy of the model for accumulating moving average of weights + self.module = deepcopy(model) + self.module.eval() + self.decay = decay + self.device = device # perform ema on different device from model if set + if self.device is not None: + self.module.to(device=device) + + def _update(self, model, update_fn): + with torch.no_grad(): + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + if self.device is not None: + model_v = model_v.to(device=self.device) + ema_v.copy_(update_fn(ema_v, model_v)) + + def update(self, model): + self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + + def set(self, model): + self._update(model, update_fn=lambda e, m: m) diff --git a/timm/utils/summary.py b/timm/utils/summary.py new file mode 100644 index 0000000000000000000000000000000000000000..a0801eaace9098851a38345162d0636d7447bb9f --- /dev/null +++ b/timm/utils/summary.py @@ -0,0 +1,34 @@ +""" Summary utilities + +Hacked together by / Copyright 2020 Ross Wightman +""" +import csv +import os +from collections import OrderedDict + + +def get_outdir(path, *paths, inc=False): + outdir = os.path.join(path, *paths) + if not os.path.exists(outdir): + os.makedirs(outdir) + elif inc: + count = 1 + outdir_inc = outdir + '-' + str(count) + while os.path.exists(outdir_inc): + count = count + 1 + outdir_inc = outdir + '-' + str(count) + assert count < 100 + outdir = outdir_inc + os.makedirs(outdir) + return outdir + + +def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False): + rowd = OrderedDict(epoch=epoch) + rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) + rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) + with open(filename, mode='a') as cf: + dw = csv.DictWriter(cf, fieldnames=rowd.keys()) + if write_header: # first iteration (epoch == 1 can't be used) + dw.writeheader() + dw.writerow(rowd) diff --git a/timm/version.py b/timm/version.py new file mode 100644 index 0000000000000000000000000000000000000000..73e3bb4f38bada34ca9359469effd05bb6767ac7 --- /dev/null +++ b/timm/version.py @@ -0,0 +1 @@ +__version__ = '0.3.2' diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/fid_score.py b/tools/fid_score.py new file mode 100644 index 0000000000000000000000000000000000000000..d07585349edfa52829ab471aa8c85848ad99ee72 --- /dev/null +++ b/tools/fid_score.py @@ -0,0 +1,260 @@ +"""Calculates the Frechet Inception Distance (FID) to evalulate GANs + +The FID metric calculates the distance between two distributions of images. +Typically, we have summary statistics (mean & covariance matrix) of one +of these distributions, while the 2nd distribution is given by a GAN. + +When run as a stand-alone program, it compares the distribution of +images that are stored as PNG/JPEG at a specified location with a +distribution given by summary statistics (in pickle format). + +The FID is calculated by assuming that X_1 and X_2 are the activations of +the pool_3 layer of the inception net for generated samples and real world +samples respectively. + +See --help to see further details. 
+ +Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead +of Tensorflow + +Copyright 2018 Institute of Bioinformatics, JKU Linz + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import pathlib + +import numpy as np +import torch +import torchvision.transforms as TF +from PIL import Image +from scipy import linalg +from torch.nn.functional import adaptive_avg_pool2d + +try: + from tqdm import tqdm +except ImportError: + # If tqdm is not available, provide a mock version of it + def tqdm(x): + return x + +from .inception import InceptionV3 + + +IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', + 'tif', 'tiff', 'webp'} + + +class ImagePathDataset(torch.utils.data.Dataset): + def __init__(self, files, transforms=None): + self.files = files + self.transforms = transforms + + def __len__(self): + return len(self.files) + + def __getitem__(self, i): + path = self.files[i] + img = Image.open(path).convert('RGB') + if self.transforms is not None: + img = self.transforms(img) + return img + + +def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8): + """Calculates the activations of the pool_3 layer for all images. + + Params: + -- files : List of image files paths + -- model : Instance of inception model + -- batch_size : Batch size of images for the model to process at once. + Make sure that the number of samples is a multiple of + the batch size, otherwise some samples are ignored. This + behavior is retained to match the original FID score + implementation. + -- dims : Dimensionality of features returned by Inception + -- device : Device to run calculations + -- num_workers : Number of parallel dataloader workers + + Returns: + -- A numpy array of dimension (num images, dims) that contains the + activations of the given tensor when feeding inception with the + query tensor. + """ + model.eval() + + if batch_size > len(files): + print(('Warning: batch size is bigger than the data size. ' + 'Setting batch size to data size')) + batch_size = len(files) + + dataset = ImagePathDataset(files, transforms=TF.ToTensor()) + dataloader = torch.utils.data.DataLoader(dataset, + batch_size=batch_size, + shuffle=False, + drop_last=False, + num_workers=num_workers) + + pred_arr = np.empty((len(files), dims)) + + start_idx = 0 + + for batch in tqdm(dataloader): + batch = batch.to(device) + + with torch.no_grad(): + pred = model(batch)[0] + + # If model output is not scalar, apply global spatial average pooling. + # This happens if you choose a dimensionality not equal 2048. + if pred.size(2) != 1 or pred.size(3) != 1: + pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) + + pred = pred.squeeze(3).squeeze(2).cpu().numpy() + + pred_arr[start_idx:start_idx + pred.shape[0]] = pred + + start_idx = start_idx + pred.shape[0] + + return pred_arr + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. 
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + +def calculate_activation_statistics(files, model, batch_size=50, dims=2048, + device='cpu', num_workers=8): + """Calculation of the statistics used by the FID. + Params: + -- files : List of image files paths + -- model : Instance of inception model + -- batch_size : The images numpy array is split into batches with + batch size batch_size. A reasonable batch size + depends on the hardware. + -- dims : Dimensionality of features returned by Inception + -- device : Device to run calculations + -- num_workers : Number of parallel dataloader workers + + Returns: + -- mu : The mean over samples of the activations of the pool_3 layer of + the inception model. + -- sigma : The covariance matrix of the activations of the pool_3 layer of + the inception model. 
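An editor-added sketch tying these statistics to the distance defined above; random features stand in for real pool_3 activations.

import numpy as np

act1 = np.random.randn(1000, 2048)  # placeholder activations, sample set
act2 = np.random.randn(1000, 2048)  # placeholder activations, reference set
mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)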
+ """ + act = get_activations(files, model, batch_size, dims, device, num_workers) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8): + if path.endswith('.npz'): + with np.load(path) as f: + m, s = f['mu'][:], f['sigma'][:] + else: + path = pathlib.Path(path) + files = sorted([file for ext in IMAGE_EXTENSIONS + for file in path.glob('*.{}'.format(ext))]) + m, s = calculate_activation_statistics(files, model, batch_size, + dims, device, num_workers) + + return m, s + + +def save_statistics_of_path(path, out_path, device=None, batch_size=50, dims=2048, num_workers=8): + if device is None: + device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') + else: + device = torch.device(device) + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + model = InceptionV3([block_idx]).to(device) + m1, s1 = compute_statistics_of_path(path, model, batch_size, dims, device, num_workers) + np.savez(out_path, mu=m1, sigma=s1) + + +def calculate_fid_given_paths(paths, device=None, batch_size=50, dims=2048, num_workers=8): + """Calculates the FID of two paths""" + if device is None: + device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') + else: + device = torch.device(device) + + for p in paths: + if not os.path.exists(p): + raise RuntimeError('Invalid path: %s' % p) + + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + + model = InceptionV3([block_idx]).to(device) + + m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, + dims, device, num_workers) + m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, + dims, device, num_workers) + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + + return fid_value diff --git a/tools/inception.py b/tools/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..34bdb2152f96f511978df1b44bc290e8984c5ce3 --- /dev/null +++ b/tools/inception.py @@ -0,0 +1,328 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling featurs + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=(DEFAULT_BLOCK_INDEX,), + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3 + + Parameters + ---------- + output_blocks : list of int + Indices of blocks to return features of. 
Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input : bool + If true, bilinearly resizes input to width and height 299 before + feeding input to model. As the network without fully connected + layers is fully convolutional, it should be able to handle inputs + of arbitrary size, so resizing might not be strictly needed + normalize_input : bool + If true, scales the input from range (0, 1) to the range the + pretrained Inception network expects, namely (-1, 1) + requires_grad : bool + If true, parameters of the model require gradients. Possibly useful + for finetuning the network + use_fid_inception : bool + If true, uses the pretrained Inception model used in Tensorflow's + FID implementation. If false, uses the pretrained Inception model + available in torchvision. The FID Inception model has different + weights and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get comparable + results. + """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, \ + 'Last possible output block index is 3' + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + inception = _inception_v3(pretrained=True) + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, + inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, + inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, + inception.Mixed_7b, + inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, inp): + """Get Inception feature maps + + Parameters + ---------- + inp : torch.autograd.Variable + Input tensor of shape Bx3xHxW. 
Values are expected to be in + range (0, 1) + + Returns + ------- + List of torch.autograd.Variable, corresponding to the selected output + block, sorted ascending by index + """ + outp = [] + x = inp + + if self.resize_input: + x = F.interpolate(x, + size=(299, 299), + mode='bilinear', + align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + outp.append(x) + + if idx == self.last_needed_block: + break + + return outp + + +def _inception_v3(*args, **kwargs): + """Wraps `torchvision.models.inception_v3` + + Skips default weight inititialization if supported by torchvision version. + See https://github.com/mseitzer/pytorch-fid/issues/28. + """ + try: + version = tuple(map(int, torchvision.__version__.split('.')[:2])) + except ValueError: + # Just a caution against weird version strings + version = (0,) + + if version >= (0, 6): + kwargs['init_weights'] = False + + return torchvision.models.inception_v3(*args, **kwargs) + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. + """ + inception = _inception_v3(num_classes=1008, + aux_logits=False, + pretrained=False) + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) + inception.load_state_dict(state_dict) + return inception + + +class FIDInceptionA(torchvision.models.inception.InceptionA): + """InceptionA block patched for FID computation""" + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(torchvision.models.inception.InceptionC): + """InceptionC block patched for FID computation""" + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + 
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(torchvision.models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(torchvision.models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). 
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..437fec395dc29d9f660c7370fb3ba7e7dadebb8f --- /dev/null +++ b/utils.py @@ -0,0 +1,317 @@ +import pickle + +import torch +import torch.nn as nn +import numpy as np +import os +from tqdm import tqdm +from torchvision.utils import save_image +from torch import distributed as dist +from loguru import logger +logging = logger + + +def set_logger(log_level='info', fname=None): + import logging as _logging + handler = logging.get_absl_handler() + formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s') + handler.setFormatter(formatter) + logging.set_verbosity(log_level) + if fname is not None: + handler = _logging.FileHandler(fname) + handler.setFormatter(formatter) + logging.get_absl_logger().addHandler(handler) + + +def dct2str(dct): + return str({k: f'{v:.6g}' for k, v in dct.items()}) + + +def get_nnet(name, **kwargs): + if name == 'uvit_t2i_vq': + from libs.uvit_t2i_vq import UViT + return UViT(**kwargs) + elif name == 'uvit_vq': + from libs.uvit_vq import UViT + return UViT(**kwargs) + else: + raise NotImplementedError(name) + + +def set_seed(seed: int): + if seed is not None: + torch.manual_seed(seed) + np.random.seed(seed) + + +def get_optimizer(params, name, **kwargs): + if name == 'adam': + from torch.optim import Adam + return Adam(params, **kwargs) + elif name == 'adamw': + from torch.optim import AdamW + return AdamW(params, **kwargs) + else: + raise NotImplementedError(name) + + +def customized_lr_scheduler(optimizer, warmup_steps=-1): + from torch.optim.lr_scheduler import LambdaLR + def fn(step): + if warmup_steps > 0: + return min(step / warmup_steps, 1) + else: + return 1 + return LambdaLR(optimizer, fn) + + +def get_lr_scheduler(optimizer, name, **kwargs): + if name == 'customized': + return customized_lr_scheduler(optimizer, **kwargs) + else: + raise NotImplementedError(name) + + +def ema(model_dest: nn.Module, model_src: nn.Module, rate): + param_dict_src = dict(model_src.named_parameters()) + for p_name, p_dest in model_dest.named_parameters(): + p_src = param_dict_src[p_name] + assert p_src is not p_dest + if 'adapter' not in p_name: + p_dest.data.mul_(rate).add_((1 - rate) * p_src.data) + else: + p_dest.data = p_src.detach().clone() + + +class TrainState(object): + def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None): + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.step = step + self.nnet = nnet + self.nnet_ema = nnet_ema + + def ema_update(self, rate=0.9999): + if self.nnet_ema is not None: + ema(self.nnet_ema, self.nnet, rate) + + def save(self, path, adapter_only=False,name=""): + os.makedirs(path, exist_ok=True) + torch.save(self.step, os.path.join(path, 'step.pth')) + if adapter_only: + torch.save(self.nnet.adapter.state_dict(), os.path.join(path, name+'adapter.pth')) + else: + + for key, val in self.__dict__.items(): + if key != 'step' and val is not None: + torch.save(val.state_dict(), os.path.join(path, f'{key}.pth')) + + def make_dict(self,model,state_dict): + state = {} + for k in model.state_dict().keys(): + if k in state_dict: + state[k] = state_dict[k].clone() + else: + state[k] = model.state_dict()[k].clone() + return state + + def load(self, path): + logging.info(f'load from {path}') + 
self.step = torch.load(os.path.join(path, 'step.pth'), map_location='cpu') + for key, val in self.__dict__.items(): + if key != 'step' and val is not None and key != 'optimizer' and key != 'lr_scheduler': + if key == 'nnet' or key == 'nnet_ema': + val.load_state_dict(self.make_dict(val,torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu'))) + else: + val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu')) + + def load_adapter(self,path): + logging.info('load adapter from {}'.format(path)) + adapter = torch.load(path,map_location='cpu') + keys=['nnet','nnet_ema'] + for key in keys: + if key in self.__dict__: + self.__dict__[key].adapter.load_state_dict(adapter) + else: + logging.info('adapter not in state_dict') + + def resume(self, ckpt_root,adapter_path=None, step=None): + if not os.path.exists(ckpt_root): + return + if ckpt_root.endswith('.ckpt'): + ckpt_path = ckpt_root + else: + if step is None: + ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root))) + if not ckpts: + return + steps = map(lambda x: int(x.split(".")[0]), ckpts) + step = max(steps) + ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt') + logging.info(f'resume from {ckpt_path}') + self.load(ckpt_path) + if adapter_path is not None: + self.load_adapter(adapter_path) + + def to(self, device): + for key, val in self.__dict__.items(): + if isinstance(val, nn.Module): + val.to(device) + def freeze(self): + self.nnet.requires_grad_(False) + for name, p in self.nnet.named_parameters(): + if 'adapter' in name: + p.requires_grad_(True) + + +def cnt_params(model): + return sum(param.numel() for param in model.parameters()) + + +def initialize_train_state(config, device): + params = [] + + nnet = get_nnet(**config.nnet) + params += nnet.adapter.parameters() + nnet_ema = get_nnet(**config.nnet) + nnet_ema.eval() + logging.info(f'nnet has {cnt_params(nnet)} parameters') + + optimizer = get_optimizer(params, **config.optimizer) + lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler) + + train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0, + nnet=nnet, nnet_ema=nnet_ema) + train_state.ema_update(0) + train_state.to(device) + return train_state + + +def amortize(n_samples, batch_size): + k = n_samples // batch_size + r = n_samples % batch_size + return k * [batch_size] if r == 0 else k * [batch_size] + [r] + + +def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, dist=True): + if path: + os.makedirs(path, exist_ok=True) + idx = 0 + batch_size = mini_batch_size * accelerator.num_processes if dist else mini_batch_size + + for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'): + samples = unpreprocess_fn(sample_fn(mini_batch_size)) + if dist: + samples = accelerator.gather(samples.contiguous())[:_batch_size] + if accelerator.is_main_process: + for sample in samples: + save_image(sample, os.path.join(path, f"{idx}.png")) + idx += 1 + + +def grad_norm(model): + total_norm = 0. + for p in model.parameters(): + param_norm = p.grad.data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm ** (1. / 2) + return total_norm + +from collections import defaultdict, deque +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. 
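An editor-added sketch of the tracking behaviour described above: the deque keeps the last `window_size` values for the windowed statistics, while `total`/`count` back the global average.

loss_meter = SmoothedValue(window_size=20, fmt='{median:.4f} ({global_avg:.4f})')
for step in range(100):
    loss_meter.update(1.0 / (step + 1))
print(loss_meter.median, loss_meter.avg, loss_meter.global_avg)
print(loss_meter)  # rendered with `fmt`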
+ """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter=" "): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def add_meter(self, name, meter): + self.meters[name] = meter + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + from torch._six import inf + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + return total_norm +