ArchitSharma's picture
Upload 16 files
c716076
raw
history blame
7.94 kB
import cv2
import gc
import requests
from io import BytesIO
import base64
from scipy import misc
from PIL import Image
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from typing import Tuple
import torch
from fastai.core import *
from fastai.vision import *
from .filters import IFilter, MasterFilter, ColorizerFilter
from .generators import gen_inference_deep, gen_inference_wide
# class LoadedModel
class ModelImageVisualizer:
    """Apply an image filter (e.g. a colorizer) to images and plot/save the results.

    Args:
        filter: object whose ``filter`` method is called as
            ``filter(image, image, render_factor=..., post_process=...)`` and
            returns the transformed image.
        results_dir: default directory where ``plot_transformed_image`` saves
            results. It is created (with parents) when given. When ``None``, every
            call that saves a result must supply its own ``results_dir``.
    """

    def __init__(self, filter: IFilter, results_dir: str = None):
        self.filter = filter
        self.results_dir = None if results_dir is None else Path(results_dir)
        # Only create the default output folder when one was configured;
        # calling mkdir on None used to raise AttributeError for the default.
        if self.results_dir is not None:
            self.results_dir.mkdir(parents=True, exist_ok=True)

    def _clean_mem(self):
        """Release memory cached by PyTorch's CUDA allocator."""
        torch.cuda.empty_cache()

    def _resolve_results_dir(self, results_dir=None) -> Path:
        """Return ``results_dir`` (falling back to the instance default) as a Path.

        Raises:
            ValueError: if neither a per-call nor a default directory is set.
        """
        if results_dir is None:
            results_dir = self.results_dir
        if results_dir is None:
            raise ValueError(
                'No results_dir was given and the visualizer has no default results_dir'
            )
        return Path(results_dir)

    def _open_pil_image(self, path: Path) -> Image:
        """Open the image file at ``path`` with PIL and convert it to RGB."""
        # NOTE: at module level ``from fastai.vision import *`` runs after
        # ``from PIL import Image``; fastai v1 exports its own ``Image`` class,
        # which would shadow PIL's. Import PIL's Image explicitly here.
        from PIL import Image as PILImage

        return PILImage.open(path).convert('RGB')

    def _get_image_from_url(self, url: str) -> Image:
        """Download ``url`` and decode the response body as an RGB PIL image.

        Raises:
            requests.HTTPError: if the server responds with an error status.
        """
        from PIL import Image as PILImage  # see _open_pil_image for why

        response = requests.get(url, timeout=30, headers={'Accept': '*/*;q=0.8'})
        # Surface HTTP failures directly instead of letting PIL fail on an error page.
        response.raise_for_status()
        return PILImage.open(BytesIO(response.content)).convert('RGB')

    def plot_transformed_image_from_url(
        self,
        url: str,
        path: str = 'test_images/image.png',
        results_dir: Path = None,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
        watermarked: bool = True,
    ) -> Path:
        """Download an image, save it to ``path``, then transform, plot and save it.

        All other arguments are forwarded to ``plot_transformed_image``.

        Returns:
            Path of the saved transformed image.
        """
        img = self._get_image_from_url(url)
        path = Path(path)
        # The default download location is a relative folder that may not exist yet.
        path.parent.mkdir(parents=True, exist_ok=True)
        try:
            img.save(path)
        finally:
            img.close()
        return self.plot_transformed_image(
            path=path,
            results_dir=results_dir,
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
            compare=compare,
            post_process=post_process,
            watermarked=watermarked,
        )

    def plot_transformed_image(
        self,
        path: str,
        results_dir: Path = None,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
        watermarked: bool = True,
    ) -> Path:
        """Transform the image at ``path``, plot it, and save it under ``results_dir``.

        Args:
            path: source image file; the result is saved with the same file name.
            results_dir: output folder; defaults to the instance ``results_dir``.
            figsize: matplotlib figure size.
            render_factor: forwarded to the filter and optionally shown on the plot.
            display_render_factor: draw the render factor onto the result plot.
            compare: plot the original next to the result.
            post_process: forwarded to the filter.
            watermarked: forwarded to ``get_transformed_image`` (currently unused there).

        Returns:
            Path of the saved result.

        Raises:
            ValueError: if no results directory is available.
        """
        path = Path(path)
        results_dir = self._resolve_results_dir(results_dir)
        result = self.get_transformed_image(
            path, render_factor, post_process=post_process, watermarked=watermarked
        )
        # Close the images even if plotting or saving raises.
        try:
            if compare:
                orig = self._open_pil_image(path)
                try:
                    self._plot_comparison(
                        figsize, render_factor, display_render_factor, orig, result
                    )
                finally:
                    orig.close()
            else:
                self._plot_solo(figsize, render_factor, display_render_factor, result)
            return self._save_result_image(path, result, results_dir=results_dir)
        finally:
            result.close()

    def plot_transformed_pil_image(
        self,
        input_image: Image,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
    ) -> Image:
        """Transform an in-memory image, plot it, and return the transformed image.

        Nothing is written to disk; see ``plot_transformed_image`` for the arguments.
        """
        result = self.get_transformed_pil_image(
            input_image, render_factor, post_process=post_process
        )
        if compare:
            self._plot_comparison(
                figsize, render_factor, display_render_factor, input_image, result
            )
        else:
            self._plot_solo(figsize, render_factor, display_render_factor, result)
        return result

    def _plot_comparison(
        self,
        figsize: Tuple[int, int],
        render_factor: int,
        display_render_factor: bool,
        orig: Image,
        result: Image,
    ):
        """Plot ``orig`` (left) and ``result`` (right) side by side in a new figure."""
        _, axes = plt.subplots(1, 2, figsize=figsize)
        self._plot_image(
            orig,
            axes=axes[0],
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=False,  # the label only belongs on the result
        )
        self._plot_image(
            result,
            axes=axes[1],
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
        )

    def _plot_solo(
        self,
        figsize: Tuple[int, int],
        render_factor: int,
        display_render_factor: bool,
        result: Image,
    ):
        """Plot ``result`` alone in a new single-axes figure."""
        _, axes = plt.subplots(1, 1, figsize=figsize)
        self._plot_image(
            result,
            axes=axes,
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
        )

    def _save_result_image(self, source_path: Path, image: Image, results_dir=None) -> Path:
        """Save ``image`` into the results folder under ``source_path``'s file name.

        The folder (per-call or instance default) is created if missing.

        Returns:
            Path of the written file.

        Raises:
            ValueError: if no results directory is available.
        """
        results_dir = self._resolve_results_dir(results_dir)
        results_dir.mkdir(parents=True, exist_ok=True)
        result_path = results_dir / Path(source_path).name
        image.save(result_path)
        return result_path

    def get_transformed_image(
        self, path: Path, render_factor: int = None, post_process: bool = True,
        watermarked: bool = True,
    ) -> Image:
        """Open the image at ``path`` and run it through the filter.

        Args:
            path: image file to transform.
            render_factor: forwarded to the filter.
            post_process: forwarded to the filter.
            watermarked: accepted for API compatibility; not used by this method.

        Returns:
            The filter's output image.
        """
        self._clean_mem()
        orig_image = self._open_pil_image(path)
        return self.filter.filter(
            orig_image, orig_image, render_factor=render_factor, post_process=post_process
        )

    def get_transformed_pil_image(
        self, input_image: Image, render_factor: int = None, post_process: bool = True,
    ) -> Image:
        """Run an in-memory image through the filter and return its output."""
        self._clean_mem()
        return self.filter.filter(
            input_image, input_image, render_factor=render_factor, post_process=post_process
        )

    def _plot_image(
        self,
        image: Image,
        render_factor: int,
        axes: Axes = None,
        figsize=(20, 20),
        display_render_factor=False,
    ):
        """Draw ``image`` on ``axes`` (a new figure if None) with axis decorations off.

        When ``display_render_factor`` is set and ``render_factor`` is not None, a
        "render_factor: N" label is drawn in the top-left corner of the same axes.
        """
        if axes is None:
            _, axes = plt.subplots(figsize=figsize)
        # Scale pixel values (0-255 for RGB PIL images) into [0, 1] for imshow.
        axes.imshow(np.asarray(image) / 255)
        axes.axis('off')
        if render_factor is not None and display_render_factor:
            # Draw on the given axes rather than pyplot's "current" axes.
            axes.text(
                10,
                10,
                'render_factor: ' + str(render_factor),
                color='white',
                backgroundcolor='black',
            )

    def _get_num_rows_columns(self, num_images: int, max_columns: int) -> Tuple[int, int]:
        """Return the (rows, columns) grid needed to lay out ``num_images`` tiles.

        Columns are capped at ``max_columns``; rows are rounded up.

        Raises:
            ValueError: if either argument is less than 1.
        """
        if num_images < 1 or max_columns < 1:
            raise ValueError('num_images and max_columns must both be >= 1')
        columns = min(num_images, max_columns)
        rows = -(-num_images // columns)  # ceiling division
        return rows, columns
def get_image_colorizer(
    root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
) -> ModelImageVisualizer:
    """Build a colorizing visualizer from either the artistic or the stable weights.

    Args:
        root_folder: forwarded to the selected factory.
        render_factor: forwarded to the selected factory.
        artistic: truthy selects ``get_artistic_image_colorizer``; falsy selects
            ``get_stable_image_colorizer``.
    """
    build = get_artistic_image_colorizer if artistic else get_stable_image_colorizer
    return build(root_folder=root_folder, render_factor=render_factor)
def get_stable_image_colorizer(
    root_folder: Path = Path('./'),
    weights_name: str = 'ColorizeStable_gen',
    results_dir='output',
    render_factor: int = 35
) -> ModelImageVisualizer:
    """Create a visualizer around a learner loaded by ``gen_inference_wide``.

    Args:
        root_folder: forwarded to ``gen_inference_wide``.
        weights_name: weights name forwarded to ``gen_inference_wide``.
        results_dir: default output folder of the returned visualizer.
        render_factor: default render factor of the wrapping ``MasterFilter``.
    """
    colorizer = ColorizerFilter(
        learn=gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
    )
    return ModelImageVisualizer(
        MasterFilter([colorizer], render_factor=render_factor),
        results_dir=results_dir,
    )
def get_artistic_image_colorizer(
    root_folder: Path = Path('./'),
    weights_name: str = 'ColorizeArtistic_gen',
    results_dir='output',
    render_factor: int = 35
) -> ModelImageVisualizer:
    """Create a visualizer around a learner loaded by ``gen_inference_deep``.

    Args:
        root_folder: forwarded to ``gen_inference_deep``.
        weights_name: weights name forwarded to ``gen_inference_deep``.
        results_dir: default output folder of the returned visualizer.
        render_factor: default render factor of the wrapping ``MasterFilter``.
    """
    colorizer = ColorizerFilter(
        learn=gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
    )
    return ModelImageVisualizer(
        MasterFilter([colorizer], render_factor=render_factor),
        results_dir=results_dir,
    )