PoPd-PoPArT / app.py
Vijish's picture
Update app.py
69e6377
raw
history blame contribute delete
No virus
5.35 kB
import streamlit as st
import urllib.request
import PIL.Image
from PIL import Image
import requests
import fastai
from fastai.vision import *
from fastai.utils.mem import *
from fastai.vision import open_image, load_learner, image, torch
import numpy as np
from urllib.request import urlretrieve
from io import BytesIO
import numpy as np
import torchvision.transforms as T
from PIL import Image,ImageOps,ImageFilter
from io import BytesIO
import os
class FeatureLoss(nn.Module):
    """Loss made of a pixel term, per-layer feature terms and per-layer Gram terms.

    Activations of selected layers of ``m_feat`` are captured with forward
    hooks.  All terms are computed with ``base_loss``.

    NOTE(review): ``base_loss`` and ``gram_matrix`` are not defined in this
    file as shown, so ``forward`` raises NameError unless they are provided
    elsewhere.  The class is presumably kept here (under this exact name) so
    that unpickling ``popd.pkl`` can resolve it -- do not rename it.
    """

    def __init__(self, m_feat, layer_ids, layer_wgts):
        """
        Args:
            m_feat: feature-extractor module, indexable by layer position.
            layer_ids: indices into ``m_feat`` whose outputs are hooked.
            layer_wgts: per-layer weights, paired positionally with ``layer_ids``.
        """
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False: stored activations are not detached from the graph.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        n_layers = len(layer_ids)
        self.metric_names = (['pixel']
                             + [f'feat_{i}' for i in range(n_layers)]
                             + [f'gram_{i}' for i in range(n_layers)])

    def make_features(self, x, clone=False):
        """Run ``x`` through ``m_feat`` and return the hooked activations (cloned if asked)."""
        self.m_feat(x)  # called only for its side effect of filling self.hooks.stored
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the summed loss; individual terms are kept in ``self.metrics``."""
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input, target)]
        self.feat_losses += [base_loss(f_in, f_out) * w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        # Gram (style) terms are scaled by w**2 * 5e3.
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out)) * w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # __init__ may have failed before `hooks` was assigned; a finaliser
        # must not raise, so only remove hooks that actually exist.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()
def getNeighbours(i, j, n, m):
    """Return the 8-connected neighbours of cell (i, j) in an n x m grid.

    Neighbours are listed clockwise, starting at the upper-left diagonal
    (i-1, j-1) and ending at the left cell (i, j-1).  A step in a direction
    is kept only when it stays in bounds on that side (row >= 0 / row < n,
    column >= 0 / column < m).  So for any (i, j) inside the grid, every
    returned cell is inside it too.
    """
    up, down = i - 1 >= 0, i + 1 < n
    left, right = j - 1 >= 0, j + 1 < m
    candidates = (
        (up and left, (i - 1, j - 1)),
        (up, (i - 1, j)),
        (up and right, (i - 1, j + 1)),
        (right, (i, j + 1)),
        (down and right, (i + 1, j + 1)),
        (down, (i + 1, j)),
        (down and left, (i + 1, j - 1)),
        (left, (i, j - 1)),
    )
    return [cell for keep, cell in candidates if keep]
# Exported fastai learner that performs the PoP-art transformation.
MODEL_URL = "https://www.dropbox.com/s/05ong36r29h51ov/popd.pkl?dl=1"
MODEL_FILE = "popd.pkl"

# Streamlit re-executes this whole script on every widget interaction, so
# only fetch the weights when they are not already on disk.
if not os.path.exists(MODEL_FILE):
    _partial = MODEL_FILE + ".part"
    urllib.request.urlretrieve(MODEL_URL, _partial)
    # Rename only after the download completes, so an interrupted fetch never
    # leaves a truncated pickle that later reruns would try to load.
    os.replace(_partial, MODEL_FILE)

path = Path(".")
# NOTE(review): load_learner unpickles the downloaded file, so only point
# MODEL_URL at a trusted source.  The learner is still rebuilt on each rerun;
# consider Streamlit's resource caching (API depends on the installed version).
learn = load_learner(path, MODEL_FILE)
def _model_output_to_pil(img_hr):
    """Convert the learner's output image to an RGB PIL image (scaled by 255, clipped to 0..255)."""
    pixels = np.clip(image2np(img_hr.data * 255), 0, 255).astype(np.uint8)
    return PIL.Image.fromarray(pixels).convert('RGB')


def _recolour_bright(img, colour, threshold=100):
    """Return an RGBA copy of ``img`` with its bright pixels painted ``colour``.

    A pixel is bright when the truncated mean of its R, G and B values
    exceeds ``threshold``.  ``colour`` is an (R, G, B) tuple, whose alpha
    becomes 255, or an (R, G, B, A) tuple.
    """
    rgba = np.array(img.convert('RGBA'))
    # Sum in a wide dtype: adding three uint8 channels would overflow.
    total = rgba[..., :3].astype(np.int32).sum(axis=-1)
    # int(total / 3) == total // 3 for non-negative totals.  Each pixel is
    # judged only on its own value, so one vectorized mask replaces the old
    # breadth-first walk, which visited every pixel exactly once anyway.
    bright = total // 3 > threshold
    fill = tuple(colour) if len(colour) == 4 else tuple(colour) + (255,)
    rgba[bright] = fill
    return Image.fromarray(rgba)


def predict(image, colour):
    """Show the input, run the PoP-art model on it and show the recoloured result.

    Args:
        image: path or file-like object, read by both fastai's ``open_image``
            and ``PIL.Image.open`` (a demo filename or a Streamlit upload).
        colour: (R, G, B) tuple painted over every bright pixel of the output.

    Returns:
        The value returned by ``st.image`` for the output image.
    """
    img_fast = open_image(image)
    source = PIL.Image.open(image).convert('RGB')
    st.image(source, caption='Input')
    _, img_hr, _ = learn.predict(img_fast)
    # Bring the model output back to the input's resolution before recolouring.
    stylised = _model_output_to_pil(img_hr).resize(source.size)
    return st.image(_recolour_bright(stylised, colour), caption='PoP ArT')
SIDEBAR_OPTION_DEMO_IMAGE = "Select a Demo Image"
SIDEBAR_OPTION_UPLOAD_IMAGE = "Upload an Image"
SIDEBAR_OPTIONS = [SIDEBAR_OPTION_DEMO_IMAGE, SIDEBAR_OPTION_UPLOAD_IMAGE]

# Sample images offered in demo mode; presumably shipped next to app.py.
photos = ["fight.jpg", "shaolin-kung-fu.jpg", "unnamed.jpg", "michael-jackson.png"]

# Colour name shown in the sidebar -> RGB fill passed to predict().
# Dict order is the order of the options in the "Colour" selectbox.
COLOURS = {
    'Red': (185, 39, 40),
    'Blue': (40, 96, 219),
    'Yellow': (249, 223, 2),
}


def _colour_picker_and_button(image):
    """Render the sidebar colour selector and PoP button; run predict() on ``image`` when pressed.

    Shared by both sidebar modes so the colour table lives in one place.
    """
    colour = COLOURS[st.sidebar.selectbox("Colour", list(COLOURS))]
    if st.sidebar.button('PoP'):
        st.empty()
        st.sidebar.write('Please wait for the magic to happen! This may take up to a minute.')
        predict(image, colour)


st.sidebar.write("Check out GitHub [link](https://github.com/vijishmadhavan/PoPd)")
app_mode = st.sidebar.selectbox("Please select from the following", SIDEBAR_OPTIONS)

if app_mode == SIDEBAR_OPTION_DEMO_IMAGE:
    st.sidebar.write(" ------ ")
    option = st.sidebar.selectbox('Please select a sample image,colour and then click PoP button', photos)
    _colour_picker_and_button(option)
elif app_mode == SIDEBAR_OPTION_UPLOAD_IMAGE:
    uploaded_file = st.file_uploader("Choose an image...")
    if uploaded_file is not None:
        _colour_picker_and_button(uploaded_file)