yuvalkirstain commited on
Commit
46535f9
1 Parent(s): dec59bb

first commit

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import gradio as gr
4
+ from glob import glob
5
+ import torch
6
+ from transformers import AutoModel, AutoProcessor
7
+
8
+ DEFAULT_EXAMPLE_PATH = f'{os.path.dirname(__file__)}/examples/example_0'
9
+
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ weight_dtype = torch.bfloat16 if device == "cuda" else torch.float32
12
+ print(f"Using device: {device} ({weight_dtype})")
13
+ print("Loading model...")
14
+ model_pretrained_name_or_path = "yuvalkirstain/PickScore_v1"
15
+ processor = AutoProcessor.from_pretrained(model_pretrained_name_or_path)
16
+ model = AutoModel.from_pretrained(model_pretrained_name_or_path, torch_dtype=weight_dtype).eval().to(device)
17
+ print("Model loaded.")
18
+
19
+
20
+ def calc_probs(prompt, images):
21
+ print("Processing inputs...")
22
+ image_inputs = processor(
23
+ images=images,
24
+ padding=True,
25
+ truncation=True,
26
+ max_length=77,
27
+ return_tensors="pt",
28
+ ).to(device)
29
+
30
+ image_inputs = {k: v.to(weight_dtype) for k, v in image_inputs.items()}
31
+
32
+ text_inputs = processor(
33
+ text=prompt,
34
+ padding=True,
35
+ truncation=True,
36
+ max_length=77,
37
+ return_tensors="pt",
38
+ ).to(device)
39
+
40
+ with torch.no_grad():
41
+ print("Embedding images and text...")
42
+ image_embs = model.get_image_features(**image_inputs)
43
+ image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
44
+
45
+ text_embs = model.get_text_features(**text_inputs)
46
+ text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
47
+
48
+ print("Calculating scores...")
49
+ scores = model.logit_scale.exp() * (text_embs.float() @ image_embs.float().T)[0]
50
+
51
+ print("Calculating probabilities...")
52
+ probs = torch.softmax(scores, dim=-1)
53
+
54
+ return probs.cpu().tolist()
55
+
56
+
57
+ def predict(prompt, image_1, image_2):
58
+ print(f"Starting prediction for prompt: {prompt}")
59
+ probs = calc_probs(prompt, [image_1, image_2])
60
+ print(f"Prediction: {probs}")
61
+ return str(round(probs[0], 3)), str(round(probs[1], 3))
62
+
63
+
64
+ with gr.Blocks(title="PickScore v1") as demo:
65
+ gr.Markdown("# PickScore v1")
66
+ gr.Markdown(
67
+ "This is a demo for the PickScore model - see [paper](https://arxiv.org/abs/2305.01569), [code](https://github.com/yuvalkirstain/PickScore), [dataset](https://huggingface.co/datasets/pickapic-anonymous/pickapic_v1), and [model](https://huggingface.co/yuvalkirstain/PickScore_v1).")
68
+ gr.Markdown("## Instructions")
69
+ gr.Markdown("Write a prompt, place two images, and press run to get their PickScore!")
70
+ with gr.Row():
71
+ prompt = gr.inputs.Textbox(lines=1, label="Prompt",
72
+ default=open(f'{DEFAULT_EXAMPLE_PATH}/prompt.txt').readline())
73
+ with gr.Row():
74
+ image_1 = gr.components.Image(type="pil", label="image 1",
75
+ value=Image.open(f'{DEFAULT_EXAMPLE_PATH}/image_1.png'))
76
+ image_2 = gr.components.Image(type="pil", label="image 2",
77
+ value=Image.open(f'{DEFAULT_EXAMPLE_PATH}/image_2.png'))
78
+ with gr.Row():
79
+ pred_1 = gr.outputs.Textbox(label="Probability 1")
80
+ pred_2 = gr.outputs.Textbox(label="Probability 2")
81
+
82
+ btn = gr.Button("Run")
83
+ btn.click(fn=predict, inputs=[prompt, image_1, image_2], outputs=[pred_1, pred_2])
84
+
85
+ gr.Examples(
86
+ [[open(f'{path}/prompt.txt').readline(), f'{path}/image_1.png', f'{path}/image_2.png'] for path in
87
+ glob(f'{os.path.dirname(__file__)}/examples/*')],
88
+ [prompt, image_1, image_2],
89
+ [pred_1, pred_2],
90
+ predict
91
+ )
92
+
93
+ demo.launch(share=True)
examples/example_0/image_1.png ADDED

Git LFS Details

  • SHA256: b447ee5abf70a31b20433aff393add09f9770bd6963dacc0f497f54ff1003f13
  • Pointer size: 131 Bytes
  • Size of remote file: 427 kB
examples/example_0/image_2.png ADDED

Git LFS Details

  • SHA256: c497ec4b927d7a38df6f59ab7dab0ea907ca2eb90ac9fb08b811acc6bb78e04a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
examples/example_0/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ A sign that says PICK A PIC
examples/example_1/image_1.png ADDED

