tommymarto commited on
Commit
fea6c63
·
1 Parent(s): 8870098
Files changed (3) hide show
  1. README.md +8 -6
  2. demo.py +338 -0
  3. requirements.txt +1 -0
README.md CHANGED
@@ -1,12 +1,14 @@
1
  ---
2
- title: Unsupervised Image Editing
3
- emoji: 🌍
4
- colorFrom: indigo
5
- colorTo: blue
 
6
  sdk: gradio
7
  sdk_version: 4.19.2
8
- app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Image Editing Human Evaluation
3
+ emoji: 🖼️
4
+ python_version: 3.10
5
+ colorFrom: red
6
+ colorTo: indigo
7
  sdk: gradio
8
  sdk_version: 4.19.2
9
+ app_file: demo.py
10
  pinned: false
11
+ license: mit
12
  ---
13
 
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
demo.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from concurrent.futures import ThreadPoolExecutor
2
+ import datetime
3
+ from functools import partial
4
+ import os
5
+ import io
6
+ import json
7
+ import queue
8
+ import gradio as gr
9
+ import random
10
+ import firebase_admin
11
+ from firebase_admin import db, credentials
12
+ from google.oauth2 import service_account
13
+ from googleapiclient.discovery import build
14
+ import google_auth_httplib2
15
+ import googleapiclient
16
+ from googleapiclient.http import MediaIoBaseDownload
17
+ from PIL import Image
18
+ import httplib2
19
+
20
+
21
+
22
+ #################################################################################################################################################
23
+ # Authentication
24
+ #################################################################################################################################################
25
+
26
+
27
+
28
+
29
+ # read secret api key
30
+ FIREBASE_API_KEY = os.environ['FirebaseSecret']
31
+ FIREBASE_URL = os.environ['FirebaseURL']
32
+ DRIVE_API_KEY = os.environ['DriveSecret']
33
+
34
+ SCOPES = ["https://www.googleapis.com/auth/drive"]
35
+
36
+
37
+ #################################################################################################################################################
38
+ # Types
39
+ #################################################################################################################################################
40
+
41
+
42
+ class Experiment():
43
+ def __init__(self, name, corrupted, options, selected_image=None, initialized=False):
44
+ self.name = name
45
+ self.corrupted = corrupted
46
+ self.options = options
47
+ self.selected_image = selected_image
48
+ self.initialized = initialized
49
+
50
+ def to_dict(self):
51
+ return {
52
+ "experiment_name": self.name,
53
+ "corrupted": self.corrupted["name"],
54
+ "options": [img["name"] for img in self.options],
55
+ "selected_image": self.selected_image,
56
+ "algo": self.options[self.selected_image]["algo"]
57
+ }
58
+
59
+ def to_pil(self):
60
+ return self.corrupted["pil"], [img["pil"] for img in self.options]
61
+
62
+ @staticmethod
63
+ def from_dict(source):
64
+ return Experiment(source["name"], source["corrupted"], source["options"])
65
+
66
+ def __repr__(self):
67
+ return f"Experiment(name={self.name}, corrupted={self.corrupted}, options={self.options})"
68
+
69
+ def __str__(self):
70
+ return f"Experiment(name={self.name}, corrupted={self.corrupted}, options={self.options})"
71
+
72
+ def __eq__(self, other):
73
+ return self.name == other.name and self.corrupted == other.corrupted and self.options == other.options
74
+
75
+
76
+ class App():
77
+
78
+ NUM_THREADS = 8
79
+ NUM_TO_SCHEDULE = 8
80
+
81
+ def __init__(self):
82
+ self.init_remote()
83
+ self.init_download_thread()
84
+
85
+ for _ in range(App.NUM_TO_SCHEDULE):
86
+ self.q_requested.put({})
87
+
88
+ self.next_experiment()
89
+ self.build_components_from_experiment()
90
+
91
+
92
+ def lifespan(self, fastapi_app):
93
+ yield
94
+ # cancel thredpool
95
+ self.executor.shutdown(wait=False)
96
+ # cancel download threads
97
+ for _ in range(App.NUM_THREADS):
98
+ self.q_to_download.put(None)
99
+ self.q_requested.put(None)
100
+
101
+ def init_remote(self):
102
+
103
+ def build_request(http, *args, **kwargs):
104
+ new_http = google_auth_httplib2.AuthorizedHttp(self.drive_creds, http=httplib2.Http())
105
+ return googleapiclient.http.HttpRequest(new_http, *args, **kwargs)
106
+
107
+ # init drive service
108
+ self.drive_creds = service_account.Credentials.from_service_account_info(json.loads(DRIVE_API_KEY), scopes=SCOPES)
109
+ authorized_http = google_auth_httplib2.AuthorizedHttp(self.drive_creds, http=httplib2.Http())
110
+ self.drive_service = build("drive", "v3", requestBuilder=build_request, http=authorized_http)
111
+
112
+ # init firebase service
113
+ self.firebase_creds = credentials.Certificate(json.loads(FIREBASE_API_KEY))
114
+ self.firebase_app = firebase_admin.initialize_app(self.firebase_creds, {'databaseURL': FIREBASE_URL})
115
+ self.firebase_data_ref = db.reference("data")
116
+
117
+ def init_download_thread(self):
118
+ # init download thread and queue
119
+ self.q_requested = queue.Queue()
120
+ self.q_to_download = queue.Queue()
121
+ self.q_processed = queue.Queue()
122
+ self.executor = ThreadPoolExecutor(max_workers=2*App.NUM_THREADS)
123
+ for _ in range(App.NUM_THREADS):
124
+ self.executor.submit(download_thread, self.drive_service, self.q_to_download, self.q_processed)
125
+ self.executor.submit(schedule_downloads, self.drive_service, self.q_to_download, self.q_requested)
126
+
127
+ def next_experiment(self):
128
+ self.q_requested.put({})
129
+ self.current_experiment : Experiment = self.q_processed.get()
130
+
131
+ def build_components_from_experiment(self):
132
+ corrupted = self.current_experiment.corrupted
133
+ images = self.current_experiment.options
134
+
135
+ self.corrupted_component = gr.Image(value=corrupted["pil"], label="corr", show_label=True, show_download_button=False, elem_id="padded")
136
+ self.img_components = [
137
+ gr.Image(value=img["pil"], label=f"{i}", show_label=True, show_download_button=False, elem_id="unsel")
138
+ for i, img in enumerate(images)
139
+ ]
140
+
141
+ selected_index = self.current_experiment.selected_image
142
+ if selected_index is not None:
143
+ self.img_components[selected_index] = (
144
+ gr.Image(value=images[selected_index]["pil"], label=f"{selected_index}", show_label=True, show_download_button=False, elem_id="sel")
145
+ )
146
+
147
+ return [*self.img_components, self.corrupted_component]
148
+
149
+ def on_select(self, evt: gr.SelectData): # SelectData is a subclass of EventData
150
+ self.current_experiment.selected_image = int(evt.target.label)
151
+ return self.build_components_from_experiment()
152
+
153
+ def save(self):
154
+ if save_to_firebase(self.current_experiment, self.firebase_data_ref):
155
+ self.next_experiment()
156
+ self.build_components_from_experiment()
157
+ return [*self.img_components, self.corrupted_component]
158
+
159
+
160
+ #################################################################################################################################################
161
+ # API calls
162
+ #################################################################################################################################################
163
+
164
+ def save_to_firebase(experiment, firebase_data_ref):
165
+ if experiment is None or experiment.selected_image is None:
166
+ gr.Warning("You must select an image before submitting")
167
+ return False
168
+
169
+ firebase_data_ref.push({
170
+ **experiment.to_dict(),
171
+ "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
172
+ })
173
+
174
+ gr.Info("Your choice has been saved to Firebase")
175
+ return True
176
+
177
+ def list_folders(service):
178
+ folders = []
179
+
180
+ results = (
181
+ service.files()
182
+ .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q="mimeType='application/vnd.google-apps.folder' and name contains 'Experiment'")
183
+ .execute()
184
+ )
185
+ folders.extend(results.get("files", []))
186
+
187
+ while "nextPageToken" in results:
188
+ page_token = results["nextPageToken"]
189
+ results = (
190
+ service.files()
191
+ .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q="mimeType='application/vnd.google-apps.folder' and name contains 'Experiment'", pageToken=page_token)
192
+ .execute()
193
+ )
194
+ folders.extend(results.get("files", []))
195
+
196
+ return folders
197
+
198
+ def list_files_in_folder(service, folder, filter_=""):
199
+ files = []
200
+
201
+ results = (
202
+ service.files()
203
+ .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q=f"'{folder['id']}' in parents {'and ' + filter_ if filter_ else ''}")
204
+ .execute()
205
+ )
206
+ files.extend(results.get("files", []))
207
+
208
+ while "nextPageToken" in results:
209
+ page_token = results["nextPageToken"]
210
+ results = (
211
+ service.files()
212
+ .list(pageSize=50, fields="nextPageToken, files(id, name)", orderBy="name", q=f"'{folder['id']}' in parents {'and ' + filter_ if filter_ else ''}", pageToken=page_token)
213
+ .execute()
214
+ )
215
+ files.extend(results.get("files", []))
216
+
217
+ return files
218
+
219
+ def download_file(service, file):
220
+ request = service.files().get_media(fileId=file['id'])
221
+ fh = io.BytesIO()
222
+ downloader = MediaIoBaseDownload(fh, request)
223
+ done = False
224
+ while done is False:
225
+ status, done = downloader.next_chunk(10)
226
+ print(f"Download image: '{file['name']}' - progress: {int(status.progress() * 100)}.")
227
+
228
+ return Image.open(fh)
229
+
230
+
231
+ def schedule_downloads(service, q_to_download, q_requested):
232
+ while True:
233
+ print("Waiting for new experiment to schedule")
234
+ if q_requested.get() is None:
235
+ break
236
+
237
+ print("Scheduling new experiment")
238
+ folders = list_folders(service)
239
+
240
+ # sample a random folder from folders
241
+ folder = random.choice(folders)
242
+
243
+ # list subfolders in the folder
244
+ subfolders = list_files_in_folder(service, folder, filter_="mimeType='application/vnd.google-apps.folder'")
245
+
246
+ # the results should be 2 subfolders: SDEdit and ODEdit
247
+ odedit_subfolder = [subfolder for subfolder in subfolders if "ODEdit" in subfolder["name"]][0]
248
+ sdedit_subfolder = [subfolder for subfolder in subfolders if "SDEdit" in subfolder["name"]][0]
249
+
250
+ odedit_files = list_files_in_folder(service, odedit_subfolder)
251
+ sdedit_files = list_files_in_folder(service, sdedit_subfolder)
252
+
253
+ selected_odedit_files = random.sample(odedit_files, k=5)
254
+ selected_odedit_files = [{**file, "algo": "ODEdit"} for file in selected_odedit_files]
255
+
256
+ selected_sdedit_files = random.sample(sdedit_files, k=5)
257
+ selected_sdedit_files = [{**file, "algo": "SDEdit"} for file in selected_sdedit_files]
258
+
259
+ corrupted_file = list_files_in_folder(service, folder, filter_="mimeType contains 'image/'")[0]
260
+
261
+ selected_files = [*selected_odedit_files, *selected_sdedit_files]
262
+
263
+ experiment = Experiment(folder["name"], corrupted_file, selected_files)
264
+
265
+ q_to_download.put(experiment)
266
+ q_requested.task_done()
267
+ print("Experiment scheduled")
268
+
269
+ def download_thread(service, q_to_download, q_processed):
270
+ while True:
271
+ print("Waiting for experiment to download")
272
+ experiment : Experiment = q_to_download.get()
273
+ if experiment is None:
274
+ break
275
+
276
+ corrupted = experiment.corrupted
277
+ print(f"Downloading file {corrupted['name']}")
278
+ corrupted_pil = download_file(service, corrupted)
279
+ print(f"File {corrupted['name']} downloaded")
280
+ experiment.corrupted["pil"] = corrupted_pil
281
+
282
+ for file in experiment.options:
283
+ print(f"Downloading file {file['name']}")
284
+ pil = download_file(service, file)
285
+ print(f"File {file['name']} downloaded")
286
+ file["pil"] = pil
287
+
288
+ q_processed.put(experiment)
289
+ q_to_download.task_done()
290
+ print("Experiment downloaded")
291
+
292
+
293
+ #################################################################################################################################################
294
+ # UI
295
+ #################################################################################################################################################
296
+
297
+
298
+ css = """
299
+ #unsel {border: solid 5px transparent !important; border-radius: 15px !important}
300
+ #sel {border: solid 5px #00c0ff !important; border-radius: 15px !important}
301
+ #padded {margin-left: 25% !important; margin-right: 5% !important}
302
+ #paddedRight {margin-right: 5% !important}
303
+ """
304
+
305
+ def build_demo():
306
+ app = App()
307
+
308
+ with gr.Blocks(title="Unsupervised Image Editing", css=css) as demo:
309
+
310
+ with gr.Row():
311
+ corrupted_component = gr.Image(label="corr", elem_id="padded")
312
+ with gr.Column(scale=3):
313
+ gr.Markdown("<div style='width: 100%'><h1 style='text-align: center; display: inline-block; width: 100%'>The sample on the left is a corrupted image</h1></div>", elem_id="paddedRight")
314
+ gr.Markdown("<div style='width: 100%'><h3 style='text-align: center; display: inline-block; width: 100%'>Below are decorrupted versions sampled from various models. Click on the picture you like best</h3></div>", elem_id="paddedRight")
315
+ btn = gr.Button("Submit")
316
+ gr.Markdown("<hr>")
317
+
318
+ img_components = []
319
+ with gr.Row():
320
+ for i, img in enumerate(app.img_components[:5]):
321
+ img_components.append(gr.Image(label=f"{i}", elem_id="unsel"))
322
+
323
+ with gr.Row():
324
+ for i, img in enumerate(app.img_components[5:]):
325
+ img_components.append(gr.Image(label=f"{i+5}", elem_id="unsel"))
326
+
327
+ btn.click(app.save, None, [*img_components, corrupted_component])
328
+ for img in img_components:
329
+ img.select(app.on_select, None, img_components, show_progress="hidden")
330
+
331
+ demo.load(app.build_components_from_experiment, inputs=None, outputs=[*img_components, corrupted_component])
332
+
333
+ return demo, app
334
+
335
+
336
+ if __name__ == "__main__":
337
+ demo, app = build_demo()
338
+ demo.launch(share=False, show_api=False, app_kwargs={"lifespan": app.lifespan})
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ firebase_admin