jaekookang commited on
Commit
21c7de8
β€’
1 Parent(s): 39a6dd6

update callback

Browse files
.ipynb_checkpoints/gradio_artist_classifier-checkpoint.py CHANGED
@@ -5,11 +5,12 @@ prototype
5
  ---
6
  - 2022-01-18 jkang first created
7
  '''
8
-
9
  import matplotlib.pyplot as plt
10
  import matplotlib.image as mpimg
11
  import seaborn as sns
12
 
 
13
  import json
14
  import skimage.io
15
  from loguru import logger
@@ -35,6 +36,11 @@ artist_model = from_pretrained_keras("jkang/drawing-artist-classifier")
35
  trend_model = from_pretrained_keras("jkang/drawing-artistic-trend-classifier")
36
  logger.info('both models loaded')
37
 
 
 
 
 
 
38
  def load_image_as_array(image_file):
39
  img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
40
  if (img.shape[-1] > 3) & (remove_alpha_channel): # if RGBA
@@ -48,17 +54,63 @@ def load_image_as_tensor(image_file):
48
 
49
  def predict(input_image):
50
  img_3d_array = load_image_as_array(input_image)
51
- img_4d_tensor = load_image_as_tensor(input_image)
 
52
  logger.info(f'--- {input_image} loaded')
53
 
54
- artist_model(img_4d_tensor);
55
- trend_model(img_4d_tensor);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- return img_3d_array
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  iface = gr.Interface(
60
  predict,
61
- title='Predict Artist and Artistic Trend of Drawings πŸŽ¨πŸ‘¨πŸ»β€πŸŽ¨ (prototype)',
62
  description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
63
  inputs=[
64
  gr.inputs.Image(label='Upload a drawing/image', type='file')
5
  ---
6
  - 2022-01-18 jkang first created
7
  '''
8
+ from PIL import Image
9
  import matplotlib.pyplot as plt
10
  import matplotlib.image as mpimg
11
  import seaborn as sns
12
 
13
+ import io
14
  import json
15
  import skimage.io
16
  from loguru import logger
36
  trend_model = from_pretrained_keras("jkang/drawing-artistic-trend-classifier")
37
  logger.info('both models loaded')
38
 
39
+ def load_json_as_dict(json_file):
40
+ with open(json_file, 'r') as f:
41
+ out = json.load(f)
42
+ return dict(out)
43
+
44
  def load_image_as_array(image_file):
45
  img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
46
  if (img.shape[-1] > 3) & (remove_alpha_channel): # if RGBA
54
 
55
  def predict(input_image):
56
  img_3d_array = load_image_as_array(input_image)
57
+ # img_4d_tensor = load_image_as_tensor(input_image)
58
+ img_4d_array = img_3d_array[np.newaxis,...]
59
  logger.info(f'--- {input_image} loaded')
60
 
61
+ artist2id = load_json_as_dict(ARTIST_META)
62
+ trend2id = load_json_as_dict(TREND_META)
63
+ id2artist = {artist2id[artist]:artist for artist in artist2id}
64
+ id2trend = {trend2id[trend]:trend for trend in trend2id}
65
+
66
+ # Artist model
67
+ a_heatmap, a_pred_id, a_pred_out = make_gradcam_heatmap(artist_model,
68
+ img_4d_array,
69
+ pred_idx=None)
70
+ a_img_pil = align_image_with_heatmap(
71
+ img_4d_array, a_heatmap, alpha=alpha, cmap='jet')
72
+ a_img = np.asarray(a_img_pil).astype('float32')/255
73
+ a_label = id2artist[a_pred_id]
74
+ a_prob = a_pred_out[a_pred_id]
75
+
76
+ # Trend model
77
+ t_heatmap, t_pred_id, t_pred_out = make_gradcam_heatmap(trend_model,
78
+ img_4d_array,
79
+ pred_idx=None)
80
 
