'''Artist Classifier prototype
---
- 2022-01-18 jkang first created
'''
import io
import json

from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import seaborn as sns
import skimage.io
from loguru import logger
from huggingface_hub import from_pretrained_keras
import gradio as gr
import tensorflow as tf
tfk = tf.keras

from gradcam_utils import get_img_4d_array, make_gradcam_heatmap, align_image_with_heatmap

# ---------- Settings ----------
ARTIST_META = 'artist.json'
TREND_META = 'trend.json'
EXAMPLES = ['monet.jpg']
# Value forwarded as `alpha=` to align_image_with_heatmap.
# NOTE(review): the code used to pass an undefined name `alpha` here; 0.5 is a
# placeholder -- TODO confirm the intended value against gradcam_utils.
HEATMAP_ALPHA = 0.5

# ---------- Logging ----------
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Model ----------
logger.info('loading models...')
artist_model = from_pretrained_keras("jkang/drawing-artist-classifier")
trend_model = from_pretrained_keras("jkang/drawing-artistic-trend-classifier")
logger.info('both models loaded')


def load_json_as_dict(json_file):
    '''Read `json_file` and return its top-level content as a plain dict.

    Args:
        json_file: path of the JSON file to read.

    Returns:
        dict built from the parsed JSON document.
    '''
    # Explicit encoding so parsing does not depend on the platform default.
    with open(json_file, 'r', encoding='utf-8') as f:
        out = json.load(f)
    return dict(out)


def load_image_as_array(image_file, *, remove_alpha_channel=True):
    '''Load an image with skimage and optionally drop a trailing extra channel.

    Args:
        image_file: image source forwarded unchanged to `skimage.io.imread`.
        remove_alpha_channel: when True (default), an image with more than
            three channels on its last axis has the last channel removed.

    Returns:
        The array returned by `skimage.io.imread`, possibly without its last
        channel.
    '''
    img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
    # Only 3-D arrays carry a channel axis; on a 2-D (grayscale) array the last
    # axis is the image width and must not be trimmed.
    if remove_alpha_channel and img.ndim == 3 and img.shape[-1] > 3:  # if RGBA
        img = img[..., :-1]
    return img


def load_image_as_tensor(image_file):
    '''Read `image_file` with TensorFlow and decode it as a 3-channel JPEG.

    Args:
        image_file: path passed to `tf.io.read_file`.

    Returns:
        The tensor produced by `tf.io.decode_jpeg(..., channels=3)`.
    '''
    img = tf.io.read_file(image_file)
    img = tf.io.decode_jpeg(img, channels=3)
    return img


def _invert_mapping(label2id):
    '''Return a dict mapping each value of `label2id` back to its key.'''
    return {idx: label for label, idx in label2id.items()}


def _run_gradcam(model, img_4d_array, id2label):
    '''Run one classifier through Grad-CAM.

    Returns:
        Tuple `(overlay, label, score)`: the heatmap overlay as a float32
        array divided by 255, the label looked up for the predicted id, and
        the model output indexed at the predicted id.
    '''
    heatmap, pred_id, pred_out = make_gradcam_heatmap(
        model, img_4d_array, pred_idx=None)
    overlay_pil = align_image_with_heatmap(
        img_4d_array, heatmap, alpha=HEATMAP_ALPHA, cmap='jet')
    overlay = np.asarray(overlay_pil).astype('float32') / 255
    return overlay, id2label[pred_id], pred_out[pred_id]


def predict(input_image):
    '''Classify a drawing with both models and render a summary figure.

    The figure has three panels: the input image, the artist model's Grad-CAM
    overlay and the trend model's Grad-CAM overlay, the latter two titled with
    the predicted label and its score.

    Args:
        input_image: value Gradio supplies for the image input; it is passed
            unchanged to `load_image_as_array`.

    Returns:
        PIL image of the rendered figure (JPEG-encoded in memory).
    '''
    img_3d_array = load_image_as_array(input_image)
    # The Grad-CAM helpers are called with a leading axis added to the image.
    img_4d_array = img_3d_array[np.newaxis, ...]
    logger.info(f'--- {input_image} loaded')

    id2artist = _invert_mapping(load_json_as_dict(ARTIST_META))
    id2trend = _invert_mapping(load_json_as_dict(TREND_META))

    # Artist model
    a_img, a_label, a_prob = _run_gradcam(artist_model, img_4d_array, id2artist)
    # Trend model
    t_img, t_label, t_prob = _run_gradcam(trend_model, img_4d_array, id2trend)

    with sns.plotting_context('poster', font_scale=0.7):
        fig, (ax1, ax2, ax3) = plt.subplots(
            1, 3, figsize=(12, 6), facecolor='white')
        for ax in (ax1, ax2, ax3):
            ax.set_xticks([])
            ax.set_yticks([])

        ax1.imshow(img_3d_array)
        ax2.imshow(a_img)
        ax3.imshow(t_img)

        # No ground-truth artist/trend is available for an uploaded image, so
        # the first panel is labelled generically.
        # NOTE(review): this title used to interpolate undefined names
        # `artist`/`trend` -- confirm the wording.
        ax1.set_title('Input image', ha='left', x=0, y=1.05)
        ax2.set_title(f'Artist Prediction:\n =>{a_label} ({a_prob:.2f})',
                      ha='left', x=0, y=1.05)
        ax3.set_title(f'Trend Prediction:\n =>{t_label} ({t_prob:.2f})',
                      ha='left', x=0, y=1.05)
        fig.tight_layout()

        buf = io.BytesIO()
        try:
            fig.savefig(buf, bbox_inches='tight', format='jpg')
        finally:
            # Close this specific figure even if saving fails, so figures do
            # not accumulate across requests.
            plt.close(fig)

    buf.seek(0)
    pil_img = Image.open(buf)
    logger.info('--- output generated')
    return pil_img


iface = gr.Interface(
    predict,
    title='Predict Artist and Artistic Style of Drawings 🎨👨🏻‍🎨 (prototype)',
    description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
    inputs=[
        gr.inputs.Image(label='Upload a drawing/image', type='file')
    ],
    outputs=[
        gr.outputs.Image(label='Prediction')
    ],
    examples=EXAMPLES,
)
iface.launch(debug=True, enable_queue=True)