File size: 16,966 Bytes
2c8a8d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12d2a52
1d50f06
 
 
 
 
2c8a8d3
 
 
1d50f06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c8a8d3
1d50f06
 
2c8a8d3
1d50f06
2c8a8d3
 
1d50f06
 
 
 
 
 
 
ba32bfd
1d50f06
 
 
 
 
 
 
 
 
 
 
 
 
2c8a8d3
1d50f06
 
2c8a8d3
 
 
 
 
 
 
 
 
 
 
1d50f06
2c8a8d3
 
 
 
 
1d50f06
2c8a8d3
 
 
1d50f06
 
2c8a8d3
 
1d50f06
2c8a8d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26557fc
2898bf9
2c8a8d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90e53a6
38b31c0
 
 
 
 
2c8a8d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0af732e
be75d7a
e56f00c
2c8a8d3
 
 
 
 
 
 
 
 
 
 
 
2898bf9
2c8a8d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e384182
be75d7a
2c8a8d3
 
 
 
 
 
 
 
 
 
 
e56f00c
2c8a8d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0af732e
2c8a8d3
 
e56f00c
1d50f06
2c8a8d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
# -*- coding: utf-8 -*-
"""batch aesthetics predictor v2 - release.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1zTrHop7pStcCwPAUP_nekK1rp6lcppYx
"""

# Commented out IPython magic to ensure Python compatibility.
# %%capture
# #@title Install environment & dl MLP { form-width: "100%", display-mode: "form" }
# !pip install git+https://github.com/openai/CLIP.git
# !pip install gradio~=3.18.0
# #!pip install torch==1.13.1#+cu116
# !pip install pytorch-lightning~=2.0.1
# !wget -nc  https://huggingface.co/spaces/Seedmanc/batch-laion-aesthetic-predictor/resolve/main/sac%2Blogos%2Bava1-l14-linearMSE.pth

#@title CLIP dl & init { run: "auto", vertical-output: true, form-width: "25%", display-mode: "form" }
checkpoint = "ViT-L/14" #@param ["ViT-L/14", "ViT-L/14@336px"]
import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn as nn
import clip
import time
global prev_time
global isCpu

# If you changed the MLP architecture during training, change it here too: the
# checkpoint is restored with load_state_dict(), whose keys follow the attribute
# name `layers` and the positional indices inside the nn.Sequential below.
class MLP(pl.LightningModule):
    """Regression head mapping a CLIP image embedding to one aesthetic score.

    Args:
        input_size: width of the input embedding (load_models() passes 768).
        xcol: stored on the instance; not used anywhere in this file.
        ycol: stored on the instance; not used anywhere in this file.
    """
    def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol
        # Linear stack with dropout. The ReLUs are commented out, so with
        # dropout disabled (load_models() calls model.eval()) the head is a
        # plain affine map. Do not add/remove/reorder modules without a
        # matching checkpoint: the Sequential indices are state_dict keys.
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            #nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            #nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            #nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            #nn.ReLU(),
            nn.Linear(16, 1)
        )

    def forward(self, x):
        """Return the predicted score(s): one value per input row, last dim 1."""
        return self.layers(x)

def normalized(a, axis=-1, order=2):
    """Scale *a* to unit ``order``-norm along *axis*.

    Slices whose norm is zero are returned unchanged (their divisor is
    replaced by 1) instead of turning into NaNs.

    Args:
        a: array-like of numbers.
        axis: axis along which the norm is taken (default: last axis).
        order: ``ord`` passed to ``np.linalg.norm`` (default: L2).

    Returns:
        An array with the same shape as *a*.
    """
    # keepdims leaves the reduced axis in place with size 1, so the norms
    # broadcast back against `a` for any rank. The atleast_1d/expand_dims
    # combination this replaces turned 1-D input of shape (n,) into (1, n).
    l2 = np.linalg.norm(a, ord=order, axis=axis, keepdims=True)
    l2[l2 == 0] = 1
    return a / l2


