import io

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import CLIPModel, CLIPProcessor

def convert_to_image(byte_data):
    """Convert a raw byte string from the dataset into a PIL image."""
    return Image.open(io.BytesIO(byte_data))

# Load the model and dataset
model = SentenceTransformer('clip-ViT-B-32')
ds_with_embeddings = load_dataset("kvriza8/clip_microscopy_image_text_embeddings")
# Initialize FAISS index once
ds_with_embeddings['train'].add_faiss_index(column='img_embeddings')
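# The 'img_embeddings' column holds precomputed CLIP image embeddings, so the
# FAISS index lets get_nearest_examples() run fast k-nearest-neighbour search
# over them for both text and image queries.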
def get_image_from_text(text_prompt, number_to_retrieve=1):
    prompt = model.encode(text_prompt)
    scores, retrieved_examples = ds_with_embeddings['train'].get_nearest_examples(
        'img_embeddings', prompt, k=number_to_retrieve
    )
    # Convert byte images to PIL images
    images = [convert_to_image(img) for img in retrieved_examples['images']]
    captions = retrieved_examples['caption_summary']
    return images, captions

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
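# The transformers checkpoint is the same ViT-B/32 CLIP backbone as the
# SentenceTransformer 'clip-ViT-B-32' model above, so features extracted from an
# uploaded query image live in the same embedding space as the indexed dataset.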
def get_image_from_image(query_image, number_to_retrieve=1):
    if number_to_retrieve <= 0:
        raise ValueError("number_to_retrieve must be a positive integer")
    # Gradio passes images as uint8 arrays in [0, 255]; only rescale if we
    # receive a float array in [0, 1]
    if query_image.dtype != np.uint8:
        query_image = (query_image * 255).astype(np.uint8)
    image = Image.fromarray(query_image)
    inputs = clip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        image_features = clip_model.get_image_features(**inputs)
    image_features_numpy = image_features.cpu().numpy()
    scores, retrieved_examples = ds_with_embeddings['train'].get_nearest_examples(
        'img_embeddings', image_features_numpy, k=number_to_retrieve
    )
    images = [convert_to_image(img) for img in retrieved_examples['images']]
    captions = retrieved_examples['caption_summary']
    return images, captions

def plot_images(text_prompt="", number_to_retrieve=1, query_image=None):
    if query_image is not None:
        # Handle image input
        sample_images, sample_titles = get_image_from_image(query_image, number_to_retrieve)
    elif text_prompt:
        # Handle text input
        sample_images, sample_titles = get_image_from_text(text_prompt, number_to_retrieve)
    else:
        # Handle empty input
        return [], "No input provided"
    concatenated_captions = "\n".join(sample_titles)
    return sample_images, concatenated_captions

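# Illustrative usage (a sketch; the prompt string and k value are arbitrary):
#   images, caption_text = plot_images("TEM image of nanoparticles", number_to_retrieve=2)
# returns a list of PIL images plus their captions joined by newlines.
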
iface = gr.Interface(
    title="Microscopy image retrieval",
    fn=plot_images,
    inputs=[
        gr.Textbox(lines=4, label="Insert your prompt", placeholder="Text Here..."),
        gr.Slider(1, 8, step=1, value=1, label="Number of images to retrieve"),
        gr.Image(label="Or Upload an Image"),
    ],
    outputs=[gr.Gallery(label="Retrieved Images"), gr.Textbox(label="Image Captions")],
    examples=[
        ["TEM image", 2, None],
        ["Nanoparticles", 1, None],
        ["ZnSe-ZnTe core-shell nanowire", 2, None],
    ],
)
iface.launch(debug=True)