Daniel Verdu committed on
Commit
4403f05
1 Parent(s): 6f8465a

first commit in hf_spaces

Browse files
Files changed (1) hide show
  1. deoldify/visualize.py +10 -11
deoldify/visualize.py CHANGED
@@ -8,6 +8,7 @@ from PIL import Image
8
  from matplotlib.axes import Axes
9
  from matplotlib.figure import Figure
10
  from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
 
11
 
12
  import torch
13
  from fastai.core import *
@@ -18,6 +19,7 @@ from .generators import gen_inference_deep, gen_inference_wide
18
 
19
 
20
 
 
21
  class ModelImageVisualizer:
22
  def __init__(self, filter: IFilter, results_dir: str = None):
23
  self.filter = filter
@@ -29,11 +31,11 @@ class ModelImageVisualizer:
29
  # gc.collect()
30
 
31
  def _open_pil_image(self, path: Path) -> Image:
32
- return PIL.Image.open(path).convert('RGB')
33
 
34
  def _get_image_from_url(self, url: str) -> Image:
35
  response = requests.get(url, timeout=30, headers={'Accept': '*/*;q=0.8'})
36
- img = PIL.Image.open(BytesIO(response.content)).convert('RGB')
37
  return img
38
 
39
  def plot_transformed_image_from_url(
@@ -41,7 +43,7 @@ class ModelImageVisualizer:
41
  url: str,
42
  path: str = 'test_images/image.png',
43
  results_dir:Path = None,
44
- figsize: (int, int) = (20, 20),
45
  render_factor: int = None,
46
 
47
  display_render_factor: bool = False,
@@ -66,7 +68,7 @@ class ModelImageVisualizer:
66
  self,
67
  path: str,
68
  results_dir:Path = None,
69
- figsize: (int, int) = (20, 20),
70
  render_factor: int = None,
71
  display_render_factor: bool = False,
72
  compare: bool = False,
@@ -95,7 +97,7 @@ class ModelImageVisualizer:
95
  def plot_transformed_pil_image(
96
  self,
97
  input_image: Image,
98
- figsize: (int, int) = (20, 20),
99
  render_factor: int = None,
100
  display_render_factor: bool = False,
101
  compare: bool = False,
@@ -117,7 +119,7 @@ class ModelImageVisualizer:
117
 
118
  def _plot_comparison(
119
  self,
120
- figsize: (int, int),
121
  render_factor: int,
122
  display_render_factor: bool,
123
  orig: Image,
@@ -141,7 +143,7 @@ class ModelImageVisualizer:
141
 
142
  def _plot_solo(
143
  self,
144
- figsize: (int, int),
145
  render_factor: int,
146
  display_render_factor: bool,
147
  result: Image,
@@ -172,9 +174,6 @@ class ModelImageVisualizer:
172
  orig_image, orig_image, render_factor=render_factor,post_process=post_process
173
  )
174
 
175
- # if watermarked:
176
- # return get_watermarked(filtered_image)
177
-
178
  return filtered_image
179
 
180
  def get_transformed_pil_image(
@@ -208,7 +207,7 @@ class ModelImageVisualizer:
208
  backgroundcolor='black',
209
  )
210
 
211
- def _get_num_rows_columns(self, num_images: int, max_columns: int) -> (int, int):
212
  columns = min(num_images, max_columns)
213
  rows = num_images // columns
214
  rows = rows if rows * columns == num_images else rows + 1
 
8
  from matplotlib.axes import Axes
9
  from matplotlib.figure import Figure
10
  from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
11
+ from typing import Tuple
12
 
13
  import torch
14
  from fastai.core import *
 
19
 
20
 
21
 
22
+ # class LoadedModel
23
  class ModelImageVisualizer:
24
  def __init__(self, filter: IFilter, results_dir: str = None):
25
  self.filter = filter
 
31
  # gc.collect()
32
 
33
  def _open_pil_image(self, path: Path) -> Image:
34
+ return Image.open(path).convert('RGB')
35
 
36
  def _get_image_from_url(self, url: str) -> Image:
37
  response = requests.get(url, timeout=30, headers={'Accept': '*/*;q=0.8'})
38
+ img = Image.open(BytesIO(response.content)).convert('RGB')
39
  return img
40
 
41
  def plot_transformed_image_from_url(
 
43
  url: str,
44
  path: str = 'test_images/image.png',
45
  results_dir:Path = None,
46
+ figsize: Tuple[int, int] = (20, 20),
47
  render_factor: int = None,
48
 
49
  display_render_factor: bool = False,
 
68
  self,
69
  path: str,
70
  results_dir:Path = None,
71
+ figsize: Tuple[int, int] = (20, 20),
72
  render_factor: int = None,
73
  display_render_factor: bool = False,
74
  compare: bool = False,
 
97
  def plot_transformed_pil_image(
98
  self,
99
  input_image: Image,
100
+ figsize: Tuple[int, int] = (20, 20),
101
  render_factor: int = None,
102
  display_render_factor: bool = False,
103
  compare: bool = False,
 
119
 
120
  def _plot_comparison(
121
  self,
122
+ figsize: Tuple[int, int],
123
  render_factor: int,
124
  display_render_factor: bool,
125
  orig: Image,
 
143
 
144
  def _plot_solo(
145
  self,
146
+ figsize: Tuple[int, int],
147
  render_factor: int,
148
  display_render_factor: bool,
149
  result: Image,
 
174
  orig_image, orig_image, render_factor=render_factor,post_process=post_process
175
  )
176
 
 
 
 
177
  return filtered_image
178
 
179
  def get_transformed_pil_image(
 
207
  backgroundcolor='black',
208
  )
209
 
210
+ def _get_num_rows_columns(self, num_images: int, max_columns: int) -> Tuple[int, int]:
211
  columns = min(num_images, max_columns)
212
  rows = num_images // columns
213
  rows = rows if rows * columns == num_images else rows + 1