Spaces:
Sleeping
Sleeping
File size: 2,807 Bytes
9db3d14 31140a5 d28444d 9db3d14 734dc0e 9db3d14 734dc0e d28444d 9866292 9db3d14 734dc0e 9db3d14 9866292 9db3d14 b63fee5 9db3d14 734dc0e 9db3d14 d28444d 9db3d14 706784e 9db3d14 734dc0e 9db3d14 734dc0e 9db3d14 734dc0e 9db3d14 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from transformers import ViTModel, ViTImageProcessor
from PIL import Image, ImageOps
import gradio as gr
import torch
from datasets import Dataset
from torch.nn import CosineSimilarity
# Shared ViT preprocessing: resizes/normalizes PIL images into pixel tensors.
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
# Fine-tuned encoders, switched to eval mode (disables dropout etc.).
image_encoder = ViTModel.from_pretrained("model/image_encoder/epoch_29").eval()
# NOTE(review): "scibble" looks like a typo for "scribble", but the string must
# match the on-disk checkpoint directory name -- confirm before renaming.
scribble_encoder = ViTModel.from_pretrained("model/scibble_encoder/epoch_29").eval()
# Module-level cache of candidate images + embeddings; populated by
# load_candidates_in_cache() and read by scribble_matching().
candidates: "Dataset | None" = None
cosinesimilarity = CosineSimilarity()
def load_candidates(candidate_dir, batch_size=1, progress=gr.Progress()):
    """Load uploaded candidate images and embed them with the image encoder.

    Args:
        candidate_dir: iterable of uploaded file objects (each exposing a
            ``.name`` path on disk, as gradio's File component provides).
        batch_size: number of images embedded per forward pass (default 1,
            matching the original behavior; raise it for faster indexing).
        progress: gradio progress tracker for UI feedback.

    Returns:
        A ``datasets.Dataset`` with columns ``image`` (224x224 RGB PIL images)
        and ``image_embedding`` (ViT pooler outputs, one vector per image).
    """
    def embed_batch(examples):
        # Batched map callback: attach ViT pooler embeddings to each example.
        images = list(examples["image"])
        pixel_values = image_processor(images, return_tensors="pt")["pixel_values"]
        examples["image_embedding"] = image_encoder(pixel_values)["pooler_output"]
        progress.update(len(images))
        return examples

    # `f` is a gradio temp-file wrapper; normalize every image to 224x224 RGB
    # so the ViT processor sees a uniform input size.
    records = [
        dict(image=Image.open(f.name).convert("RGB").resize((224, 224)))
        for f in progress.tqdm(candidate_dir)
    ]
    dataset = Dataset.from_list(records)
    with torch.no_grad():  # inference only -- no gradients needed
        dataset = dataset.map(embed_batch, batched=True, batch_size=batch_size)
    return dataset
def load_candidates_in_cache(candidate_files):
    """Embed the uploaded files and stash the result in the module-level cache.

    Returns the uploaded file paths so the gradio File component keeps
    displaying its selection.
    """
    global candidates
    candidates = load_candidates(candidate_files)
    return [uploaded.name for uploaded in candidate_files]
def scribble_matching(input_img: Image.Image):
    """Rank cached candidate images by cosine similarity to the drawn scribble.

    Args:
        input_img: PIL sketch from the gradio canvas. It is inverted before
            encoding (dark strokes on white become light strokes on dark).

    Returns:
        Gallery items as (image, caption) pairs: the inverted scribble first
        (captioned "preview"), then the top matches labelled with their
        cosine-similarity scores.

    Raises:
        gr.Error: if no candidate images have been loaded yet.
    """
    if candidates is None:
        # Surface a friendly UI message instead of crashing on the None cache.
        raise gr.Error("Load candidate images before running scribble matching.")
    scribble = ImageOps.invert(input_img)
    with torch.no_grad():  # inference only
        scribble_embedding = scribble_encoder(
            image_processor(scribble, return_tensors="pt")["pixel_values"]
        )["pooler_output"].to("cpu")
    image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)
    sim = cosinesimilarity(scribble_embedding, image_embeddings)
    # Never request more matches than there are candidates (topk raises otherwise).
    predicts = torch.topk(sim, k=min(15, sim.numel()))
    output_imgs = candidates[predicts.indices.tolist()]["image"]
    labels = [f"{score:.3f}" for score in predicts.values.tolist()]
    return list(zip([scribble] + output_imgs, ["preview"] + labels))
def main():
    """Build and launch the gradio demo for scribble-based image retrieval."""
    with gr.Blocks() as demo:
        with gr.Row():
            sketch_canvas = gr.Image(
                type="pil",
                label="scribble",
                height=512,
                width=512,
                source="canvas",
                tool="color-sketch",
                brush_radius=10,
            )
            gallery = gr.Gallery(min_width=512, columns=4, show_label=True)
        with gr.Row():
            candidate_picker = gr.File(file_count="directory", min_width=300, height=300)
            load_btn = gr.Button("Load", variant="secondary", size="sm")
        match_btn = gr.Button("Scribble Matching", variant="primary")
        # Wire the two actions: index the uploaded images, then match sketches.
        load_btn.click(
            fn=load_candidates_in_cache,
            inputs=[candidate_picker],
            outputs=candidate_picker,
        )
        match_btn.click(fn=scribble_matching, inputs=[sketch_canvas], outputs=[gallery])
    demo.queue().launch()
# Entry guard; the trailing "|" on the original last line was a scrape artifact
# (file-viewer table residue) and is a syntax error -- removed.
if __name__ == "__main__":
    main()