CatVsDog / app.py
PrakhAI's picture
Update app.py
a442541
raw
history blame contribute delete
No virus
3.62 kB
import streamlit as st
from PIL import Image
import jax
import numpy as np
import jax.numpy as jnp # JAX NumPy
from flax.training import train_state # Useful dataclass to keep train state
from flax import linen as nn # Linen API
from huggingface_hub import HfFileSystem
from flax.serialization import msgpack_restore, from_state_dict
import os
import tensorflow as tf
class CNN(nn.Module):
    """Small convolutional network that maps a batch of images to two logits.

    Two convolution / ReLU / average-pool stages are followed by a dense head
    of widths 256, 16 and 2. Submodules are created inline under
    ``@nn.compact`` in the order conv, conv, dense, dense, dense; other parts
    of this script look parameters up by auto-generated names such as
    ``Conv_0`` and ``Conv_1``, so that creation order must not change.
    """

    @nn.compact
    def __call__(self, x):
        """Run the forward pass and return the unnormalised class scores.

        Args:
            x: Input batch; the leading axis is preserved by the flatten step.
               This script initialises the model with shape (2, 50, 50, 3).

        Returns:
            Array with one row per input and two columns (no softmax applied).
        """
        # Feature extractor: 3x3 convolutions with 32 then 64 features, each
        # followed by ReLU and a 2x2 average pool with stride 2.
        for channels in (32, 64):
            x = nn.Conv(features=channels, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Collapse everything except the leading axis into one feature vector.
        batch_size = x.shape[0]
        x = x.reshape((batch_size, -1))

        # Hidden dense layers, each followed by ReLU.
        for width in (256, 16):
            x = nn.Dense(features=width)(x)
            x = nn.relu(x)

        # Output layer: two scores, left unnormalised for the caller.
        return nn.Dense(features=2)(x)
# Instantiate the network and initialise it on a dummy batch of ones so that
# `params` has the tree structure the checkpoint will be restored into.
cnn = CNN()
params = cnn.init(
    jax.random.PRNGKey(0),
    jnp.ones([2, 50, 50, 3]),
)['params']

# Read the serialized checkpoint from the Hugging Face Hub and overwrite the
# freshly initialised values with the stored ones.
fs = HfFileSystem()
with fs.open("PrakhAI/CatVsDog/checkpoint.msgpack", "rb") as checkpoint_file:
    restored = msgpack_restore(checkpoint_file.read())
params = from_state_dict(params, restored["params"])
# Let the user upload one or more images, classify them in a single batch and
# show each image together with the model's cat/dog probabilities.
uploaded_files = st.file_uploader("Input images of cats or dogs (examples in files)", type=['jpg','png','tif'], accept_multiple_files=True)
if not uploaded_files:
    st.write("Please upload an image!")
else:
    # Decode every upload once; the same PIL image is reused for display.
    opened_images = [Image.open(uploaded_file) for uploaded_file in uploaded_files]

    # The model is initialised on 50x50 inputs with 3 channels. PNG/TIFF
    # uploads may be RGBA, greyscale or palette images, which would give a
    # different channel count and break batching or the first convolution,
    # so every image is converted to RGB before resizing and scaling to [0, 1].
    batch = jnp.array([
        tf.cast(
            tf.image.resize(tf.convert_to_tensor(picture.convert("RGB")), [50, 50]),
            tf.float32,
        ) / 255.
        for picture in opened_images
    ])
    predictions = cnn.apply({"params": params}, batch)

    for picture, prediction in zip(opened_images, predictions):
        st.image(picture)
        # Softmax turns the two raw scores into probabilities (cat, dog).
        cat_prob, dog_prob = jax.nn.softmax(prediction)
        # The more likely class is listed first.
        if cat_prob > dog_prob:
            st.write(f"Model Prediction - Cat ({100*cat_prob:.2f}%), Dog ({100*dog_prob:.2f}%)")
        else:
            st.write(f"Model Prediction - Dog ({100*dog_prob:.2f}%), Cat ({100*cat_prob:.2f}%)")
def gridify_rgb(kernel, grid, kernel_size, scaling=5, padding=1):
    """Render a bank of 3-channel kernels as one RGB mosaic via ``st.image``.

    Args:
        kernel: Rank-4 array whose first two axes match ``kernel_size``, whose
            third axis has size 3 (used as the colour channels) and whose last
            axis holds ``grid[0] * grid[1]`` kernels.
        grid: ``(rows, cols)`` layout of the mosaic.
        kernel_size: ``(height, width)`` of a single kernel.
        scaling: Each weight is drawn as a ``scaling`` x ``scaling`` square.
        padding: Width in pixels of the dark border around every tile and
            around the whole mosaic.

    Weights are mapped to pixel values with ``(w + 1) * 127``, so -1 is black
    and +1 is near white. Values outside [-1, 1] are clipped to the valid
    pixel range instead of being cast out of range to ``uint8``.
    """
    rows, cols = grid[0], grid[1]
    tile_h = kernel_size[0] * scaling + 2 * padding
    tile_w = kernel_size[1] * scaling + 2 * padding

    # Enlarge each weight along both spatial axes so the kernels are visible.
    enlarged = np.repeat(np.repeat(kernel, repeats=scaling, axis=0), repeats=scaling, axis=1)
    # Frame every kernel with -1, which maps to pixel value 0 below.
    framed = np.pad(enlarged, ((padding,), (padding,), (0,), (0,)), 'constant', constant_values=(-1,))

    # Split the last axis into (rows, cols) and interleave it with the spatial
    # axes so the tiles are laid out on a rows x cols grid.
    tiles = framed.reshape((tile_h, tile_w, 3, rows, cols))
    mosaic = tiles.transpose(3, 0, 4, 1, 2).reshape(rows * tile_h, cols * tile_w, 3)

    pixels = (np.array(mosaic) + 1) * 127.
    # Outer border around the whole mosaic.
    pixels = np.pad(pixels, ((padding,), (padding,), (0,)), 'constant', constant_values=(0,))
    # Keep values inside the uint8 range so the cast below cannot wrap around.
    pixels = np.clip(pixels, 0, 255)
    st.image(Image.fromarray(pixels.astype(np.uint8), mode="RGB"))
def gridify_grayscale(kernel, grid, kernel_size, scaling=5, padding=1):
    """Render a bank of single-channel kernel slices as a grey mosaic via ``st.image``.

    Args:
        kernel: Rank-4 array whose first two axes match ``kernel_size`` and
            whose last two axes together hold ``grid[0] * grid[1]`` slices.
        grid: ``(rows, cols)`` layout of the mosaic.
        kernel_size: ``(height, width)`` of a single kernel.
        scaling: Each weight is drawn as a ``scaling`` x ``scaling`` square.
        padding: Width in pixels of the dark border around every tile and
            around the whole mosaic.

    Weights are mapped to pixel values with ``(w + 1) * 127``, so -1 is black
    and +1 is near white. Values outside [-1, 1] are clipped to the valid
    pixel range instead of being cast out of range to ``uint8``.
    """
    rows, cols = grid[0], grid[1]
    tile_h = kernel_size[0] * scaling + 2 * padding
    tile_w = kernel_size[1] * scaling + 2 * padding

    # Enlarge each weight along both spatial axes so the kernels are visible.
    enlarged = np.repeat(np.repeat(kernel, repeats=scaling, axis=0), repeats=scaling, axis=1)
    # Frame every slice with -1, which maps to pixel value 0 below.
    framed = np.pad(enlarged, ((padding,), (padding,), (0,), (0,)), 'constant', constant_values=(-1,))

    # Interleave the (rows, cols) axes with the spatial axes so the tiles are
    # laid out on a rows x cols grid.
    tiles = framed.reshape((tile_h, tile_w, rows, cols))
    mosaic = tiles.transpose(2, 0, 3, 1).reshape(rows * tile_h, cols * tile_w)

    # Outer border around the whole mosaic (same width on every side).
    pixels = np.pad((np.array(mosaic) + 1) * 127., (padding,), 'constant', constant_values=(0,))
    # Keep values inside the uint8 range so the cast below cannot wrap around.
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)

    # Copy the grey plane into three identical channels for an RGB image.
    rgb = np.stack((pixels,) * 3, axis=-1)
    st.image(Image.fromarray(rgb, mode="RGB"))
# Collapsible visualisations of the learned convolution kernels.
# First layer: 32 kernels over the 3 input channels, drawn as a 4 x 8 colour grid.
with st.expander("See first convolutional layer"):
    gridify_rgb(
        kernel=params["Conv_0"]["kernel"],
        grid=(4, 8),
        kernel_size=(3, 3),
    )

# Second layer: one grey tile per (input, output) pair on a 32 x 64 grid.
with st.expander("See second convolutional layer"):
    gridify_grayscale(
        kernel=params["Conv_1"]["kernel"],
        grid=(32, 64),
        kernel_size=(3, 3),
    )

# Footer pointing at the model card.
st.write("The model and its details are at https://huggingface.co/PrakhAI/CatVsDog")