81
+ t_img_pil = align_image_with_heatmap(
82
+ img_4d_array, t_heatmap, alpha=alpha, cmap='jet')
83
+ t_img = np.asarray(t_img_pil).astype('float32')/255
84
+ t_label = id2trend[t_pred_id]
85
+ t_prob = t_pred_out[t_pred_id]
86
+
87
+ with sns.plotting_context('poster', font_scale=0.7):
88
+ fig, (ax1, ax2, ax3) = plt.subplots(
89
+ 1, 3, figsize=(12, 6), facecolor='white')
90
+ for ax in (ax1, ax2, ax3):
91
+ ax.set_xticks([])
92
+ ax.set_yticks([])
93
+
94
+ ax1.imshow(img_3d_array)
95
+ ax2.imshow(a_img)
96
+ ax3.imshow(t_img)
97
+
98
+ ax1.set_title(f'Artist: {artist}\nTrend: {trend}', ha='left', x=0, y=1.05)
99
+ ax2.set_title(f'Artist Prediction:\n =>{a_label} ({a_prob:.2f})', ha='left', x=0, y=1.05)
100
+ ax3.set_title(f'Trend Prediction:\n =>{t_label} ({t_prob:.2f})', ha='left', x=0, y=1.05)
101
+ fig.tight_layout()
102
+
103
+ buf = io.BytesIO()
104
+ fig.save(buf, bbox_inces='tight', fotmat='jpg')
105
+ buf.seek(0)
106
+ pil_img = Image.open(buf)
107
+ plt.close()
108
+ logger.info('--- output generated')
109
+ return pil_img
110
 
