""" # Copyright (c) 2022, salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ from PIL import Image import requests import torch import os from lavis.common.registry import registry from lavis.processors import * from lavis.models import * from lavis.common.utils import build_default_model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def load_demo_image(): img_url = ( "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg" ) raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") return raw_image def read_img(filepath): raw_image = Image.open(filepath).convert("RGB") return raw_image # model model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth" feature_extractor = BlipFeatureExtractor(pretrained=model_url) feature_extractor.eval() feature_extractor = feature_extractor.to(device) # preprocessors vis_processor = BlipImageEvalProcessor(image_size=224) text_processor = BlipCaptionProcessor() # files to process # file_root = "/export/home/.cache/lavis/coco/images/val2014" file_root = "/export/home/.cache/lavis/coco/images/train2014" filepaths = os.listdir(file_root) print(len(filepaths)) caption = "dummy" path2feat = dict() bsz = 256 images_in_batch = [] filepaths_in_batch = [] for i, filename in enumerate(filepaths): if i % bsz == 0 and i > 0: images_in_batch = torch.cat(images_in_batch, dim=0).to(device) with torch.no_grad(): image_features = feature_extractor( images_in_batch, caption, mode="image", normalized=True )[:, 0] for filepath, image_feat in zip(filepaths_in_batch, image_features): path2feat[os.path.basename(filepath)] = image_feat.detach().cpu() images_in_batch = [] filepaths_in_batch = [] print(len(path2feat), image_features.shape) else: filepath = os.path.join(file_root, filename) image = read_img(filepath) image = vis_processor(image).unsqueeze(0) images_in_batch.append(image) filepaths_in_batch.append(filepath) torch.save(path2feat, "path2feat_coco_train2014.pth")