"""Artist Classifier prototype

---
- 2022-01-18 jkang first created
"""
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns
import json
import skimage.io
from loguru import logger
from huggingface_hub import from_pretrained_keras
import gradio as gr
import tensorflow as tf
tfk = tf.keras
from gradcam_utils import get_img_4d_array, make_gradcam_heatmap, align_image_with_heatmap

# ---------- Settings ----------
# NOTE(review): ARTIST_META / TREND_META are not read anywhere in this file yet;
# presumably label metadata for the two models — confirm before relying on them.
ARTIST_META = 'artist.json'
TREND_META = 'trend.json'
EXAMPLES = ['monet.jpg']

# ---------- Logging ----------
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Model ----------
logger.info('loading models...')
artist_model = from_pretrained_keras("jkang/drawing-artist-classifier")
trend_model = from_pretrained_keras("jkang/drawing-artistic-trend-classifier")
logger.info('both models loaded')


def _as_path(image_file):
    """Return a filesystem path for ``image_file``.

    With ``type='file'`` Gradio passes a temp-file wrapper object; otherwise a
    plain path string may be given. ``tf.io.read_file`` requires a path, so the
    ``.name`` attribute is used when present.
    """
    return getattr(image_file, 'name', image_file)


def load_image_as_array(image_file, remove_alpha_channel=True):
    """Read an image from disk into a NumPy array.

    Args:
        image_file: path (or file object with a ``.name``) of the image.
        remove_alpha_channel: when True, drop the 4th (alpha) channel of an
            RGBA image so the result has 3 channels.

    Returns:
        The decoded image array as returned by ``skimage.io.imread``.
    """
    img = skimage.io.imread(_as_path(image_file), as_gray=False, plugin='matplotlib')
    # Only a 3-D array with exactly 4 channels is RGBA; checking shape[-1]
    # alone would misfire on 2-D grayscale images (shape[-1] is the width).
    if remove_alpha_channel and img.ndim == 3 and img.shape[-1] == 4:
        img = img[..., :3]
    return img


def load_image_as_tensor(image_file):
    """Read and decode an image file into a 3-D uint8 tensor with 3 channels.

    ``decode_image`` handles JPEG, PNG, GIF and BMP; ``expand_animations=False``
    keeps the result 3-D (first frame only for animated GIFs).
    """
    raw = tf.io.read_file(_as_path(image_file))
    img = tf.io.decode_image(raw, channels=3, expand_animations=False)
    return img


def predict(input_image):
    """Run both classifiers on the uploaded image and return the image for display.

    Args:
        input_image: the Gradio input (temp-file object or path).

    Returns:
        The image as a NumPy array (shown in the output ``Image`` component).
        Model predictions are currently only logged.
    """
    image_path = _as_path(input_image)
    img_3d_array = load_image_as_array(image_path)
    # Keras models expect a leading batch dimension: (H, W, 3) -> (1, H, W, 3).
    img_4d_tensor = tf.expand_dims(load_image_as_tensor(image_path), axis=0)
    logger.info(f'--- {image_path} loaded')

    # NOTE(review): the required input size/dtype (resizing, float scaling)
    # of these models is not visible here — TODO confirm against the model cards.
    artist_pred = artist_model(img_4d_tensor)
    trend_pred = trend_model(img_4d_tensor)
    logger.info(f'artist prediction: {artist_pred}')
    logger.info(f'trend prediction: {trend_pred}')

    return img_3d_array


iface = gr.Interface(
    predict,
    title='Predict Artist and Artistic Trend of Drawings 🎨👨🏻‍🎨 (prototype)',
    description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
    inputs=[
        gr.inputs.Image(label='Upload a drawing/image', type='file')
    ],
    outputs=[
        gr.outputs.Image(label='Prediction')
    ],
    examples=EXAMPLES,
)

iface.launch(debug=True, enable_queue=True)