def load_models():
    """Load the aesthetic MLP head and the CLIP model onto the best device.

    Side effects: sets the module globals ``device`` ('cuda' when available,
    otherwise 'cpu') and ``isCpu``.

    Returns:
        dict with keys 'classifier' (MLP in eval mode), 'clip_model',
        'clip_preprocess' and 'device'.
    """
    global device, isCpu

    classifier = MLP(768)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    isCpu = device == "cpu"

    # Restore the trained head weights directly onto the chosen device.
    weights = torch.load("sac+logos+ava1-l14-linearMSE.pth", map_location=device)
    classifier.load_state_dict(weights)
    classifier.to(device)
    classifier.eval()

    clip_model, clip_preprocess = clip.load(checkpoint, device=device, jit=True)

    return {
        'classifier': classifier,
        'clip_model': clip_model,
        'clip_preprocess': clip_preprocess,
        'device': device,
    }

if __name__ == '__main__':
    # Script/notebook entry: load the models once.
    # NOTE(review): `description` (shown by the GUI section) and `prev_time`
    # (read by log()) are only defined on this path, so importing this file
    # as a module would fail further down.
    print('\tinit models')

    global model_dict  # module-level `global` is a no-op; model_dict is a module global anyway
    prev_time = time.time()

    model_dict = load_models()
    print('model load', time.time() - prev_time)
    # Markdown README for the GUI accordion; checkpoint and device are interpolated.
    # NOTE(review): item 2 says the output is bounded 0..10, but MLP ends in an
    # unbounded nn.Linear and nothing in this file clamps the score.
    description = f"""
    ## Batch Image Aesthetic Predictor  
    0. Based on https://huggingface.co/spaces/Geonmo/laion-aesthetic-predictor, I just expanded the GUI & added stats.
    1. This model is designed by adding five MLP layers on top of (frozen) CLIP <u>**{checkpoint}**</u> checkpoint and only the MLP layers are fine-tuned with a lot of images by a regression loss term such as MSE and MAE.
    2. Output is bounded from 0 to 10. The higher the better.
    3. The MLP being used currently is: **sac+logos+ava1-l14-linearMSE.pth** trained on 224x224 images.
    4. Running on **{device}**{', be patient. Progressive output & immediate stats are available.' if isCpu else '. Batch mode enabled, results after completion.'}
    5. Please don't click 'Submit' again during the processing, it'll mess things up. To stop processing, clear the file input. If the results are missing from the stats or export areas at the end, sort the table by any header & wait.
    {'6. The MLP was not retrained for this CLIP checkpoint, correct results are not guaranteed. It is also 2x slower.' if checkpoint != "ViT-L/14" else ''}
    """

#@title 👁️⃤ { run: "auto", form-width: "15%" }
global predict  # "predictor"; module-level `global` statements are no-ops
# When True, predict() also returns the preprocessed CLIP input tensor and
# batch_predict() saves each one as CLIPped/<name>.png with torchvision.
# Disabled in v1; presumably written as `#param` (no `@`) so Colab shows no form field.
writeClip = False #param {type:"boolean"}
import os
from PIL import Image

if writeClip: #disabled in v1
  # Only needed for torchvision.utils.save_image() in batch_predict().
  import torchvision
  os.makedirs('CLIPped', exist_ok=True)


