# Daniel Verdu
# merged changes
# d984001
# raw
# history blame
# 9.56 kB
import cv2
import gc
import requests
from io import BytesIO
import base64
from scipy import misc
from PIL import Image
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from typing import Tuple
import torch
from fastai.core import *
from fastai.vision import *
from .filters import IFilter, MasterFilter, ColorizerFilter
from .generators import gen_inference_deep, gen_inference_wide
<<<<<<< HEAD
=======
# class LoadedModel
>>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
class ModelImageVisualizer:
def __init__(self, filter: IFilter, results_dir: str = None):
self.filter = filter
self.results_dir = None if results_dir is None else Path(results_dir)
<<<<<<< HEAD
if self.results_dir is not None:
self.results_dir.mkdir(parents=True, exist_ok=True)
=======
self.results_dir.mkdir(parents=True, exist_ok=True)
>>>>>>> 878ecf212e9f3f2f6e923e3bfff6ec899dc40143
def _clean_mem(self):
torch.cuda.empty_cache()
# gc.collect()
def _open_pil_image(self, path: Path) -> Image:
return Image.open(path).convert('RGB')
def _get_image_from_url(self, url: str) -> Image:
response = requests.get(url, timeout=30, headers={'Accept': '*/*;q=0.8'})
img = Image.open(BytesIO(response.content)).convert('RGB')
return img
def plot_transformed_image_from_url(
self,
url: str,
path: str = 'test_images/image.png',
results_dir:Path = None,
figsize: Tuple[int, int] = (20, 20),
render_factor: int = None,
display_render_factor: bool = False,
compare: bool = False,
post_process: bool = True,
watermarked: bool = True,
) -> Path:
img = self._get_image_from_url(url)
img.save(path)
return self.plot_transformed_image(
path=path,
results_dir=results_dir,
figsize=figsize,
render_factor=render_factor,
display_render_factor=display_render_factor,
compare=compare,
post_process = post_process,
watermarked=watermarked,
)
def plot_transformed_image(
self,
path: str,
results_dir:Path = None,
figsize: Tuple[int, int] = (20, 20),
render_factor: int = None,
display_render_factor: bool = False,
compare: bool = False,
post_process: bool = True,
watermarked: bool = True,
) -> Path:
path = Path(path)
if results_dir is None:
results_dir = Path(self.results_dir)
result = self.get_transformed_image(
path, render_factor, post_process=post_process,watermarked=watermarked
)
orig = self._open_pil_image(path)
if compare:
self._plot_comparison(
figsize, render_factor, display_render_factor, orig, result
)
else:
self._plot_solo(figsize, render_factor, display_render_factor, result)
orig.close()
result_path = self._save_result_image(path, result, results_dir=results_dir)
result.close()
return result_path
def plot_transformed_pil_image(
self,
input_image: Image,
figsize: Tuple[int, int] = (20, 20),
render_factor: int = None,
display_render_factor: bool = False,
compare: bool = False,
post_process: bool = True,
) -> Image:
result = self.get_transformed_pil_image(
input_image, render_factor, post_process=post_process
)
if compare:
self._plot_comparison(
figsize, render_factor, display_render_factor, input_image, result
)
else:
self._plot_solo(figsize, render_factor, display_render_factor, result)
return result
def _plot_comparison(
self,
figsize: Tuple[int, int],
render_factor: int,
display_render_factor: bool,
orig: Image,
result: Image,
):
fig, axes = plt.subplots(1, 2, figsize=figsize)
self._plot_image(
orig,
axes=axes[0],
figsize=figsize,
render_factor=render_factor,
display_render_factor=False,
)
self._plot_image(
result,
axes=axes[1],
figsize=figsize,
render_factor=render_factor,
display_render_factor=display_render_factor,
)
def _plot_solo(
self,
figsize: Tuple[int, int],
render_factor: int,
display_render_factor: bool,
result: Image,
):
fig, axes = plt.subplots(1, 1, figsize=figsize)
self._plot_image(
result,
axes=axes,
figsize=figsize,
render_factor=render_factor,
display_render_factor=display_render_factor,
)
def _save_result_image(self, source_path: Path, image: Image, results_dir = None) -> Path:
if results_dir is None:
results_dir = Path(self.results_dir)
result_path = results_dir / source_path.name
image.save(result_path)
return result_path
def get_transformed_image(
self, path: Path, render_factor: int = None, post_process: bool = True,
watermarked: bool = True,
) -> Image:
self._clean_mem()
orig_image = self._open_pil_image(path)
filtered_image = self.filter.filter(
orig_image, orig_image, render_factor=render_factor,post_process=post_process
)
return filtered_image
def get_transformed_pil_image(
self, input_image: Image, render_factor: int = None, post_process: bool = True,
) -> Image:
self._clean_mem()
filtered_image = self.filter.filter(
input_image, input_image, render_factor=render_factor,post_process=post_process
)
return filtered_image
def _plot_image(
self,
image: Image,
render_factor: int,
axes: Axes = None,
figsize=(20, 20),
display_render_factor = False,
):
if axes is None:
_, axes = plt.subplots(figsize=figsize)
axes.imshow(np.asarray(image) / 255)
axes.axis('off')
if render_factor is not None and display_render_factor:
plt.text(
10,
10,
'render_factor: ' + str(render_factor),
color='white',
backgroundcolor='black',
)
def _get_num_rows_columns(self, num_images: int, max_columns: int) -> Tuple[int, int]:
columns = min(num_images, max_columns)
rows = num_images // columns
rows = rows if rows * columns == num_images else rows + 1
return rows, columns
def get_image_colorizer(
    root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
) -> ModelImageVisualizer:
    """Build a colorizing visualizer, choosing between the two model variants.

    Args:
        root_folder: Forwarded to the selected factory.
        render_factor: Forwarded to the selected factory.
        artistic: When true, use ``get_artistic_image_colorizer``; otherwise
            use ``get_stable_image_colorizer``.

    Returns:
        The visualizer created by the selected factory.
    """
    if artistic:
        return get_artistic_image_colorizer(root_folder=root_folder, render_factor=render_factor)
    return get_stable_image_colorizer(root_folder=root_folder, render_factor=render_factor)
