import streamlit as st
import os
import torch
from transformers import AutoTokenizer
from jax import numpy as jnp
import json
import requests
import zipfile
import io
import natsort
from PIL import Image as PilImage
from torchvision import datasets, transforms
from torchvision.transforms import CenterCrop, Normalize, Resize, ToTensor
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm

from modeling_hybrid_clip import FlaxHybridCLIP
import utils


@st.cache
def get_model():
    # Load the Italian CLIP checkpoint (Flax hybrid vision/text model).
    return FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")


@st.cache
def download_images():
    # from sentence_transformers import SentenceTransformer, util
    img_folder = "photos/"
    if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
        os.makedirs(img_folder, exist_ok=True)

        photo_filename = "unsplash-25k-photos.zip"
        if not os.path.exists(photo_filename):
            # Download the dataset if it does not exist yet.
            print(f"Downloading {photo_filename}...")
            r = requests.get("http://sbert.net/datasets/" + photo_filename, stream=True)
            z = zipfile.ZipFile(io.BytesIO(r.content))
            print("Extracting the dataset...")
            z.extractall(path=img_folder)
            print("Done.")


@st.cache
def get_image_features():
    # Precomputed CLIP image embeddings for the Unsplash photos.
    return jnp.load("static/features/features.npy")


"""
# CLIP Italian Demo (Flax Community Week)
"""

os.environ["TOKENIZERS_PARALLELISM"] = "false"

query = st.text_input("Insert a query text")
if query:
    with st.spinner("Computing in progress..."):
        model = get_model()
        download_images()

        image_features = get_image_features()

        tokenizer = AutoTokenizer.from_pretrained(
            "dbmdz/bert-base-italian-xxl-uncased", cache_dir=None, use_fast=True
        )

        image_size = model.config.vision_config.image_size

        # Same preprocessing as CLIP: resize, center crop, normalize.
        val_preprocess = transforms.Compose(
            [
                Resize([image_size], interpolation=InterpolationMode.BICUBIC),
                CenterCrop(image_size),
                ToTensor(),
                Normalize(
                    (0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )

        dataset = utils.CustomDataSet("photos/", transform=val_preprocess)

        # Retrieve the n images whose embeddings best match the query text.
        image_paths = utils.find_image(
            query, model, dataset, tokenizer, image_features, n=2
        )

    st.image(image_paths)
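
# ---------------------------------------------------------------------------
# NOTE: the helpers below are NOT part of this app. They are a minimal,
# hypothetical sketch of what utils.CustomDataSet and utils.find_image are
# assumed to do: enumerate the downloaded Unsplash photos and rank them
# against the query by cosine similarity between the query's text embedding
# and the precomputed image embeddings. The real implementations live in
# utils.py; names, signatures, and details may differ, and the
# get_text_features call is assumed from the FlaxHybridCLIP example code.
# ---------------------------------------------------------------------------
from torch.utils.data import Dataset


class CustomDataSetSketch(Dataset):
    """Hypothetical folder-backed image dataset (illustrative only)."""

    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        self.total_imgs = natsort.natsorted(os.listdir(main_dir))

    def __len__(self):
        return len(self.total_imgs)

    def get_image_name(self, idx):
        return self.total_imgs[idx]

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        image = PilImage.open(img_loc).convert("RGB")
        return self.transform(image)


def find_image_sketch(text_query, model, dataset, tokenizer, image_features, n=2):
    """Hypothetical retrieval helper (illustrative only)."""
    # Embed the query with the Italian BERT text tower of the hybrid CLIP.
    inputs = tokenizer([text_query], padding=True, return_tensors="np")
    text_features = model.get_text_features(
        inputs["input_ids"], inputs["attention_mask"]
    )
    # Cosine similarity: L2-normalize both sides, then take dot products.
    text_features = text_features / jnp.linalg.norm(
        text_features, axis=-1, keepdims=True
    )
    image_features = image_features / jnp.linalg.norm(
        image_features, axis=-1, keepdims=True
    )
    scores = jnp.squeeze(image_features @ text_features.T, axis=-1)
    # Paths of the n best-scoring images, to be passed to st.image.
    best = jnp.argsort(-scores)[:n]
    return [
        os.path.join(dataset.main_dir, dataset.get_image_name(int(i))) for i in best
    ]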