demo-artist-classifier / gradio_artist_classifier.py
jaekookang
test app
39a6dd6
raw history blame
No virus
2.04 kB
'''Artist Classifier
prototype
---
- 2022-01-18 jkang first created
'''
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns
import json
import skimage.io
from loguru import logger
from huggingface_hub import from_pretrained_keras
import gradio as gr
import tensorflow as tf
tfk = tf.keras
from gradcam_utils import get_img_4d_array, make_gradcam_heatmap, align_image_with_heatmap
# ---------- Settings ----------
# File names for artist / trend metadata; neither constant is read by the
# code visible in this file section.
ARTIST_META = 'artist.json'
TREND_META = 'trend.json'
# Example image(s) handed to the Gradio interface via its `examples` argument.
EXAMPLES = ['monet.jpg']
# ---------- Logging ----------
# Register app.log as a log sink, opened in append mode (mode='a').
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')
# ---------- Model ----------
# Both classifiers are loaded once, at import time, through
# huggingface_hub.from_pretrained_keras, so every predict() call reuses them.
logger.info('loading models...')
artist_model = from_pretrained_keras("jkang/drawing-artist-classifier")
trend_model = from_pretrained_keras("jkang/drawing-artistic-trend-classifier")
logger.info('both models loaded')
def load_image_as_array(image_file, remove_alpha_channel=True):
    """Read an image file into an array, optionally dropping its alpha channel.

    Args:
        image_file: Image location handed to ``skimage.io.imread``.
        remove_alpha_channel: When true (the default), a channel-last image
            with four channels (RGBA) is reduced to its first three.

    Returns:
        The array returned by ``skimage.io.imread``, without the alpha
        channel when it was stripped.
    """
    img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
    # Only a 3-D (H, W, C) array can carry an alpha channel. For a 2-D
    # grayscale image shape[-1] is the width, which must not be trimmed.
    if remove_alpha_channel and img.ndim == 3 and img.shape[-1] == 4:
        img = img[..., :3]
    return img
def load_image_as_tensor(image_file):
    """Read an image file and decode it into a three-channel image tensor.

    Args:
        image_file: Path of the image, passed to ``tf.io.read_file``.

    Returns:
        The decoded image tensor with three channels. No batch dimension is
        added here.
    """
    raw_bytes = tf.io.read_file(image_file)
    # decode_image detects the format (BMP, GIF, JPEG or PNG) from the bytes,
    # so the loader is not tied to JPEG uploads. expand_animations=False keeps
    # the result 3-D (first frame only) even when the input is a GIF.
    return tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
def predict(input_image):
    """Run both classifiers on an uploaded drawing and return the image array.

    Args:
        input_image: The upload from the Gradio image input. With
            ``type='file'`` this is a temporary-file object exposing its path
            as ``.name``; a plain path string is accepted as well.

    Returns:
        The image array produced by ``load_image_as_array``. The outputs of
        the two models are not used yet.
    """
    # Both loaders need a filesystem path, not a file object.
    image_path = getattr(input_image, 'name', input_image)
    img_3d_array = load_image_as_array(image_path)
    # The decoded image is (height, width, channels); the Keras models are
    # called on a batch, so prepend a batch axis of size one.
    img_4d_tensor = tf.expand_dims(load_image_as_tensor(image_path), axis=0)
    logger.info(f'--- {image_path} loaded')
    # NOTE(review): the input size / dtype / normalization the models expect
    # is not visible here -- confirm whether resizing or scaling is needed.
    artist_model(img_4d_tensor)
    trend_model(img_4d_tensor)
    return img_3d_array
# Gradio UI: one image upload in, one image out, served by predict().
iface = gr.Interface(
    predict,
    # The emoji are written as escapes (palette, then man + light skin tone +
    # zero-width joiner + palette) so the title cannot be garbled again by a
    # wrong source-file encoding.
    title='Predict Artist and Artistic Trend of Drawings \U0001F3A8\U0001F468\U0001F3FB\u200D\U0001F3A8 (prototype)',
    description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
    inputs=[
        # type='file' makes Gradio pass the upload to predict() as a file object.
        gr.inputs.Image(label='Upload a drawing/image', type='file')
    ],
    outputs=[
        gr.outputs.Image(label='Prediction')
    ],
    examples=EXAMPLES,
)
# Start the app at import time, with debug output and request queueing on.
iface.launch(debug=True, enable_queue=True)