def get_stable_image_colorizer(
    root_folder: Path = Path('./'),
    weights_name: str = 'ColorizeStable_gen',
    results_dir='result_images',
    render_factor: int = 35
) -> ModelImageVisualizer:
    """Build a visualizer around the learner from ``gen_inference_wide``.

    Args:
        root_folder: Forwarded to ``gen_inference_wide``.
        weights_name: Forwarded to ``gen_inference_wide``.
        results_dir: Forwarded to ``ModelImageVisualizer``. Pass ``None`` to
            build a visualizer without a default results directory.
        render_factor: Forwarded to ``MasterFilter``.

    Returns:
        A ``ModelImageVisualizer`` wrapping a single ``ColorizerFilter``.
    """
    learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
    # Honour the results_dir argument rather than silently dropping it.
    vis = ModelImageVisualizer(filtr, results_dir=results_dir)
    return vis


def get_artistic_image_colorizer(
    root_folder: Path = Path('./'),
    weights_name: str = 'ColorizeArtistic_gen',
    results_dir='result_images',
    render_factor: int = 35
) -> ModelImageVisualizer:
    """Build a visualizer around the learner from ``gen_inference_deep``.

    Args:
        root_folder: Forwarded to ``gen_inference_deep``.
        weights_name: Forwarded to ``gen_inference_deep``.
        results_dir: Forwarded to ``ModelImageVisualizer``. Pass ``None`` to
            build a visualizer without a default results directory.
        render_factor: Forwarded to ``MasterFilter``.

    Returns:
        A ``ModelImageVisualizer`` wrapping a single ``ColorizerFilter``.
    """
    learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
    # Honour the results_dir argument rather than silently dropping it.
    vis = ModelImageVisualizer(filtr, results_dir=results_dir)
    return vis