# Source: Hugging Face Space by abhirajeshbhai.
# Commit 9205986: "implement unet and deplot" (file size at that commit: 648 bytes).
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image, UnidentifiedImageError

from model import model, image_transforms
def col_select(value: object) -> None:
    """Print *value* to standard output.

    Not referenced by the visible script code; presumably a leftover
    debugging helper — NOTE(review): confirm it is still needed before
    removing it.
    """
    print(value)
# Streamlit page: upload an image, run it through `image_transforms` and
# `model` from the local `model` module, and show the transformed input next
# to the model output.

# NOTE(review): "Banan" may be a typo for "Banana" — confirm before renaming
# this user-facing title.
st.title("Banan Image Colorizer")
upload_file = st.file_uploader("Upload Image")

# st.file_uploader returns None until the user has picked a file.
if upload_file is not None:
    try:
        image = Image.open(upload_file)
    except UnidentifiedImageError:
        # Report unreadable (non-image) uploads on the page instead of
        # crashing the whole script with a traceback.
        st.error("The uploaded file could not be read as an image.")
    else:
        # image_gs must be rank 3 for the three-axis permute below;
        # presumably (C, H, W) — TODO confirm against image_transforms.
        image_gs = image_transforms(image)
        image_gs_prev = image_gs.permute(1, 2, 0).detach().cpu().numpy()

        # Inference only: no_grad avoids building an autograd graph.
        # NOTE(review): confirm model.py puts `model` in eval() mode.
        with torch.no_grad():
            output = model(image_gs.unsqueeze(0))
        # squeeze(0) removes only the batch axis added by unsqueeze(0) above;
        # a bare squeeze() would also drop a size-1 channel/height/width axis
        # and break the permute.
        image_color = output.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()

        col1, col2 = st.columns(2)
        # clamp=True: st.image raises on float data outside [0, 1] unless
        # clamping is enabled; this matches the colorized preview.
        col1.image(image_gs_prev, clamp=True)
        col2.image(image_color, clamp=True, channels='RGB')