Spaces:
Build error
Build error
import jax | |
import flax | |
import matplotlib.pyplot as plt | |
import nmslib | |
import numpy as np | |
import os | |
import streamlit as st | |
from tempfile import NamedTemporaryFile | |
from torchvision.transforms import Compose, Resize, ToPILImage | |
from transformers import CLIPProcessor, FlaxCLIPModel | |
from PIL import Image | |
import utils | |
# Model identifiers handed to utils.load_model() by app().
# BASELINE_MODEL is the stock OpenAI CLIP checkpoint; MODEL_PATH is the
# checkpoint this demo actually queries (per the page text, CLIP fine-tuned
# on RSICD).
BASELINE_MODEL = "openai/clip-vit-base-patch32"
# Earlier absolute-path setting, kept for reference:
# MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
MODEL_PATH = "flax-community/clip-rsicd-v2"

# Location of a TSV file of image vectors. Not referenced by the code
# visible in this module -- NOTE(review): confirm whether it is still needed.
# Earlier absolute-path settings, kept for reference:
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-baseline.tsv"
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"

# Directory of images. Also not referenced by the code visible in this
# module -- NOTE(review): confirm whether it is still needed.
# IMAGES_DIR = "/home/shared/data/rsicd_images"
IMAGES_DIR = "./images"
def split_image(X, patch_size=224):
    """Cut an image array into a grid of non-overlapping square patches.

    Pixels at the bottom and right edges that do not fill a complete patch
    are discarded.

    Args:
        X: array indexed as (row, column, ...), for example an image of
            shape (height, width, channels).
        patch_size: side length of each square patch, in pixels. Defaults
            to 224, the tile size the rest of this app works with.

    Returns:
        A tuple ``(num_rows, num_cols, patches)``. ``patches`` is a list of
        ``num_rows * num_cols`` slices of ``X`` in row-major order (left to
        right within a row, rows from top to bottom). It is empty when
        ``X`` is smaller than one patch in either dimension.

    Raises:
        ValueError: if ``patch_size`` is not positive.
    """
    if patch_size <= 0:
        raise ValueError("patch_size must be a positive integer")

    num_rows = X.shape[0] // patch_size
    num_cols = X.shape[1] // patch_size

    # Only the first two axes are sliced, so any trailing axes (such as
    # colour channels) are carried through untouched.
    patches = [
        X[top:top + patch_size, left:left + patch_size]
        for top in range(0, num_rows * patch_size, patch_size)
        for left in range(0, num_cols * patch_size, patch_size)
    ]
    return num_rows, num_cols, patches
def get_patch_probabilities(patches, searched_feature,
                            image_preprocesor,
                            model, processor):
    """Score each patch against a text query using the CLIP model.

    Args:
        patches: sequence of image patches, as produced by splitting the
            uploaded image.
        searched_feature: the feature to look for, as a string; it is
            inserted into the prompt "An aerial image of <feature>".
        image_preprocesor: callable applied to every patch before it is
            handed to ``processor``. (The spelling is part of the existing
            signature and is kept for keyword callers.)
        model: model called with the processor output; its result must
            expose ``logits_per_text``.
        processor: callable that turns the images and prompt into model
            inputs.

    Returns:
        A 1-D NumPy array with one value per patch: the softmax of the
        text-to-image logits taken across the patches. An empty float32
        array is returned when ``patches`` is empty.
    """
    # Nothing to score: skip the model call rather than handing the
    # processor an empty image batch.
    if len(patches) == 0:
        return np.zeros(0, dtype=np.float32)

    images = [image_preprocesor(patch) for patch in patches]
    text = "An aerial image of {:s}".format(searched_feature)
    inputs = processor(images=images,
                       text=text,
                       return_tensors="jax",
                       padding=True)
    outputs = model(**inputs)
    # Softmax over the last axis of logits_per_text, i.e. across the image
    # patches, so the scores for the single prompt sum to 1.
    probs = jax.nn.softmax(outputs.logits_per_text, axis=-1)
    # Only one prompt was supplied, so keep its row.
    return np.asarray(probs)[0]
