# involution / app.py
# Last commit: ec329bf ("Update app.py") by merve (HF staff).
from huggingface_hub import from_pretrained_keras
import tensorflow as tf
import gradio as gr
# Hugging Face Hub repository that hosts the pretrained involution model.
_MODEL_REPO = "keras-io/involution"

# Load the model once at import time so every call to `infer` reuses the same
# instance instead of fetching it again per request.
vis_model = from_pretrained_keras(_MODEL_REPO)
def infer(test_image):
    """Return the summed involution-kernel activation maps for one image.

    The image is converted to a tensor, resized to 32x32, given a leading
    batch axis, and run through ``vis_model``. The model must yield exactly
    three outputs, each a pair whose second element is an involution kernel.
    Each kernel is summed over its last three axes. Its single batch entry is
    then rendered as a one-channel image with ``tf.keras.utils.array_to_img``.

    Args:
        test_image: image array supplied by the Gradio ``Image`` input;
            presumably (height, width, channels) -- TODO confirm, since a
            batch axis is prepended below.

    Returns:
        A 3-tuple of images, one per involution layer, in model output order.
    """
    # Resize to the model's fixed 32x32 input and prepend a batch axis of 1.
    # NOTE(review): tf.image.resize does not rescale pixel values, so the
    # model receives whatever range the UI supplies (0-255 for uint8 input);
    # confirm this matches the range the model was trained on.
    batch = tf.image.resize(tf.constant(test_image), (32, 32))[None, ...]

    # Unpacking into three names enforces exactly three layer outputs.
    first, second, third = vis_model.predict(batch)

    maps = []
    for _, kernel in (first, second, third):
        # Collapse the trailing three kernel axes, presumably leaving one
        # value per spatial position.
        summed = tf.reduce_sum(kernel, axis=[-1, -2, -3])
        # Select batch entry 0 and append a channel axis for array_to_img.
        maps.append(tf.keras.utils.array_to_img(summed[0, ..., None]))
    return tuple(maps)
# HTML shown with the demo (passed as `article=` to gr.Interface). It contains
# author credits, a link to the Involution paper on arXiv, and two
# illustrative images comparing a convolution kernel with an involution kernel.
article = """<center>
Authors: <a href='https://twitter.com/ariG23498' target='_blank'>Aritra Roy Gosthipaty</a> |
<a href='https://twitter.com/ritwik_raha' target='_blank'>Ritwik Raha</a>
<br>
<a href='https://arxiv.org/abs/2103.06255' target='_blank'>Involution: Inverting the Inherence of Convolution for Visual Recognition</a>
<br>
Convolution Kernel
<img src='https://i.imgur.com/Y7xVrwb.png' alt='Convolution'>
<br>
Involution Kernel
<img src='https://i.imgur.com/jHIW26g.png' alt='Involution'>
</center>"""
# Short description shown on the demo page (passed as `description=` to
# gr.Interface). The trailing emoji is the "man detective, light skin tone"
# sequence (U+1F575 U+1F3FB U+200D U+2642 U+FE0F). It is written with escape
# sequences so this source line stays ASCII-only and cannot be mis-decoded
# into mojibake.
description = """
Visualize the activation maps from the Involution Kernel.\U0001F575\U0001F3FB\u200d\u2642\ufe0f
"""
# Build the demo UI: one input image, and the three activation maps that
# `infer` returns (one per involution layer, in the same order).
# NOTE(review): gr.inputs / gr.outputs and the `layout` argument are the legacy
# Gradio API (deprecated in Gradio 3, removed in Gradio 4); check the Gradio
# version this app pins before migrating to gr.Image.
iface = gr.Interface(
    fn=infer,
    title="Involutional Neural Networks",
    article=article,
    description=description,
    inputs=gr.inputs.Image(label="Input Image"),
    outputs=[
        gr.outputs.Image(label="Activation from Kernel 1"),
        gr.outputs.Image(label="Activation from Kernel 2"),
        gr.outputs.Image(label="Activation from Kernel 3"),
    ],
    examples=[["examples/lama.jpeg"], ["examples/dalai_lama.jpeg"]],
    layout="horizontal",
)

# Launch as a separate statement so `iface` stays bound to the Interface
# object rather than to whatever launch() returns.
iface.launch()