haor committed on
Commit 2996858
1 Parent(s): 5181467

Create app.py

Files changed (1)
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
+ import gradio as gr
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import clip
+ import pandas as pd
+ import hashlib
+ import numpy as np
+ import cv2
+ from PIL import Image
+
+ # if you changed the MLP architecture during training, change it here as well:
+ class MLP(nn.Module):
+     def __init__(self, input_size, xcol="emb", ycol="avg_rating"):
+         super().__init__()
+         self.input_size = input_size
+         self.xcol = xcol
+         self.ycol = ycol
+         self.layers = nn.Sequential(
+             nn.Linear(self.input_size, 1024),
+             nn.Dropout(0.2),
+             nn.Linear(1024, 128),
+             nn.Dropout(0.2),
+             nn.Linear(128, 64),
+             nn.Dropout(0.1),
+             nn.Linear(64, 16),
+             nn.Linear(16, 1),
+         )
+
+     def forward(self, x):
+         return self.layers(x)
+
+ def _binary_array_to_hex(arr):  # pack a boolean array into a hex string
+     bit_string = ''.join(str(b) for b in 1 * arr.flatten())
+     width = int(np.ceil(len(bit_string) / 4))
+     return '{:0>{width}x}'.format(int(bit_string, 2), width=width)
+
+ def phashstr(image, hash_size=8, highfreq_factor=4):  # perceptual hash (pHash) of a PIL image
+     if hash_size < 2:
+         raise ValueError('Hash size must be greater than or equal to 2')
+
+     import scipy.fftpack
+     img_size = hash_size * highfreq_factor
+     image = image.convert('L').resize((img_size, img_size), Image.Resampling.LANCZOS)
+     pixels = np.asarray(image)
+     dct = scipy.fftpack.dct(scipy.fftpack.dct(pixels, axis=0), axis=1)
+     dctlowfreq = dct[:hash_size, :hash_size]
+     med = np.median(dctlowfreq)
+     diff = dctlowfreq > med
+     return _binary_array_to_hex(diff.flatten())
+
+ def normalized(a, axis=-1, order=2):  # L2-normalize along the given axis
+     l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
+     l2[l2 == 0] = 1
+     return a / np.expand_dims(l2, axis)
+
+ def predict(image):
+     model = MLP(768)  # CLIP embedding dim is 768 for CLIP ViT-L/14
+     pthpath = "https://huggingface.co/haor/aesthetics/resolve/main/sac%2Blogos%2Bava1-l14-linearMSE.pth"
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     model.load_state_dict(torch.hub.load_state_dict_from_url(pthpath, map_location=device))
+     model.to(device).eval()
+     model2, preprocess = clip.load("ViT-L/14", device=device)
+
+     image = Image.fromarray(image)
+     image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
+     laplacian_variance = cv2.Laplacian(image_np, cv2.CV_64F).var()  # variance of the Laplacian as a sharpness measure
+     phash = phashstr(image)
+     md5 = hashlib.md5(image.tobytes()).hexdigest()
+     sha1 = hashlib.sha1(image.tobytes()).hexdigest()
+
+     inputs = preprocess(image).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         img_emb = model2.encode_image(inputs)
+         img_emb = normalized(img_emb.cpu().numpy())
+         prediction = model(torch.from_numpy(img_emb).to(device).float()).item()  # float32 works on CPU and GPU
+
+     result = {
+         "clip_aesthetic": prediction,
+         "phash": phash,
+         "md5": md5,
+         "sha1": sha1,
+         "laplacian_variance": laplacian_variance
+     }
+     return result
+
+ title = "CLIP Aesthetic Score"
+ description = "Upload an image to predict its aesthetic score using the CLIP model and calculate other image metrics."
+
+ gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="numpy"),
+     outputs=gr.JSON(label="Result"),
+     title=title,
+     description=description,
+     examples=[["example1.jpg"], ["example2.jpg"]]
+ ).launch()