File size: 2,807 Bytes
9db3d14
 
 
 
 
 
 
31140a5
d28444d
 
9db3d14
 
 
 
 
 
 
734dc0e
9db3d14
734dc0e
d28444d
9866292
9db3d14
734dc0e
9db3d14
9866292
9db3d14
b63fee5
9db3d14
 
 
 
 
 
734dc0e
9db3d14
 
 
 
 
 
d28444d
9db3d14
 
 
 
 
 
 
706784e
9db3d14
 
 
 
 
 
 
 
 
 
734dc0e
9db3d14
 
 
 
 
734dc0e
9db3d14
 
734dc0e
9db3d14
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from transformers import  ViTModel, ViTImageProcessor
from PIL import Image, ImageOps
import gradio as gr
import torch
from datasets import Dataset
from torch.nn import CosineSimilarity

# Shared ViT preprocessor used for both the candidate images and the scribbles.
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
# Fine-tuned encoders (epoch-29 checkpoints), switched to eval mode for inference.
image_encoder = ViTModel.from_pretrained("model/image_encoder/epoch_29").eval()
# NOTE(review): "scibble" looks like a typo for "scribble" — confirm it matches the on-disk checkpoint path before renaming.
scribble_encoder = ViTModel.from_pretrained("model/scibble_encoder/epoch_29").eval()

# Module-level cache of embedded candidate images; populated by load_candidates_in_cache()
# and read by scribble_matching(). None until the user clicks "Load".
candidates: Dataset = None

# Cosine similarity over the embedding dimension (default dim=1).
cosinesimilarity = CosineSimilarity()



def load_candidates(candidate_dir, progress=gr.Progress()):
    """Load uploaded candidate images and precompute their ViT embeddings.

    Args:
        candidate_dir: iterable of uploaded temp-file objects, each exposing a
            ``.name`` filesystem path (as produced by ``gr.File``).
        progress: gradio progress tracker, injected by gradio via the default.

    Returns:
        A ``datasets.Dataset`` with columns ``image`` (224x224 RGB PIL images)
        and ``image_embedding`` (pooler output of ``image_encoder``).
    """
    def preprocess(examples):
        # Batched map callback: embed one batch of images with the frozen encoder.
        images = list(examples["image"])
        pixel_values = image_processor(images, return_tensors="pt")["pixel_values"]
        examples["image_embedding"] = image_encoder(pixel_values)["pooler_output"]
        # Advance the tracker registered via progress.tqdm(dataset) below.
        progress.update(len(images))
        return examples

    # First pass (tracked): decode every upload to a uniform 224x224 RGB image.
    records = [
        dict(image=Image.open(upload.name).convert("RGB").resize((224, 224)))
        for upload in progress.tqdm(candidate_dir)
    ]
    dataset = Dataset.from_list(records)
    # Register a second tracker sized to the dataset; preprocess() advances it.
    progress.tqdm(dataset)
    # Inference only — no gradients needed while embedding candidates.
    with torch.no_grad():
        dataset = dataset.map(preprocess, batched=True, batch_size=1)
    return dataset


def load_candidates_in_cache(candidate_files):
    """Embed the uploaded files and stash the result in the module-level cache.

    Returns the uploaded file paths so the gr.File component keeps showing them.
    """
    global candidates
    candidates = load_candidates(candidate_files)
    return [uploaded.name for uploaded in candidate_files]


def scribble_matching(input_img: Image.Image):
    """Rank the cached candidate images by cosine similarity to a scribble.

    Args:
        input_img: scribble drawn on the gradio canvas (PIL image).

    Returns:
        List of ``(image, caption)`` pairs for a ``gr.Gallery``: the inverted
        scribble preview first, followed by the top matches captioned with
        their similarity scores.

    Raises:
        gr.Error: if no candidates have been loaded yet.
    """
    if candidates is None:
        # Guard: matching before "Load" was clicked would crash indexing None.
        raise gr.Error("Load candidate images first.")

    # Canvas scribbles are inverted before encoding — presumably the scribble
    # encoder was trained on light-on-dark input; TODO confirm.
    scribble = ImageOps.invert(input_img)

    # Inference only — skip autograd graph construction.
    with torch.no_grad():
        scribble_embedding = scribble_encoder(
            image_processor(scribble, return_tensors="pt")["pixel_values"]
        )["pooler_output"].to("cpu")
    image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)

    sim = cosinesimilarity(scribble_embedding, image_embeddings)

    # Clamp k: topk raises when fewer than 15 candidates are loaded.
    predicts = torch.topk(sim, k=min(15, sim.numel()))

    output_imgs = candidates[predicts.indices.tolist()]["image"]
    labels = [f"{score:.3f}" for score in predicts.values.tolist()]

    return list(zip([scribble] + output_imgs, ["preview"] + labels))


def main():
    """Build and launch the gradio scribble-matching demo."""
    with gr.Blocks() as demo:
        # Top row: drawing canvas on the left, result gallery on the right.
        with gr.Row():
            sketch_input = gr.Image(
                type="pil",
                label="scribble",
                height=512,
                width=512,
                source="canvas",
                tool="color-sketch",
                brush_radius=10,
            )
            gallery = gr.Gallery(min_width=512, columns=4, show_label=True)

        # Bottom row: candidate-directory upload plus its load trigger.
        with gr.Row():
            file_picker = gr.File(file_count="directory", min_width=300, height=300)
            load_btn = gr.Button("Load", variant="secondary", size="sm")
        match_btn = gr.Button("Scribble Matching", variant="primary")

        load_btn.click(fn=load_candidates_in_cache, inputs=[file_picker], outputs=file_picker)
        match_btn.click(fn=scribble_matching, inputs=[sketch_input], outputs=[gallery])

    demo.queue().launch()

if __name__ == "__main__":
    # Launch the demo only when executed as a script, not on import.
    main()