def predict(image):
  """Score one image with the CLIP encoder + MLP aesthetic head.

  Args:
    image: anything PIL's Image.open accepts (a path or a binary file object;
      batch_predict passes the objects produced by gr.Files).

  Returns:
    (score, clipped): the predicted score as a Python float, and the
    preprocessed CLIP input tensor when writeClip is enabled (else None).
  """
  # Close the PIL image as soon as preprocessing is done. The preprocess output
  # is a tensor (it is unsqueezed below), so the pixels are decoded by then.
  with Image.open(image) as pil_image:
    img_input = model_dict['clip_preprocess'](pil_image)

  clipped = img_input if writeClip else None

  target_device = model_dict['device']
  image_input = img_input.unsqueeze(0).to(target_device)  # batch of one

  with torch.no_grad():
    image_features = model_dict['clip_model'].encode_image(image_input)
    # L2-normalise on the host, then move back as float32 for the classifier.
    # .cpu() is a no-op for CPU tensors, so one path serves every device
    # (the previous cuda/else split failed on any non-CPU, non-CUDA device).
    im_emb_arr = normalized(image_features.detach().cpu().numpy())
    im_emb = torch.from_numpy(im_emb_arr).to(target_device).float()
    prediction = model_dict['classifier'](im_emb)

  return prediction.item(), clipped

#@title Wrapper & stats { form-width: "10%" }
# DEBUG enables log() tracing, the ../tmp clean-up at start-up, a debug-mode
# launch, and keeps the queue API closed (api_open=not DEBUG).
# autoclearLog: with DEBUG on, clear the terminal whenever the file input is cleared.
DEBUG = True #@param {type:"boolean"}
autoclearLog = True #@param {type:"boolean"}
import csv
import sys
import gradio as gr 
if DEBUG: print(gr.__version__) # show the installed gradio version

def defStats():
  """Return a fresh stats container with one empty label dict per mode.

  'Max' holds stats relative to the top score and 'Max - min' relative to the
  score span; statistics() fills both.
  """
  modes = ('Max', 'Max - min')
  return {mode: {} for mode in modes}

# Run state shared by the gradio handlers (module globals):
#   Ready   - files are loaded; cleared to make batch_predict stop mid-run
#   canPoll - pollStatus reports progress instead of 'idle'
#   eta, avgScore, speed - progress figures shown in the status bar
Ready = canPoll = False
avgScore = eta = None
speed = 0

Stats = defStats()
# The second mode ('Max - min') is the initial selection of the stats radio.
default_mode = list(Stats)[1]

def log(x = '', y = None): # debug only
  """Debug tracer: print the caller's name, *x* with elapsed time, and extra *y*.

  Does nothing when DEBUG is off. Elapsed seconds are measured from the
  previous log() call through the global ``prev_time``, which every traced
  call resets. Returns *x* unchanged so calls can wrap expressions.
  """
  global prev_time
  if DEBUG:
    caller = sys._getframe().f_back.f_code.co_name
    print(f"<\033[97m{caller}\033[96m>:")
    if x:
      elapsed = round(time.time() - prev_time, 3)
      print(time.strftime('%M:%S '), x, elapsed, '\033[0m')
    if y:
      print('   extra: ', y, '\033[0m')
    prev_time = time.time()
  return x

    
def pollStatus(table=None):  ### TODO idk what to do
  """Return the status-bar line built from the module's progress globals.

  Sleeps one second first. The status textbox's own ``change`` event calls this
  function and writes the result back into that textbox, so the delay
  presumably throttles that refresh loop.

  Args:
    table: unused. Defaults to None instead of a shared mutable ``[]``.

  Returns:
    'idle' when ``canPoll`` is false. Otherwise a "[MM:SS] ..." line with the
    running average, the ETA (suffixed "s." when it is an int) and the average
    speed (s/f when a file takes at least a second, else f/s). Each part is
    omitted while its value is unset.
  """
  time.sleep(1)
  spd = speed and (f'{round(speed,1)} s/f' if speed >= 1 else f'{round(1/speed,1)} f/s')
  stext = f' Avg speed: {spd}.' if speed else ''
  etext = f' ETA: {eta} {"s." if isinstance(eta, int) else ""}' if eta else ''
  atext = f'Running average: {avgScore}.' if avgScore else ''

  if not canPoll:
    return 'idle'
  ready = '' if atext else ' Ready.'
  return f"[{time.strftime('%M:%S')}] {ready}  {atext} {etext} {stext}"