111
  iface = gr.Interface(
112
  predict,
113
+ title='Predict Artist and Artistic Style of Drawings πŸŽ¨πŸ‘¨πŸ»β€πŸŽ¨ (prototype)',
114
  description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
115
  inputs=[
116
  gr.inputs.Image(label='Upload a drawing/image', type='file')
gradio_artist_classifier.py CHANGED
@@ -5,11 +5,12 @@ prototype
5
  ---
6
  - 2022-01-18 jkang first created
7
  '''
8
-
9
  import matplotlib.pyplot as plt
10
  import matplotlib.image as mpimg
11
  import seaborn as sns
12
 
 
13
  import json
14
  import skimage.io
15
  from loguru import logger
@@ -35,6 +36,11 @@ artist_model = from_pretrained_keras("jkang/drawing-artist-classifier")
35
  trend_model = from_pretrained_keras("jkang/drawing-artistic-trend-classifier")
36
  logger.info('both models loaded')
37
 
 
 
 
 
 
38
  def load_image_as_array(image_file):
39
  img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
40
  if (img.shape[-1] > 3) & (remove_alpha_channel): # if RGBA
@@ -48,17 +54,63 @@ def load_image_as_tensor(image_file):
48
 
49
  def predict(input_image):
50
  img_3d_array = load_image_as_array(input_image)
51
- img_4d_tensor = load_image_as_tensor(input_image)
 
52
  logger.info(f'--- {input_image} loaded')
53
 
54
- artist_model(img_4d_tensor);
55
- trend_model(img_4d_tensor);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- return img_3d_array
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  iface = gr.Interface(
60
  predict,
61
- title='Predict Artist and Artistic Trend of Drawings πŸŽ¨πŸ‘¨πŸ»β€πŸŽ¨ (prototype)',
62
  description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
63
  inputs=[
64
  gr.inputs.Image(label='Upload a drawing/image', type='file')
5
  ---
6
  - 2022-01-18 jkang first created
7
  '''
8
+ from PIL import Image
9
  import matplotlib.pyplot as plt
10
  import matplotlib.image as mpimg
11
  import seaborn as sns
12
 
13
+ import io
14
  import json
15
  import skimage.io
16
  from loguru import logger
36
  trend_model = from_pretrained_keras("jkang/drawing-artistic-trend-classifier")
37
  logger.info('both models loaded')
38
 
39
+ def load_json_as_dict(json_file):
40
+ with open(json_file, 'r') as f:
41
+ out = json.load(f)
42
+ return dict(out)
43
+
44
  def load_image_as_array(image_file):
45
  img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
46
  if (img.shape[-1] > 3) & (remove_alpha_channel): # if RGBA
54
 
55
  def predict(input_image):
56
  img_3d_array = load_image_as_array(input_image)
57
+ # img_4d_tensor = load_image_as_tensor(input_image)
58
+ img_4d_array = img_3d_array[np.newaxis,...]
59
  logger.info(f'--- {input_image} loaded')
60
 
61
+ artist2id = load_json_as_dict(ARTIST_META)
62
+ trend2id = load_json_as_dict(TREND_META)
63
+ id2artist = {artist2id[artist]:artist for artist in artist2id}
64
+ id2trend = {trend2id[trend]:trend for trend in trend2id}
65
+
66
+ # Artist model
67
+ a_heatmap, a_pred_id, a_pred_out = make_gradcam_heatmap(artist_model,
68
+ img_4d_array,
69
+ pred_idx=None)
70
+ a_img_pil = align_image_with_heatmap(
71
+ img_4d_array, a_heatmap, alpha=alpha, cmap='jet')
72
+ a_img = np.asarray(a_img_pil).astype('float32')/255
73
+ a_label = id2artist[a_pred_id]
74
+ a_prob = a_pred_out[a_pred_id]
75
+
76
+ # Trend model
77
+ t_heatmap, t_pred_id, t_pred_out = make_gradcam_heatmap(trend_model,
78
+ img_4d_array,
79
+ pred_idx=None)
80
 
81
+ t_img_pil = align_image_with_heatmap(
82
+ img_4d_array, t_heatmap, alpha=alpha, cmap='jet')
83
+ t_img = np.asarray(t_img_pil).astype('float32')/255
84
+ t_label = id2trend[t_pred_id]
85
+ t_prob = t_pred_out[t_pred_id]
86
+
87
+ with sns.plotting_context('poster', font_scale=0.7):
88
+ fig, (ax1, ax2, ax3) = plt.subplots(
89
+ 1, 3, figsize=(12, 6), facecolor='white')
90
+ for ax in (ax1, ax2, ax3):
91
+ ax.set_xticks([])
92
+ ax.set_yticks([])
93
+
94
+ ax1.imshow(img_3d_array)
95
+ ax2.imshow(a_img)
96
+ ax3.imshow(t_img)
97
+
98
+ ax1.set_title(f'Artist: {artist}\nTrend: {trend}', ha='left', x=0, y=1.05)
99
+ ax2.set_title(f'Artist Prediction:\n =>{a_label} ({a_prob:.2f})', ha='left', x=0, y=1.05)
100
+ ax3.set_title(f'Trend Prediction:\n =>{t_label} ({t_prob:.2f})', ha='left', x=0, y=1.05)
101
+ fig.tight_layout()
102
+
103
+ buf = io.BytesIO()
104
+ fig.save(buf, bbox_inces='tight', fotmat='jpg')
105
+ buf.seek(0)
106
+ pil_img = Image.open(buf)
107
+ plt.close()
108
+ logger.info('--- output generated')
109
+ return pil_img
110
 
111
  iface = gr.Interface(
112
  predict,
113
+ title='Predict Artist and Artistic Style of Drawings πŸŽ¨πŸ‘¨πŸ»β€πŸŽ¨ (prototype)',
114
  description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
115
  inputs=[
116
  gr.inputs.Image(label='Upload a drawing/image', type='file')
requirements-dev.txt CHANGED
@@ -3,7 +3,8 @@ huggingface_hub==0.4.0
3
  loguru==0.5.3
4
  matplotlib==3.5.1
5
  numpy==1.22.0
 
6
  scikit_image==0.19.1
7
  seaborn==0.11.2
8
- scikit-image==0.19.1
9
  tensorflow==2.7.0
3
  loguru==0.5.3
4
  matplotlib==3.5.1
5
  numpy==1.22.0
6
+ Pillow==9.0.0
7
  scikit_image==0.19.1
8
  seaborn==0.11.2
9
+ skimage==0.0
10
  tensorflow==2.7.0