jaekookang commited on
Commit
c31a89f
β€’
1 Parent(s): b0e79f5

update gr.outputs.Label

Browse files
Files changed (1) hide show
  1. gradio_artist_classifier.py +11 -6
gradio_artist_classifier.py CHANGED
@@ -98,8 +98,8 @@ def predict(input_image):
98
  ax3.imshow(t_img)
99
 
100
  ax1.set_title(f'Input Image', ha='left', x=0, y=1.05)
101
- ax2.set_title(f'Artist Prediction:\n =>{a_label} ({a_prob:.2f})', ha='left', x=0, y=1.05)
102
- ax3.set_title(f'Trend Prediction:\n =>{t_label} ({t_prob:.2f})', ha='left', x=0, y=1.05)
103
  fig.tight_layout()
104
 
105
  buf = io.BytesIO()
@@ -107,18 +107,23 @@ def predict(input_image):
107
  buf.seek(0)
108
  pil_img = Image.open(buf)
109
  plt.close()
110
- logger.info('--- output generated')
111
- return pil_img
 
 
 
112
 
113
  iface = gr.Interface(
114
  predict,
115
  title='Predict Artist and Artistic Style of Drawings πŸŽ¨πŸ‘¨πŸ»β€πŸŽ¨ (prototype)',
116
- description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
117
  inputs=[
118
  gr.inputs.Image(label='Upload a drawing/image', type='file')
119
  ],
120
  outputs=[
121
- gr.outputs.Image(label='Prediction')
 
 
122
  ],
123
  examples=EXAMPLES,
124
  )
 
98
  ax3.imshow(t_img)
99
 
100
  ax1.set_title(f'Input Image', ha='left', x=0, y=1.05)
101
+ ax2.set_title(f'Artist Prediction:\n => {a_label} ({a_prob:.2f})', ha='left', x=0, y=1.05)
102
+ ax3.set_title(f'Style Prediction:\n => {t_label} ({t_prob:.2f})', ha='left', x=0, y=1.05)
103
  fig.tight_layout()
104
 
105
  buf = io.BytesIO()
 
107
  buf.seek(0)
108
  pil_img = Image.open(buf)
109
  plt.close()
110
+ logger.info('--- image generated')
111
+
112
+ a_labels = {id2artist[i]: float(pred) for i, pred in enumerate(a_pred_out)}
113
+ t_labels = {id2trend[i]: float(pred) for i, pred in enumerate(t_pred_out)}
114
+ return a_labels, t_labels, pil_img
115
 
116
  iface = gr.Interface(
117
  predict,
118
  title='Predict Artist and Artistic Style of Drawings πŸŽ¨πŸ‘¨πŸ»β€πŸŽ¨ (prototype)',
119
+ description='Upload a drawing/image and the model will predict how likely it seems given 10 artists and their trend/style',
120
  inputs=[
121
  gr.inputs.Image(label='Upload a drawing/image', type='file')
122
  ],
123
  outputs=[
124
+ gr.outputs.Label(label='Artists', num_top_classes=5, type='auto'),
125
+ gr.outputs.Label(label='Styles', num_top_classes=5, type='auto'),
126
+ gr.outputs.Image(label='Prediction with GradCAM')
127
  ],
128
  examples=EXAMPLES,
129
  )