def switch_stats(mode):
  """Remember the selected stats mode and return that mode's label dict.

  A falsy *mode* (e.g. the radio being cleared) falls back to 'Max'.
  """
  global default_mode
  if not mode:
    mode = 'Max'
  default_mode = mode
  return Stats[default_mode]


def writeStats(labels):
  """Overwrite stats.csv with two rows: the stat names, then their values.

  *labels* is the dict returned by statistics(); both rows go through gradio's
  CSV sanitiser before being written.
  """
  with open('stats.csv', 'w', newline='') as out:
    log('actual write stats', labels and labels.values()) #
    sheet = csv.writer(out)
    for row in (labels.keys(), labels.values()):
      sheet.writerow(gr.utils.sanitize_list_for_csv(row))

# MAIN ################################################################
def batch_predict(files=None, progress=gr.Progress()): #=> stats_toggle, stats_output, table_output, submit_btn
  """Score every uploaded image and stream the results to the GUI (generator).

  Bound to ``submit_button.click``. Every ``yield`` is a 4-tuple of updates for
  (stats_toggle, stats_output, table_output, submit_button).

  Args:
    files: the gr.Files value (objects exposing ``orig_name``), or None.
    progress: gradio progress tracker injected by the event handler.

  Flow:
    - Reset ``Stats`` and the old export files.
    - Run predict() on each file. On CPU, partial tables are yielded for files
      2..n-1, with the running average and ETA kept in the status globals. On
      GPU only the final table is yielded.
    - With more than one result, fill the third ('MSE') column with each
      score's squared deviation from the mean and write stats.csv.

  Side effects: mutates the module globals eta, avgScore, speed and Stats.
  Returns early, without a final yield, once ``Ready`` is cleared (the file
  input was emptied).
  """
  run_time = time.time()
  if files and len(files) > 1:
    # `global` is function-wide in Python: later eta writes hit the module
    # global even when this branch is skipped.
    global eta
    eta = 'calculating...'

  results = list()
  log('has file(s)?', files and files[0])
  global Stats
  global Ready 
  Stats = defStats()   

  if files is None: 
    log('empty load')  
    yield gr.update(), None, None, gr.update()
    log('ABORT')
    return    
  else:
    # On CPU the progress bar is only advanced for the first file; after that
    # progress is reported through the status globals instead.
    maxSteps = min(len(files), 3) if isCpu else len(files)
    log('good2go')
    # Hide the stats widgets and mark the button as busy.
    yield gr.update(visible=False), gr.update(visible=False), None, gr.update(variant="secondary")

  progress((1, maxSteps), unit='', desc='Importing...')
  clearStats()

  log('start the main loop')  
   
  times=list()
  clips=list()
  
  for idx,file in enumerate(files, start=1):
    prev_time = time.time()  # local per-file timer (not log()'s global prev_time)
    score, clipped = predict(file)
    
    if not Ready: # the solution to the interruption bug, do not remove #
      return 

    results.append([file.orig_name, round(score, 5), None]) 

    if writeClip: #disabled in v1
      clips.append((clipped, file.orig_name))

    times.append(time.time() - prev_time) #simplify
    # CPU: stream from the 2nd file on. GPU: len-1 makes the streaming branch
    # below unreachable, so only the final table is sent.
    asyncThreshold = 1 if isCpu else len(files)-1

    if (idx <= asyncThreshold):
      progress((idx+1, maxSteps), unit='', desc='Starting...' if isCpu else 'Working...')

    if (idx > asyncThreshold) and (idx < len(files)): # === False if not isCpu
      global avgScore
      global speed
      speed = np.mean(times)  # mean seconds per file so far
      avgScore = statistics(results, False)
      # NOTE(review): after `idx` files, len(files)-idx remain, so the +1
      # looks like it overestimates by one file — confirm intent.
      eta = round(speed*(len(files)-idx+1)) # +1 or [1::]?
      log(idx)
      yield gr.update(), None, results, gr.update()

  table_data = results  # same list object: the MSE writes below also update `results`
  if DEBUG: print('RUN time', time.time() - run_time, 'avg',  np.mean(times)) #

  if len(results) > 1:
    eta = 'finishing...' 
    log('finishing')
    stats = statistics(results)
    
    # Third column: squared deviation of each score from the batch mean.
    for i, row in enumerate(table_data):
      table_data[i][2] = round((row[1] - stats['AVG'])**2, 4)  # pylint: disable=report-general-type-issues
    
    writeStats(stats) 
    log('|2|', table_data) #      
    yield gr.update(visible=True), gr.update(value=switch_stats(default_mode), visible=True), table_data, gr.update(variant="primary")
  else:
    log('I', table_data) #
    yield gr.update(visible=False), gr.update(value=None, visible=False), table_data,  gr.update(variant="primary") #

  avgScore = None
  if writeClip:     #supposedly runs async w/o delaying the results? disabled in v1
    log('beforeWrite')  
    for c,f in clips:
      torchvision.utils.save_image(c, 'CLIPped/'+f+'.png', normalize=True)
    log('afterWrite')

  log('Exit main loop')
  # Whole-run average (seconds per file), shown by pollStatus afterwards.
  speed = (time.time() - run_time)/len(files)  
