'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
'''
import warnings
warnings.filterwarnings("ignore")
import torch
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from transformers import BertTokenizer
from .vit import VisionTransformer, interpolate_pos_embed
def default_bert():
    """Return the default path to the local bert-base-uncased directory.

    The project root is taken to be four directory levels above this
    file; the tokenizer lives under models/QualityMetric/bert-base-uncased.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    root = os.path.abspath(os.path.join(here, '../../../../'))
    return os.path.join(root, 'models', 'QualityMetric', 'bert-base-uncased')
def init_tokenizer(bert_model_path):
    """Load a BertTokenizer and register BLIP's special tokens.

    Adds '[DEC]' as the bos token and '[ENC]' as an additional special
    token, and exposes the latter's id as ``tokenizer.enc_token_id``.

    Args:
        bert_model_path: directory or hub name accepted by
            ``BertTokenizer.from_pretrained``.

    Returns:
        The configured tokenizer.
    """
    tok = BertTokenizer.from_pretrained(bert_model_path)
    tok.add_special_tokens({'bos_token': '[DEC]'})
    tok.add_special_tokens({'additional_special_tokens': ['[ENC]']})
    # Cache the id of '[ENC]' for fast access by the caller.
    tok.enc_token_id = tok.additional_special_tokens_ids[0]
    return tok
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
    """Build a ViT visual encoder.

    Args:
        vit: 'base' (depth 12, width 768, 12 heads) or 'large'
            (depth 24, width 1024, 16 heads).
        image_size: input image resolution in pixels.
        use_grad_checkpointing: enable gradient checkpointing in the encoder.
        ckpt_layer: number of layers to checkpoint.
        drop_path_rate: stochastic-depth rate. 0 means "use the variant's
            default" (0 for 'base', 0.1 for 'large').

    Returns:
        (visual_encoder, vision_width)
    """
    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit == 'base':
        vision_width = 768
        visual_encoder = VisionTransformer(
            img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
            num_heads=12, use_grad_checkpointing=use_grad_checkpointing,
            ckpt_layer=ckpt_layer,
            # BUGFIX: was `0 or drop_path_rate`, a confusing no-op spelling.
            drop_path_rate=drop_path_rate,
        )
    else:  # 'large'
        vision_width = 1024
        visual_encoder = VisionTransformer(
            img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
            num_heads=16, use_grad_checkpointing=use_grad_checkpointing,
            ckpt_layer=ckpt_layer,
            # BUGFIX: the original `0.1 or drop_path_rate` always evaluated
            # to 0.1, silently discarding the caller's drop_path_rate.
            # Honor a non-zero caller value; keep 0.1 as the default.
            drop_path_rate=drop_path_rate or 0.1,
        )
    return visual_encoder, vision_width
def is_url(url_or_filename):
    """Return True iff the string parses as an http or https URL."""
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
def load_checkpoint(model, url_or_filename):
    """Load a BLIP checkpoint into ``model`` from a URL or local file.

    Interpolates the visual encoder's (and, if present, the momentum
    encoder's) position embeddings to the model's resolution, drops any
    checkpoint tensors whose shape disagrees with the model's, and loads
    the remainder non-strictly.

    Args:
        model: module exposing ``visual_encoder`` (and optionally
            ``visual_encoder_m``) attributes.
        url_or_filename: http(s) URL or local checkpoint path.

    Returns:
        (model, msg) where ``msg`` is the missing/unexpected-keys report
        from ``load_state_dict``.

    Raises:
        RuntimeError: if ``url_or_filename`` is neither a URL nor a file.
    """
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']
    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
        state_dict['visual_encoder.pos_embed'], model.visual_encoder)

    # PERF: snapshot the model's state dict once. The original rebuilt it
    # on every loop iteration (and twice more inside the loop body), each
    # call materializing the full parameter dict.
    model_state = model.state_dict()
    if 'visual_encoder_m.pos_embed' in model_state:
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(
            state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m)

    # Drop checkpoint tensors whose shapes don't match the model's, so
    # load_state_dict(strict=False) won't raise on them.
    for key in model_state:
        if key in state_dict and state_dict[key].shape != model_state[key].shape:
            print(key, ": ", state_dict[key].shape, ', ', model_state[key].shape)
            del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg