File size: 5,230 Bytes
a4f4faa
 
 
 
 
 
 
 
 
 
 
 
 
 
3fc94e2
9801dba
bf5d203
3fc94e2
a4f4faa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
898e49f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4f4faa
 
 
 
 
8792b08
bf5d203
 
1eb089e
a4f4faa
 
 
 
 
9801dba
 
 
3fc94e2
 
 
 
328206f
3fc94e2
898e49f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8792b08
898e49f
 
 
 
bf5d203
 
 
 
64f91ef
bf5d203
6db8171
eddc711
bf5d203
 
3c9073a
8792b08
1a8a3d2
1a9f3e6
 
8792b08
006d371
 
c35f067
 
 
 
477c58b
 
 
 
 
64f91ef
 
 
bf5d203
8792b08
006d371
c562de5
c35f067
 
 
 
c99486c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import streamlit as st
import urllib.request
import PIL.Image
from PIL import Image
import requests
import fastai
from fastai.vision import *
from fastai.utils.mem import *
from fastai.vision import open_image, load_learner, image, torch
import numpy as np
from urllib.request import urlretrieve
from io import BytesIO
import numpy as np
import torchvision.transforms as T
from PIL import Image,ImageOps,ImageFilter
from io import BytesIO
import os




class FeatureLoss(nn.Module):
    """Perceptual loss: a pixel loss plus per-layer feature and Gram-matrix losses.

    Nothing in this file calls it. NOTE(review): it is presumably defined here
    so that ``load_learner`` can unpickle ``popd.pkl``, which would reference
    this class from the training setup -- confirm before removing or renaming.

    Args:
        m_feat: feature-extractor module; indexed with ``layer_ids`` and called
            on images.
        layer_ids: indices into ``m_feat`` of the layers whose outputs are hooked.
        layer_wgts: one weight per hooked layer, multiplied into that layer's
            feature loss and (squared, times 5e3) into its Gram loss.
    """
    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        self.m_feat = m_feat
        # The layers whose activations feed the loss.
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # Hooks record the outputs of those layers whenever m_feat runs.
        # detach=False: outputs are stored undetached (NOTE(review): presumably
        # so gradients flow through the feature losses).
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        # One name per entry of feat_losses, in the same order: the pixel loss,
        # then a feature loss per hooked layer, then a Gram loss per hooked layer.
        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
              ] + [f'gram_{i}' for i in range(len(layer_ids))]

    def make_features(self, x, clone=False):
        """Return the hooked layer outputs for input ``x`` (cloned if ``clone``)."""
        # The forward result is discarded; only the hooks' side effect is needed.
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]
    
    def forward(self, input, target):
        """Return the summed loss between ``input`` and ``target``.

        The individual terms are kept in ``self.feat_losses`` and, keyed by
        ``self.metric_names``, in ``self.metrics``.
        """
        # Target features are cloned, presumably because the following
        # make_features(input) call replaces what the hooks have stored.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        # NOTE(review): base_loss and gram_matrix are not defined in this file
        # as shown; unless they come from one of the star imports, calling
        # forward() raises NameError. Inference never calls it.
        self.feat_losses = [base_loss(input,target)]
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)
    
    # Detach the forward hooks from m_feat's layers when this object is collected.
    def __del__(self): self.hooks.remove()


def getNeighbours(i, j, n, m):
    """List the 8-connected neighbours of cell (i, j) on an n x m grid.

    Neighbours are returned clockwise starting at the upper-left diagonal:
    (i-1, j-1), (i-1, j), (i-1, j+1), (i, j+1), (i+1, j+1), (i+1, j),
    (i+1, j-1), (i, j-1). A neighbour is dropped when stepping to it would
    cross 0 (for a -1 step) or reach n / m (for a +1 step); a 0 step on an
    axis is never checked.
    """
    def _fits(coord, step, limit):
        # Only the edge being stepped towards is tested.
        if step < 0:
            return coord + step >= 0
        if step > 0:
            return coord + step < limit
        return True

    clockwise = (
        (-1, -1), (-1, 0), (-1, 1), (0, 1),
        (1, 1), (1, 0), (1, -1), (0, -1),
    )
    return [
        (i + di, j + dj)
        for di, dj in clockwise
        if _fits(i, di, n) and _fits(j, dj, m)
    ]

MODEL_URL = "https://www.dropbox.com/s/05ong36r29h51ov/popd.pkl?dl=1"
MODEL_FILE = "popd.pkl"


def _ensure_model_file(url=MODEL_URL, filename=MODEL_FILE):
    """Download ``url`` to ``filename`` unless that file already exists.

    Streamlit re-executes the whole script on every widget interaction, so an
    unconditional download here would re-fetch the model on each click. The
    file is written under a temporary name and only renamed into place once the
    download finished, so an interrupted download cannot leave a truncated file
    that later runs would treat as the model. Delete ``filename`` to force a
    fresh download.
    """
    if os.path.exists(filename):
        return
    partial = filename + ".part"
    urllib.request.urlretrieve(url, partial)
    os.replace(partial, filename)


