YANGYYYY commited on
Commit
ad243cd
1 Parent(s): 7e4e601

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +330 -0
inference.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import cv2
3
+ import os
4
+ import numpy as np
5
+ import shutil
6
+ from models.anime_gan import GeneratorV1
7
+ from models.anime_gan_v2 import GeneratorV2
8
+ from models.anime_gan_v3 import GeneratorV3
9
+ from utils.common import load_checkpoint, RELEASED_WEIGHTS
10
+ from utils.image_processing import resize_image, normalize_input, denormalize_input
11
+ from utils import read_image, is_image_file
12
+ from tqdm import tqdm
13
+ # from torch.cuda.amp import autocast
14
+
15
+ try:
16
+ import matplotlib.pyplot as plt
17
+ except ImportError:
18
+ plt = None
19
+
20
+ try:
21
+ import moviepy.video.io.ffmpeg_writer as ffmpeg_writer
22
+ from moviepy.video.io.VideoFileClip import VideoFileClip
23
+ except ImportError:
24
+ ffmpeg_writer = None
25
+ VideoFileClip = None
26
+
27
+
28
+ VALID_FORMATS = {
29
+ 'jpeg', 'jpg', 'jpe',
30
+ 'png', 'bmp',
31
+ }
32
+
33
+ def auto_load_weight(weight, version=None, map_location=None):
34
+ """Auto load Generator version from weight."""
35
+ weight_name = os.path.basename(weight).lower()
36
+ if version is not None:
37
+ version = version.lower()
38
+ assert version in {"v1", "v2", "v3"}, f"Version {version} does not exist"
39
+ # If version is provided, use it.
40
+ cls = {
41
+ "v1": GeneratorV1,
42
+ "v2": GeneratorV2,
43
+ "v3": GeneratorV3
44
+ }[version]
45
+ else:
46
+ # Try to get class by name of weight file
47
+ # For convenenice, weight should start with classname
48
+ # e.g: Generatorv2_{anything}.pt
49
+ if weight_name in RELEASED_WEIGHTS:
50
+ version = RELEASED_WEIGHTS[weight_name][0]
51
+ return auto_load_weight(weight, version=version, map_location=map_location)
52
+
53
+ elif weight_name.startswith("generatorv2"):
54
+ cls = GeneratorV2
55
+ elif weight_name.startswith("generatorv3"):
56
+ cls = GeneratorV3
57
+ elif weight_name.startswith("generator"):
58
+ cls = GeneratorV1
59
+ else:
60
+ raise ValueError((f"Can not get Model from {weight_name}, "
61
+ "you might need to explicitly specify version"))
62
+ model = cls()
63
+ load_checkpoint(model, weight, strip_optimizer=True, map_location=map_location)
64
+ model.eval()
65
+ return model
66
+
67
+
68
+ class Predictor:
69
+ def __init__(self, weight='hayao', device='cpu', amp=True):
70
+ # if not torch.cuda.is_available():
71
+ # device = 'cpu'
72
+ # # Amp not working on cpu
73
+ # amp = False
74
+
75
+ self.amp = False # Automatic Mixed Precision
76
+ #self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
77
+ self.device_type = 'cpu'
78
+ self.device = torch.device(device)
79
+ self.G = auto_load_weight(weight, map_location=device)
80
+ self.G.to(self.device)
81
+
82
+ def transform_and_show(
83
+ self,
84
+ image_path,
85
+ figsize=(18, 10),
86
+ save_path=None
87
+ ):
88
+ image = resize_image(read_image(image_path))
89
+ anime_img = self.transform(image)
90
+ anime_img = anime_img.astype('uint8')
91
+
92
+ fig = plt.figure(figsize=figsize)
93
+ fig.add_subplot(1, 2, 1)
94
+ # plt.title("Input")
95
+ plt.imshow(image)
96
+ plt.axis('off')
97
+ fig.add_subplot(1, 2, 2)
98
+ # plt.title("Anime style")
99
+ plt.imshow(anime_img[0])
100
+ plt.axis('off')
101
+ plt.tight_layout()
102
+ plt.show()
103
+ if save_path is not None:
104
+ plt.savefig(save_path)
105
+
106
+ def transform(self, image, denorm=True):
107
+ '''
108
+ Transform a image to animation
109
+
110
+ @Arguments:
111
+ - image: np.array, shape = (Batch, width, height, channels)
112
+
113
+ @Returns:
114
+ - anime version of image: np.array
115
+ '''
116
+ with torch.no_grad():
117
+ image = self.preprocess_images(image)
118
+ # image = image.to(self.device)
119
+ # with autocast(self.device_type, enabled=self.amp):
120
+ # print(image.dtype, self.G)
121
+ fake = self.G(image)
122
+ fake = fake.detach().cpu().numpy()
123
+ # Channel last
124
+ fake = fake.transpose(0, 2, 3, 1)
125
+
126
+ if denorm:
127
+ fake = denormalize_input(fake, dtype=np.uint8)
128
+ return fake
129
+
130
+ def transform_image(self,image):
131
+ # if not is_image_file(save_path):
132
+ # raise ValueError(f"{save_path} is not valid")
133
+
134
+ # image = read_image(file_path)
135
+ #
136
+ # if image is None:
137
+ # raise ValueError(f"Could not get image from {file_path}")
138
+
139
+ anime_img = self.transform(resize_image(image))[0]
140
+ return anime_img
141
+ # cv2.imwrite(save_path, anime_img[..., ::-1])
142
+ # print(f"Anime image saved to {save_path}")
143
+
144
+ def transform_in_dir(self, img_dir, dest_dir, max_images=0, img_size=(512, 512)):
145
+ '''
146
+ Read all images from img_dir, transform and write the result
147
+ to dest_dir
148
+
149
+ '''
150
+ os.makedirs(dest_dir, exist_ok=True)
151
+
152
+ files = os.listdir(img_dir)
153
+ files = [f for f in files if self.is_valid_file(f)]
154
+ print(f'Found {len(files)} images in {img_dir}')
155
+
156
+ if max_images:
157
+ files = files[:max_images]
158
+
159
+ for fname in tqdm(files):
160
+ image = cv2.imread(os.path.join(img_dir, fname))[:,:,::-1]
161
+ image = resize_image(image)
162
+ anime_img = self.transform(image)[0]
163
+ ext = fname.split('.')[-1]
164
+ fname = fname.replace(f'.{ext}', '')
165
+ cv2.imwrite(os.path.join(dest_dir, f'{fname}.jpg'), anime_img[..., ::-1])
166
+
167
+ def transform_video_yuan(self, input_path, output_path, batch_size=4, start=0, end=0):
168
+
169
+ '''
170
+ Transform a video to animation version
171
+ https://github.com/lengstrom/fast-style-transfer/blob/master/evaluate.py#L21
172
+ '''
173
+ # Force to None
174
+ end = end or None
175
+
176
+ if not os.path.isfile(input_path):
177
+ raise FileNotFoundError(f'{input_path} does not exist')
178
+
179
+ output_dir = "/".join(output_path.split("/")[:-1])
180
+ os.makedirs(output_dir, exist_ok=True)
181
+ is_gg_drive = '/drive/' in output_path
182
+ temp_file = ''
183
+ #output_file = open(output_path, 'wb')
184
+
185
+ if is_gg_drive:
186
+ # Writing directly into google drive can be inefficient
187
+ temp_file = f'tmp_anime.{output_path.split(".")[-1]}'
188
+
189
+ def transform_and_write(frames, count, writer):
190
+ anime_images = self.transform(frames)
191
+ for i in range(0, count):
192
+ img = np.clip(anime_images[i], 0, 255)
193
+ writer.write_frame(img)
194
+
195
+ video_clip = VideoFileClip(input_path, audio=False)
196
+ if start or end:
197
+ video_clip = video_clip.subclip(start, end)
198
+
199
+ video_writer = ffmpeg_writer.FFMPEG_VideoWriter(
200
+ output_path,
201
+ video_clip.size, video_clip.fps, codec="libx264",
202
+ # preset="medium", bitrate="2000k",
203
+ audiofile=input_path, threads=None,
204
+ ffmpeg_params=None)
205
+
206
+ total_frames = round(video_clip.fps * video_clip.duration)
207
+ print(f'Transfroming video {input_path}, {total_frames} frames, size: {video_clip.size}')
208
+
209
+ batch_shape = (batch_size, video_clip.size[1], video_clip.size[0], 3)
210
+ frame_count = 0
211
+ frames = np.zeros(batch_shape, dtype=np.float32)
212
+ for frame in tqdm(video_clip.iter_frames()):
213
+ try:
214
+ frames[frame_count] = frame
215
+ frame_count += 1
216
+ if frame_count == batch_size:
217
+ transform_and_write(frames, frame_count, video_writer)
218
+ frame_count = 0
219
+ except Exception as e:
220
+ print(e)
221
+ break
222
+
223
+
224
+ # The last frames
225
+ if frame_count != 0:
226
+ transform_and_write(frames, frame_count, video_writer)
227
+
228
+ if temp_file:
229
+ # move to output path
230
+ shutil.move(temp_file, output_path)
231
+
232
+ print(f'Animation video saved to {output_path}')
233
+ video_writer.close()
234
+ #output_file.close()
235
+
236
+ def transform_video(self, input_path, output_path, batch_size=4, start=0, end=0):
237
+ end = end or None
238
+
239
+ if not os.path.isfile(input_path):
240
+ raise FileNotFoundError(f'{input_path} does not exist')
241
+
242
+ output_dir = "/".join(output_path.split("/")[:-1])
243
+ os.makedirs(output_dir, exist_ok=True)
244
+ is_gg_drive = '/drive/' in output_path
245
+ temp_file = ''
246
+
247
+ if is_gg_drive:
248
+ temp_file = f'tmp_anime.{output_path.split(".")[-1]}'
249
+
250
+ def transform_and_write(frames, count, writer):
251
+ anime_images = self.transform(frames)
252
+ for i in range(count):
253
+ img = np.clip(anime_images[i], 0, 255).astype(np.uint8)
254
+ writer.write(img)
255
+
256
+ video_capture = cv2.VideoCapture(input_path)
257
+ frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
258
+ frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
259
+ fps = int(video_capture.get(cv2.CAP_PROP_FPS))
260
+ frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
261
+
262
+ if start or end:
263
+ start_frame = int(start * fps)
264
+ end_frame = int(end * fps) if end else frame_count
265
+ video_capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
266
+ frame_count = end_frame - start_frame
267
+
268
+ video_writer = cv2.VideoWriter(
269
+ output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
270
+
271
+ print(f'Transforming video {input_path}, {frame_count} frames, size: ({frame_width}, {frame_height})')
272
+
273
+ batch_shape = (batch_size, frame_height, frame_width, 3)
274
+ frames = np.zeros(batch_shape, dtype=np.uint8)
275
+ frame_idx = 0
276
+
277
+ try:
278
+ for _ in tqdm(range(frame_count)):
279
+ ret, frame = video_capture.read()
280
+ if not ret:
281
+ break
282
+ frames[frame_idx] = frame
283
+ frame_idx += 1
284
+ if frame_idx == batch_size:
285
+ transform_and_write(frames, frame_idx, video_writer)
286
+ frame_idx = 0
287
+ except Exception as e:
288
+ print(e)
289
+ finally:
290
+ video_capture.release()
291
+ video_writer.release()
292
+
293
+ if temp_file:
294
+ shutil.move(temp_file, output_path)
295
+
296
+ print(f'Animation video saved to {output_path}')
297
+ def preprocess_images(self, images):
298
+ '''
299
+ Preprocess image for inference
300
+
301
+ @Arguments:
302
+ - images: np.ndarray
303
+
304
+ @Returns
305
+ - images: torch.tensor
306
+ '''
307
+ images = images.astype(np.float32)
308
+
309
+ # Normalize to [-1, 1]
310
+ images = normalize_input(images)
311
+ images = torch.from_numpy(images)
312
+
313
+ images = images.to(self.device)
314
+
315
+ # Add batch dim
316
+ if len(images.shape) == 3:
317
+ images = images.unsqueeze(0)
318
+
319
+ # channel first
320
+ images = images.permute(0, 3, 1, 2)
321
+
322
+ return images
323
+
324
+
325
+ @staticmethod
326
+ def is_valid_file(fname):
327
+ ext = fname.split('.')[-1]
328
+ return ext in VALID_FORMATS
329
+
330
+