Spaces:
Runtime error
Runtime error
import cv2 | |
import gc | |
import requests | |
from io import BytesIO | |
import base64 | |
from scipy import misc | |
from PIL import Image | |
from matplotlib.axes import Axes | |
from matplotlib.figure import Figure | |
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas | |
from typing import Tuple | |
import torch | |
from fastai.core import * | |
from fastai.vision import * | |
from .filters import IFilter, MasterFilter, ColorizerFilter | |
from .generators import gen_inference_deep, gen_inference_wide | |
class ModelImageVisualizer:
    """Apply an image filter (e.g. a colorizer) to images, then plot and save results.

    Args:
        filter: Object whose ``filter(orig_image, filtered_image,
            render_factor=..., post_process=...)`` method returns the
            transformed image.
        results_dir: Optional default directory for saved result images. When
            given it is created (with parents). When ``None``, the plotting
            and saving methods need an explicit ``results_dir`` argument.
    """

    def __init__(self, filter: IFilter, results_dir: str = None):
        self.filter = filter
        self.results_dir = None if results_dir is None else Path(results_dir)
        # Only create an output directory when one was configured; a visualizer
        # used purely for in-memory transforms does not need one.
        if self.results_dir is not None:
            self.results_dir.mkdir(parents=True, exist_ok=True)

    def _clean_mem(self):
        """Release unused cached CUDA memory held by PyTorch before inference."""
        torch.cuda.empty_cache()

    def _open_pil_image(self, path: Path) -> Image:
        """Open the image at *path* and return an RGB copy.

        ``convert`` produces an independent image with its pixels loaded, so
        the file handle opened by ``Image.open`` is released right away.
        """
        with Image.open(path) as img:
            return img.convert('RGB')

    def _get_image_from_url(self, url: str) -> Image:
        """Download *url* (30 s timeout) and decode the body as an RGB image.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        response = requests.get(url, timeout=30, headers={'Accept': '*/*;q=0.8'})
        # Surface HTTP failures directly instead of letting PIL fail to decode
        # an error page.
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as img:
            return img.convert('RGB')

    def plot_transformed_image_from_url(
        self,
        url: str,
        path: str = 'test_images/image.png',
        results_dir: Path = None,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
        watermarked: bool = True,
    ) -> Path:
        """Download *url* to *path*, then transform, plot and save it.

        All remaining arguments are forwarded to ``plot_transformed_image``.

        Returns:
            Path of the saved result image.
        """
        img = self._get_image_from_url(url)
        download_path = Path(path)
        # The default download folder may not exist yet.
        download_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            img.save(download_path)
        finally:
            img.close()
        return self.plot_transformed_image(
            path=download_path,
            results_dir=results_dir,
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
            compare=compare,
            post_process=post_process,
            watermarked=watermarked,
        )

    def plot_transformed_image(
        self,
        path: str,
        results_dir: Path = None,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
        watermarked: bool = True,
    ) -> Path:
        """Transform the image at *path*, plot it, and save the result.

        Args:
            path: Source image file.
            results_dir: Output directory; defaults to ``self.results_dir``.
            figsize: Matplotlib figure size.
            render_factor: Passed to the filter; optionally shown on the plot.
            display_render_factor: Draw the render factor onto the result plot.
            compare: Plot the original and the result side by side.
            post_process: Passed to the filter.
            watermarked: Forwarded to ``get_transformed_image`` (currently unused there).

        Returns:
            Path of the saved result image (``results_dir / path.name``).

        Raises:
            ValueError: If no results directory is given or configured.
        """
        path = Path(path)
        # Resolve the output directory before the expensive model pass so a
        # missing directory fails fast.
        results_dir = self._resolve_results_dir(results_dir)
        result = self.get_transformed_image(
            path, render_factor, post_process=post_process, watermarked=watermarked
        )
        try:
            orig = self._open_pil_image(path)
            try:
                if compare:
                    self._plot_comparison(
                        figsize, render_factor, display_render_factor, orig, result
                    )
                else:
                    self._plot_solo(figsize, render_factor, display_render_factor, result)
            finally:
                orig.close()
            return self._save_result_image(path, result, results_dir=results_dir)
        finally:
            result.close()

    def plot_transformed_pil_image(
        self,
        input_image: Image,
        figsize: Tuple[int, int] = (20, 20),
        render_factor: int = None,
        display_render_factor: bool = False,
        compare: bool = False,
        post_process: bool = True,
    ) -> Image:
        """Transform an in-memory PIL image, plot it, and return the result."""
        result = self.get_transformed_pil_image(
            input_image, render_factor, post_process=post_process
        )
        if compare:
            self._plot_comparison(
                figsize, render_factor, display_render_factor, input_image, result
            )
        else:
            self._plot_solo(figsize, render_factor, display_render_factor, result)
        return result

    def _plot_comparison(
        self,
        figsize: Tuple[int, int],
        render_factor: int,
        display_render_factor: bool,
        orig: Image,
        result: Image,
    ):
        """Plot *orig* (left) and *result* (right) in one new figure."""
        _, axes = plt.subplots(1, 2, figsize=figsize)
        self._plot_image(
            orig,
            axes=axes[0],
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=False,
        )
        self._plot_image(
            result,
            axes=axes[1],
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
        )

    def _plot_solo(
        self,
        figsize: Tuple[int, int],
        render_factor: int,
        display_render_factor: bool,
        result: Image,
    ):
        """Plot *result* alone in a new figure."""
        _, axes = plt.subplots(1, 1, figsize=figsize)
        self._plot_image(
            result,
            axes=axes,
            figsize=figsize,
            render_factor=render_factor,
            display_render_factor=display_render_factor,
        )

    def _resolve_results_dir(self, results_dir=None) -> Path:
        """Return *results_dir*, or the configured default, as a ``Path``.

        Raises:
            ValueError: If neither an explicit nor a configured directory exists.
        """
        if results_dir is None:
            results_dir = self.results_dir
        if results_dir is None:
            raise ValueError(
                'No results_dir was given and none is configured on this '
                'ModelImageVisualizer'
            )
        return Path(results_dir)

    def _save_result_image(self, source_path: Path, image: Image, results_dir=None) -> Path:
        """Save *image* under the results directory using *source_path*'s file name.

        Returns:
            The path the image was written to.

        Raises:
            ValueError: If no results directory is given or configured.
        """
        results_dir = self._resolve_results_dir(results_dir)
        # An explicitly passed directory was not created by __init__.
        results_dir.mkdir(parents=True, exist_ok=True)
        result_path = results_dir / Path(source_path).name
        image.save(result_path)
        return result_path

    def get_transformed_image(
        self, path: Path, render_factor: int = None, post_process: bool = True,
        watermarked: bool = True,
    ) -> Image:
        """Open the image at *path* and return the filter's output for it.

        NOTE(review): ``watermarked`` is accepted for API compatibility but is
        not used here; no watermark is applied.
        """
        self._clean_mem()
        orig_image = self._open_pil_image(path)
        filtered_image = self.filter.filter(
            orig_image, orig_image, render_factor=render_factor, post_process=post_process
        )
        return filtered_image

    def get_transformed_pil_image(
        self, input_image: Image, render_factor: int = None, post_process: bool = True,
    ) -> Image:
        """Return the filter's output for an in-memory PIL image."""
        self._clean_mem()
        filtered_image = self.filter.filter(
            input_image, input_image, render_factor=render_factor, post_process=post_process
        )
        return filtered_image

    def _plot_image(
        self,
        image: Image,
        render_factor: int,
        axes: Axes = None,
        figsize=(20, 20),
        display_render_factor=False,
    ):
        """Draw *image* (scaled to [0, 1]) on *axes*, creating axes if none given."""
        if axes is None:
            _, axes = plt.subplots(figsize=figsize)
        axes.imshow(np.asarray(image) / 255)
        axes.axis('off')
        if render_factor is not None and display_render_factor:
            # Draw on the target axes, not pyplot's "current" axes, which need
            # not be the axes this image was drawn into.
            axes.text(
                10,
                10,
                'render_factor: ' + str(render_factor),
                color='white',
                backgroundcolor='black',
            )

    def _get_num_rows_columns(self, num_images: int, max_columns: int) -> Tuple[int, int]:
        """Return ``(rows, columns)`` for a grid holding *num_images* images.

        Uses at most *max_columns* columns. Both arguments must be positive;
        a zero value raises ``ZeroDivisionError``.
        """
        columns = min(num_images, max_columns)
        rows = -(-num_images // columns)  # ceiling division
        return rows, columns
def get_image_colorizer(
    root_folder: Path = Path('./'), render_factor: int = 35, artistic: bool = True
) -> ModelImageVisualizer:
    """Build a colorizing ``ModelImageVisualizer``.

    Args:
        root_folder: Forwarded to the selected colorizer factory.
        render_factor: Default render factor forwarded to the factory.
        artistic: Use ``get_artistic_image_colorizer`` when True, otherwise
            ``get_stable_image_colorizer``.

    Returns:
        The visualizer produced by the selected factory.
    """
    factory = get_artistic_image_colorizer if artistic else get_stable_image_colorizer
    return factory(root_folder=root_folder, render_factor=render_factor)
def _build_colorizer_visualizer(learn, render_factor: int, results_dir) -> ModelImageVisualizer:
    """Wrap *learn* in a ColorizerFilter/MasterFilter pipeline and a visualizer."""
    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
    return ModelImageVisualizer(filtr, results_dir=results_dir)


def get_stable_image_colorizer(
    root_folder: Path = Path('./'),
    weights_name: str = 'ColorizeStable_gen',
    results_dir='result_images',
    render_factor: int = 35,
) -> ModelImageVisualizer:
    """Build a visualizer around the "stable" colorizer (``gen_inference_wide``).

    Args:
        root_folder: Folder passed to ``gen_inference_wide``.
        weights_name: Weights file name passed to ``gen_inference_wide``.
        results_dir: Output directory forwarded to ``ModelImageVisualizer``.
        render_factor: Default render factor for the ``MasterFilter``.
    """
    learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
    return _build_colorizer_visualizer(learn, render_factor, results_dir)


def get_artistic_image_colorizer(
    root_folder: Path = Path('./'),
    weights_name: str = 'ColorizeArtistic_gen',
    results_dir='result_images',
    render_factor: int = 35,
) -> ModelImageVisualizer:
    """Build a visualizer around the "artistic" colorizer (``gen_inference_deep``).

    Args:
        root_folder: Folder passed to ``gen_inference_deep``.
        weights_name: Weights file name passed to ``gen_inference_deep``.
        results_dir: Output directory forwarded to ``ModelImageVisualizer``.
        render_factor: Default render factor for the ``MasterFilter``.
    """
    learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
    return _build_colorizer_visualizer(learn, render_factor, results_dir)