_ensure_model_file()
path = Path(".")
# SECURITY NOTE: load_learner deserialises the file with pickle (via torch.load);
# only point MODEL_URL at a model from a trusted source.
learn = load_learner(path, MODEL_FILE)

def _fill_from_corner(img, colour, threshold=100):
    """Paint, in place, the bright region of ``img`` 8-connected to pixel (0, 0).

    A pixel counts as bright when ``int(mean of its first three channels)``
    exceeds ``threshold``. Starting at (0, 0), every bright pixel reachable
    through bright 8-neighbours (as listed by ``getNeighbours``) is set to
    ``colour``. If (0, 0) itself is not bright, nothing is painted.

    Args:
        img: PIL image with at least three channels per pixel; modified in place.
        colour: pixel value written into each painted pixel.
        threshold: brightness cut-off; pixels at or below it stop the fill.

    Returns:
        The same ``img`` object.
    """
    from collections import deque  # O(1) popleft; list.pop(0) made this quadratic

    bitmap = img.load()
    width, height = img.size
    if width == 0 or height == 0:
        return img
    # Cells are marked when queued rather than when popped, so each pixel is
    # enqueued at most once instead of once per neighbour. Every pixel is
    # tested exactly once, before it could be painted, so the painted set is
    # the same as with mark-on-pop.
    seen = [[False] * height for _ in range(width)]
    seen[0][0] = True
    queue = deque([(0, 0)])
    while queue:
        x, y = queue.popleft()
        pixel = bitmap[x, y]
        if int((pixel[0] + pixel[1] + pixel[2]) / 3) <= threshold:
            continue
        bitmap[x, y] = colour
        for nx, ny in getNeighbours(x, y, width, height):
            if not seen[nx][ny]:
                seen[nx][ny] = True
                queue.append((nx, ny))
    return img


def predict(image,colour):
    """Turn ``image`` into pop art and show the input and the result in Streamlit.

    The image is run through the module-level ``learn`` model. The output is
    resized back to the input's size, then the bright region connected to the
    top-left pixel is painted with ``colour``.

    Args:
        image: a file name or file-like object accepted by both ``open_image``
            and ``PIL.Image.open`` (the demo passes a file name, the upload
            branch passes the Streamlit upload object).
        colour: RGB tuple used to paint the region.

    Returns:
        Whatever ``st.image`` returns for the result image.
    """
    img_fast = open_image(image)
    original = PIL.Image.open(image).convert('RGB')
    st.image(original, caption='Input')

    _, img_hr, _ = learn.predict(img_fast)
    # Scale by 255 and clamp to the uint8 range before casting.
    pixels = np.clip(image2np(img_hr.data * 255), 0, 255).astype(np.uint8)
    styled = PIL.Image.fromarray(pixels).convert('RGB').resize(original.size)

    # The fill works on an RGBA copy. Encoding to PNG in memory and decoding
    # again is lossless, so no BytesIO round-trip is needed.
    canvas = styled.convert('RGBA')
    _fill_from_corner(canvas, colour)
    return st.image(canvas, caption='PoP ArT')
 
SIDEBAR_OPTION_DEMO_IMAGE = "Select a Demo Image"
SIDEBAR_OPTION_UPLOAD_IMAGE = "Upload an Image"

SIDEBAR_OPTIONS = [SIDEBAR_OPTION_DEMO_IMAGE, SIDEBAR_OPTION_UPLOAD_IMAGE]

# Fill colours offered in the sidebar, in display order, mapped to RGB values.
COLOURS = {
    'Red': (185, 39, 40),
    'Blue': (40, 96, 219),
    'Yellow': (249, 223, 2),
}


def _choose_colour():
    """Show the sidebar colour picker and return the chosen RGB tuple."""
    name = st.sidebar.selectbox("Colour", list(COLOURS))
    return COLOURS[name]


def _pop_button(image, colour):
    """Show the 'PoP' button; when it is pressed, run ``predict`` on ``image``."""
    if st.sidebar.button('PoP'):
        st.sidebar.write('Please wait for the magic to happen! This may take up to a minute.')
        predict(image, colour)


app_mode = st.sidebar.selectbox("Please select from the following", SIDEBAR_OPTIONS)
photos = ["fight.jpg", "shaolin-kung-fu.jpg"]

if app_mode == SIDEBAR_OPTION_DEMO_IMAGE:
    st.sidebar.write(" ------ ")
    # The button is labelled 'PoP', so the prompt names that button.
    option = st.sidebar.selectbox('Please select a sample image, then click the PoP button', photos)
    colour = _choose_colour()
    _pop_button(option, colour)

elif app_mode == SIDEBAR_OPTION_UPLOAD_IMAGE:
    uploaded_file = st.file_uploader("Choose an image...")
    # The colour picker and button appear only once a file has been uploaded.
    if uploaded_file is not None:
        colour = _choose_colour()
        _pop_button(uploaded_file, colour)