Spaces:
Running
Running
import streamlit as st | |
import os | |
import torch | |
from transformers import AutoTokenizer | |
from jax import numpy as jnp | |
import json | |
import requests | |
import zipfile | |
import io | |
import natsort | |
from PIL import Image as PilImage | |
from torchvision import datasets, transforms | |
from torchvision.transforms import CenterCrop, Normalize, Resize, ToTensor | |
from torchvision.transforms.functional import InterpolationMode | |
from tqdm import tqdm | |
from modeling_hybrid_clip import FlaxHybridCLIP | |
def get_model():
    """Load and return the pretrained hybrid CLIP model.

    Returns:
        The ``FlaxHybridCLIP`` instance built from the
        ``clip-italian/clip-italian`` checkpoint.
    """
    checkpoint = "clip-italian/clip-italian"
    return FlaxHybridCLIP.from_pretrained(checkpoint)
def download_images():
    """Populate the local ``photos/`` folder with the Unsplash photo dataset.

    Does nothing when ``photos/`` already exists and contains at least one
    entry. Otherwise the folder is created and the dataset archive is
    extracted into it: a local ``unsplash-25k-photos.zip`` is used when
    present, else the archive is downloaded into memory.

    Raises:
        requests.HTTPError: if the download responds with an error status.
        zipfile.BadZipFile: if the archive is not a valid zip file.
    """
    img_folder = "photos/"
    if os.path.exists(img_folder) and os.listdir(img_folder):
        return  # images are already in place

    os.makedirs(img_folder, exist_ok=True)
    photo_filename = "unsplash-25k-photos.zip"

    if os.path.exists(photo_filename):
        # Reuse an archive that is already on disk instead of skipping the
        # extraction and leaving the photo folder empty.
        archive_source = photo_filename
    else:
        print(f"Downloading {photo_filename}...")
        r = requests.get("http://sbert.net/datasets/" + photo_filename)
        # Fail loudly on 4xx/5xx rather than feeding an error page to ZipFile.
        r.raise_for_status()
        archive_source = io.BytesIO(r.content)

    print("Extracting the dataset...")
    # Context manager guarantees the archive handle is closed.
    with zipfile.ZipFile(archive_source) as z:
        z.extractall(path=img_folder)
    print("Done.")
def get_image_features(model, image_dir):
    """Build the image dataset for ``image_dir`` and embed every image.

    Args:
        model: Object exposing ``config.vision_config.image_size``; only the
            image size is read from it here.
        image_dir: Folder whose files are loaded as images.

    Returns:
        A ``(features, dataset)`` tuple: the result of
        ``precompute_image_features`` on the loader, and the ``CustomDataSet``
        the loader iterates over.
    """
    side = model.config.vision_config.image_size

    # NOTE(review): the Normalize constants look like the standard CLIP
    # preprocessing mean/std — confirm against the checkpoint's config.
    steps = [
        Resize([side], interpolation=InterpolationMode.BICUBIC),
        CenterCrop(side),
        ToTensor(),
        Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        ),
    ]
    dataset = CustomDataSet(image_dir, transform=transforms.Compose(steps))

    # shuffle=False keeps feature rows aligned with the dataset's file order.
    loader_options = {
        "batch_size": 256,
        "shuffle": False,
        "num_workers": 2,
        "persistent_workers": True,
        "drop_last": False,
    }
    loader = torch.utils.data.DataLoader(dataset, **loader_options)

    features = precompute_image_features(loader)
    return features, dataset
class CustomDataSet(torch.utils.data.Dataset):
    """Dataset over the files of a single folder, served as transformed RGB images.

    Entries are every name returned by ``os.listdir(main_dir)``, ordered with
    ``natsort.natsorted``.
    """

    def __init__(self, main_dir, transform):
        """Record the folder, the transform, and the sorted list of file names.

        Args:
            main_dir: Folder to list.
            transform: Callable applied to each RGB PIL image in ``__getitem__``.
        """
        self.main_dir = main_dir
        self.transform = transform
        # Natural sort gives a deterministic order that indices map onto.
        self.total_imgs = natsort.natsorted(os.listdir(main_dir))

    def __len__(self):
        """Return the number of listed entries."""
        return len(self.total_imgs)

    def get_image_name(self, idx):
        """Return the file name (without folder) stored at position ``idx``."""
        return self.total_imgs[idx]

    def __getitem__(self, idx):
        """Open entry ``idx``, convert it to RGB, and return the transformed result."""
        path = os.path.join(self.main_dir, self.total_imgs[idx])
        rgb = PilImage.open(path).convert("RGB")
        return self.transform(rgb)
