import io
import os
import zipfile

import natsort
import requests
import streamlit as st
import torch
from jax import numpy as jnp
from PIL import Image as PilImage
from torchvision import transforms
from torchvision.transforms import CenterCrop, Normalize, Resize, ToTensor
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from transformers import AutoTokenizer

from modeling_hybrid_clip import FlaxHybridCLIP


# Models are not safely hashable by st.cache, so skip the output-mutation check.
@st.cache(allow_output_mutation=True)
def get_model():
    return FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")


@st.cache
def download_images():
    img_folder = "photos/"
    if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
        os.makedirs(img_folder, exist_ok=True)

        photo_filename = "unsplash-25k-photos.zip"
        if not os.path.exists(photo_filename):  # Download the dataset if it does not exist
            print(f"Downloading {photo_filename}...")
            r = requests.get("http://sbert.net/datasets/" + photo_filename)
            r.raise_for_status()
            # The archive is read fully into memory and extracted from there.
            z = zipfile.ZipFile(io.BytesIO(r.content))
            print("Extracting the dataset...")
            z.extractall(path=img_folder)
            print("Done.")


def build_preprocess(image_size):
    """CLIP evaluation transform: resize shorter side, center-crop, to tensor, normalize."""
    return transforms.Compose(
        [
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ToTensor(),
            # Per-channel mean and std used by the original CLIP preprocessing
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )


@st.cache(allow_output_mutation=True)
def get_image_features(model, image_dir):
    dataset = CustomDataSet(
        image_dir, transform=build_preprocess(model.config.vision_config.image_size)
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=16,
        shuffle=False,
        num_workers=4,
        drop_last=False,
    )
    # precompute_image_features needs the model as well as the loader
    return precompute_image_features(model, loader), dataset


class CustomDataSet(torch.utils.data.Dataset):
    """Flat image folder, ordered by natural sort so indices are stable across runs."""

    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        all_imgs = os.listdir(main_dir)
        self.total_imgs = natsort.natsorted(all_imgs)

    def __len__(self):
        return len(self.total_imgs)

    def get_image_name(self, idx):
        return self.total_imgs[idx]

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        image = PilImage.open(img_loc).convert("RGB")
        return self.transform(image)


def text_encoder(text, model, tokenizer):
    """Encode a single query into an L2-normalized embedding of shape (1, dim)."""
    inputs = tokenizer(
        [text],
        max_length=96,
        truncation=True,
        padding="max_length",
        return_tensors="np",
    )
    embedding = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])[0]
    embedding /= jnp.linalg.norm(embedding)
    return jnp.expand_dims(embedding, axis=0)


@st.cache(allow_output_mutation=True)
def precompute_image_features(model, loader):
    """Embed every image in the loader and L2-normalize each embedding."""
    image_features = []
    for images in tqdm(loader):
        # Convert PyTorch's NCHW batches to the channels-last (NHWC) layout
        # this Flax model's vision encoder expects.
        images = images.permute(0, 2, 3, 1).numpy()
        features = model.get_image_features(images)
        features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
        image_features.append(features)
    return jnp.concatenate(image_features, axis=0)


def find_image(text_query, dataset, model, tokenizer, image_features, n=1):
    """Return the paths of the n images whose embeddings best match the query."""
    query_embedding = text_encoder(text_query, model, tokenizer)
    query_embedding /= jnp.linalg.norm(query_embedding)
    # Both sides are L2-normalized, so the dot product is the cosine similarity.
    similarities = jnp.dot(image_features, query_embedding.reshape(-1, 1))[:, 0]
    # argsort is ascending: reverse it so the best match comes first, keep the top n.
    top_indices = jnp.argsort(similarities)[::-1][:n]
    return [
        os.path.join(dataset.main_dir, dataset.get_image_name(int(idx)))
        for idx in top_indices
    ]


"""
# CLIP Italian Demo (Flax Community Week)
"""

# Disable Hugging Face tokenizers parallelism to avoid fork-related warnings
# once DataLoader worker processes are started.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

query = st.text_input("Enter a text query")
if query:
    with st.spinner("Computing..."):
        model = get_model()
        download_images()

        tokenizer = AutoTokenizer.from_pretrained(
            "dbmdz/bert-base-italian-xxl-uncased", cache_dir=None, use_fast=True
        )

        image_size = model.config.vision_config.image_size
        dataset = CustomDataSet("photos/", transform=build_preprocess(image_size))
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=16,
            shuffle=False,
            num_workers=2,
            drop_last=False,
        )
        image_features = precompute_image_features(model, loader)
        image_paths = find_image(query, dataset, model, tokenizer, image_features, n=2)

    st.image(image_paths)
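

# Sketch (hypothetical helper `rank_by_cosine`, not called by the app above):
# the ranking step of find_image reduced to plain NumPy, so the retrieval logic
# can be checked without downloading the model or the photos. It assumes image
# embeddings are the rows of a 2-D array and the query is a 1-D vector of the
# same width. It L2-normalizes both sides itself, mirroring what
# precompute_image_features and text_encoder do, so the dot product it ranks by
# is the cosine similarity.
import numpy as np


def rank_by_cosine(image_features, query_embedding, n=1):
    """Return indices of the n rows most cosine-similar to query_embedding, best first."""
    image_features = image_features / np.linalg.norm(image_features, axis=-1, keepdims=True)
    query_embedding = query_embedding / np.linalg.norm(query_embedding)
    similarities = image_features @ query_embedding  # shape: (num_images,)
    # argsort is ascending, so reverse it to put the most similar image first
    return np.argsort(similarities)[::-1][:n].tolist()


# Example: three 2-D "image" embeddings and a query closest to the second row.
# rank_by_cosine(np.array([[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]]), np.array([0.5, 0.866]), n=2)
# -> [1, 2]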