File size: 4,887 Bytes
af72b72
 
 
 
 
 
 
806e6c5
21c7de8
af72b72
 
 
 
21c7de8
39a6dd6
bcaf154
57064d0
39a6dd6
57064d0
39a6dd6
 
af72b72
 
 
 
 
 
39a6dd6
 
 
825e1d8
bcaf154
c9b69b7
 
af72b72
39a6dd6
 
 
 
 
 
 
 
 
 
21c7de8
 
 
 
 
39a6dd6
 
dbb7b85
39a6dd6
 
 
57064d0
 
 
 
 
39a6dd6
 
 
c9b69b7
21c7de8
39a6dd6
 
21c7de8
 
 
 
 
 
 
 
 
 
bcaf154
21c7de8
 
 
 
 
 
 
 
39a6dd6
21c7de8
bcaf154
21c7de8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcaf154
c31a89f
 
21c7de8
 
 
806e6c5
21c7de8
 
 
c31a89f
 
 
 
 
39a6dd6
 
 
21c7de8
c31a89f
39a6dd6
 
 
 
c31a89f
 
 
39a6dd6
 
 
806e6c5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
'''Artist Classifier

prototype

---
- 2022-01-18 jkang first created
'''
from gradcam_utils import get_img_4d_array, make_gradcam_heatmap, align_image_with_heatmap
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

import io
import json
import numpy as np
import skimage
import skimage.io
from skimage.transform import resize
from loguru import logger
from huggingface_hub import from_pretrained_keras
import gradio as gr
import tensorflow as tf
tfk = tf.keras

from gradcam_utils import get_img_4d_array, make_gradcam_heatmap, align_image_with_heatmap

# ---------- Settings ----------
# JSON files holding the {label name: class id} mapping of each classifier;
# predict() inverts them to turn predicted ids back into names.
ARTIST_META = 'artist.json'
TREND_META = 'trend.json'

# Clickable sample inputs shown under the Gradio interface.
# NOTE(review): these are file names, presumably matching files shipped next to
# the app -- keep the spellings ('surrelaism', 'augste') unless the files are renamed.
EXAMPLES = [
    'monet2.jpg',
    'surrelaism.png',
    'graffitiart.png',
    'lichtenstein_popart.jpg',
    'pierre_augste_renoir.png',
]

# Blending weight handed to align_image_with_heatmap for the Grad-CAM overlay.
ALPHA = 0.9

# Every input image is resized to this many pixels before inference.
IMG_WIDTH, IMG_HEIGHT = 299, 299

# ---------- Logging ----------
# Also send log records to 'app.log', opened in append mode so history survives restarts.
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Model ----------
# Both classifiers are loaded once, at import time, and shared by every predict() call.
# from_pretrained_keras pulls them from the Hugging Face Hub by repo id
# (presumably cached locally after the first download -- verify for offline deploys).
logger.info('loading models...')
artist_model = from_pretrained_keras("jkang/drawing-artist-classifier")
trend_model = from_pretrained_keras("jkang/drawing-artistic-trend-classifier")
logger.info('both models loaded')

def load_json_as_dict(json_file):
    """Load a JSON file and return its decoded content as a ``dict``.

    Args:
        json_file: Path of the JSON file to read (e.g. ``ARTIST_META``).

    Returns:
        A new ``dict`` built with ``dict()`` from the decoded JSON value, so the
        file is expected to contain a JSON object.
    """
    # JSON text is UTF-8 by specification; say so explicitly so non-ASCII label
    # names decode identically regardless of the host's locale encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        out = json.load(f)
    return dict(out)

def load_image_as_array(image_file):
    """Read an image file into an array with exactly three colour channels.

    Grayscale input (a 2-D array or a single channel) and grayscale+alpha input
    (two channels) are expanded to three identical channels. Input with more
    than three channels (e.g. RGBA) keeps only the first three. The dtype
    produced by ``skimage.io.imread`` is left unchanged.

    Args:
        image_file: Path or file object handed straight to ``skimage.io.imread``.

    Returns:
        An ``(H, W, 3)`` array, assuming ``imread`` yields a 2-D or 3-D array.
    """
    img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
    # A 2-D grayscale array has no channel axis: add one, so the checks below
    # inspect the channel count rather than the image width.
    if img.ndim == 2:
        img = img[..., np.newaxis]
    n_channels = img.shape[-1]
    if n_channels < 3:  # gray (1) or gray + alpha (2): replicate the gray channel
        img = np.repeat(img[..., :1], 3, axis=-1)
    elif n_channels > 3:  # e.g. RGBA: drop the extra (alpha) channel(s)
        img = img[..., :3]
    return img

def resize_image(img_array, width, height):
    """Resample ``img_array`` to ``height`` x ``width`` pixels and return it as uint8.

    Resizing is anti-aliased and does not preserve the input's value range;
    the floating-point result is then converted to 8-bit with
    ``skimage.img_as_ubyte``.
    """
    output_shape = (height, width)  # skimage expects (rows, cols)
    resampled = resize(
        img_array,
        output_shape,
        preserve_range=False,
        anti_aliasing=True,
    )
    return skimage.img_as_ubyte(resampled)