# /main ##################################################################### 

def statistics(results, full=True):
  """Summarise the score column of *results*.

  Args:
    results: table rows ``[file name, score, extra]``; only column 1 is used.
    full: if False, return only the mean (rounded to 3 places) and leave the
      global ``Stats`` untouched.

  Returns:
    The rounded mean when ``full`` is False. Otherwise a dict with keys 'AVG',
    'CoV' (std/mean in percent), 'M-m' (max - min), 'min', 'Med' and 'MAX'.

  Side effects (``full`` only), writing label -> value entries for the stats
  Label widget:
    - ``Stats['Max']``: each statistic divided by the maximum score.
    - ``Stats['Max - min']``: the same statistics rescaled to the min..max span.
  """
  # top/bottom/spread instead of max/min/range: those names shadowed builtins.
  scores = np.array(results).T[1].astype(float)

  top = np.max(scores)
  avg = round(scores.mean(), 3)
  if not full:
    return avg
  med = round(np.median(scores), 3)
  bottom = scores.min()
  std = round(scores.std(), 4)   # population standard deviation
  cov = round(std/avg*100, 2)    # coefficient of variation, in percent
  rng = round(top - bottom, 3)   # rounded span: shown in labels and returned
  spread = top - bottom          # exact span: used as a divisor

  rel_max = Stats['Max']
  rel_max[f'MAX: {round(top, 3)}'] = 1
  rel_max[f'min: {round(bottom, 3)}'] = bottom/top
  rel_max[f"CoV: {cov}%"] = std/top
  rel_max[f'AVG: {avg}'] = avg/top
  rel_max[f'Med: {med}'] = med/top
  rel_max[f'M-m: {rng}'] = spread/top

  # All scores equal: avoid dividing by a zero span.
  if spread == 0:
    spread = 1
  rel_span = Stats['Max - min']
  rel_span[f'MAX: {round(top, 3)}'] = 1
  rel_span[f'min: {round(bottom, 3)}'] = 0
  rel_span[f"CoV: {cov}%"] = std/spread
  rel_span[f'AVG: {avg}'] = (avg-bottom)/spread
  rel_span[f'Med: {med}'] = (med-bottom)/spread
  rel_span[f'M-m: {rng}'] = rng/top  # kept relative to max, as in the 'Max' view

  return {'AVG': avg, 'CoV': cov, 'M-m': rng, 'min': round(bottom, 3), 'Med': med, 'MAX': round(top, 3)}


