File size: 2,043 Bytes
af72b72
 
 
 
 
 
 
 
 
 
 
 
39a6dd6
 
 
 
af72b72
 
 
 
 
 
39a6dd6
 
 
 
af72b72
39a6dd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
'''Artist Classifier

prototype

---
- 2022-01-18 jkang first created
'''

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

import json
import skimage.io
from loguru import logger
from huggingface_hub import from_pretrained_keras
import gradio as gr
import tensorflow as tf
tfk = tf.keras

from gradcam_utils import get_img_4d_array, make_gradcam_heatmap, align_image_with_heatmap

# ---------- Settings ----------
# Metadata file names; neither constant is referenced elsewhere in the
# visible part of this file.
ARTIST_META = 'artist.json'
TREND_META = 'trend.json'
# Passed as `examples=` to the Gradio interface built further down.
EXAMPLES = ['monet.jpg']

# ---------- Logging ----------
# Register app.log as an additional log sink.
# NOTE(review): mode='a' presumably opens the file in append mode so earlier
# runs are kept -- confirm against the loguru documentation.
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Model ----------
# Both models are loaded once at import time; predict() calls them directly.
logger.info('loading models...')
artist_model = from_pretrained_keras("jkang/drawing-artist-classifier")
trend_model = from_pretrained_keras("jkang/drawing-artistic-trend-classifier")
logger.info('both models loaded')

def load_image_as_array(image_file, remove_alpha_channel=True):
    '''Read an image with scikit-image and optionally drop its alpha channel.

    Args:
        image_file: Image location handed straight to `skimage.io.imread`.
        remove_alpha_channel: When True (default), an image whose last axis
            holds more than three channels (e.g. RGBA) is trimmed by
            removing the final channel.

    Returns:
        The array produced by `skimage.io.imread`, with the last channel
        removed when the conditions above are met.
    '''
    img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
    # Only channelled (3-D) images can carry an alpha channel; without the
    # ndim guard a 2-D grayscale image would lose its last pixel column.
    if remove_alpha_channel and img.ndim == 3 and img.shape[-1] > 3:
        img = img[..., :-1]
    return img

def load_image_as_tensor(image_file):
    '''Read a file with TensorFlow and decode it as a 3-channel image.

    Args:
        image_file: Path given to `tf.io.read_file`.

    Returns:
        The tensor returned by `tf.io.decode_jpeg(..., channels=3)`.
    '''
    raw_bytes = tf.io.read_file(image_file)
    decoded = tf.io.decode_jpeg(raw_bytes, channels=3)
    return decoded

def predict(input_image):
    '''Load the uploaded image, run both models on it and return the image.

    Args:
        input_image: Either a filesystem path or a file object exposing the
            path through its `name` attribute (what the Gradio image input
            delivers when configured with type='file').

    Returns:
        The image array from `load_image_as_array`. The model outputs are
        currently discarded (prototype).
    '''
    # tf.io.read_file needs a path string, so unwrap file objects via
    # `.name` and pass plain strings through unchanged.
    image_path = getattr(input_image, 'name', input_image)

    img_3d_array = load_image_as_array(image_path)
    img_4d_tensor = load_image_as_tensor(image_path)
    logger.info(f'--- {image_path} loaded')

    # NOTE(review): load_image_as_tensor returns the decode_jpeg result with
    # no batch axis added, despite the `4d` name -- confirm the models accept
    # this input (a batch dimension / resize may be required).
    artist_model(img_4d_tensor)
    trend_model(img_4d_tensor)

    return img_3d_array

# Build the Gradio UI: a single image upload wired to predict(), which
# returns an image shown in the output panel.
iface = gr.Interface(
    predict,
    # The emoji were mojibake (UTF-8 bytes decoded with a single-byte
    # codepage); restored to "artist palette" + "man artist, light skin tone".
    title='Predict Artist and Artistic Trend of Drawings 🎨👨🏻‍🎨 (prototype)',
    description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
    inputs=[
        # type='file' makes Gradio hand predict() a temporary file object.
        gr.inputs.Image(label='Upload a drawing/image', type='file')
    ],
    outputs=[
        gr.outputs.Image(label='Prediction')
    ],
    examples=EXAMPLES,
)
# Start the web server; runs at import time, as in the original script.
iface.launch(debug=True, enable_queue=True)