seedmanc commited on
Commit
2c8a8d3
1 Parent(s): b04e41a

Update opp.py

Browse files
Files changed (1) hide show
  1. opp.py +434 -66
opp.py CHANGED
@@ -1,11 +1,31 @@
1
- import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import torch
4
  import pytorch_lightning as pl
5
  import torch.nn as nn
6
  import clip
7
- from PIL import Image, ImageFile
8
- import gradio as gr
 
9
 
10
  # if you changed the MLP architecture during training, change it also here:
11
  class MLP(pl.LightningModule):
@@ -24,45 +44,26 @@ class MLP(pl.LightningModule):
24
  nn.Linear(128, 64),
25
  #nn.ReLU(),
26
  nn.Dropout(0.1),
27
-
28
  nn.Linear(64, 16),
29
  #nn.ReLU(),
30
-
31
  nn.Linear(16, 1)
32
  )
33
 
34
  def forward(self, x):
35
  return self.layers(x)
36
 
37
- def training_step(self, batch, batch_idx):
38
- x = batch[self.xcol]
39
- y = batch[self.ycol].reshape(-1, 1)
40
- x_hat = self.layers(x)
41
- loss = F.mse_loss(x_hat, y)
42
- return loss
43
-
44
- def validation_step(self, batch, batch_idx):
45
- x = batch[self.xcol]
46
- y = batch[self.ycol].reshape(-1, 1)
47
- x_hat = self.layers(x)
48
- loss = F.mse_loss(x_hat, y)
49
- return loss
50
-
51
- def configure_optimizers(self):
52
- optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
53
- return optimizer
54
-
55
  def normalized(a, axis=-1, order=2):
56
- import numpy as np # pylint: disable=import-outside-toplevel
57
-
58
  l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
59
  l2[l2 == 0] = 1
60
  return a / np.expand_dims(l2, axis)
61
 
 
62
  def load_models():
63
  model = MLP(768)
64
-
65
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
66
 
67
  s = torch.load("sac+logos+ava1-l14-linearMSE.pth", map_location=device)
68
 
@@ -70,7 +71,7 @@ def load_models():
70
  model.to(device)
71
  model.eval()
72
 
73
- model2, preprocess = clip.load("ViT-L/14", device=device)
74
 
75
  model_dict = {}
76
  model_dict['classifier'] = model
@@ -80,54 +81,421 @@ def load_models():
80
 
81
  return model_dict
82
 
83
- def predict(image):
84
- image_input = model_dict['clip_preprocess'](image).unsqueeze(0).to(model_dict['device'])
85
- with torch.no_grad():
86
- image_features = model_dict['clip_model'].encode_image(image_input)
87
- if model_dict['device'] == 'cuda':
88
- im_emb_arr = normalized(image_features.detach().cpu().numpy())
89
- im_emb = torch.from_numpy(im_emb_arr).to(model_dict['device']).type(torch.cuda.FloatTensor)
90
- else:
91
- im_emb_arr = normalized(image_features.detach().numpy())
92
- im_emb = torch.from_numpy(im_emb_arr).to(model_dict['device']).type(torch.FloatTensor)
93
-
94
- prediction = model_dict['classifier'](im_emb)
95
- score = prediction.item()
96
-
97
- return {'aesthetic score': score}
98
-
99
  if __name__ == '__main__':
100
  print('\tinit models')
101
 
102
  global model_dict
 
103
 
104
  model_dict = load_models()
 
 
 
 
 
 
 
 
 
 
 
105
 
106
- inputs = [gr.inputs.Image(type='pil', label='Image')]
 
 
 
 
107
 
108
- outputs = gr.outputs.JSON()
 
 
109
 
110
- title = 'image aesthetic predictor'
111
 
112
- examples = ['example1.jpg', 'example2.jpg', 'example3.jpg']
 
113
 
