"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os

import requests
import torch
from PIL import Image

from lavis.common.registry import registry
from lavis.common.utils import build_default_model
from lavis.models import *
from lavis.processors import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_demo_image():
    """Download the BLIP demo image and return it as an RGB PIL image."""
    img_url = (
        "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
    )
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

    return raw_image


def read_img(filepath):
    """Load an image from disk and return it as an RGB PIL image."""
    raw_image = Image.open(filepath).convert("RGB")

    return raw_image


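# Build the BLIP feature extractor from the pretrained BLIP-base checkpoint and
# run it in inference mode on the selected device.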
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"
feature_extractor = BlipFeatureExtractor(pretrained=model_url)

feature_extractor.eval()
feature_extractor = feature_extractor.to(device)

# Processors that turn raw images and caption strings into model-ready inputs.
vis_processor = BlipImageEvalProcessor(image_size=224)
text_processor = BlipCaptionProcessor()

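# Directory of COCO train2014 images whose features will be cached.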
file_root = "/export/home/.cache/lavis/coco/images/train2014"
filepaths = os.listdir(file_root)

print(len(filepaths))

# The extractor's forward call still takes a caption in image-only mode; a placeholder string is enough.
caption = "dummy"

# Maps each image filename to its cached feature vector, filled batch by batch.
path2feat = dict()
bsz = 256

images_in_batch = []
filepaths_in_batch = []

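# Accumulate preprocessed images until a full batch is ready, then extract and cache features.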
for i, filename in enumerate(filepaths):
    filepath = os.path.join(file_root, filename)

    image = read_img(filepath)
    image = vis_processor(image).unsqueeze(0)

    images_in_batch.append(image)
    filepaths_in_batch.append(filepath)

    # Flush once the batch is full, or on the last image so the final partial batch is kept.
    if len(images_in_batch) == bsz or i == len(filepaths) - 1:
        image_batch = torch.cat(images_in_batch, dim=0).to(device)
        with torch.no_grad():
            # Keep only the first (CLS) token of the normalized per-token image embeddings.
            image_features = feature_extractor(
                image_batch, caption, mode="image", normalized=True
            )[:, 0]

        for filepath, image_feat in zip(filepaths_in_batch, image_features):
            path2feat[os.path.basename(filepath)] = image_feat.detach().cpu()

        images_in_batch = []
        filepaths_in_batch = []

        print(len(path2feat), image_features.shape)

torch.save(path2feat, "path2feat_coco_train2014.pth")
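
# The cached features can later be reloaded with torch.load("path2feat_coco_train2014.pth").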