Chakshu123 commited on
Commit
77a6843
1 Parent(s): fe4350f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -1
app.py CHANGED
@@ -27,6 +27,8 @@ print('Use device:', device)
27
 
28
  net = torch.jit.load(f'weights/pkp-v1.{device}.jit.pt')
29
 
 
 
30
 
31
  def resize_original(img: Image.Image):
32
  if img is None:
@@ -53,7 +55,7 @@ def resize_original(img: Image.Image):
53
  return gr.Image.update(value=guide_img.convert('RGBA')), guide_img.convert('RGBA')
54
 
55
 
56
- def colorize(img: Dict[str, Image.Image], guide_img: Image.Image, seed: int, hint_mode: str):
57
  if not isinstance(img, dict):
58
  return gr.update(visible=True)
59
 
@@ -84,6 +86,23 @@ def colorize(img: Dict[str, Image.Image], guide_img: Image.Image, seed: int, hin
84
  return Image.fromarray(out).convert('RGB')
85
 
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  with gr.Blocks() as demo:
88
  gr.Markdown('''<center><h1>Image Colorization With Hint</h1></center>
89
  <h2>Colorize your images/sketches with hint points.</h2>
@@ -110,6 +129,25 @@ with gr.Blocks() as demo:
110
  btn = gr.Button("Run")
111
  with gr.Column():
112
  output = gr.Image(type="pil", label="Output", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  gr.Markdown('''
114
  Upon uploading an image, kindly give color hints at specific points, and then run the model. Average inference time is about 52 seconds.<br />
115
  ''')
@@ -124,6 +162,11 @@ Upon uploading an image, kindly give color hints at specific points, and then ru
124
  [inp, inp_store, seed, hint_mode],
125
  output
126
  )
 
 
 
 
 
127
 
128
  if __name__ == "__main__":
129
  demo.launch()
 
27
 
28
  net = torch.jit.load(f'weights/pkp-v1.{device}.jit.pt')
29
 
30
+ model_net = torch.load(f'weights/colorizer.pt')
31
+
32
 
33
  def resize_original(img: Image.Image):
34
  if img is None:
 
55
  return gr.Image.update(value=guide_img.convert('RGBA')), guide_img.convert('RGBA')
56
 
57
 
58
+ def colorize(img: Dict[str Image.Image], guide_img: Image.Image, seed: int, hint_mode: str):
59
  if not isinstance(img, dict):
60
  return gr.update(visible=True)
61
 
 
86
  return Image.fromarray(out).convert('RGB')
87
 
88
 
89
+ def colorize2(img: Image.Image, model_option: str):
90
+ if not isinstance(img, dict):
91
+ return gr.update(visible=True)
92
+
93
+ if hint_mode == "Model 1":
94
+ model_int = 0
95
+ elif hint_mode == "Model 2":
96
+ model_int = 0
97
+
98
+ with torch.inference_mode():
99
+ out2 = model(input)
100
+ out = sample[0].cpu().numpy().transpose([1,2,0])
101
+ out = np.uint8(((out + 1) / 2 * 255).clip(0,255))
102
+
103
+ return Image.fromarray(out).convert('RGB')
104
+
105
+
106
  with gr.Blocks() as demo:
107
  gr.Markdown('''<center><h1>Image Colorization With Hint</h1></center>
108
  <h2>Colorize your images/sketches with hint points.</h2>
 
129
  btn = gr.Button("Run")
130
  with gr.Column():
131
  output = gr.Image(type="pil", label="Output", interactive=False)
132
+ with gr.Row():
133
+ with gr.Column():
134
+ inp2 = gr.Image(
135
+ source="upload",
136
+ type="pil",
137
+ label="Sketch",
138
+ interactive=True
139
+ )
140
+ inp_store2 = gr.Image(
141
+ type="pil",
142
+ interactive=False
143
+ )
144
+ inp_store2.visible = False
145
+ with gr.Column():
146
+ # seed = gr.Slider(1, 2**32, step=1, label="Seed", interactive=True, randomize=True)
147
+ model_option = gr.Radio(["Model 1", "Model 2"], value="Model 1", label="Model 2")
148
+ btn2 = gr.Button("Run Colorization")
149
+ with gr.Column():
150
+ output2 = gr.Image(type="pil", label="Output2", interactive=False)
151
  gr.Markdown('''
152
  Upon uploading an image, kindly give color hints at specific points, and then run the model. Average inference time is about 52 seconds.<br />
153
  ''')
 
162
  [inp, inp_store, seed, hint_mode],
163
  output
164
  )
165
+ btn2.click(
166
+ colorize2,
167
+ [inp, model_option],
168
+ output2
169
+ )
170
 
171
  if __name__ == "__main__":
172
  demo.launch()