def clearStats():
  """Delete previously exported scores*/stats* files from the working directory.

  Only the top level of the working directory is scanned. That is where
  writeStats() and writeScores() write stats.csv and scores.csv (relative paths).
  """
  log('clst too many calls?') #
  # A single non-recursive scan. os.walk('.') walked every sub-directory (slow
  # on big trees such as a mounted drive) and then removed nested matches by
  # bare name relative to the CWD, which raised or hit the wrong file.
  with os.scandir('.') as entries:
    stale = [entry.path for entry in entries
             if entry.is_file() and entry.name.startswith(('scores', 'stats'))] # TODO separate folder, names?
  for path in stale:
    os.remove(path)


def scan():
  """Return the export files (scores.csv, stats.csv) present in the CWD, in that order."""
  return list(filter(os.path.isfile, ('scores.csv', 'stats.csv')))

# Export writer for the results table (noted by the author as fragile).
def writeScores(table, files): # => csv_output, stats_output, stats_toggle
  """Export the results table to scores.csv once it holds a row per file.

  Bound to ``table_output.change`` with ``preprocess=False``:
    - *table* is the raw Dataframe payload (a dict with 'headers' and 'data')
      or None.
    - *files* is the current value of the file input.

  Returns three updates for (csv_output, stats_output, stats_toggle):
    - The export file list is always refreshed from scan().
    - Both stats widgets are shown only once a multi-file run has a row for
      every file, which also sets the global ``eta`` to 0.
    - When *files* is None, the progress globals are reset and the export
      files are cleared instead.
  """
  global eta
  rows = table and table['data']
  log('Entering the scores writer', 'from table change' if files and table else None)

  def outputs(stats_visible=False):
    # Export-list update first, then one identical visibility update for each
    # of the two stats widgets.
    return [gr.update(value=scan()), *(gr.update(visible=stats_visible) for _ in range(2))]

  if files is None:
    log('No files, exiting writer')#
    resetStatus('from table') # refactor
    return outputs()

  def writes(tbl):
    # utf-8 so that image names outside the platform default encoding cannot
    # make the export fail.
    with open('scores.csv', 'w', newline='', encoding='utf-8') as f: #try tsv, json
      writer = csv.writer(f)
      log('Actual saving scores', len(tbl['data'])) #
      writer.writerow(gr.utils.sanitize_list_for_csv(tbl['headers']))
      writer.writerows(gr.utils.sanitize_list_for_csv(tbl['data']))

  # `rows and` guards an empty data list, where rows[0] would raise IndexError.
  if rows and any(rows[0]):
    if len(rows) > 1:
      # Multi-file run: export only once every file has produced a row.
      if len(rows) >= len(files):
        writes(table)
        log('Updating two', 'finished') #
        eta = 0
        return outputs(stats_visible=True)
    elif len(files) == 1:
      # Single-file run: its one row is already the complete table.
      writes(table)
      log('updating 1') #
      return outputs()

  log('Not ready for writing yet, exiting.', f'total files: {files and len(files)}, but ready rows: {rows and len(rows)}')
  return outputs()

#@title GUI { vertical-output: true, form-width: "50%", display-mode: "both" }
# Queue tuning knobs used by the GUI section:
#   tableQueued_False     -> queue= of table_output.change
#   queueConcurrency_2    -> demo.queue(concurrency_count=...), clamped to >= 1
#   queueUpdateInterval_0 -> demo.queue(status_update_rate=...), 0 meaning 'auto'
# NOTE(review): the suffixes in these names don't always match the values
# (queueConcurrency_2 is 10); they look like leftover defaults — confirm.
# NOTE(review): prevent_thread_lock, described below, is not passed to demo.launch().
tableQueued_False = False #@param {type:"boolean"}
queueConcurrency_2 = 10 #@param {type:"integer", min:1}
queueUpdateInterval_0 = 0 #@param {type:"slider", min:0, max:10, step:0.2} 
#@markdown tableQueued == True + queueConcurrency == 1 guarantees stalling on CPU
#@markdown 
#@markdown tableQueued - unknown effect on speed or stability
#@markdown 
#@markdown queueConcurrency > 1 - technically should improve speed?
#@markdown 
#@markdown queueUpdateInterval - in (0, 1] slows down processing, otherwise seems useless. 

