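"""Gradio demo for microscopy image retrieval.

Text or image queries are encoded with CLIP (ViT-B/32) and matched against
precomputed image embeddings from the kvriza8/clip_microscopy_image_text_embeddings
dataset using a FAISS nearest-neighbour index.
"""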
import io

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import CLIPModel, CLIPProcessor

def convert_to_image(byte_data):
    """Convert raw image bytes (as stored in the dataset) to a PIL Image."""
    return Image.open(io.BytesIO(byte_data))

# Load the model and dataset
model = SentenceTransformer('clip-ViT-B-32')
ds_with_embeddings = load_dataset("kvriza8/clip_microscopy_image_text_embeddings")


# Initialize FAISS index once
ds_with_embeddings['train'].add_faiss_index(column='img_embeddings')
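# The index performs nearest-neighbour search over the precomputed CLIP image
# embeddings (ViT-B/32, 512-dimensional), so both text and image queries must be
# encoded into the same CLIP space before searching.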


def get_image_from_text(text_prompt, number_to_retrieve=1):
    """Encode a text prompt with CLIP and return the closest images and captions."""
    if number_to_retrieve <= 0:
        raise ValueError("Number to retrieve must be a positive integer")

    prompt_embedding = model.encode(text_prompt)
    scores, retrieved_examples = ds_with_embeddings['train'].get_nearest_examples(
        'img_embeddings', prompt_embedding, k=number_to_retrieve
    )

    # Convert byte images to PIL images
    images = [convert_to_image(img) for img in retrieved_examples['images']]
    captions = retrieved_examples['caption_summary']
    return images, captions

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
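# Note: the sentence-transformers 'clip-ViT-B-32' model used above is based on the
# same openai/clip-vit-base-patch32 backbone, so image features computed here are
# comparable to the indexed embeddings.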

def get_image_from_image(query_image, number_to_retrieve=1):
    """Encode a query image with CLIP and return the closest images and captions."""
    if number_to_retrieve <= 0:
        raise ValueError("Number to retrieve must be a positive integer")

    # gr.Image provides a uint8 array in [0, 255]; rescale only if a float array in [0, 1] is passed
    query_image = np.asarray(query_image)
    if np.issubdtype(query_image.dtype, np.floating):
        query_image = np.clip(query_image * 255, 0, 255)
    image = Image.fromarray(query_image.astype(np.uint8))
    inputs = clip_processor(images=image, return_tensors="pt")

    with torch.no_grad():
        image_features = clip_model.get_image_features(**inputs)

    image_features_numpy = image_features.cpu().detach().numpy()

    scores, retrieved_examples = ds_with_embeddings['train'].get_nearest_examples('img_embeddings', image_features_numpy, k=number_to_retrieve)
    
    images = [convert_to_image(img) for img in retrieved_examples['images']]
    captions = retrieved_examples['caption_summary']
    return images, captions

def plot_images(text_prompt="", number_to_retrieve=1, query_image=None):
    """Route the query to image- or text-based retrieval and join the captions for display."""
    if query_image is not None:
        # Handle image input
        sample_images, sample_titles = get_image_from_image(query_image, number_to_retrieve)
    elif text_prompt:
        # Handle text input
        sample_images, sample_titles = get_image_from_text(text_prompt, number_to_retrieve)
    else:
        # Handle empty input
        return [], "No input provided"

    concatenated_captions = "\n".join(sample_titles)
    return sample_images, concatenated_captions

iface = gr.Interface(
    title="Microscopy image retrieval",
    fn=plot_images,
    inputs=[
        gr.Textbox(lines=4, label="Insert your prompt", placeholder="Text Here..."),
        gr.Slider(1, 8, step=1, value=1, label="Number of images to retrieve"),
        gr.Image(label="Or Upload an Image")
    ],
    outputs=[gr.Gallery(label="Retrieved Images"), gr.Textbox(label="Image Captions")],
    # Each example must supply a value for every input (None for the image query)
    examples=[["TEM image", 2, None], ["Nanoparticles", 1, None], ["ZnSe-ZnTe core-shell nanowire", 2, None]]
)

iface.launch(debug=True)