Git LFS Details

  • SHA256: 6afec84a65f687c76704f6dabd73718d346f0d093f655b625a590f4229ee6be5
  • Pointer size: 131 Bytes
  • Size of remote file: 841 kB
examples/example_1/image_2.png ADDED

Git LFS Details

  • SHA256: aca819d9f57e35eddd3cc971002719869b709ea88e02e19db0323d8d59a46a6b
  • Pointer size: 131 Bytes
  • Size of remote file: 934 kB
examples/example_1/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ A bee devouring the world
examples/example_2/image_1.png ADDED

Git LFS Details

  • SHA256: ab1ef7d55c7d04827e8b0e01f1c197c7a82644b6ba672bf71c4e7e26f6a36c5b
  • Pointer size: 131 Bytes
  • Size of remote file: 992 kB
examples/example_2/image_2.png ADDED

Git LFS Details

  • SHA256: 73b3764e591b45eb4f19cf24b0c0d92b1d60644c711652ba3ec9bf2737253d42
  • Pointer size: 131 Bytes
  • Size of remote file: 857 kB
examples/example_2/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Crazy frog on one wheel
examples/example_3/image_1.png ADDED

Git LFS Details

  • SHA256: 27a539671b2b928e6a57c11452991150bea83ce572d500db68dbf92826ebd733
  • Pointer size: 131 Bytes
  • Size of remote file: 740 kB
examples/example_3/image_2.png ADDED

Git LFS Details

  • SHA256: 37bfc68bebe837e432d64725439519cf85a06c80ba8ecd31f3738102fa80badc
  • Pointer size: 131 Bytes
  • Size of remote file: 760 kB
examples/example_3/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ A cute fluffy easter bunny singing
examples/example_4/image_1.png ADDED

Git LFS Details

  • SHA256: d5d00769883a61b32c9589913c956acddb81181ebd2cc56ffc0de3a0dc72fe84
  • Pointer size: 131 Bytes
  • Size of remote file: 971 kB
examples/example_4/image_2.png ADDED

Git LFS Details

  • SHA256: b38d23aaf87595ac3aa6f74258fbdfc5b7342cb92f8edf38e96eda0376726d36
  • Pointer size: 131 Bytes
  • Size of remote file: 928 kB
examples/example_4/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Silver furred Lion man hybrid with few leather clothes
examples/example_5/image_1.png ADDED

Git LFS Details

  • SHA256: f38c1605bccdd08783c0b362a1948998fdb0dfb7f269ee3f4a9a782782264b75
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
examples/example_5/image_2.png ADDED

Git LFS Details

  • SHA256: b3fadbd6c75ee681588c6155a159770b9cfcc12ac6b7217c52548dd1ead8e12f
  • Pointer size: 131 Bytes
  • Size of remote file: 999 kB
examples/example_5/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Superman's corpse found
examples/example_6/image_1.png ADDED

Git LFS Details

  • SHA256: 73ddbe2000310f1bb6f8712a5f08c28c60788cbf107b29e4c2f549fc206dc520
  • Pointer size: 131 Bytes
  • Size of remote file: 877 kB
examples/example_6/image_2.png ADDED

Git LFS Details

  • SHA256: ec2bb1b16de70f8ff009deb043f1ea9835497c6d3421fa9d6a2179808711b8d2
  • Pointer size: 131 Bytes
  • Size of remote file: 717 kB
examples/example_6/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ a girl
examples/example_7/image_1.png ADDED

Git LFS Details

  • SHA256: cb8290da1816eb9ea701154677a83afbfcb977978e2d7473dd6bef1944852fcf
  • Pointer size: 131 Bytes
  • Size of remote file: 846 kB
examples/example_7/image_2.png ADDED

Git LFS Details

  • SHA256: 07f8b3657fee644b8ef129f60d5c0d7d2caad3a97992f075e55c675b6df70b1a
  • Pointer size: 131 Bytes
  • Size of remote file: 385 kB
examples/example_7/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ A gloomy rabbit drinks wine
examples/example_8/image_1.png ADDED

Git LFS Details

  • SHA256: 83bf47ebdff85b39fe67cb34ecd413832ec1d492eef823c4395a73b04412fd55
  • Pointer size: 132 Bytes
  • Size of remote file: 1.76 MB
examples/example_8/image_2.png ADDED

Git LFS Details

  • SHA256: 9952429fcac76f9a116ac043615668528817b51174f5bb2eea4e8bc9529dad12
  • Pointer size: 132 Bytes
  • Size of remote file: 1.77 MB
examples/example_8/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ a beautiful landscape
examples/example_9/image_1.png ADDED

Git LFS Details

  • SHA256: cb8290da1816eb9ea701154677a83afbfcb977978e2d7473dd6bef1944852fcf
  • Pointer size: 131 Bytes
  • Size of remote file: 846 kB
examples/example_9/image_2.png ADDED

Git LFS Details

  • SHA256: 07f8b3657fee644b8ef129f60d5c0d7d2caad3a97992f075e55c675b6df70b1a
  • Pointer size: 131 Bytes
  • Size of remote file: 385 kB
examples/example_9/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ A gloomy rabbit drinks wine
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ transformers