#@markdown prevent_thread_lock - keep the "busy cell" behavior of debug mode without it to avoid multiple instances running in parallel;
#@markdown effects on speed & stability unknown
if DEBUG:
  import shutil #i doshutilsya
  import subprocess


def _purge_dir_contents(path, recursive=False):
  """Delete the entries directly inside *path*, keeping *path* itself.

  Files and symlinks are removed. Sub-directories are removed with
  shutil.rmtree only when *recursive* is true; otherwise they are left alone.
  A missing *path* is ignored.
  """
  try:
    with os.scandir(path) as it:
      entries = list(it)  # snapshot before deleting anything
  except FileNotFoundError:
    return
  for entry in entries:
    if entry.is_dir(follow_symlinks=False):
      if recursive:
        shutil.rmtree(entry.path)
    else:
      os.remove(entry.path)


# Calling a helper keeps loop variables out of the module namespace (a
# top-level `for dir in ...` shadowed the builtin `dir`). It also builds every
# path from its real parent instead of prefixing the top-level folder.
if writeClip: # disabled in v1
  _purge_dir_contents('CLIPped')

if DEBUG: #debug only
  _purge_dir_contents('../tmp', recursive=True)

def resetStatus(msg = 'clear'):
  """Reset the progress globals shown in the status bar and log *msg*.

  For any *msg* other than the default 'clear' (writeScores passes
  'from table'), the exported scores/stats files are also deleted and a blank
  line is printed.
  """
  global avgScore, eta, speed
  avgScore, eta, speed = None, None, 0
  log(msg)
  if msg == 'clear':
    return
  clearStats()
  print('\n')


# CSS tweaks for the gradio Blocks. Each selector targets an elem_id set in the
# GUI section: lbl, addimg, csv_out, tbl_out, sbmt, mid_col. The file-input rule
# used `#add_img`, which matched nothing because the component's id is "addimg".
Css = '''
  #lbl .output-class {
    background-color: transparent;
    max-height: 0;
    color: transparent;
    padding: var(--size-3);
  }
  #addimg .file-preview .file td:first-child {
    overflow-wrap: anywhere;
  } 
  #csv_out .file-preview {
    margin-bottom: var(--size-4);
    overflow-x: visible; 
  }
  #tbl_out tbody .cell-wrap:first-child {
    overflow-wrap: anywhere;
  }
  button#sbmt:focus:not(:active) {
    opacity: 0.75;
    pointer-events: none;
  }
  #mid_col :not(#csv_out) .wrap.default {
    opacity: 0!important;
  }
'''

def toggleRun(files): # => submit, dataframe, status
  """Arm or disarm a run whenever the file input changes.

  Sets the globals ``Ready`` and ``canPoll`` to whether any files are present.
  When the input was cleared, it:
    - logs the outstanding ETA of an interrupted run (if any),
    - resets the progress globals,
    - with DEBUG and autoclearLog, clears the terminal.
  Previously exported files are removed in both cases.

  Returns updates for (submit button, results table, status bar):
    - the button turns 'primary' when files are present, 'secondary' otherwise;
    - the table is emptied;
    - the status text is recomputed (pollStatus sleeps ~1 s).
  """
  global Ready 
  Ready = files is not None
  log('Toggle', Ready) 
  global canPoll
  canPoll = Ready

  if not Ready:
    # eta is still set if a run was in progress when the files were cleared;
    # batch_predict notices Ready == False and stops.
    if eta:
      log('INTERRUPTED at ss remaining (extra)', eta)
    resetStatus()
    if DEBUG and autoclearLog:
      subprocess.call('clear')
    print('\r')
      
  clearStats() 
  return gr.Button.update(variant='primary' if Ready else 'secondary'), None, pollStatus()
  # ''', interactive=True''')

