Spaces:
Runtime error
Runtime error
'''Artist Classifier | |
prototype | |
--- | |
- 2022-01-18 jkang first created | |
''' | |
import matplotlib.pyplot as plt | |
import matplotlib.image as mpimg | |
import seaborn as sns | |
import json | |
import skimage.io | |
from loguru import logger | |
from huggingface_hub import from_pretrained_keras | |
import gradio as gr | |
import tensorflow as tf | |
tfk = tf.keras | |
from gradcam_utils import get_img_4d_array, make_gradcam_heatmap, align_image_with_heatmap | |
# ---------- Settings ----------
# JSON metadata files, one per classifier.
# NOTE(review): neither file is read anywhere yet — presumably meant for
# mapping model outputs to label names; confirm.
ARTIST_META, TREND_META = 'artist.json', 'trend.json'
# Sample image(s) offered as clickable examples in the Gradio UI.
EXAMPLES = [
    'monet.jpg',
]
# ---------- Logging ----------
# Log to app.log in append mode so history survives app restarts.
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Model ----------
# Load both Keras classifiers from their Hugging Face Hub repo ids, in order:
# artist first, then artistic trend.
logger.info('loading models...')
artist_model, trend_model = (
    from_pretrained_keras(repo_id)
    for repo_id in (
        "jkang/drawing-artist-classifier",
        "jkang/drawing-artistic-trend-classifier",
    )
)
logger.info('both models loaded')
def load_image_as_array(image_file, remove_alpha_channel=True):
    """Read an image file into an array, optionally dropping an alpha channel.

    Args:
        image_file: Image to read; passed straight to ``skimage.io.imread``
            (matplotlib plugin).
        remove_alpha_channel: When True (the default), a 4-channel colour
            image (presumably RGBA) has its last channel removed so the
            result is 3-channel.

    Returns:
        The array produced by ``skimage.io.imread``.
        NOTE(review): the dtype/value range depends on the file format under
        the matplotlib plugin — confirm before using values numerically.
    """
    img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
    # Only an (H, W, C) image with exactly 4 channels carries alpha. Testing
    # shape[-1] alone would also match the *width* of a 2-D grayscale image
    # and wrongly drop its last pixel column.
    if remove_alpha_channel and img.ndim == 3 and img.shape[-1] == 4:
        img = img[..., :-1]
    return img
def load_image_as_tensor(image_file):
    """Read an image file into a 3-channel ``uint8`` tensor of shape (H, W, 3).

    Args:
        image_file: Filename of the image (``tf.io.read_file`` takes a
            filename string, not a file object).

    Returns:
        A 3-D tensor without a batch axis; callers must add one before
        feeding it to a Keras model.
    """
    raw_bytes = tf.io.read_file(image_file)
    # decode_image selects the decoder from the file header, so any uploaded
    # format TensorFlow supports is accepted, not just JPEG.
    # expand_animations=False keeps the result 3-D for animated GIFs (first
    # frame only) instead of returning a 4-D stack of frames.
    img = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    return img
def predict(input_image):
    """Run both classifiers on an uploaded image and return the image for display.

    Args:
        input_image: The value Gradio passes for the image input. With
            ``type='file'`` this is presumably a temporary-file object whose
            ``.name`` is its path (gradio 2.x behaviour — confirm for the
            installed version). A plain path string or ``os.PathLike`` also
            works.

    Returns:
        The image array from ``load_image_as_array``.
        NOTE(review): the model outputs are only logged; they are not yet
        returned to the UI.
    """
    import os

    # Both loaders need a path: tf.io.read_file rejects file objects. Path-like
    # inputs are used as-is; checking them first avoids misreading
    # pathlib.Path.name, which is only the basename.
    if isinstance(input_image, (str, os.PathLike)):
        image_path = os.fspath(input_image)
    else:
        image_path = getattr(input_image, 'name', input_image)

    img_3d_array = load_image_as_array(image_path)
    img_3d_tensor = load_image_as_tensor(image_path)
    logger.info(f'--- {image_path} loaded')

    # Keras models are called on batches: add a leading batch axis -> (1, H, W, C).
    # NOTE(review): no resizing or normalisation happens here — confirm the
    # saved models contain their own preprocessing layers.
    img_4d_tensor = tf.expand_dims(img_3d_tensor, axis=0)
    artist_pred = artist_model(img_4d_tensor)
    trend_pred = trend_model(img_4d_tensor)
    logger.info(f'--- artist prediction: {artist_pred}')
    logger.info(f'--- trend prediction: {trend_pred}')
    return img_3d_array
# Gradio UI: one image upload in, one image out, wired to `predict`.
# NOTE(review): `gr.inputs` / `gr.outputs` and `launch(enable_queue=...)` look
# like the gradio 2.x API — pin/confirm the gradio version in requirements.
iface = gr.Interface(
    predict,
    # The title emoji had been garbled into Greek-looking characters (UTF-8
    # bytes mis-decoded). Restored as artist palette + man artist (light skin
    # tone), written as escapes so a re-encode cannot mangle them again.
    title=(
        'Predict Artist and Artistic Trend of Drawings '
        '\U0001F3A8\U0001F468\U0001F3FB\u200D\U0001F3A8 (prototype)'
    ),
    description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
    inputs=[
        gr.inputs.Image(label='Upload a drawing/image', type='file')
    ],
    outputs=[
        gr.outputs.Image(label='Prediction')
    ],
    examples=EXAMPLES,
)
iface.launch(debug=True, enable_queue=True)