Daniel Verdu committed
Commit d984001
Parents: 0cb9530 878ecf2

merged changes

.gitignore ADDED
@@ -0,0 +1 @@
 
 
+ __pycache__/
__pycache__/app_utils.cpython-37.pyc ADDED
Binary file (3.35 kB).
 
app.py CHANGED
@@ -1,4 +1,4 @@
- #importing the libraries
+ # Import general purpose libraries
import os, sys, re
import streamlit as st
import PIL
@@ -6,21 +6,8 @@ from PIL import Image
import cv2
import numpy as np
import uuid
- 
- import ssl
- ssl._create_default_https_context = ssl._create_unverified_context
- 
- # Import torch libraries
- import fastai
- import torch
- 
- # Import util functions from app_utils
- from app_utils import download
- from app_utils import generate_random_filename
- from app_utils import clean_me
- from app_utils import clean_all
- from app_utils import get_model_bin
- from app_utils import convertToJPG
+ from zipfile import ZipFile, ZIP_DEFLATED
+ from io import BytesIO

# Import util functions from deoldify
# NOTE: This must be the first call in order to work properly!
@@ -30,13 +17,17 @@ from deoldify.device_id import DeviceId
device.set(device=DeviceId.CPU)
from deoldify.visualize import *

+ # Import util functions from app_utils
+ from app_utils import get_model_bin
+ 
+ 

####### INPUT PARAMS ###########
model_folder = 'models/'
max_img_size = 800
################################

- @st.cache(allow_output_mutation=True)
+ @st.cache(allow_output_mutation=True, show_spinner=False)
def load_model(model_dir, option):
    if option.lower() == 'artistic':
        model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
@@ -68,42 +59,132 @@ def resize_img(input_img, max_size):

    return img

- def get_image_download_link(img,filename,text):
+ def get_image_download_link(img, filename, button_text):
    button_uuid = str(uuid.uuid4()).replace('-', '')
    button_id = re.sub('\d+', '', button_uuid)
- 
-     custom_css = f"""
-         <style>
-             #{button_id} {{
-                 background-color: rgb(255, 255, 255);
-                 color: rgb(38, 39, 48);
-                 padding: 0.25em 0.38em;
-                 position: relative;
-                 text-decoration: none;
-                 border-radius: 4px;
-                 border-width: 1px;
-                 border-style: solid;
-                 border-color: rgb(230, 234, 241);
-                 border-image: initial;
- 
-             }}
-             #{button_id}:hover {{
-                 border-color: rgb(246, 51, 102);
-                 color: rgb(246, 51, 102);
-             }}
-             #{button_id}:active {{
-                 box-shadow: none;
-                 background-color: rgb(246, 51, 102);
-                 color: white;
-             }}
-         </style> """

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
-     href = custom_css + f'<a href="data:file/txt;base64,{img_str}" id="{button_id}" download="{filename}">{text}</a>'
+ 
+     return get_button_html_code(img_str, filename, 'txt', button_id, button_text)
+ 
+ def get_button_html_code(data_str, filename, filetype, button_id, button_txt='Download file'):
+     custom_css = f"""
+         <style>
+             #{button_id} {{
+                 background-color: rgb(255, 255, 255);
+                 color: rgb(38, 39, 48);
+                 padding: 0.25em 0.38em;
+                 position: relative;
+                 text-decoration: none;
+                 border-radius: 4px;
+                 border-width: 1px;
+                 border-style: solid;
+                 border-color: rgb(230, 234, 241);
+                 border-image: initial;
+ 
+             }}
+             #{button_id}:hover {{
+                 border-color: rgb(246, 51, 102);
+                 color: rgb(246, 51, 102);
+             }}
+             #{button_id}:active {{
+                 box-shadow: none;
+                 background-color: rgb(246, 51, 102);
+                 color: white;
+             }}
+         </style> """
+ 
+     href = custom_css + f'<a href="data:file/{filetype};base64,{data_str}" id="{button_id}" download="{filename}">{button_txt}</a>'
    return href

+ def display_single_image(uploaded_file, img_size=800):
+     print('Type: ', type(uploaded_file))
+     st_title_message.markdown("**Processing your image, please wait** ⌛")
+     img_name = uploaded_file.name
+ 
+     # Open the image
+     pil_img = PIL.Image.open(uploaded_file)
+     img_rgb = np.array(pil_img)
+     resized_img_rgb = resize_img(img_rgb, img_size)
+     resized_pil_img = PIL.Image.fromarray(resized_img_rgb)
+ 
+     # Send the image to the model
+     output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
+ 
+     # Plot images
+     st_input_img.image(resized_pil_img, 'Input image', use_column_width=True)
+     st_output_img.image(output_pil_img, 'Output image', use_column_width=True)
+ 
+     # Show download button
+     st_download_button.markdown(get_image_download_link(output_pil_img, img_name, 'Download Image'), unsafe_allow_html=True)
+ 
+     # Reset the message
+     st_title_message.markdown("**To begin, please upload an image** 👇")
+ 
+ def process_multiple_images(uploaded_files, img_size=800):
+     num_imgs = len(uploaded_files)
+ 
+     output_images_list = []
+     img_names_list = []
+     idx = 1
+     for idx, uploaded_file in enumerate(uploaded_files, start=1):
+         st_title_message.markdown("**Processing image {}/{}. Please wait** ⌛".format(idx,
+                                                                                      num_imgs))
+ 
+         img_name = uploaded_file.name
+         img_type = uploaded_file.type
+ 
+         # Open the image
+         pil_img = PIL.Image.open(uploaded_file)
+         img_rgb = np.array(pil_img)
+         resized_img_rgb = resize_img(img_rgb, img_size)
+         resized_pil_img = PIL.Image.fromarray(resized_img_rgb)
+ 
+         # Send the image to the model
+         output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
+ 
+         output_images_list.append(output_pil_img)
+         img_names_list.append(img_name.split('.')[0])
+ 
+     # Zip output files
+     zip_path = 'processed_images.zip'
+     zip_buf = zip_multiple_images(output_images_list, img_names_list, zip_path)
+ 
+     st_download_button.download_button(
+         label='Download ZIP file',
+         data=zip_buf.read(),
+         file_name=zip_path,
+         mime="application/zip"
+     )
+ 
+     # Show message
+     st_title_message.markdown("**Images are ready for download** 💾")
+ 
+ def zip_multiple_images(pil_images_list, img_names_list, dest_path):
+     # Create zip file on memory
+     zip_buf = BytesIO()
+ 
+     with ZipFile(zip_buf, 'w', ZIP_DEFLATED) as zipObj:
+         for pil_img, img_name in zip(pil_images_list, img_names_list):
+             with BytesIO() as output:
+                 # Save image in memory
+                 pil_img.save(output, format="PNG")
+ 
+                 # Read data
+                 contents = output.getvalue()
+ 
+                 # Write it to zip file
+                 zipObj.writestr(img_name+".png", contents)
+     zip_buf.seek(0)
+     return zip_buf
+ 
+ 
+ 
+ ###########################
+ ###### STREAMLIT CODE #####
+ ###########################

# General configuration
# st.set_page_config(layout="centered")
@@ -118,12 +199,16 @@ unsafe_allow_html=True)
# Main window configuration
st.title("Black and white colorizer")
st.markdown("This app puts color into your black and white pictures")
- title_message = st.empty()
+ st_title_message = st.empty()
+ st_file_uploader = st.empty()
+ st_input_img = st.empty()
+ st_output_img = st.empty()
+ st_download_button = st.empty()

- title_message.markdown("**Model loading, please wait** ⌛")
+ st_title_message.markdown("**Model loading, please wait** ⌛")

# # Sidebar
- color_option = st.sidebar.selectbox('Select colorizer mode',
+ st_color_option = st.sidebar.selectbox('Select colorizer mode',
                                      ('Artistic', 'Stable'))

# st.sidebar.title('Model parameters')
@@ -132,40 +217,27 @@ color_option = st.sidebar.selectbox('Select colorizer mode',

# Load models
try:
-     colorizer = load_model(model_folder, color_option)
+     print('before loading the model')
+     colorizer = load_model(model_folder, st_color_option)
+     print('after loading the model')
+ 
except Exception as e:
-     print(e)
    colorizer = None
    print('Error while loading the model. Please refresh the page')
+     print(e)
+     st_title_message.markdown("**Error while loading the model. Please refresh the page**")

if colorizer is not None:
-     print('Running colorizer')
-     title_message.markdown("**To begin, please upload an image** 👇")
+     st_title_message.markdown("**To begin, please upload an image** 👇")

    #Choose your own image
-     uploaded_file = st.file_uploader("Upload a black and white photo", type=['png', 'jpg', 'jpeg'])
- 
-     # show = st.image(use_column_width='auto')
-     input_img_pos = st.empty()
-     output_img_pos = st.empty()
+     uploaded_files = st_file_uploader.file_uploader("Upload a black and white photo",
+                                                     type=['png', 'jpg', 'jpeg'],
+                                                     accept_multiple_files=True)

-     if uploaded_file is not None:
-         img_name = uploaded_file.name
- 
-         pil_img = PIL.Image.open(uploaded_file)
-         img_rgb = np.array(pil_img)
- 
-         resized_img_rgb = resize_img(img_rgb, max_img_size)
-         resized_pil_img = PIL.Image.fromarray(resized_img_rgb)
- 
-         title_message.markdown("**Processing your image, please wait** ⌛")
- 
-         output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
- 
-         title_message.markdown("**To begin, please upload an image** 👇")
+     if len(uploaded_files) == 1:
+         display_single_image(uploaded_files[0], max_img_size)
+     elif len(uploaded_files) > 1:
+         process_multiple_images(uploaded_files, max_img_size)

-         # Plot images
-         input_img_pos.image(resized_pil_img, 'Input image', use_column_width=True)
-         output_img_pos.image(output_pil_img, 'Output image', use_column_width=True)

-         st.markdown(get_image_download_link(output_pil_img, img_name, 'Download '+img_name), unsafe_allow_html=True)
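
For reference, the multi-image path added above builds the ZIP archive entirely in memory and hands the buffer to Streamlit's download button. A self-contained sketch of that pattern (the helper name and the sample images are illustrative, not taken from the commit):

from io import BytesIO
from zipfile import ZipFile, ZIP_DEFLATED

from PIL import Image

def zip_images_in_memory(pil_images, names):
    # Write each PIL image into an in-memory ZIP archive and return the rewound buffer
    buf = BytesIO()
    with ZipFile(buf, 'w', ZIP_DEFLATED) as archive:
        for img, name in zip(pil_images, names):
            with BytesIO() as img_bytes:
                img.save(img_bytes, format="PNG")
                archive.writestr(name + ".png", img_bytes.getvalue())
    buf.seek(0)  # rewind so the caller can read the archive from the start
    return buf

# Illustrative usage inside a Streamlit app:
# import streamlit as st
# images = [Image.new("RGB", (64, 64)), Image.new("RGB", (64, 64))]
# zip_buf = zip_images_in_memory(images, ["first", "second"])
# st.download_button("Download ZIP file", data=zip_buf.read(),
#                    file_name="processed_images.zip", mime="application/zip")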
 
 
app_utils.py CHANGED
@@ -106,6 +106,7 @@ def clean_all(files):
        clean_me(me)


+ <<<<<<< HEAD
def get_model_bin(url, output_path):
    # print('Getting model dir: ', output_path)
    if not os.path.exists(output_path):
@@ -115,14 +116,27 @@ def get_model_bin(url, output_path):
        output_folder = output_path.replace('\\','/').split('/')[0]
        if not os.path.exists(output_folder):
            os.makedirs(output_folder, exist_ok=True)
+ =======
+ def create_directory(path):
+     os.makedirs(os.path.dirname(path), exist_ok=True)
+ 
+ 
+ def get_model_bin(url, output_path):
+     # print('Getting model dir: ', output_path)
+     if not os.path.exists(output_path):
+         create_directory(output_path)
+ >>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143

        urllib.request.urlretrieve(url, output_path)

        # cmd = "wget -O %s %s" % (output_path, url)
        # print(cmd)
        # os.system(cmd)
+ <<<<<<< HEAD
    else:
        print('Model exists')
+ =======
+ >>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143

    return output_path
142
 
deoldify/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (218 Bytes).

deoldify/__pycache__/_device.cpython-37.pyc ADDED
Binary file (1.32 kB).

deoldify/__pycache__/augs.cpython-37.pyc ADDED
Binary file (867 Bytes).

deoldify/__pycache__/critics.cpython-37.pyc ADDED
Binary file (1.52 kB).

deoldify/__pycache__/dataset.cpython-37.pyc ADDED
Binary file (1.53 kB).

deoldify/__pycache__/device_id.cpython-37.pyc ADDED
Binary file (510 Bytes).

deoldify/__pycache__/filters.cpython-37.pyc ADDED
Binary file (4.9 kB).

deoldify/__pycache__/generators.cpython-37.pyc ADDED
Binary file (3.15 kB).

deoldify/__pycache__/layers.cpython-37.pyc ADDED
Binary file (1.43 kB).

deoldify/__pycache__/loss.cpython-37.pyc ADDED
Binary file (6.47 kB).

deoldify/__pycache__/unet.cpython-37.pyc ADDED
Binary file (8.21 kB).

deoldify/__pycache__/visualize.cpython-37.pyc ADDED
Binary file (6.62 kB).
 
deoldify/generators.py CHANGED
@@ -1,3 +1,4 @@
+ <<<<<<< HEAD
from fastai.basics import F, nn
from fastai.basic_data import DataBunch
from fastai.basic_train import Learner
@@ -5,12 +6,17 @@ from fastai.layers import NormType
from fastai.torch_core import SplitFuncOrIdxList, to_device, apply_init
from fastai.vision import *
from fastai.vision.learner import cnn_config, create_body
+ =======
+ from fastai.vision import *
+ from fastai.vision.learner import cnn_config
+ >>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
from .unet import DynamicUnetWide, DynamicUnetDeep
from .loss import FeatureLoss
from .dataset import *

# Weights are implicitly read from ./models/ folder
def gen_inference_wide(
+ <<<<<<< HEAD
    root_folder: Path, weights_name: str, nf_factor: int = 2,
    arch=models.resnet101
) -> Learner:
@@ -41,6 +47,16 @@ def get_inference(learn, root_folder, weights_name) -> Learner:
        print('Error while reading the model')
    learn.model.eval()

+ =======
+     root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101) -> Learner:
+     data = get_dummy_databunch()
+     learn = gen_learner_wide(
+         data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
+     )
+     learn.path = root_folder
+     learn.load(weights_name)
+     learn.model.eval()
+ >>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
    return learn


@@ -104,10 +120,29 @@ def unet_learner_wide(

# ----------------------------------------------------------------------

+ <<<<<<< HEAD
def gen_learner_deep(data: ImageDataBunch, gen_loss, arch=models.resnet34,
                     nf_factor: float = 1.5
                     ) -> Learner:

+ =======
+ # Weights are implicitly read from ./models/ folder
+ def gen_inference_deep(
+     root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5) -> Learner:
+     data = get_dummy_databunch()
+     learn = gen_learner_deep(
+         data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor
+     )
+     learn.path = root_folder
+     learn.load(weights_name)
+     learn.model.eval()
+     return learn
+ 
+ 
+ def gen_learner_deep(
+     data: ImageDataBunch, gen_loss, arch=models.resnet34, nf_factor: float = 1.5
+ ) -> Learner:
+ >>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
    return unet_learner_deep(
        data,
        arch,
@@ -123,6 +158,7 @@ def gen_learner_deep(data: ImageDataBunch, gen_loss, arch=models.resnet34,

# The code below is meant to be merged into fastaiv1 ideally
def unet_learner_deep(
+ <<<<<<< HEAD
    data: DataBunch,
    arch: Callable,
    pretrained: bool = True,
@@ -138,6 +174,22 @@ def unet_learner_deep(
    **kwargs: Any
) -> Learner:

+ =======
+     data: DataBunch,
+     arch: Callable,
+     pretrained: bool = True,
+     blur_final: bool = True,
+     norm_type: Optional[NormType] = NormType,
+     split_on: Optional[SplitFuncOrIdxList] = None,
+     blur: bool = False,
+     self_attention: bool = False,
+     y_range: Optional[Tuple[float, float]] = None,
+     last_cross: bool = True,
+     bottle: bool = False,
+     nf_factor: float = 1.5,
+     **kwargs: Any
+ ) -> Learner:
+ >>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
    "Build Unet learner from `data` and `arch`."
    meta = cnn_config(arch)
    body = create_body(arch, pretrained)
deoldify/visualize.py CHANGED
@@ -19,13 +19,21 @@ from .generators import gen_inference_deep, gen_inference_wide



+ <<<<<<< HEAD
+ =======
+ # class LoadedModel
+ >>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
class ModelImageVisualizer:
    def __init__(self, filter: IFilter, results_dir: str = None):
        self.filter = filter
        self.results_dir = None if results_dir is None else Path(results_dir)
+ <<<<<<< HEAD

        if self.results_dir is not None:
            self.results_dir.mkdir(parents=True, exist_ok=True)
+ =======
+         self.results_dir.mkdir(parents=True, exist_ok=True)
+ >>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143

    def _clean_mem(self):
        torch.cuda.empty_cache()
@@ -215,15 +223,22 @@ class ModelImageVisualizer:
        return rows, columns


+ <<<<<<< HEAD
def get_image_colorizer(root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
                        ) -> ModelImageVisualizer:

+ =======
+ def get_image_colorizer(
+     root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
+ ) -> ModelImageVisualizer:
+ >>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
    if artistic:
        return get_artistic_image_colorizer(root_folder=root_folder, render_factor=render_factor)
    else:
        return get_stable_image_colorizer(root_folder=root_folder, render_factor=render_factor)


+ <<<<<<< HEAD
def get_stable_image_colorizer(root_folder: Path = Path('./'), weights_name: str = 'ColorizeStable_gen',
                               results_dir='result_images', render_factor: int = 35
                               ) -> ModelImageVisualizer:
@@ -243,4 +258,27 @@ def get_artistic_image_colorizer(root_folder: Path = Path('./'), weights_name: s
    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
    # vis = ModelImageVisualizer(filtr, results_dir=results_dir)
    vis = ModelImageVisualizer(filtr)
+ =======
+ def get_stable_image_colorizer(
+     root_folder: Path = Path('./'),
+     weights_name: str = 'ColorizeStable_gen',
+     results_dir='result_images',
+     render_factor: int = 35
+ ) -> ModelImageVisualizer:
+     learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
+     filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
+     vis = ModelImageVisualizer(filtr, results_dir=results_dir)
+     return vis
+ 
+ 
+ def get_artistic_image_colorizer(
+     root_folder: Path = Path('./'),
+     weights_name: str = 'ColorizeArtistic_gen',
+     results_dir='result_images',
+     render_factor: int = 35
+ ) -> ModelImageVisualizer:
+     learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
+     filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
+     vis = ModelImageVisualizer(filtr, results_dir=results_dir)
+ >>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
    return vis
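
Taken together with app.py above, the colorizer exposed by this module is driven roughly as follows. A minimal usage sketch, assuming the deoldify package from this repo and the downloaded weights under ./models/ are available (the input path is illustrative):

import PIL
from deoldify import device          # assumed standard DeOldify device import
from deoldify.device_id import DeviceId

device.set(device=DeviceId.CPU)      # must be set before creating the visualizer

from deoldify.visualize import get_image_colorizer

colorizer = get_image_colorizer(artistic=True, render_factor=35)
input_img = PIL.Image.open("my_bw_photo.jpg")
# plot_transformed_pil_image is the method app.py in this commit calls on the visualizer
output_img = colorizer.plot_transformed_pil_image(input_img, render_factor=35, compare=False)
output_img.save("my_bw_photo_color.jpg")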