def _invert_mapping(label2id):
    """Return the ``{id: label}`` inverse of a ``{label: id}`` mapping."""
    return {idx: label for label, idx in label2id.items()}


def _gradcam_prediction(model, img_4d_array, id2label):
    """Run one classifier and build its Grad-CAM overlay for the top class.

    Args:
        model: One of the loaded Keras classifiers.
        img_4d_array: Single-image batch (batch axis first) fed to the model.
        id2label: Mapping from predicted class id to display name.

    Returns:
        ``(overlay, label, prob, pred_out)``:
        - ``overlay``: the heatmap blended onto the input, as float32 divided by 255.
        - ``label`` and ``prob``: the top-1 class name and its score.
        - ``pred_out``: the full per-class output from ``make_gradcam_heatmap``.
    """
    heatmap, pred_id, pred_out = make_gradcam_heatmap(model,
                                                      img_4d_array,
                                                      pred_idx=None)
    overlay_pil = align_image_with_heatmap(
        img_4d_array, heatmap, alpha=ALPHA, cmap='jet')
    # presumably an 8-bit RGB PIL image -> scale into [0, 1] for imshow (TODO confirm)
    overlay = np.asarray(overlay_pil).astype('float32')/255
    return overlay, id2label[pred_id], pred_out[pred_id], pred_out


def _render_comparison(img_3d_array, a_img, a_label, a_prob, t_img, t_label, t_prob):
    """Draw input / artist overlay / style overlay side by side; return it as a PIL JPEG image."""
    fig = None
    try:
        with sns.plotting_context('poster', font_scale=0.7):
            fig, axes = plt.subplots(1, 3, figsize=(12, 6), facecolor='white')
            panels = (
                (img_3d_array, 'Input Image'),
                (a_img, f'Artist Prediction:\n => {a_label} ({a_prob:.2f})'),
                (t_img, f'Style Prediction:\n => {t_label} ({t_prob:.2f})'),
            )
            for ax, (image, title) in zip(axes, panels):
                ax.set_xticks([])
                ax.set_yticks([])
                ax.imshow(image)
                ax.set_title(title, ha='left', x=0, y=1.05)
            fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, bbox_inches='tight', format='jpg')
    finally:
        # Close exactly this figure (not pyplot's "current" one), and do it even
        # when drawing fails, so figures never pile up across requests.
        if fig is not None:
            plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def predict(input_image):
    """Predict the artist and the artistic style of an uploaded image.

    Args:
        input_image: The uploaded image, passed straight to ``load_image_as_array``.

    Returns:
        ``(a_labels, t_labels, pil_img)``:
        - ``a_labels``: ``{artist name: score}`` for every artist class.
        - ``t_labels``: ``{style name: score}`` for every style class.
        - ``pil_img``: a figure showing the input next to both Grad-CAM overlays.
    """
    img_3d_array = resize_image(load_image_as_array(input_image), IMG_WIDTH, IMG_HEIGHT)
    img_4d_array = img_3d_array[np.newaxis, ...]  # prepend a batch axis of size 1
    logger.info(f'--- {input_image} loaded')

    # Label maps are re-read from disk on every call.
    id2artist = _invert_mapping(load_json_as_dict(ARTIST_META))
    id2trend = _invert_mapping(load_json_as_dict(TREND_META))

    a_img, a_label, a_prob, a_pred_out = _gradcam_prediction(
        artist_model, img_4d_array, id2artist)
    t_img, t_label, t_prob, t_pred_out = _gradcam_prediction(
        trend_model, img_4d_array, id2trend)

    pil_img = _render_comparison(img_3d_array, a_img, a_label, a_prob,
                                 t_img, t_label, t_prob)
    logger.info('--- image generated')

    a_labels = {id2artist[i]: float(pred) for i, pred in enumerate(a_pred_out)}
    t_labels = {id2trend[i]: float(pred) for i, pred in enumerate(t_pred_out)}
    return a_labels, t_labels, pil_img

# Gradio front end: one uploaded image in; artist scores, style scores and the
# Grad-CAM comparison figure out, in the same order predict() returns them.
iface = gr.Interface(
    predict,
    # Title emoji (artist palette, then "man artist: light skin tone") are
    # written as escapes so the source stays ASCII and cannot be mis-decoded again.
    title=('Predict Artist and Artistic Style of Drawings '
           '\U0001F3A8\U0001F468\U0001F3FB\u200d\U0001F3A8 (prototype)'),
    # NOTE(review): "10 artists" is hard-coded; presumably it must match artist.json.
    description='Upload a drawing/image and the model will predict how likely it seems given 10 artists and their trend/style',
    inputs=[
        gr.inputs.Image(label='Upload a drawing/image', type='file')
    ],
    outputs=[
        gr.outputs.Label(label='Artists', num_top_classes=5, type='auto'),
        gr.outputs.Label(label='Styles', num_top_classes=5, type='auto'),
        gr.outputs.Image(label='Prediction with GradCAM')
    ],
    examples=EXAMPLES,
)
# Called at module level (no __main__ guard), so importing this file starts the app.
iface.launch(debug=True, enable_queue=True)