demo-artist-classifier / .ipynb_checkpoints /gradio_artist_classifier-checkpoint.py
jaekookang
fix minor
bcaf154
raw history blame
No virus
4.16 kB
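'''Artist Classifier (prototype)

Gradio demo that predicts the artist and the artistic trend of an uploaded
drawing and overlays a Grad-CAM heatmap for each prediction.
---
- 2022-01-18 jkang first created
'''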
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns
import io
import json
import numpy as np
import skimage.io
from loguru import logger
from huggingface_hub import from_pretrained_keras
import gradio as gr
import tensorflow as tf
tfk = tf.keras
from gradcam_utils import get_img_4d_array, make_gradcam_heatmap, align_image_with_heatmap
# ---------- Settings ----------
# JSON label maps; predict() inverts them to turn predicted class ids back
# into human-readable names.
ARTIST_META, TREND_META = 'artist.json', 'trend.json'
# Sample inputs shown under the Gradio UI.
# NOTE(review): some Gradio versions expect a nested list per sample
# (e.g. [['monet.jpg']]) -- confirm against the pinned gradio version.
EXAMPLES = ['monet.jpg']

# ---------- Logging ----------
# Append to the log file so earlier runs are preserved across restarts.
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Model ----------
# Both classifiers are pulled from the Hugging Face Hub once, at import time,
# artist model first, then trend model.
logger.info('loading models...')
artist_model, trend_model = (
    from_pretrained_keras(repo_id)
    for repo_id in (
        "jkang/drawing-artist-classifier",
        "jkang/drawing-artistic-trend-classifier",
    )
)
logger.info('both models loaded')
def load_json_as_dict(json_file):
    """Load a JSON file and return its top-level value converted to a ``dict``.

    Args:
        json_file: Path to a JSON file whose top-level value is an object
            (or anything ``dict()`` accepts, such as a list of key/value pairs).

    Returns:
        dict: The parsed content.
    """
    # JSON is UTF-8 by specification; without an explicit encoding, open()
    # falls back to the platform locale (e.g. cp1252 on Windows), which would
    # garble non-ASCII label names.
    with open(json_file, 'r', encoding='utf-8') as f:
        return dict(json.load(f))
def load_image_as_array(image_file):
    """Read an image file into an ``(H, W, 3)`` NumPy array.

    Args:
        image_file: Path or file-like object accepted by ``skimage.io.imread``.

    Returns:
        numpy.ndarray: Colour image with any channels beyond the first three
        (e.g. an alpha channel) removed; single-channel (grayscale) images
        are replicated to three channels.
    """
    img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
    if img.ndim == 2:
        # Grayscale file: replicate to RGB so the caller always gets a channel
        # axis (predict() builds a batch from this array and labels it 3-D).
        img = np.stack((img,) * 3, axis=-1)
    elif img.shape[-1] > 3:
        # Only a 3-D array has a channel axis. Testing shape[-1] without
        # checking ndim treats a 2-D image's *width* as a channel count and
        # chops off the last pixel column instead of an alpha channel.
        img = img[..., :3]  # e.g. RGBA -> RGB
    return img
def load_image_as_tensor(image_file):
    """Read an image file into a 3-channel ``uint8`` tensor of shape (H, W, 3).

    Args:
        image_file: Path to a BMP, GIF, JPEG or PNG file.

    Returns:
        tf.Tensor: Decoded image with exactly three channels.
    """
    raw_bytes = tf.io.read_file(image_file)
    # decode_image detects the format from the file contents, so BMP and
    # animated GIFs are handled too; expand_animations=False returns only the
    # first GIF frame, keeping the result 3-D for every format.
    return tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
def _invert_mapping(label2id):
    """Return the inverse ``{id: label}`` of a ``{label: id}`` mapping."""
    return {idx: label for label, idx in label2id.items()}


def _gradcam_overlay(model, img_4d_array, id2label, alpha):
    """Run Grad-CAM for one model and return (overlay, label, probability).

    ``overlay`` is the input blended with the heatmap, converted to float32
    and divided by 255 for ``imshow`` (assumes ``align_image_with_heatmap``
    returns an 8-bit image -- TODO confirm in gradcam_utils).
    """
    # pred_idx=None presumably lets the helper explain its own top prediction.
    heatmap, pred_id, pred_out = make_gradcam_heatmap(model,
                                                      img_4d_array,
                                                      pred_idx=None)
    overlay_pil = align_image_with_heatmap(
        img_4d_array, heatmap, alpha=alpha, cmap='jet')
    overlay = np.asarray(overlay_pil).astype('float32') / 255
    return overlay, id2label[pred_id], pred_out[pred_id]


def _render_panels(img_3d_array, a_img, a_label, a_prob, t_img, t_label, t_prob):
    """Draw the input and both Grad-CAM overlays side by side as a PIL image."""
    with sns.plotting_context('poster', font_scale=0.7):
        fig, (ax1, ax2, ax3) = plt.subplots(
            1, 3, figsize=(12, 6), facecolor='white')
        for ax in (ax1, ax2, ax3):
            ax.set_xticks([])
            ax.set_yticks([])
        ax1.imshow(img_3d_array)
        ax2.imshow(a_img)
        ax3.imshow(t_img)
        # Uploaded images carry no ground-truth artist/trend, so the first
        # panel is simply labelled as the input.
        ax1.set_title('Input image', ha='left', x=0, y=1.05)
        ax2.set_title(f'Artist Prediction:\n =>{a_label} ({a_prob:.2f})', ha='left', x=0, y=1.05)
        ax3.set_title(f'Trend Prediction:\n =>{t_label} ({t_prob:.2f})', ha='left', x=0, y=1.05)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, bbox_inches='tight', format='jpg')
        # Close this specific figure (not just the "current" one) so repeated
        # requests don't accumulate open figures.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def predict(input_image, alpha=0.4):
    """Predict the artist and artistic trend of a drawing, with Grad-CAM.

    Args:
        input_image: Image supplied by the Gradio ``Image(type='file')``
            input; anything ``load_image_as_array`` can read.
        alpha: Heatmap opacity forwarded to ``align_image_with_heatmap``.
            0.4 is an assumed default -- TODO confirm the intended strength.

    Returns:
        PIL.Image.Image: A three-panel JPEG figure (input, artist Grad-CAM,
        trend Grad-CAM) annotated with each predicted label and its score.
    """
    img_3d_array = load_image_as_array(input_image)
    img_4d_array = img_3d_array[np.newaxis, ...]  # prepend a batch axis
    logger.info(f'--- {input_image} loaded')

    id2artist = _invert_mapping(load_json_as_dict(ARTIST_META))
    id2trend = _invert_mapping(load_json_as_dict(TREND_META))

    a_img, a_label, a_prob = _gradcam_overlay(
        artist_model, img_4d_array, id2artist, alpha)
    t_img, t_label, t_prob = _gradcam_overlay(
        trend_model, img_4d_array, id2trend, alpha)

    pil_img = _render_panels(img_3d_array,
                             a_img, a_label, a_prob,
                             t_img, t_label, t_prob)
    logger.info('--- output generated')
    return pil_img
# Gradio UI: one uploaded image in, one rendered prediction figure out.
iface = gr.Interface(
    predict,
    # Title emoji were mis-encoded (UTF-8 bytes read as cp1253); escapes spell
    # out 🎨 followed by 👨🏻‍🎨 (man artist, light skin tone, joined by ZWJ).
    title='Predict Artist and Artistic Style of Drawings '
          '\U0001F3A8\U0001F468\U0001F3FB\u200d\U0001F3A8 (prototype)',
    description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
    inputs=[
        gr.inputs.Image(label='Upload a drawing/image', type='file'),
    ],
    outputs=[
        gr.outputs.Image(label='Prediction'),
    ],
    examples=EXAMPLES,
)
# Start the app with debug mode and request queueing enabled.
iface.launch(debug=True, enable_queue=True)