def text_encoder(text, tokenizer):
    """Embed a single text query as an L2-normalized row vector.

    Note: this reads the module-level ``model``; it must be assigned before
    the function is called.

    Args:
        text: The query string.
        tokenizer: Callable tokenizer; invoked with truncation and padding to
            a fixed length of 96 tokens, returning NumPy tensors.

    Returns:
        The normalized text embedding with a leading axis of size 1 added.
    """
    encoded = tokenizer(
        [text],
        max_length=96,
        truncation=True,
        padding="max_length",
        return_tensors="np",
    )
    batch_features = model.get_text_features(
        encoded["input_ids"], encoded["attention_mask"]
    )
    # Only one text was tokenized, so take the first (only) result.
    vector = batch_features[0]
    unit_vector = vector / jnp.linalg.norm(vector)
    return jnp.expand_dims(unit_vector, axis=0)
def precompute_image_features(loader):
    """Embed every batch from ``loader`` and return the normalized features.

    Note: this reads the module-level ``model``; it must be assigned before
    the function is called.

    Args:
        loader: Iterable yielding 4-D torch tensors of images.

    Returns:
        A JAX array with one L2-normalized feature row per image, in loader
        order. An empty loader yields an empty array.
    """
    batches = []
    for images in tqdm(loader):
        # Move axis 1 to the end before handing the batch to the model
        # (presumably channels-first -> channels-last; confirm with the model).
        images = images.permute(0, 2, 3, 1).numpy()
        features = model.get_image_features(images)
        features = features / jnp.linalg.norm(features, axis=-1, keepdims=True)
        # Keep whole batches; unstacking them row by row is needlessly slow.
        batches.append(features)

    if not batches:
        # Same result as before for an empty loader: jnp.array([]).
        return jnp.array(batches)
    return jnp.concatenate(batches, axis=0)
def find_image(text_query, dataset, tokenizer, image_features, n=1):
    """Return the paths of the ``n`` images that best match ``text_query``.

    Args:
        text_query: Text to search for.
        dataset: Object exposing ``get_image_name(idx)`` for each feature row.
        tokenizer: Tokenizer forwarded to ``text_encoder``.
        image_features: 2-D array with one feature row per dataset entry.
        n: Number of results wanted. Values above the number of images are
            clamped; values <= 0 give an empty list.

    Returns:
        A list of ``"photos/<name>"`` strings, best match first.
    """
    zeroshot_weights = text_encoder(text_query, tokenizer)
    # Cheap re-normalization so ranking does not rely on the encoder's output scale.
    zeroshot_weights = zeroshot_weights / jnp.linalg.norm(zeroshot_weights)

    # Dot products act as similarity scores: larger means a closer match.
    distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))

    # Sort once (ascending), flip to best-first, then slice; slicing clamps n
    # to the number of available images instead of indexing out of bounds.
    ranked = jnp.argsort(distances[:, 0])[::-1]
    top_indices = ranked[: max(n, 0)].tolist()  # plain Python ints

    return ["photos/" + dataset.get_image_name(idx) for idx in top_indices]
""" | |
# CLIP Italian Demo (Flax Community Week) | |
""" | |
# Set before the tokenizer is created.
# NOTE(review): presumably silences the tokenizers parallelism warning when
# DataLoader workers are spawned — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

query = st.text_input("Insert a query text")

# Streamlit reruns this script on each interaction; work only when a query exists.
if query:
    with st.spinner("Computing in progress..."):
        # `model` must stay a module-level name: text_encoder and
        # precompute_image_features read it as a global.
        model = get_model()
        download_images()

        tokenizer = AutoTokenizer.from_pretrained(
            "dbmdz/bert-base-italian-xxl-uncased",
            cache_dir=None,
            use_fast=True,
        )
        image_features, dataset = get_image_features(model, "photos")
        image_paths = find_image(
            query, dataset, tokenizer, image_features, n=3
        )

    st.image(image_paths)