def get_image_ranks(probs):
    """Return the 0-based rank of every entry of ``probs``.

    The largest probability receives rank 0, the next largest rank 1, and
    so on.

    Args:
        probs: 1-D NumPy array of scores.

    Returns:
        Integer NumPy array, the same length as ``probs``, holding the rank
        of each entry.
    """
    # Indices of probs ordered from the highest score to the lowest.
    descending = np.argsort(-probs)
    # Invert that permutation: the entry found at sorted position k is
    # assigned rank k.
    positions = np.arange(len(probs))
    ranks = np.empty_like(descending)
    ranks[descending] = positions
    return ranks
def app():
    """Render the "Find Features in Images" Streamlit page.

    Loads the model and processor through ``utils.load_model``, shows the
    introductory text and demo images, and offers an image uploader plus a
    text box for the feature to search for. When "Find" is pressed the
    uploaded image is validated, split into 224 x 224 tiles, each tile is
    scored against the query, and the tiles are displayed row by row with
    their probability and rank as captions.

    Returns early, after showing an error message, when no file has been
    uploaded, when the image is not a 3-channel image, or when it is
    smaller than one tile.
    """
    model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)

    st.title("Find Features in Images")
    st.markdown("""
The CLIP model from OpenAI is trained in a self-supervised manner using
contrastive learning to project images and caption text onto a common
embedding space. We have fine-tuned the model using the RSICD dataset
(10k images and ~50k captions from the remote sensing domain).
This demo shows the ability of the model to find specific features
(specified as text queries) in the image. As an example, say you wish to
find the parts of the following image that contain a `beach`, `houses`,
or `ships`. We partition the image into tiles of (224, 224) and report
how likely each of them are to contain each text features.
""")
    st.image("demo-images/st_tropez_1.png")
    st.image("demo-images/st_tropez_2.png")
    st.markdown("""
For this image and the queries listed above, our model reports that the
two left tiles are most likely to contain a `beach`, the two top right
tiles are most likely to contain `houses`, and the two bottom right tiles
are likely to contain `boats`.
You can try it yourself with your own photographs.
[Unsplash](https://unsplash.com/s/photos/aerial-view) has some good
aerial photographs. You will need to download from Unsplash to your
computer and upload it to the demo app.
""")

    buf = st.file_uploader("Upload Image for Analysis")
    searched_feature = st.text_input("Feature to find")

    if not st.button("Find"):
        return

    # The uploader yields None until the user has picked a file.
    if buf is None:
        st.error("Please upload an image before pressing Find")
        return

    # plt.imread reopens the file by name, so the bytes must be flushed to
    # disk first; the context manager closes (and removes) the temp file.
    with NamedTemporaryFile() as ftmp:
        ftmp.write(buf.getvalue())
        ftmp.flush()
        image = plt.imread(ftmp.name)

    # Reject anything that is not height x width x 3. Checking the number
    # of dimensions first keeps shape[2] from being read on a 2-D image.
    if len(image.shape) != 3 or image.shape[2] != 3:
        st.error("Image should be an RGB image")
        return
    if image.shape[0] < 224 or image.shape[1] < 224:
        st.error("Image should be at least (224 x 224)")
        return

    st.image(image, caption="Input Image")
    st.markdown("---")

    image_preprocessor = Compose([
        ToPILImage(),
        Resize(224)
    ])
    num_rows, num_cols, patches = split_image(image)
    patch_probs = get_patch_probabilities(
        patches,
        searched_feature,
        image_preprocessor,
        model,
        processor)
    patch_ranks = get_image_ranks(patch_probs)

    # Patches are stored row-major, so each grid row is one contiguous
    # slice of the flat lists.
    for i in range(num_rows):
        row = slice(i * num_cols, (i + 1) * num_cols)
        # Ranks are 0-based internally; show them 1-based.
        captions = ["p({:s})={:.3f}, rank={:d}".format(searched_feature, p, r + 1)
                    for p, r in zip(patch_probs[row], patch_ranks[row])]
        st.image(patches[row], caption=captions)