tongyx361 commited on
Commit
935b23c
β€’
0 Parent(s):

Initialize the demo.

Browse files
Files changed (4) hide show
  1. .gitattributes +34 -0
  2. README.md +13 -0
  3. app.py +241 -0
  4. requirements.txt +13 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ImageReward Demo
3
+ emoji: πŸ‘©β€πŸŽ¨
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.28.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from diffusers import StableDiffusionPipeline
6
+ from PIL import Image
7
+
8
+ import ImageReward as RM
9
+
10
+ # initialize
11
+
12
+ model_id = "runwayml/stable-diffusion-v1-5"
13
+ pipe = StableDiffusionPipeline.from_pretrained(
14
+ model_id,
15
+ torch_dtype=torch.float16,
16
+ )
17
+
18
+ model = RM.load("ImageReward-v1.0")
19
+
20
+ images_in_gallery = []
21
+ rewards_in_gallery = []
22
+
23
+ # event functions
24
+
25
+
26
+ def generate_images(
27
+ prompt, magic_words, num, height, width, num_inference_steps, guidance_scale
28
+ ):
29
+ global images_in_gallery, rewards_in_gallery
30
+
31
+ if magic_words is not None:
32
+ prompt += ", ".join(magic_words)
33
+
34
+ images_in_gallery = pipe(
35
+ prompt,
36
+ height=height,
37
+ width=width,
38
+ num_inference_steps=num_inference_steps,
39
+ guidance_scale=guidance_scale,
40
+ num_images_per_prompt=num,
41
+ ).images
42
+ rewards_in_gallery = [None] * len(images_in_gallery)
43
+ return list(zip(images_in_gallery, rewards_in_gallery))
44
+
45
+
46
+ def score_and_rank(prompt):
47
+ global rewards_in_gallery, images_in_gallery
48
+
49
+ num_not_scored = rewards_in_gallery.count(None)
50
+
51
+ if num_not_scored > 0:
52
+ images_to_score = images_in_gallery[-num_not_scored:]
53
+ with torch.no_grad():
54
+ ranks, rewards = model.inference_rank(prompt, images_to_score)
55
+
56
+ if not isinstance(rewards, list):
57
+ rewards = [rewards]
58
+ rewards_in_gallery = rewards_in_gallery[:-num_not_scored] + rewards
59
+
60
+ outputs = sorted(
61
+ zip(images_in_gallery, rewards_in_gallery), key=lambda x: x[1], reverse=True
62
+ )
63
+
64
+ images_in_gallery = [image for image, _ in outputs]
65
+ rewards_in_gallery = [reward for _, reward in outputs]
66
+
67
+ return outputs, [
68
+ [idx + 1, reward] for idx, reward in enumerate(rewards_in_gallery)
69
+ ]
70
+ else:
71
+ return list(zip(images_in_gallery, rewards_in_gallery)), [
72
+ [idx + 1, reward] for idx, reward in enumerate(rewards_in_gallery)
73
+ ]
74
+
75
+
76
+ def upload_images_to_gallery(uploaded_image_files):
77
+ global images_in_gallery, rewards_in_gallery
78
+
79
+ uploaded_image_file_paths = [file.name for file in uploaded_image_files]
80
+ uploaded_images = [Image.open(path) for path in uploaded_image_file_paths]
81
+ for path in uploaded_image_file_paths:
82
+ os.remove(path)
83
+ images_in_gallery = images_in_gallery + uploaded_images
84
+ rewards_in_gallery = rewards_in_gallery + [None] * len(uploaded_images)
85
+
86
+ return list(zip(images_in_gallery, rewards_in_gallery))
87
+
88
+
89
+ def clear_images():
90
+ global images_in_gallery, rewards_in_gallery
91
+ images_in_gallery = []
92
+ rewards_in_gallery = []
93
+ return None
94
+
95
+
96
+ if __name__ == "__main__":
97
+ # UI
98
+ with gr.Blocks(
99
+ theme=gr.themes.Monochrome(),
100
+ css=r".caption-label { color: black; }",
101
+ ) as demo:
102
+ gr.HTML(
103
+ """
104
+ <h1 align="center">ImageReward Demo</h1>
105
+ <p align="center"><a href="https://github.com/THUDM/ImageReward">GitHub Repo</a> β€’ πŸ€— <a href="https://huggingface.co/THUDM/ImageReward" target="_blank">HF Repo</a> β€’ 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> β€’ πŸ“ƒ <a href="https://arxiv.org/abs/2304.05977" target="_blank">Paper</a><br></p>
106
+ <br>
107
+ <p dir="auto">ImageReward is the first general-purpose text-to-image <strong>human preference RM</strong>, which is trained on in total <strong>137k pairs of expert comparisons</strong>!</p>
108
+ <p dir="auto">The calculation of ImageRewards is based on <strong>both the prompt and images</strong>.</p>
109
+ """
110
+ )
111
+ with gr.Row():
112
+ with gr.Column():
113
+ gr.HTML(
114
+ """
115
+ <p dir="auto">Try ImageReward with only 2 steps:</p>
116
+ <ol dir="auto">
117
+ <li>Click the <strong>"Generate"</strong> button <strong>in the middle of the bottom</strong>.</li>
118
+ <li>Click the <strong>"Score&Rank"</strong> button <strong>below the gallery</strong>.</li>
119
+ </ol>
120
+ <p dir="auto">Finally, just check ImageRewards <strong>along with images or on the right of the gallery</strong>.</p>
121
+ <br>
122
+ <p dir="auto">This demo uses <code>runwayml/stable-diffusion-v1-5</code> as image generation model.</p>
123
+ """
124
+ )
125
+ with gr.Column():
126
+ gr.HTML(
127
+ """
128
+ <p dir="auto">Besides generating images, you can also <strong>upload</strong> images to score:</p>
129
+ <ol dir="auto">
130
+ <li>Upload images <strong>in the bottom right corner</strong>.</li>
131
+ <li>Change the <strong>"Prompt"</strong> to correspond to the images.</li>
132
+ <li>Click the <strong>"Score&Rank"</strong> button <strong>below the gallery</strong>.</li>
133
+ </ol>
134
+ <br>
135
+ <p dir="auto">For more details about using ImageReward in your own program, check <a href="https://github.com/THUDM/ImageReward">the README.md in our Github Repo</a>.</p>
136
+ """
137
+ )
138
+
139
+ with gr.Row(elem_id="outputs_row"):
140
+ with gr.Column(elem_id="gallery_column", scale=4):
141
+ gallery = gr.Gallery(
142
+ label="Images (scored ones sorted)",
143
+ show_label=False,
144
+ elem_id="gallery",
145
+ ).style(columns=4, object_fit="contain", full_width=True)
146
+ with gr.Column(elem_id="rewards_column"):
147
+ rewards = gr.Matrix(
148
+ value=[[None, None]],
149
+ headers=["Rank", "ImageReward"],
150
+ datatype="number",
151
+ )
152
+ with gr.Row():
153
+ score_and_rank_button = gr.Button("Score&Rank")
154
+ clear_button = gr.Button("Clear Gallery")
155
+ with gr.Row().style(equal_height=True):
156
+ with gr.Column():
157
+ prompt = gr.Textbox(
158
+ label="Prompt",
159
+ value="A painting of an ocean with clouds and birds, day time, low depth field effect, oil painting, impressionism",
160
+ )
161
+
162
+ examples = [
163
+ "A painting of an ocean with clouds and birds, day time, low depth field effect, oil painting, impressionism",
164
+ "A painting of a girl walking in a hallway and suddenly finds a giant sunflower on the floor blocking her way",
165
+ "Coronation of the sun emperor, digital art, illustration,4k resolution,intricate extremely detailed, depth,vivid colors",
166
+ "Symmetry!! Product render poster vivid colors divine proportion owl,glowing fog intricate,elegant, highly detailed",
167
+ "A unicorn in a clearing.it has a single shining horn. volumetric light.by emmanuel shiu, harry potter, eragon",
168
+ "Highly detailed portrait of a woman with long hairs,stephen bliss. unreal engine, fantasy art by greg rutkowski",
169
+ "Sculpture made of flame,portrait, female,future, torch,fire,harper's bazaar,vogue, fashion magazine, intricate",
170
+ ]
171
+ prompt_examples = gr.Examples(
172
+ examples=examples,
173
+ label="Prompt Examples",
174
+ inputs=[prompt],
175
+ elem_id="prompt_examples",
176
+ )
177
+
178
+ with gr.Column():
179
+ choices = [
180
+ "HDR, UHD, 4K, 8K, 64K",
181
+ "highly detailed",
182
+ "studio lighting",
183
+ "professional",
184
+ "trending on artstation",
185
+ "unreal engine",
186
+ "vivid colors",
187
+ ]
188
+ magic_words = gr.CheckboxGroup(
189
+ choices=choices,
190
+ value=choices,
191
+ type="value",
192
+ label="Magic Words to Append to Prompt",
193
+ )
194
+
195
+ num = gr.Slider(1, 16, step=1, label="Number of images", value=8)
196
+ height = gr.Slider(256, 2048, step=256, label="Height", value=512)
197
+ width = gr.Slider(256, 2048, step=256, label="Width", value=512)
198
+ num_inference_steps = gr.Slider(
199
+ 0, 200, step=10, label="Number of inference steps", value=50
200
+ )
201
+ guidance_scale = gr.Slider(
202
+ 0, 25, step=0.1, label="Guidance scale", value=7.5
203
+ )
204
+
205
+ generate_button = gr.Button("Generate")
206
+ with gr.Column():
207
+ gr.Markdown(
208
+ """
209
+ - To clear all uploaded images, click the **"Clear Gallery"** button above.
210
+ - To clear the upload list and add additional images, click the **`x` in the upper right corner of the uploading window**.
211
+ - Additional images will be appended to the gallery, instead of replacing the existing ones.
212
+ """
213
+ )
214
+ uploaded_image_files = gr.File(
215
+ file_count="multiple",
216
+ file_types=["image"],
217
+ type="file",
218
+ label="Upload Images",
219
+ show_label=True,
220
+ )
221
+
222
+ generate_button.click(
223
+ generate_images,
224
+ [
225
+ prompt,
226
+ magic_words,
227
+ num,
228
+ height,
229
+ width,
230
+ num_inference_steps,
231
+ guidance_scale,
232
+ ],
233
+ [gallery],
234
+ )
235
+ score_and_rank_button.click(score_and_rank, [prompt], [gallery, rewards])
236
+ uploaded_image_files.upload(
237
+ upload_images_to_gallery, [uploaded_image_files], [gallery]
238
+ )
239
+ clear_button.click(clear_images, None, [gallery])
240
+
241
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image-reward
2
+ timm==0.6.13
3
+ transformers==4.27.4
4
+ fairscale==0.4.13
5
+ huggingface_hub==0.13.4
6
+ clip @ git+https://github.com/openai/CLIP.git
7
+
8
+ torch
9
+ diffusers
10
+ accelerate
11
+
12
+ Pillow
13
+ gradio