114
- description = """
115
- # Image Aesthetic Predictor Demo
116
- This model (Image Aesthetic Predictor) is trained by LAION Team. See [https://github.com/christophschuhmann/improved-aesthetic-predictor](https://github.com/christophschuhmann/improved-aesthetic-predictor)
117
- 1. This model is desgined by adding five MLP layers on top of (frozen) CLIP ViT-L/14 and only the MLP layers are fine-tuned with a lot of images by a regression loss term such as MSE and MAE.
118
- 2. Output is bounded from 0 to 10. The higher the better.
119
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- article = "<p style='text-align: center'><a href='https://laion.ai/blog/laion-aesthetics/'>LAION aesthetics blog post</a></p>"
122
-
123
- with gr.Blocks() as demo:
124
- gr.Markdown(description)
125
- with gr.Row():
126
- with gr.Column():
127
- image_input = gr.Image(type='pil', label='Input image')
128
- submit_button = gr.Button('Submit')
129
- json_output = gr.JSON(label='Output')
130
- submit_button.click(predict, inputs=image_input, outputs=json_output)
131
- gr.Examples(examples=examples, inputs=image_input)
132
- gr.HTML(article)
133
- demo.launch()
 
1
+ # -*- coding: utf-8 -*-
2
+ """batch aesthetics predictor v2 - release.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1zTrHop7pStcCwPAUP_nekK1rp6lcppYx
8
+ """
9
+
10
+ # Commented out IPython magic to ensure Python compatibility.
11
+ # %%capture
12
+ # #@title Install environment & dl MLP { form-width: "100%", display-mode: "form" }
13
+ # !pip install git+https://github.com/openai/CLIP.git
14
+ # !pip install gradio~=3.18.0
15
+ # #!pip install torch==1.13.1#+cu116
16
+ # !pip install pytorch-lightning~=2.0.1
17
+ # !wget -nc https://huggingface.co/spaces/Seedmanc/batch-laion-aesthetic-predictor/resolve/main/sac%2Blogos%2Bava1-l14-linearMSE.pth
18
+
19
+ #@title CLIP dl & init { run: "auto", vertical-output: true, form-width: "25%", display-mode: "form" }
20
+ checkpoint = "ViT-L/14" #@param ["ViT-L/14", "ViT-L/14@336px"]
21
  import numpy as np
22
  import torch
23
  import pytorch_lightning as pl
24
  import torch.nn as nn
25
  import clip
26
+ import time
27
+ global prev_time
28
+ global isCpu
29
 
30
  # if you changed the MLP architecture during training, change it also here:
31
  class MLP(pl.LightningModule):
 
44
  nn.Linear(128, 64),
45
  #nn.ReLU(),
46
  nn.Dropout(0.1),
 
47
  nn.Linear(64, 16),
48
  #nn.ReLU(),
 
49
  nn.Linear(16, 1)
50
  )
51
 
52
  def forward(self, x):
53
  return self.layers(x)
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def normalized(a, axis=-1, order=2):
 
 
56
  l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
57
  l2[l2 == 0] = 1
58
  return a / np.expand_dims(l2, axis)
59
 
60
+
61
  def load_models():
62
  model = MLP(768)
63
+ global device
64
  device = "cuda" if torch.cuda.is_available() else "cpu"
65
+ global isCpu
66
+ isCpu = device == "cpu"
67
 
68
  s = torch.load("sac+logos+ava1-l14-linearMSE.pth", map_location=device)
69
 
 
71
  model.to(device)
72
  model.eval()
73
 
74
+ model2, preprocess = clip.load(checkpoint, device=device)
75
 
76
  model_dict = {}
77
  model_dict['classifier'] = model
 
81
 
82
  return model_dict
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  if __name__ == '__main__':
85
  print('\tinit models')
86
 
87
  global model_dict
88
+ prev_time = time.time()
89
 
90
  model_dict = load_models()
91
+ print('model load', time.time() - prev_time)
92
+ description = f"""
93
+ ## Batch Image Aesthetic Predictor
94
+ 0. Based on https://huggingface.co/spaces/Geonmo/laion-aesthetic-predictor, I just expanded the GUI & added stats.
95
+ 1. This model is designed by adding five MLP layers on top of (frozen) CLIP <u>**{checkpoint}**</u> checkpoint and only the MLP layers are fine-tuned with a lot of images by a regression loss term such as MSE and MAE.
96
+ 2. Output is bounded from 0 to 10. The higher the better.
97
+ 3. The MLP being used currently is: **sac+logos+ava1-l14-linearMSE.pth** trained on 224x224 images.
98
+ 4. Running on **{device}**{', be patient. Progressive output & immediate stats are available.' if isCpu else '. Batch mode enabled, results after completion.'}
99
+ 5. Please don't click 'Submit' again during the processing, it'll mess things up. To stop processing, clear the file input. If the results are missing from the stats or export areas at the end, sort the table by any header & wait.
100
+ {'6. The MLP was not retrained for this CLIP checkpoint, correct results are not guaranteed. It is also 2x slower.' if checkpoint != "ViT-L/14" else ''}
101
+ """
102
 
103
+ #@title 👁️⃤ { run: "auto", form-width: "15%" }
104
+ global predict#or
105
+ writeClip = False #param {type:"boolean"}
106
+ import os
107
+ from PIL import Image
108
 
109
+ if writeClip: #disabled in v1
110
+ import torchvision
111
+ os.makedirs('CLIPped', exist_ok=True)
112
 
 
113
 
114
+ def predict(image):
115
+ img_input = model_dict['clip_preprocess'](Image.open(image))
116
 
117
+ clipped = None
118
+ if writeClip:
119
+ clipped = img_input
120
+
121
+ image_input = img_input.unsqueeze(0).to(model_dict['device']) #try batch
122
+
123
+ with torch.no_grad():
124
+ image_features = model_dict['clip_model'].encode_image(image_input)
125
+
126
+ if model_dict['device'] == 'cuda': # add TPU support?
127
+ im_emb_arr = normalized(image_features.detach().cpu().numpy())
128
+ im_emb = torch.from_numpy(im_emb_arr).to(model_dict['device']).type(torch.cuda.FloatTensor)
129
+ else:
130
+ im_emb_arr = normalized(image_features.detach().numpy())
131
+ im_emb = torch.from_numpy(im_emb_arr).to(model_dict['device']).type(torch.FloatTensor)
132
+
133
+ prediction = model_dict['classifier'](im_emb)
134
+ score = prediction.item() #optimize?
135
+
136
+ return score, clipped
137
+
138
+ #@title Wrapper & stats { form-width: "10%" }
139
+ DEBUG = True #@param {type:"boolean"}
140
+ autoclearLog = True #@param {type:"boolean"}
141
+ import csv
142
+ import sys
143
+ import gradio as gr
144
+ from google.colab import output
145
+ if DEBUG: print(gr.__version__) #
146
+
147
+ def defStats():
148
+ return {'Max':{}, 'Max - min': {}}
149
+
150
+ global Ready
151
+ global avgScore
152
+ global eta
153
+ global speed
154
+ global canPoll
155
+ canPoll=Ready =False
156
+ eta = avgScore = None
157
+ speed = 0
158
+
159
+ Stats = defStats()
160
+ global default_mode
161
+ default_mode = list(Stats.keys())[1]
162
+
163
+ def log(x = '', y = None): # debug only
164
+ if not DEBUG:
165
+ return x
166
+
167
+ global prev_time
168
+ with output.use_tags('debug', append=True):
169
+ print(f"<\033[94m{sys._getframe( 1).f_back.f_code.co_name}\033[96m => \033[97m{sys._getframe().f_back.f_code.co_name}\033[96m>:")
170
+ if x:
171
+ print(time.strftime('%M:%S '), x, round(time.time() - prev_time, 3), '\033[0m')
172
+ if y:
173
+ print(' extra: ', y, '\033[0m')
174
+ prev_time = time.time()
175
+
176
+ return x
177
+
178
+
179
+ def pollStatus(table=[]): ### TODO idk what to do
180
+ time.sleep(1)
181
+ spd = speed and (f'{round(speed,1)} s/f' if speed >= 1 else f'{round(1/speed,1)} f/s')
182
+ stext = f' Avg speed: {spd}.' if speed else ''
183
+ etext = f' ETA: {eta} {"s." if type(eta) == int else ""}' if eta else ''
184
+ atext = f'Running average: {avgScore}.' if avgScore else ''
185
+
186
+ return f"[{time.strftime('%M:%S')}] {' Ready.' if not atext else ''} {atext} {etext} {stext}" if canPoll else 'idle'
187
+
188
+
189
+ def switch_stats(mode):
190
+ global default_mode
191
+ default_mode = mode if mode else 'Max'
192
+
193
+ return Stats[default_mode]
194
+
195
+
196
+ def writeStats(labels):
197
+ with open('stats.csv', 'w', newline='') as f:
198
+ writer = csv.writer(f)
199
+ log('actual write stats', labels and labels.values()) #
200
+ writer.writerow(gr.utils.sanitize_list_for_csv(labels.keys()))
201
+ writer.writerow(gr.utils.sanitize_list_for_csv(labels.values()))
202
+
203
+ # MAIN ################################################################
204
+ def batch_predict(files=None, progress=gr.Progress()): #=> stats_toggle, stats_output, table_output, submit_btn
205
+ run_time = time.time()
206
+ if files and len(files) > 1:
207
+ global eta
208
+ eta = 'calculating...'
209
+
210
+ results = list()
211
+ log('has file(s)?', files and files[0])
212
+ global Stats
213
+ global Ready
214
+ Stats = defStats()
215
+
216
+ if files is None:
217
+ log('empty load')
218
+ yield gr.update(), None, None, gr.update()
219
+ log('ABORT')
220
+ return
221
+ else:
222
+ maxSteps = min(len(files), 3) if isCpu else len(files)
223
+ log('good2go')
224
+ yield gr.update(visible=False), gr.update(visible=False), None, gr.update(variant="secondary")
225
+
226
+ progress((1, maxSteps), unit='', desc='Importing...')
227
+ clearStats()
228
+
229
+ log('start the main loop')
230
+
231
+ times=list()
232
+ clips=list()
233
+
234
+ for idx,file in enumerate(files, start=1):
235
+ prev_time = time.time()
236
+ score, clipped = predict(file)
237
+
238
+ if not Ready: # the solution to the interruption bug, do not remove #
239
+ return
240
+
241
+ results.append([file.orig_name, round(score, 5), None])
242
+
243
+ if writeClip: #disabled in v1
244
+ clips.append((clipped, file.orig_name))
245
+
246
+ times.append(time.time() - prev_time) #simplify
247
+ asyncThreshold = 1 if isCpu else len(files)-1
248
+
249
+ if (idx <= asyncThreshold):
250
+ progress((idx+1, maxSteps), unit='', desc='Starting...' if isCpu else 'Working...')
251
+
252
+ if (idx > asyncThreshold) and (idx < len(files)): # === False if not isCpu
253
+ global avgScore
254
+ global speed
255
+ speed = np.mean(times)
256
+ avgScore = statistics(results, False)
257
+ eta = round(speed*(len(files)-idx+1)) # +1 or [1::]?
258
+ log(idx)
259
+ yield gr.update(), None, results, gr.update()
260
+
261
+ table_data = results
262
+ if DEBUG: print('RUN time', time.time() - run_time, 'avg', np.mean(times)) #
263
+
264
+ if len(results) > 1:
265
+ eta = 'finishing...'
266
+ log('finishing')
267
+ stats = statistics(results)
268
+
269
+ for i, row in enumerate(table_data):
270
+ table_data[i][2] = round((row[1] - stats['AVG'])**2, 4) # pylint: disable=report-general-type-issues
271
+
272
+ writeStats(stats)
273
+ log('|2|', table_data) #
274
+ yield gr.update(visible=True), gr.update(value=switch_stats(default_mode), visible=True), table_data, gr.update(variant="primary")
275
+ else:
276
+ log('I', table_data) #
277
+ yield gr.update(visible=False), gr.update(value=None, visible=False), table_data, gr.update(variant="primary") #
278
+
279
+ avgScore = None
280
+ if writeClip: #supposedly runs async w/o delaying the results? disabled in v1 anyway
281
+ log('beforeWrite')
282
+ for c,f in clips:
283
+ torchvision.utils.save_image(c, 'CLIPped/'+f+'.png', normalize=True)
284
+ log('afterWrite')
285
+
286
+ log('Exit main loop')
287
+ speed = (time.time() - run_time)/len(files)
288
+ # /main #####################################################################
289
+
290
+ def statistics(results, full=True):
291
+ array = np.array(results).T[1].astype(float)
292
+
293
+ max = np.max(array)
294
+ avg = round(array.mean(), 3)
295
+ if (not full): return avg
296
+ med = round(np.median(array), 3)
297
+ min = array.min()
298
+ std = round(array.std(), 4)
299
+ cov = round(std/avg*100, 2)
300
+ rng = round(max-min, 3)
301
+ range = max-min
302
+
303
+ Stats['Max'][f'MAX: {round(max, 3)}'] = 1
304
+ Stats['Max'][f'min: {round(min, 3)}'] = min/max
305
+ Stats['Max'][f"CoV: {cov}%"] = std/max
306
+ Stats['Max'][f'AVG: {avg}'] = avg/max
307
+ Stats['Max'][f'Med: {med}'] = med/max
308
+ Stats['Max'][f'M-m: {rng}'] = range/max
309
+ # TODO can this be shortened?
310
+ if (range == 0):
311
+ range = 1
312
+ Stats['Max - min'][f'MAX: {round(max, 3)}'] = 1
313
+ Stats['Max - min'][f'min: {round(min, 3)}'] = 0
314
+ Stats['Max - min'][f"CoV: {cov}%"] = std/range
315
+ Stats['Max - min'][f'AVG: {avg}'] = (avg-min)/range
316
+ Stats['Max - min'][f'Med: {med}'] = (med-min)/range
317
+ Stats['Max - min'][f'M-m: {rng}'] = rng/max
318
+
319
+ return dict(zip(('AVG','CoV','M-m','min','Med','MAX'), (avg, cov, rng, round(min,3), med, round(max,3))))
320
+
321
+
322
+ def clearStats():
323
+ log('clst too many calls?') #
324
+ for root, dirs, files in os.walk('.'):
325
+ for file in files:
326
+ if (file.startswith(('scores','stats'))): # TODO separate folder, names?
327
+ os.remove(file)
328
+
329
+
330
+ def scan():
331
+ r = ['scores.csv', 'stats.csv']
332
+ return [x for x in r if os.path.isfile(x)]
333
+
334
+ # buggy as fuck
335
+ def writeScores(table, files): # => csv_output, stats_output, stats_toggle
336
+ statsVisible = False
337
+ rows = table and table['data']
338
+ log('Entering the scores writer', 'from table change' if files and table else None)
339
+ showStats = (gr.update(visible=statsVisible) for x in range(0,2)) # add full return statement?
340
+
341
+ if files is None:
342
+ log('No files, exiting writer')#
343
+ resetStatus('from table') # refactor
344
+ return [gr.update(value=scan()), *list(showStats)]
345
+ ######
346
+ def writes(tbl):
347
+ with open('scores.csv', 'w', newline='') as f: #try tsv, json
348
+ writer = csv.writer(f)
349
+ log('Actual saving scores', len(tbl['data'])) #
350
+ writer.writerow(gr.utils.sanitize_list_for_csv(tbl['headers']))
351
+ writer.writerows(gr.utils.sanitize_list_for_csv(tbl['data']))
352
+ ######
353
+ if table and any([x for x in rows[0]]):
354
+ if (len(rows) > 1):
355
+ statsVisible = len(rows) >= len(files)
356
+
357
+ if statsVisible:
358
+ writes(table)
359
+ log('Updating two', 'finished') #
360
+ global eta
361
+ eta = 0
362
+ return [gr.update(value=scan()), *list(showStats)]
363
+
364
+ else:
365
+ statsVisible = False
366
+ if (len(files) == 1):
367
+ writes(table)
368
+ log('updating 1') #
369
+ return [gr.update(value=scan()), *list(showStats)]
370
+
371
+ log('Not ready for writing yet, exiting.', f'total files: {files and len(files)}, but ready rows: {rows and len(rows)}')
372
+ return [gr.update(value=scan()), *list(showStats)]
373
+
374
+ #@title GUI { vertical-output: true, form-width: "50%", display-mode: "both" }
375
+ tableQueued_False = False #@param {type:"boolean"}
376
+ queueConcurrency_2 = 2 #@param {type:"integer", min:1}
377
+ queueUpdateInterval_0 = 0 #@param {type:"slider", min:0, max:10, step:0.2}
378
+ prevent_thread_lock = False #@param {type:"boolean"}
379
+ #@markdown tableQueued == True + queueConcurrency == 1 guarantees stalling on CPU
380
+ #@markdown
381
+ #@markdown tableQueued - unknown effect on speed or stability
382
+ #@markdown
383
+ #@markdown queueConcurrency > 1 - technically should improve speed?
384
+ #@markdown
385
+ #@markdown queueUpdateInterval - in (0, 1] slows down processing, otherwise seems useless.
386
+
387
+ #@markdown prevent_thread_lock - keep the "busy cell" behavior of debug mode without it to avoid multiple instances running in parallel;
388
+ #@markdown effects on speed & stability unknown
389
+ if DEBUG:
390
+ import shutil #i doshutilsya
391
+
392
+ if writeClip: # disabled in v1
393
+ for root, dirs, files in os.walk('CLIPped'):
394
+ for file in files:
395
+ os.remove('CLIPped/'+file)
396
+
397
+ if DEBUG:
398
+ for root, dirs, files in os.walk('../tmp'): #debug only
399
+ for dir in dirs:
400
+ shutil.rmtree('../tmp/'+dir)
401
+ for file in files:
402
+ os.remove('../tmp/'+file) #/debug
403
+
404
+ def resetStatus(msg = 'clear'):
405
+ global avgScore
406
+ global eta
407
+ global speed
408
+ avgScore = None
409
+ eta = None
410
+ speed = 0
411
+ log(msg)
412
+ if msg != 'clear':
413
+ clearStats()
414
+ print('\n')
415
+
416
+
417
+ Css = '''
418
+ #lbl .output-class {
419
+ background-color: transparent;
420
+ max-height: 0;
421
+ color: transparent;
422
+ padding: var(--size-3);
423
+ }
424
+ #add_img .file-preview .file td:first-child {
425
+ overflow-wrap: anywhere;
426
+ }
427
+ #csv_out .file-preview {
428
+ margin-bottom: var(--size-4);
429
+ overflow-x: visible;
430
+ }
431
+ #tbl_out tbody .cell-wrap:first-child {
432
+ overflow-wrap: anywhere;
433
+ }
434
+ button#sbmt:focus:not(:active) {
435
+ opacity: 0.75;
436
+ pointer-events: none;
437
+ }
438
+ #mid_col :not(#csv_out) .wrap.default {
439
+ opacity: 0!important;
440
+ }
441
+ '''
442
+
443
+ def toggleRun(files): # => submit, dataframe, status
444
+ global Ready
445
+ Ready = files is not None
446
+ log('Toggle', Ready)
447
+ global canPoll
448
+ canPoll = Ready
449
+
450
+ if not Ready:
451
+ if eta:
452
+ log('INTERRUPTED at ss remaining (extra)', eta)
453
+ resetStatus()
454
+ if (DEBUG and autoclearLog):
455
+ output.clear(output_tags='debug')
456
+ print('\r')
457
+
458
+ clearStats()
459
+ return gr.Button.update(variant='primary' if Ready else 'secondary'), None, pollStatus()
460
+ # ''', interactive=True''')
461
+
462
+ log('GUI start')
463
+
464
+ blks = gr.Blocks(analytics_enabled=False, title="Batch Image Aesthetic Predictor", css=Css)
465
+ with blks as demo:
466
+ with gr.Accordion('README', open=False):
467
+ gr.Markdown(description)
468
+ if DEBUG and not autoclearLog:
469
+ gr.Button('Clear logs').click(lambda: output.clear(output_tags='debug')) #debug
470
+ with gr.Row().style(equal_height=False):
471
+ with gr.Column(scale=2):
472
+ imageinput = gr.Files(file_types=["image"], label="Add images", elem_id="addimg")
473
+ submit_button = gr.Button('Submit', variant="secondary", elem_id='sbmt') #TODO interactive
474
+ with gr.Column(variant="compact", min_width=256, elem_id="mid_col"):
475
+ stats_toggle = gr.Radio(list(Stats.keys()), show_label=True, label='Stats relative to:', value=default_mode, visible=False)
476
+ stats_output = gr.Label(label='Stats', visible=False, elem_id="lbl")
477
+ csv_output = gr.File( label="Export", elem_id='csv_out' )
478
+ with gr.Column(scale=2):
479
+ table_output = gr.Dataframe(headers=['Image', 'Score', 'MSE'], max_rows=15, overflow_row_behaviour="paginate", interactive=False, wrap=True, elem_id="tbl_out")
480
+
481
+ status = gr.Textbox(pollStatus(), max_lines=1, show_label=False, placeholder='Status bar').style(container=False)
482
+ status.change(pollStatus, None, status, show_progress= False, queue=False)
483
+
484
+ tch = table_output.change(writeScores, [table_output, imageinput], [csv_output, stats_output, stats_toggle], preprocess=False, queue= tableQueued_False, show_progress=not isCpu)
485
+ stats_toggle.change(switch_stats, [stats_toggle], [stats_output], queue=False, show_progress=False)
486
+
487
+ run = submit_button.click(batch_predict, imageinput, [stats_toggle, stats_output, table_output, submit_button], queue=True, scroll_to_output=True)
488
+ #imageinput.clear(reset, [imageinput], [table_output], queue=False, show_progress=True, preprocess=False)
489
+ imageinput.change(toggleRun, imageinput, [submit_button, table_output, status], queue= False, cancels=[run], show_progress=False) #
490
+ # try .then()
491
+ if DEBUG:
492
+ demo.load(lambda: log('load'), queue=not True, show_progress=False)
493
+
494
+ demo.queue(api_open= not DEBUG, status_update_rate='auto' if queueUpdateInterval_0 == 0 else queueUpdateInterval_0 , concurrency_count=max(queueConcurrency_2, 1))
495
+ log('Prelaunch')
496
+
497
+ #demo.dev_mode = DEBUG
498
+ demo.launch(debug=DEBUG, quiet= not DEBUG, show_error= True, prevent_thread_lock=prevent_thread_lock, height=768)
499
+ if (prevent_thread_lock and not DEBUG): demo.block_thread()
500
 
501
+ #demo.close()