log('GUI start')

# Build the gradio UI and wire the handlers:
#   imageinput.change   -> toggleRun    (arm/disarm a run; cancels a running batch)
#   submit_button.click -> batch_predict (generator streaming table/stats updates)
#   table_output.change -> writeScores  (exports scores.csv, shows the stats widgets)
#   stats_toggle.change -> switch_stats (swaps the stats Label between modes)
#   status.change       -> pollStatus   (status-bar text)
# NOTE(review): `description` is only defined inside the `if __name__ == '__main__'`
# block, so this section assumes the file is run as a script/notebook.
blks = gr.Blocks(analytics_enabled=False, title="Batch Image Aesthetic Predictor", css=Css)
with blks as demo:
  with gr.Accordion('README', open=False):
    gr.Markdown(description) 
  with gr.Row().style(equal_height=False):
      with gr.Column(scale=2):
          # The elem_id values in this layout are what the selectors in Css target.
          imageinput = gr.Files(file_types=["image"], label="Add images", elem_id="addimg")
          submit_button = gr.Button('Submit', variant="secondary", elem_id='sbmt') #TODO interactive
      with gr.Column(variant="compact", min_width=256, elem_id="mid_col"):
          stats_toggle = gr.Radio(list(Stats.keys()), show_label=True, label='Stats relative to:', value=default_mode, visible=False)
          stats_output = gr.Label(label='Stats', visible=False, elem_id="lbl")
          csv_output = gr.File(  label="Export", elem_id='csv_out' )
      with gr.Column(scale=2):
        table_output = gr.Dataframe(headers=['Image', 'Score', 'MSE'], max_rows=15, overflow_row_behaviour="paginate", interactive=False, wrap=True, elem_id="tbl_out") 

  # pollStatus() sleeps ~1 s and its result is written back into `status`,
  # whose change event calls pollStatus again. While canPoll is set, the text
  # carries a timestamp and keeps changing; presumably the loop settles once
  # the text becomes the constant 'idle' — confirm.
  status = gr.Textbox(pollStatus(), max_lines=1, show_label=False, placeholder='Status bar').style(container=False)
  status.change(pollStatus, None, status, show_progress=  False, queue=False)

  # preprocess=False: writeScores receives the raw payload and reads its
  # 'headers' and 'data' keys.
  tch = table_output.change(writeScores, [table_output, imageinput], [csv_output, stats_output, stats_toggle], preprocess=False, queue= tableQueued_False, show_progress=not isCpu)
  stats_toggle.change(switch_stats, [stats_toggle], [stats_output], queue=False, show_progress=False)  

  # batch_predict is a generator that also uses gr.Progress; it runs through the queue.
  run = submit_button.click(batch_predict, imageinput, [stats_toggle, stats_output, table_output, submit_button], queue=True, scroll_to_output=True)
  #imageinput.clear(reset, [imageinput], [table_output], queue=False, show_progress=True, preprocess=False)
  # cancels=[run]: changing or clearing the files cancels an in-flight batch_predict.
  imageinput.change(toggleRun, imageinput, [submit_button, table_output, status], queue=  False, cancels=[run], show_progress=False) #
  # try .then()
  if DEBUG:
    demo.load(lambda: log('load'), queue=not True, show_progress=False)
 
  # Queue settings come from the Colab form params; an update interval of 0 means 'auto'.
  demo.queue(api_open=   not DEBUG, status_update_rate='auto' if queueUpdateInterval_0 == 0 else queueUpdateInterval_0 , concurrency_count=max(queueConcurrency_2, 1))
  log('Prelaunch')

  # debug=DEBUG keeps the launching cell busy while the app runs (see the
  # prevent_thread_lock note in the form section).
  demo.launch(debug=DEBUG, quiet=not DEBUG, show_error=True)

#demo.close()