Lodor committed • Commit 47c60f9
1 parent: bb85d6b
Initial commit

Files changed:
- .gitignore +124 -0
- .streamlit/config.toml +6 -0
- Dockerfile +9 -0
- README.md +1 -0
- app.py +158 -0
- app_.py +57 -0
- assets/demo.jpg +0 -0
- docker-compose.yml +14 -0
- models/.gitkeep +0 -0
- requirements.txt +12 -0
- src/__init__.py +0 -0
- src/app_utils.py +131 -0
- src/deoldify/__init__.py +3 -0
- src/deoldify/_device.py +30 -0
- src/deoldify/augs.py +29 -0
- src/deoldify/critics.py +44 -0
- src/deoldify/dataset.py +48 -0
- src/deoldify/device_id.py +12 -0
- src/deoldify/filters.py +120 -0
- src/deoldify/generators.py +151 -0
- src/deoldify/layers.py +48 -0
- src/deoldify/loss.py +136 -0
- src/deoldify/save.py +29 -0
- src/deoldify/unet.py +285 -0
- src/deoldify/visualize.py +247 -0
- src/st_style.py +42 -0
.gitignore
ADDED
@@ -0,0 +1,124 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
.streamlit/config.toml
ADDED
@@ -0,0 +1,6 @@
[server]
maxUploadSize = 10

[theme]
base="light"
primaryColor="#0074ff"
Dockerfile
ADDED
@@ -0,0 +1,9 @@
FROM pytorch/pytorch:latest

WORKDIR /app

COPY . .

RUN pip install -r requirements.txt

CMD [ "streamlit", "run", "app.py" ]
README.md
CHANGED
@@ -5,6 +5,7 @@ colorFrom: gray
 colorTo: blue
 sdk: streamlit
 sdk_version: 1.2.0
+python_version: 3.9.5
 app_file: app.py
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,158 @@
# Based on: https://github.com/jantic/DeOldify
import os, re, time

os.environ["TORCH_HOME"] = os.path.join(os.getcwd(), ".cache")
os.environ["XDG_CACHE_HOME"] = os.path.join(os.getcwd(), ".cache")

import streamlit as st
import PIL
import cv2
import numpy as np
import uuid
from zipfile import ZipFile, ZIP_DEFLATED
from io import BytesIO
from random import randint
from datetime import datetime

from src.deoldify import device
from src.deoldify.device_id import DeviceId
from src.deoldify.visualize import *
from src.app_utils import get_model_bin


device.set(device=DeviceId.CPU)


@st.cache(allow_output_mutation=True, show_spinner=False)
def load_model(model_dir, option):
    if option.lower() == 'artistic':
        model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
        get_model_bin(model_url, os.path.join(model_dir, "ColorizeArtistic_gen.pth"))
        colorizer = get_image_colorizer(artistic=True)
    elif option.lower() == 'stable':
        model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0"
        get_model_bin(model_url, os.path.join(model_dir, "ColorizeStable_gen.pth"))
        colorizer = get_image_colorizer(artistic=False)

    return colorizer


def resize_img(input_img, max_size):
    img = input_img.copy()
    img_height, img_width = img.shape[0], img.shape[1]

    if max(img_height, img_width) > max_size:
        if img_height > img_width:
            new_width = img_width * (max_size / img_height)
            new_height = max_size
            resized_img = cv2.resize(img, (int(new_width), int(new_height)))
            return resized_img

        elif img_height <= img_width:
            new_width = img_height * (max_size / img_width)
            new_height = max_size
            resized_img = cv2.resize(img, (int(new_width), int(new_height)))
            return resized_img

    return img


def colorize_image(pil_image, img_size=800) -> "PIL.Image":
    # Open the image
    pil_img = pil_image.convert("RGB")
    img_rgb = np.array(pil_img)
    resized_img_rgb = resize_img(img_rgb, img_size)
    resized_pil_img = PIL.Image.fromarray(resized_img_rgb)

    # Send the image to the model
    output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)

    return output_pil_img


def image_download_button(pil_image, filename: str, fmt: str, label="Download"):
    if fmt not in ["jpg", "png"]:
        raise Exception(f"Unknown image format (Available: {fmt} - case sensitive)")

    pil_format = "JPEG" if fmt == "jpg" else "PNG"
    file_format = "jpg" if fmt == "jpg" else "png"
    mime = "image/jpeg" if fmt == "jpg" else "image/png"

    buf = BytesIO()
    pil_image.save(buf, format=pil_format)

    return st.download_button(
        label=label,
        data=buf.getvalue(),
        file_name=f'{filename}.{file_format}',
        mime=mime,
    )


###########################
###### STREAMLIT CODE #####
###########################


st_color_option = "Artistic"

# Load models
try:
    with st.spinner("Loading..."):
        print('before loading the model')
        colorizer = load_model('models/', st_color_option)
        print('after loading the model')

except Exception as e:
    colorizer = None
    print('Error while loading the model. Please refresh the page')
    print(e)
    st.write("**App loading error. Please try again later.**")


if colorizer is not None:
    st.title("AI Photo Colorization")

    st.image(open("assets/demo.jpg", "rb").read())

    st.markdown(
        """
        Colorizing black & white photo can be expensive and time consuming. We introduce AI that can colorize
        grayscale photo in seconds. **Just upload your grayscale image, then click colorize.**
        """
    )

    uploaded_file = st.file_uploader("Upload photo", accept_multiple_files=False, type=["png", "jpg", "jpeg"])

    if uploaded_file is not None:
        bytes_data = uploaded_file.getvalue()
        img_input = PIL.Image.open(BytesIO(bytes_data)).convert("RGB")

        with st.expander("Original photo", True):
            st.image(img_input)

        if st.button("Colorize!") and uploaded_file is not None:

            with st.spinner("AI is doing the magic!"):
                img_output = colorize_image(img_input)
                img_output = img_output.resize(img_input.size)

            # NOTE: Calm! I'm not logging the input and outputs.
            # It is impossible to access the filesystem in spaces environment.
            now = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
            img_input.convert("RGB").save(f"./output/{now}-input.jpg")
            img_output.convert("RGB").save(f"./output/{now}-output.jpg")

            st.write("AI has finished the job!")
            st.image(img_output)
            # reuse = st.button('Edit again (Re-use this image)', on_click=set_image, args=(inpainted_img, ))

            uploaded_name = os.path.splitext(uploaded_file.name)[0]
            image_download_button(
                pil_image=img_output,
                filename=uploaded_name,
                fmt="jpg",
                label="Download Image"
            )
app_.py
ADDED
@@ -0,0 +1,57 @@
# App code based on:
# Model based on:

import numpy as np
import pandas as pd
import streamlit as st
import os
from datetime import datetime
from PIL import Image
from streamlit_drawable_canvas import st_canvas
from io import BytesIO
from copy import deepcopy

from src.core import process_inpaint


st.title("AI Photo Colorization")

st.image(open("assets/demo.png", "rb").read())

st.markdown(
    """
    Colorizing black & white photo can be expensive and time consuming. We introduce AI that can colorize
    grayscale photo in seconds. **Just upload your grayscale image, then click colorize.**
    """
)
uploaded_file = st.file_uploader("Choose image", accept_multiple_files=False, type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    bytes_data = uploaded_file.getvalue()
    img_input = Image.open(BytesIO(bytes_data)).convert("RGBA")

if uploaded_file is not None and st.button("Colorize!"):

    with st.spinner("AI is doing the magic!"):
        img_output = """TODO"""

    # NOTE: Calm! I'm not logging the input and outputs.
    # It is impossible to access the filesystem in spaces environment.
    now = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
    img_input.convert("RGB").save(f"./output/{now}.jpg")
    Image.fromarray(img_output).convert("RGB").save(f"./output/{now}-edited.jpg")

    st.write("AI has finished the job!")
    st.image(img_output)
    # reuse = st.button('Edit again (Re-use this image)', on_click=set_image, args=(inpainted_img, ))

    with open(f"./output/{now}-edited.jpg", "rb") as fs:
        uploaded_name = os.path.splitext(uploaded_file.name)[0]
        st.download_button(
            label="Download",
            data=fs,
            file_name=f'edited_{uploaded_name}.jpg',
        )

    st.info("**TIP**: If the result is not perfect, you can download then "
            "re-upload the result then remove the artifacts.")
assets/demo.jpg
ADDED
docker-compose.yml
ADDED
@@ -0,0 +1,14 @@
---
version: '3'
services:
  st-photo-colorization:
    build: .
    container_name: st-photo-colorization
    restart: unless-stopped
    ports:
      - 51004:8501
    volumes:
      - .:/app
    environment:
      - TZ=Asia/Jakarta
    # command: streamlit run sdc.py
models/.gitkeep
ADDED
File without changes
requirements.txt
ADDED
@@ -0,0 +1,12 @@
torch
torchvision
numpy
opencv-python-headless
matplotlib
streamlit
scipy==1.7.1
scikit_image==0.18.3
requests
fastai==1.0.51
Pillow
stqdm
src/__init__.py
ADDED
File without changes
src/app_utils.py
ADDED
@@ -0,0 +1,131 @@
import os
import requests
import random
import _thread as thread
from uuid import uuid4
import urllib

import numpy as np
import skimage
from skimage.filters import gaussian
from PIL import Image

def compress_image(image, path_original):
    size = 1920, 1080
    width = 1920
    height = 1080

    name = os.path.basename(path_original).split('.')
    first_name = os.path.join(os.path.dirname(path_original), name[0] + '.jpg')

    if image.size[0] > width and image.size[1] > height:
        image.thumbnail(size, Image.ANTIALIAS)
        image.save(first_name, quality=85)
    elif image.size[0] > width:
        wpercent = (width / float(image.size[0]))
        height = int((float(image.size[1]) * float(wpercent)))
        image = image.resize((width, height), Image.ANTIALIAS)
        image.save(first_name, quality=85)
    elif image.size[1] > height:
        wpercent = (height / float(image.size[1]))
        width = int((float(image.size[0]) * float(wpercent)))
        image = image.resize((width, height), Image.ANTIALIAS)
        image.save(first_name, quality=85)
    else:
        image.save(first_name, quality=85)


def convertToJPG(path_original):
    img = Image.open(path_original)
    name = os.path.basename(path_original).split('.')
    first_name = os.path.join(os.path.dirname(path_original), name[0] + '.jpg')

    if img.format == "JPEG":
        image = img.convert('RGB')
        compress_image(image, path_original)
        img.close()

    elif img.format == "GIF":
        i = img.convert("RGBA")
        bg = Image.new("RGBA", i.size)
        image = Image.composite(i, bg, i)
        compress_image(image, path_original)
        img.close()

    elif img.format == "PNG":
        try:
            image = Image.new("RGB", img.size, (255, 255, 255))
            image.paste(img, img)
            compress_image(image, path_original)
        except ValueError:
            image = img.convert('RGB')
            compress_image(image, path_original)

        img.close()

    elif img.format == "BMP":
        image = img.convert('RGB')
        compress_image(image, path_original)
        img.close()


def blur(image, x0, x1, y0, y1, sigma=1, multichannel=True):
    y0, y1 = min(y0, y1), max(y0, y1)
    x0, x1 = min(x0, x1), max(x0, x1)
    im = image.copy()
    sub_im = im[y0:y1, x0:x1].copy()
    blur_sub_im = gaussian(sub_im, sigma=sigma, multichannel=multichannel)
    blur_sub_im = np.round(255 * blur_sub_im)
    im[y0:y1, x0:x1] = blur_sub_im
    return im


def download(url, filename):
    data = requests.get(url).content
    with open(filename, 'wb') as handler:
        handler.write(data)

    return filename


def generate_random_filename(upload_directory, extension):
    filename = str(uuid4())
    filename = os.path.join(upload_directory, filename + "." + extension)
    return filename


def clean_me(filename):
    if os.path.exists(filename):
        os.remove(filename)


def clean_all(files):
    for me in files:
        clean_me(me)


def create_directory(path):
    os.makedirs(os.path.dirname(path), exist_ok=True)


def get_model_bin(url, output_path):
    # print('Getting model dir: ', output_path)
    if not os.path.exists(output_path):
        create_directory(output_path)
        urllib.request.urlretrieve(url, output_path)

    # cmd = "wget -O %s %s" % (output_path, url)
    # print(cmd)
    # os.system(cmd)

    return output_path


# model_list = [(url, output_path), (url, output_path)]
def get_multi_model_bin(model_list):
    for m in model_list:
        thread.start_new_thread(get_model_bin, m)
src/deoldify/__init__.py
ADDED
@@ -0,0 +1,3 @@
from src.deoldify._device import _Device

device = _Device()
src/deoldify/_device.py
ADDED
@@ -0,0 +1,30 @@
import os
from enum import Enum
from .device_id import DeviceId

#NOTE: This must be called first before any torch imports in order to work properly!

class DeviceException(Exception):
    pass

class _Device:
    def __init__(self):
        self.set(DeviceId.CPU)

    def is_gpu(self):
        ''' Returns `True` if the current device is GPU, `False` otherwise. '''
        return self.current() is not DeviceId.CPU

    def current(self):
        return self._current_device

    def set(self, device: DeviceId):
        if device == DeviceId.CPU:
            os.environ['CUDA_VISIBLE_DEVICES'] = ''
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = str(device.value)
            import torch
            torch.backends.cudnn.benchmark = False

        self._current_device = device
        return device
src/deoldify/augs.py
ADDED
@@ -0,0 +1,29 @@
import random

from fastai.vision.image import TfmPixel

# Contributed by Rani Horev. Thank you!
def _noisify(
    x, pct_pixels_min: float = 0.001, pct_pixels_max: float = 0.4, noise_range: int = 30
):
    if noise_range > 255 or noise_range < 0:
        raise Exception("noise_range must be between 0 and 255, inclusively.")

    h, w = x.shape[1:]
    img_size = h * w
    mult = 10000.0
    pct_pixels = (
        random.randrange(int(pct_pixels_min * mult), int(pct_pixels_max * mult)) / mult
    )
    noise_count = int(img_size * pct_pixels)

    for ii in range(noise_count):
        yy = random.randrange(h)
        xx = random.randrange(w)
        noise = random.randrange(-noise_range, noise_range) / 255.0
        x[:, yy, xx].add_(noise)

    return x


noisify = TfmPixel(_noisify)
src/deoldify/critics.py
ADDED
@@ -0,0 +1,44 @@
from fastai.core import *
from fastai.torch_core import *
from fastai.vision import *
from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand

_conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)


def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
    return conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)


def custom_gan_critic(
    n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: int = 0.15
):
    "Critic to train a `GAN`."
    layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]
    for i in range(n_blocks):
        layers += [
            _conv(nf, nf, ks=3, stride=1),
            nn.Dropout2d(p),
            _conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
        ]
        nf *= 2
    layers += [
        _conv(nf, nf, ks=3, stride=1),
        _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
        Flatten(),
    ]
    return nn.Sequential(*layers)


def colorize_crit_learner(
    data: ImageDataBunch,
    loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()),
    nf: int = 256,
) -> Learner:
    return Learner(
        data,
        custom_gan_critic(nf=nf),
        metrics=accuracy_thresh_expand,
        loss_func=loss_critic,
        wd=1e-3,
    )
src/deoldify/dataset.py
ADDED
@@ -0,0 +1,48 @@
import fastai
from fastai import *
from fastai.core import *
from fastai.vision.transform import get_transforms
from fastai.vision.data import ImageImageList, ImageDataBunch, imagenet_stats
from .augs import noisify


def get_colorize_data(
    sz: int,
    bs: int,
    crappy_path: Path,
    good_path: Path,
    random_seed: int = None,
    keep_pct: float = 1.0,
    num_workers: int = 8,
    stats: tuple = imagenet_stats,
    xtra_tfms=[],
) -> ImageDataBunch:

    src = (
        ImageImageList.from_folder(crappy_path, convert_mode='RGB')
        .use_partial_data(sample_pct=keep_pct, seed=random_seed)
        .split_by_rand_pct(0.1, seed=random_seed)
    )

    data = (
        src.label_from_func(lambda x: good_path / x.relative_to(crappy_path))
        .transform(
            get_transforms(
                max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms
            ),
            size=sz,
            tfm_y=True,
        )
        .databunch(bs=bs, num_workers=num_workers, no_check=True)
        .normalize(stats, do_y=True)
    )

    data.c = 3
    return data


def get_dummy_databunch() -> ImageDataBunch:
    path = Path('./assets/dummy/')
    return get_colorize_data(
        sz=1, bs=1, crappy_path=path, good_path=path, keep_pct=0.001
    )
src/deoldify/device_id.py
ADDED
@@ -0,0 +1,12 @@
from enum import IntEnum

class DeviceId(IntEnum):
    GPU0 = 0,
    GPU1 = 1,
    GPU2 = 2,
    GPU3 = 3,
    GPU4 = 4,
    GPU5 = 5,
    GPU6 = 6,
    GPU7 = 7,
    CPU = 99
src/deoldify/filters.py
ADDED
@@ -0,0 +1,120 @@
from numpy import ndarray
from abc import ABC, abstractmethod
from .critics import colorize_crit_learner
from fastai.core import *
from fastai.vision import *
from fastai.vision.image import *
from fastai.vision.data import *
from fastai import *
import math
from scipy import misc
import cv2
from PIL import Image as PilImage


class IFilter(ABC):
    @abstractmethod
    def filter(
        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int
    ) -> PilImage:
        pass


class BaseFilter(IFilter):
    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
        super().__init__()
        self.learn = learn
        self.device = next(self.learn.model.parameters()).device
        self.norm, self.denorm = normalize_funcs(*stats)

    def _transform(self, image: PilImage) -> PilImage:
        return image

    def _scale_to_square(self, orig: PilImage, targ: int) -> PilImage:
        # a simple stretch to fit a square really makes a big difference in rendering quality/consistency.
        # I've tried padding to the square as well (reflect, symetric, constant, etc). Not as good!
        targ_sz = (targ, targ)
        return orig.resize(targ_sz, resample=PIL.Image.BILINEAR)

    def _get_model_ready_image(self, orig: PilImage, sz: int) -> PilImage:
        result = self._scale_to_square(orig, sz)
        result = self._transform(result)
        return result

    def _model_process(self, orig: PilImage, sz: int) -> PilImage:
        model_image = self._get_model_ready_image(orig, sz)
        x = pil2tensor(model_image, np.float32)
        x = x.to(self.device)
        x.div_(255)
        x, y = self.norm((x, x), do_x=True)

        try:
            result = self.learn.pred_batch(
                ds_type=DatasetType.Valid, batch=(x[None], y[None]), reconstruct=True
            )
        except RuntimeError as rerr:
            if 'memory' not in str(rerr):
                raise rerr
            print('Warning: render_factor was set too high, and out of memory error resulted. Returning original image.')
            return model_image

        out = result[0]
        out = self.denorm(out.px, do_x=False)
        out = image2np(out * 255).astype(np.uint8)
        return PilImage.fromarray(out)

    def _unsquare(self, image: PilImage, orig: PilImage) -> PilImage:
        targ_sz = orig.size
        image = image.resize(targ_sz, resample=PIL.Image.BILINEAR)
        return image


class ColorizerFilter(BaseFilter):
    def __init__(self, learn: Learner, stats: tuple = imagenet_stats):
        super().__init__(learn=learn, stats=stats)
        self.render_base = 16

    def filter(
        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int, post_process: bool = True) -> PilImage:
        render_sz = render_factor * self.render_base
        model_image = self._model_process(orig=filtered_image, sz=render_sz)
        raw_color = self._unsquare(model_image, orig_image)

        if post_process:
            return self._post_process(raw_color, orig_image)
        else:
            return raw_color

    def _transform(self, image: PilImage) -> PilImage:
        return image.convert('LA').convert('RGB')

    # This takes advantage of the fact that human eyes are much less sensitive to
    # imperfections in chrominance compared to luminance. This means we can
    # save a lot on memory and processing in the model, yet get a great high
    # resolution result at the end. This is primarily intended just for
    # inference
    def _post_process(self, raw_color: PilImage, orig: PilImage) -> PilImage:
        color_np = np.asarray(raw_color)
        orig_np = np.asarray(orig)
        color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
        # do a black and white transform first to get better luminance values
        orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_BGR2YUV)
        hires = np.copy(orig_yuv)
        hires[:, :, 1:3] = color_yuv[:, :, 1:3]
        final = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)
        final = PilImage.fromarray(final)
        return final


class MasterFilter(BaseFilter):
    def __init__(self, filters: [IFilter], render_factor: int):
        self.filters = filters
        self.render_factor = render_factor

    def filter(
        self, orig_image: PilImage, filtered_image: PilImage, render_factor: int = None, post_process: bool = True) -> PilImage:
        render_factor = self.render_factor if render_factor is None else render_factor
        for filter in self.filters:
            filtered_image = filter.filter(orig_image, filtered_image, render_factor, post_process)

        return filtered_image
src/deoldify/generators.py
ADDED
@@ -0,0 +1,151 @@
from fastai.vision import *
from fastai.vision.learner import cnn_config
from .unet import DynamicUnetWide, DynamicUnetDeep
from .loss import FeatureLoss
from .dataset import *

# Weights are implicitly read from ./models/ folder
def gen_inference_wide(
    root_folder: Path, weights_name: str, nf_factor: int = 2, arch=models.resnet101) -> Learner:
    data = get_dummy_databunch()
    learn = gen_learner_wide(
        data=data, gen_loss=F.l1_loss, nf_factor=nf_factor, arch=arch
    )
    learn.path = root_folder
    learn.load(weights_name)
    learn.model.eval()
    return learn


def gen_learner_wide(
    data: ImageDataBunch, gen_loss, arch=models.resnet101, nf_factor: int = 2
) -> Learner:
    return unet_learner_wide(
        data,
        arch=arch,
        wd=1e-3,
        blur=True,
        norm_type=NormType.Spectral,
        self_attention=True,
        y_range=(-3.0, 3.0),
        loss_func=gen_loss,
        nf_factor=nf_factor,
    )


# The code below is meant to be merged into fastaiv1 ideally
def unet_learner_wide(
    data: DataBunch,
    arch: Callable,
    pretrained: bool = True,
    blur_final: bool = True,
    norm_type: Optional[NormType] = NormType,
    split_on: Optional[SplitFuncOrIdxList] = None,
    blur: bool = False,
    self_attention: bool = False,
    y_range: Optional[Tuple[float, float]] = None,
    last_cross: bool = True,
    bottle: bool = False,
    nf_factor: int = 1,
    **kwargs: Any
) -> Learner:
    "Build Unet learner from `data` and `arch`."
    meta = cnn_config(arch)
    body = create_body(arch, pretrained)
    model = to_device(
        DynamicUnetWide(
            body,
            n_classes=data.c,
            blur=blur,
            blur_final=blur_final,
            self_attention=self_attention,
            y_range=y_range,
            norm_type=norm_type,
            last_cross=last_cross,
            bottle=bottle,
            nf_factor=nf_factor,
        ),
        data.device,
    )
    learn = Learner(data, model, **kwargs)
    learn.split(ifnone(split_on, meta['split']))
    if pretrained:
        learn.freeze()
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn


# ----------------------------------------------------------------------

# Weights are implicitly read from ./models/ folder
def gen_inference_deep(
    root_folder: Path, weights_name: str, arch=models.resnet34, nf_factor: float = 1.5) -> Learner:
    data = get_dummy_databunch()
    learn = gen_learner_deep(
        data=data, gen_loss=F.l1_loss, arch=arch, nf_factor=nf_factor
    )
    learn.path = root_folder
    learn.load(weights_name)
    learn.model.eval()
    return learn


def gen_learner_deep(
    data: ImageDataBunch, gen_loss, arch=models.resnet34, nf_factor: float = 1.5
) -> Learner:
    return unet_learner_deep(
        data,
        arch,
        wd=1e-3,
        blur=True,
        norm_type=NormType.Spectral,
        self_attention=True,
        y_range=(-3.0, 3.0),
        loss_func=gen_loss,
        nf_factor=nf_factor,
    )


# The code below is meant to be merged into fastaiv1 ideally
def unet_learner_deep(
    data: DataBunch,
    arch: Callable,
    pretrained: bool = True,
    blur_final: bool = True,
    norm_type: Optional[NormType] = NormType,
    split_on: Optional[SplitFuncOrIdxList] = None,
    blur: bool = False,
    self_attention: bool = False,
    y_range: Optional[Tuple[float, float]] = None,
    last_cross: bool = True,
    bottle: bool = False,
    nf_factor: float = 1.5,
    **kwargs: Any
) -> Learner:
    "Build Unet learner from `data` and `arch`."
    meta = cnn_config(arch)
    body = create_body(arch, pretrained)
    model = to_device(
        DynamicUnetDeep(
            body,
            n_classes=data.c,
            blur=blur,
            blur_final=blur_final,
            self_attention=self_attention,
            y_range=y_range,
            norm_type=norm_type,
            last_cross=last_cross,
            bottle=bottle,
            nf_factor=nf_factor,
        ),
        data.device,
    )
    learn = Learner(data, model, **kwargs)
    learn.split(ifnone(split_on, meta['split']))
    if pretrained:
        learn.freeze()
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn


# -----------------------------
src/deoldify/layers.py
ADDED
@@ -0,0 +1,48 @@
from fastai.layers import *
from fastai.torch_core import *
from torch.nn.parameter import Parameter
from torch.autograd import Variable


# The code below is meant to be merged into fastaiv1 ideally


def custom_conv_layer(
    ni: int,
    nf: int,
    ks: int = 3,
    stride: int = 1,
    padding: int = None,
    bias: bool = None,
    is_1d: bool = False,
    norm_type: Optional[NormType] = NormType.Batch,
    use_activ: bool = True,
    leaky: float = None,
    transpose: bool = False,
    init: Callable = nn.init.kaiming_normal_,
    self_attention: bool = False,
    extra_bn: bool = False,
):
    "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
    if padding is None:
        padding = (ks - 1) // 2 if not transpose else 0
    bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn == True
    if bias is None:
        bias = not bn
    conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
    conv = init_default(
        conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding),
        init,
    )
    if norm_type == NormType.Weight:
        conv = weight_norm(conv)
    elif norm_type == NormType.Spectral:
        conv = spectral_norm(conv)
    layers = [conv]
    if use_activ:
        layers.append(relu(True, leaky=leaky))
    if bn:
        layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
    if self_attention:
        layers.append(SelfAttention(nf))
    return nn.Sequential(*layers)
src/deoldify/loss.py
ADDED
@@ -0,0 +1,136 @@
from fastai import *
from fastai.core import *
from fastai.torch_core import *
from fastai.callbacks import hook_outputs
import torchvision.models as models


class FeatureLoss(nn.Module):
    def __init__(self, layer_wgts=[20, 70, 10]):
        super().__init__()

        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
        requires_grad(self.m_feat, False)
        blocks = [
            i - 1
            for i, o in enumerate(children(self.m_feat))
            if isinstance(o, nn.MaxPool2d)
        ]
        layer_ids = blocks[2:5]
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = ['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))]
        self.base_loss = F.l1_loss

    def _make_features(self, x, clone=False):
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        out_feat = self._make_features(target, clone=True)
        in_feat = self._make_features(input)
        self.feat_losses = [self.base_loss(input, target)]
        self.feat_losses += [
            self.base_loss(f_in, f_out) * w
            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
        ]

        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        self.hooks.remove()


# Refactored code, originally from https://github.com/VinceMarron/style_transfer
class WassFeatureLoss(nn.Module):
    def __init__(self, layer_wgts=[5, 15, 2], wass_wgts=[3.0, 0.7, 0.01]):
        super().__init__()
        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
        requires_grad(self.m_feat, False)
        blocks = [
            i - 1
            for i, o in enumerate(children(self.m_feat))
            if isinstance(o, nn.MaxPool2d)
        ]
        layer_ids = blocks[2:5]
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.wass_wgts = wass_wgts
        self.metric_names = (
            ['pixel']
            + [f'feat_{i}' for i in range(len(layer_ids))]
            + [f'wass_{i}' for i in range(len(layer_ids))]
        )
        self.base_loss = F.l1_loss

    def _make_features(self, x, clone=False):
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def _calc_2_moments(self, tensor):
        chans = tensor.shape[1]
        tensor = tensor.view(1, chans, -1)
        n = tensor.shape[2]
        mu = tensor.mean(2)
        tensor = (tensor - mu[:, :, None]).squeeze(0)
        # Prevents nasty bug that happens very occassionally- divide by zero. Why such things happen?
        if n == 0:
            return None, None
        cov = torch.mm(tensor, tensor.t()) / float(n)
        return mu, cov

    def _get_style_vals(self, tensor):
        mean, cov = self._calc_2_moments(tensor)
        if mean is None:
            return None, None, None
        eigvals, eigvects = torch.symeig(cov, eigenvectors=True)
        eigroot_mat = torch.diag(torch.sqrt(eigvals.clamp(min=0)))
        root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())
        tr_cov = eigvals.clamp(min=0).sum()
        return mean, tr_cov, root_cov

    def _calc_l2wass_dist(
        self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth
    ):
        tr_cov_synth = torch.symeig(cov_synth, eigenvectors=True)[0].clamp(min=0).sum()
        mean_diff_squared = (mean_stl - mean_synth).pow(2).sum()
        cov_prod = torch.mm(torch.mm(root_cov_stl, cov_synth), root_cov_stl)
        var_overlap = torch.sqrt(
            torch.symeig(cov_prod, eigenvectors=True)[0].clamp(min=0) + 1e-8
        ).sum()
        dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2 * var_overlap
        return dist

    def _single_wass_loss(self, pred, targ):
        mean_test, tr_cov_test, root_cov_test = targ
        mean_synth, cov_synth = self._calc_2_moments(pred)
        loss = self._calc_l2wass_dist(
            mean_test, tr_cov_test, root_cov_test, mean_synth, cov_synth
        )
        return loss

    def forward(self, input, target):
        out_feat = self._make_features(target, clone=True)
        in_feat = self._make_features(input)
        self.feat_losses = [self.base_loss(input, target)]
        self.feat_losses += [
            self.base_loss(f_in, f_out) * w
            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
        ]

        styles = [self._get_style_vals(i) for i in out_feat]

        if styles[0][0] is not None:
            self.feat_losses += [
                self._single_wass_loss(f_pred, f_targ) * w
                for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)
            ]

        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        self.hooks.remove()
src/deoldify/save.py
ADDED
@@ -0,0 +1,29 @@
from fastai.basic_train import Learner, LearnerCallback
from fastai.vision.gan import GANLearner


class GANSaveCallback(LearnerCallback):
    """A `LearnerCallback` that saves history of metrics while training `learn` into CSV `filename`."""

    def __init__(
        self,
        learn: GANLearner,
        learn_gen: Learner,
        filename: str,
        save_iters: int = 1000,
    ):
        super().__init__(learn)
        self.learn_gen = learn_gen
        self.filename = filename
        self.save_iters = save_iters

    def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
        if iteration == 0:
            return

        if iteration % self.save_iters == 0:
            self._save_gen_learner(iteration=iteration, epoch=epoch)

    def _save_gen_learner(self, iteration: int, epoch: int):
        filename = '{}_{}_{}'.format(self.filename, epoch, iteration)
        self.learn_gen.save(filename)
src/deoldify/unet.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastai.layers import *
|
2 |
+
from .layers import *
|
3 |
+
from fastai.torch_core import *
|
4 |
+
from fastai.callbacks.hooks import *
|
5 |
+
from fastai.vision import *
|
6 |
+
|
7 |
+
|
8 |
+
# The code below is meant to be merged into fastaiv1 ideally
|
9 |
+
|
10 |
+
__all__ = ['DynamicUnetDeep', 'DynamicUnetWide']
|
11 |
+
|
12 |
+
|
13 |
+
def _get_sfs_idxs(sizes: Sizes) -> List[int]:
|
14 |
+
"Get the indexes of the layers where the size of the activation changes."
|
15 |
+
feature_szs = [size[-1] for size in sizes]
|
16 |
+
sfs_idxs = list(
|
17 |
+
np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0]
|
18 |
+
)
|
19 |
+
if feature_szs[0] != feature_szs[1]:
|
20 |
+
sfs_idxs = [0] + sfs_idxs
|
21 |
+
return sfs_idxs
|
22 |
+
|
23 |
+
|
24 |
+
class CustomPixelShuffle_ICNR(nn.Module):
|
25 |
+
"Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
ni: int,
|
30 |
+
nf: int = None,
|
31 |
+
scale: int = 2,
|
32 |
+
blur: bool = False,
|
33 |
+
leaky: float = None,
|
34 |
+
**kwargs
|
35 |
+
):
|
36 |
+
super().__init__()
|
37 |
+
nf = ifnone(nf, ni)
|
38 |
+
self.conv = custom_conv_layer(
|
39 |
+
ni, nf * (scale ** 2), ks=1, use_activ=False, **kwargs
|
40 |
+
)
|
41 |
+
icnr(self.conv[0].weight)
|
42 |
+
self.shuf = nn.PixelShuffle(scale)
|
43 |
+
# Blurring over (h*w) kernel
|
44 |
+
# "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
|
45 |
+
# - https://arxiv.org/abs/1806.02658
|
46 |
+
self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
|
47 |
+
self.blur = nn.AvgPool2d(2, stride=1)
|
48 |
+
self.relu = relu(True, leaky=leaky)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
x = self.shuf(self.relu(self.conv(x)))
|
52 |
+
return self.blur(self.pad(x)) if self.blur else x
|
53 |
+
|
54 |
+
|
55 |
+
class UnetBlockDeep(nn.Module):
|
56 |
+
"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
up_in_c: int,
|
61 |
+
x_in_c: int,
|
62 |
+
hook: Hook,
|
63 |
+
final_div: bool = True,
|
64 |
+
blur: bool = False,
|
65 |
+
leaky: float = None,
|
66 |
+
self_attention: bool = False,
|
67 |
+
nf_factor: float = 1.0,
|
68 |
+
**kwargs
|
69 |
+
):
|
70 |
+
super().__init__()
|
71 |
+
self.hook = hook
|
72 |
+
self.shuf = CustomPixelShuffle_ICNR(
|
73 |
+
up_in_c, up_in_c // 2, blur=blur, leaky=leaky, **kwargs
|
74 |
+
)
|
75 |
+
self.bn = batchnorm_2d(x_in_c)
|
76 |
+
ni = up_in_c // 2 + x_in_c
|
77 |
+
nf = int((ni if final_div else ni // 2) * nf_factor)
|
78 |
+
self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)
|
79 |
+
self.conv2 = custom_conv_layer(
|
80 |
+
nf, nf, leaky=leaky, self_attention=self_attention, **kwargs
|
81 |
+
)
|
82 |
+
self.relu = relu(leaky=leaky)
|
83 |
+
|
84 |
+
def forward(self, up_in: Tensor) -> Tensor:
|
85 |
+
s = self.hook.stored
|
86 |
+
up_out = self.shuf(up_in)
|
87 |
+
ssh = s.shape[-2:]
|
88 |
+
if ssh != up_out.shape[-2:]:
|
89 |
+
up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
|
90 |
+
cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
|
91 |
+
return self.conv2(self.conv1(cat_x))
|
92 |
+
|
93 |
+
|
94 |
+
class DynamicUnetDeep(SequentialEx):
|
95 |
+
"Create a U-Net from a given architecture."
|
96 |
+
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
encoder: nn.Module,
|
100 |
+
n_classes: int,
|
101 |
+
blur: bool = False,
|
102 |
+
blur_final=True,
|
103 |
+
self_attention: bool = False,
|
104 |
+
y_range: Optional[Tuple[float, float]] = None,
|
105 |
+
last_cross: bool = True,
|
106 |
+
bottle: bool = False,
|
107 |
+
norm_type: Optional[NormType] = NormType.Batch,
|
108 |
+
nf_factor: float = 1.0,
|
109 |
+
**kwargs
|
110 |
+
):
|
111 |
+
extra_bn = norm_type == NormType.Spectral
|
112 |
+
imsize = (256, 256)
|
113 |
+
sfs_szs = model_sizes(encoder, size=imsize)
|
114 |
+
sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
|
115 |
+
self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
|
116 |
+
x = dummy_eval(encoder, imsize).detach()
|
117 |
+
|
118 |
+
ni = sfs_szs[-1][1]
|
119 |
+
middle_conv = nn.Sequential(
|
120 |
+
custom_conv_layer(
|
121 |
+
ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
|
122 |
+
),
|
123 |
+
custom_conv_layer(
|
124 |
+
ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
|
125 |
+
),
|
126 |
+
).eval()
|
127 |
+
x = middle_conv(x)
|
128 |
+
layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
|
129 |
+
|
130 |
+
for i, idx in enumerate(sfs_idxs):
|
131 |
+
not_final = i != len(sfs_idxs) - 1
|
132 |
+
up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
|
133 |
+
do_blur = blur and (not_final or blur_final)
|
134 |
+
sa = self_attention and (i == len(sfs_idxs) - 3)
|
135 |
+
unet_block = UnetBlockDeep(
|
136 |
+
up_in_c,
|
137 |
+
x_in_c,
|
138 |
+
self.sfs[i],
|
139 |
+
final_div=not_final,
|
140 |
+
blur=blur,
|
141 |
+
self_attention=sa,
|
142 |
+
norm_type=norm_type,
|
143 |
+
extra_bn=extra_bn,
|
144 |
+
nf_factor=nf_factor,
|
145 |
+
**kwargs
|
146 |
+
).eval()
|
147 |
+
layers.append(unet_block)
|
148 |
+
x = unet_block(x)
|
149 |
+
|
150 |
+
ni = x.shape[1]
|
151 |
+
if imsize != sfs_szs[0][-2:]:
|
152 |
+
layers.append(PixelShuffle_ICNR(ni, **kwargs))
|
153 |
+
if last_cross:
|
154 |
+
layers.append(MergeLayer(dense=True))
|
155 |
+
ni += in_channels(encoder)
|
156 |
+
layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
|
157 |
+
layers += [
|
158 |
+
custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
|
159 |
+
]
|
160 |
+
if y_range is not None:
|
161 |
+
layers.append(SigmoidRange(*y_range))
|
162 |
+
super().__init__(*layers)
|
163 |
+
|
164 |
+
def __del__(self):
|
165 |
+
if hasattr(self, "sfs"):
|
166 |
+
self.sfs.remove()
|
167 |
+
|
168 |
+
|
169 |
+
# ------------------------------------------------------
|
170 |
+
class UnetBlockWide(nn.Module):
|
171 |
+
"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
|
172 |
+
|
173 |
+
def __init__(
|
174 |
+
self,
|
175 |
+
up_in_c: int,
|
176 |
+
x_in_c: int,
|
177 |
+
n_out: int,
|
178 |
+
hook: Hook,
|
179 |
+
final_div: bool = True,
|
180 |
+
blur: bool = False,
|
181 |
+
leaky: float = None,
|
182 |
+
self_attention: bool = False,
|
183 |
+
**kwargs
|
184 |
+
):
|
185 |
+
super().__init__()
|
186 |
+
self.hook = hook
|
187 |
+
up_out = x_out = n_out // 2
|
188 |
+
self.shuf = CustomPixelShuffle_ICNR(
|
189 |
+
up_in_c, up_out, blur=blur, leaky=leaky, **kwargs
|
190 |
+
)
|
191 |
+
self.bn = batchnorm_2d(x_in_c)
|
192 |
+
ni = up_out + x_in_c
|
193 |
+
self.conv = custom_conv_layer(
|
194 |
+
ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs
|
195 |
+
)
|
196 |
+
self.relu = relu(leaky=leaky)
|
197 |
+
|
198 |
+
def forward(self, up_in: Tensor) -> Tensor:
|
199 |
+
s = self.hook.stored
|
200 |
+
up_out = self.shuf(up_in)
|
201 |
+
ssh = s.shape[-2:]
|
202 |
+
if ssh != up_out.shape[-2:]:
|
203 |
+
up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
|
204 |
+
cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
|
205 |
+
return self.conv(cat_x)
|
206 |
+
|
207 |
+
|
208 |
+
class DynamicUnetWide(SequentialEx):
    "Create a U-Net from a given architecture."

    def __init__(
        self,
        encoder: nn.Module,
        n_classes: int,
        blur: bool = False,
        blur_final=True,
        self_attention: bool = False,
        y_range: Optional[Tuple[float, float]] = None,
        last_cross: bool = True,
        bottle: bool = False,
        norm_type: Optional[NormType] = NormType.Batch,
        nf_factor: int = 1,
        **kwargs
    ):

        nf = 512 * nf_factor
        extra_bn = norm_type == NormType.Spectral
        imsize = (256, 256)
        # probe the encoder with a dummy input to find which layers change size,
        # then hook those layers so their activations can feed the cross connections
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs
            ),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        # one UnetBlockWide per hooked encoder activation, from deepest to shallowest
        for i, idx in enumerate(sfs_idxs):
            not_final = i != len(sfs_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sfs_idxs) - 3)

            n_out = nf if not_final else nf // 2

            unet_block = UnetBlockWide(
                up_in_c,
                x_in_c,
                n_out,
                self.sfs[i],
                final_div=not_final,
                blur=blur,
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                **kwargs
            ).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, "sfs"):
            self.sfs.remove()
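For orientation, here is a minimal sketch of how a generator could be assembled from the class above. It is not taken from this repo's src/deoldify/generators.py (which visualize.py imports for the real wiring); the encoder architecture, nf_factor, and y_range below are illustrative assumptions, and the import assumes src/ is on the Python path.

# Illustrative only: building a wide U-Net generator around a torchvision encoder.
# Hyperparameters and the resnet34 body are assumptions, not this Space's settings.
import torch
from fastai.vision import *                      # fastai v1 API, as used elsewhere in this repo
from fastai.vision.learner import create_body

from deoldify.unet import DynamicUnetWide

# Cut a pretrained ResNet down to its convolutional body to act as the encoder.
body = create_body(models.resnet34, pretrained=True)

# Wrap it in the wide decoder defined above; 3 output channels for an RGB image,
# with outputs squashed into the (illustrative) range (-3, 3) by SigmoidRange.
net = DynamicUnetWide(
    encoder=body,
    n_classes=3,
    blur=True,
    self_attention=True,
    y_range=(-3.0, 3.0),
    norm_type=NormType.Spectral,
    nf_factor=2,
)

x = torch.randn(1, 3, 256, 256)                  # dummy input
with torch.no_grad():
    y = net.eval()(x)                            # -> torch.Size([1, 3, 256, 256])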
src/deoldify/visualize.py
ADDED
@@ -0,0 +1,247 @@
import cv2
import gc
import requests
from io import BytesIO
import base64
from scipy import misc
from PIL import Image
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from typing import Tuple

import torch
from fastai.core import *
from fastai.vision import *

from .filters import IFilter, MasterFilter, ColorizerFilter
from .generators import gen_inference_deep, gen_inference_wide


# class LoadedModel
class ModelImageVisualizer:
    def __init__(self, filter: IFilter, results_dir: str = None):
        self.filter = filter
        self.results_dir = None if results_dir is None else Path(results_dir)
        # only create the output directory when one was actually given
        if self.results_dir is not None:
            self.results_dir.mkdir(parents=True, exist_ok=True)

    def _clean_mem(self):
        torch.cuda.empty_cache()
        # gc.collect()

    def _open_pil_image(self, path: Path) -> Image:
        return Image.open(path).convert('RGB')

    def _get_image_from_url(self, url: str) -> Image:
        response = requests.get(url, timeout=30, headers={'Accept': '*/*;q=0.8'})
        img = Image.open(BytesIO(response.content)).convert('RGB')
        return img

    def plot_transformed_image_from_url(
        self,
        url: str,
        path: str = 'test_images/image.png',
        results_dir: Path = None,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
        watermarked: bool = True,
    ) -> Path:
        img = self._get_image_from_url(url)
        img.save(path)
        return self.plot_transformed_image(
            path=path,
            results_dir=results_dir,
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
            compare=compare,
            post_process=post_process,
            watermarked=watermarked,
        )

    def plot_transformed_image(
        self,
        path: str,
        results_dir: Path = None,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
        watermarked: bool = True,
    ) -> Path:
        path = Path(path)
        if results_dir is None:
            results_dir = Path(self.results_dir)
        result = self.get_transformed_image(
            path, render_factor, post_process=post_process, watermarked=watermarked
        )
        orig = self._open_pil_image(path)
        if compare:
            self._plot_comparison(
                figsize, render_factor, display_render_factor, orig, result
            )
        else:
            self._plot_solo(figsize, render_factor, display_render_factor, result)

        orig.close()
        result_path = self._save_result_image(path, result, results_dir=results_dir)
        result.close()
        return result_path

    def plot_transformed_pil_image(
        self,
        input_image: Image,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
    ) -> Image:

        result = self.get_transformed_pil_image(
            input_image, render_factor, post_process=post_process
        )

        if compare:
            self._plot_comparison(
                figsize, render_factor, display_render_factor, input_image, result
            )
        else:
            self._plot_solo(figsize, render_factor, display_render_factor, result)

        return result

    def _plot_comparison(
        self,
        figsize: Tuple[int, int],
        render_factor: int,
        display_render_factor: bool,
        orig: Image,
        result: Image,
    ):
        fig, axes = plt.subplots(1, 2, figsize=figsize)
        self._plot_image(
            orig,
            axes=axes[0],
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=False,
        )
        self._plot_image(
            result,
            axes=axes[1],
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
        )

    def _plot_solo(
        self,
        figsize: Tuple[int, int],
        render_factor: int,
        display_render_factor: bool,
        result: Image,
    ):
        fig, axes = plt.subplots(1, 1, figsize=figsize)
        self._plot_image(
            result,
            axes=axes,
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
        )

    def _save_result_image(self, source_path: Path, image: Image, results_dir=None) -> Path:
        if results_dir is None:
            results_dir = Path(self.results_dir)
        result_path = results_dir / source_path.name
        image.save(result_path)
        return result_path

    def get_transformed_image(
        self, path: Path, render_factor: int = None, post_process: bool = True,
        watermarked: bool = True,
    ) -> Image:
        self._clean_mem()
        orig_image = self._open_pil_image(path)
        filtered_image = self.filter.filter(
            orig_image, orig_image, render_factor=render_factor, post_process=post_process
        )

        return filtered_image

    def get_transformed_pil_image(
        self, input_image: Image, render_factor: int = None, post_process: bool = True,
    ) -> Image:
        self._clean_mem()
        filtered_image = self.filter.filter(
            input_image, input_image, render_factor=render_factor, post_process=post_process
        )

        return filtered_image

    def _plot_image(
        self,
        image: Image,
        render_factor: int,
        axes: Axes = None,
        figsize=(20, 20),
        display_render_factor=False,
    ):
        if axes is None:
            _, axes = plt.subplots(figsize=figsize)
        axes.imshow(np.asarray(image) / 255)
        axes.axis('off')
        if render_factor is not None and display_render_factor:
            plt.text(
                10,
                10,
                'render_factor: ' + str(render_factor),
                color='white',
                backgroundcolor='black',
            )

    def _get_num_rows_columns(self, num_images: int, max_columns: int) -> Tuple[int, int]:
        columns = min(num_images, max_columns)
        rows = num_images // columns
        rows = rows if rows * columns == num_images else rows + 1
        return rows, columns


def get_image_colorizer(
    root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
) -> ModelImageVisualizer:
    if artistic:
        return get_artistic_image_colorizer(root_folder=root_folder, render_factor=render_factor)
    else:
        return get_stable_image_colorizer(root_folder=root_folder, render_factor=render_factor)


def get_stable_image_colorizer(
    root_folder: Path = Path('./'),
    weights_name: str = 'ColorizeStable_gen',
    results_dir='output',
    render_factor: int = 35
) -> ModelImageVisualizer:
    learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
    vis = ModelImageVisualizer(filtr, results_dir=results_dir)
    return vis


def get_artistic_image_colorizer(
    root_folder: Path = Path('./'),
    weights_name: str = 'ColorizeArtistic_gen',
    results_dir='output',
    render_factor: int = 35
) -> ModelImageVisualizer:
    learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
    vis = ModelImageVisualizer(filtr, results_dir=results_dir)
    return vis
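A minimal usage sketch of the helpers above, for reference. It assumes the pretrained generator weights are already in place where gen_inference_deep / gen_inference_wide expect them, that src/ is on the import path, and that the input file name is a placeholder; only get_image_colorizer and get_transformed_pil_image come from this file.

# Hedged sketch: colorize one PIL image with the factory above.
from pathlib import Path
from PIL import Image

from deoldify.visualize import get_image_colorizer

colorizer = get_image_colorizer(root_folder=Path('./'), render_factor=35, artistic=True)

bw = Image.open('my_old_photo.jpg').convert('RGB')      # hypothetical grayscale input
color = colorizer.get_transformed_pil_image(bw, render_factor=35, post_process=True)
color.save('output/my_old_photo_colorized.jpg')         # 'output' is the default results_dir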
src/st_style.py
ADDED
@@ -0,0 +1,42 @@
button_style = """
<style>
div.stButton > button:first-child {
    background-color: rgb(255, 75, 75);
    color: rgb(255, 255, 255);
}
div.stButton > button:hover {
    background-color: rgb(255, 75, 75);
    color: rgb(255, 255, 255);
}
div.stButton > button:active {
    background-color: rgb(255, 75, 75);
    color: rgb(255, 255, 255);
}
div.stButton > button:focus {
    background-color: rgb(255, 75, 75);
    color: rgb(255, 255, 255);
}
.css-1cpxqw2:focus:not(:active) {
    background-color: rgb(255, 75, 75);
    border-color: rgb(255, 75, 75);
    color: rgb(255, 255, 255);
}
</style>
"""

style = """
<style>
#MainMenu {
    visibility: hidden;
}
footer {
    visibility: hidden;
}
header {
    visibility: hidden;
}
</style>
"""


def apply_prod_style(st):
    return st.markdown(style, unsafe_allow_html=True)
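A short sketch of how these snippets might be wired into a Streamlit script. The repo's real entry point is app.py; the import path, button label, and the separate injection of button_style below are assumptions for illustration only.

# Hedged, stand-alone example; only apply_prod_style and button_style come from the file above.
import streamlit as st

from src.st_style import apply_prod_style, button_style   # assumes running from the repo root

apply_prod_style(st)                                       # hides Streamlit's menu, header, and footer
st.markdown(button_style, unsafe_allow_html=True)          # button CSS is not applied by apply_prod_style

if st.button("Colorize!"):                                 # hypothetical button
    st.write("clicked")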