import torch
import cv2
import os
import numpy as np
import shutil
from models.anime_gan import GeneratorV1
from models.anime_gan_v2 import GeneratorV2
from models.anime_gan_v3 import GeneratorV3
from utils.common import load_checkpoint, RELEASED_WEIGHTS
from utils.image_processing import resize_image, normalize_input, denormalize_input
from utils import read_image, is_image_file
from tqdm import tqdm

# from torch.cuda.amp import autocast

try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

try:
    import moviepy.video.io.ffmpeg_writer as ffmpeg_writer
    from moviepy.video.io.VideoFileClip import VideoFileClip
except ImportError:
    ffmpeg_writer = None
    VideoFileClip = None

VALID_FORMATS = {
    'jpeg', 'jpg', 'jpe',
    'png', 'bmp',
}


def auto_load_weight(weight, version=None, map_location=None):
    """Auto-load the right Generator version for a weight file."""
    weight_name = os.path.basename(weight).lower()
    if version is not None:
        version = version.lower()
        assert version in {"v1", "v2", "v3"}, f"Version {version} does not exist"
        # If a version is provided, use it.
        cls = {
            "v1": GeneratorV1,
            "v2": GeneratorV2,
            "v3": GeneratorV3,
        }[version]
    else:
        # Try to get the class from the name of the weight file.
        # For convenience, the weight name should start with the class name,
        # e.g. Generatorv2_{anything}.pt
        if weight_name in RELEASED_WEIGHTS:
            version = RELEASED_WEIGHTS[weight_name][0]
            return auto_load_weight(weight, version=version, map_location=map_location)
        elif weight_name.startswith("generatorv2"):
            cls = GeneratorV2
        elif weight_name.startswith("generatorv3"):
            cls = GeneratorV3
        elif weight_name.startswith("generator"):
            cls = GeneratorV1
        else:
            raise ValueError(f"Cannot get model from {weight_name}, "
                             "you might need to explicitly specify the version")
    model = cls()
    load_checkpoint(model, weight, strip_optimizer=True, map_location=map_location)
    model.eval()
    return model
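
# Usage sketch (illustrative; these weight paths are hypothetical):
#   G = auto_load_weight("weights/GeneratorV2_hayao.pt")     # version inferred from file name
#   G = auto_load_weight("weights/custom.pt", version="v1")  # version given explicitly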


class Predictor:
    def __init__(self, weight='hayao', device='cpu', amp=True):
        # if not torch.cuda.is_available():
        #     device = 'cpu'
        #     # AMP does not work on CPU.
        #     amp = False
        self.amp = False  # Automatic Mixed Precision (disabled for CPU-only inference)
        # self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
        self.device_type = 'cpu'
        self.device = torch.device(device)
        self.G = auto_load_weight(weight, map_location=device)
        self.G.to(self.device)

    def transform_and_show(
        self,
        image_path,
        figsize=(18, 10),
        save_path=None,
    ):
        if plt is None:
            raise ImportError("matplotlib is required for transform_and_show")
        image = resize_image(read_image(image_path))
        anime_img = self.transform(image)
        anime_img = anime_img.astype('uint8')
        fig = plt.figure(figsize=figsize)
        fig.add_subplot(1, 2, 1)
        # plt.title("Input")
        plt.imshow(image)
        plt.axis('off')
        fig.add_subplot(1, 2, 2)
        # plt.title("Anime style")
        plt.imshow(anime_img[0])
        plt.axis('off')
        plt.tight_layout()
        # Save before show(), which may clear the current figure.
        if save_path is not None:
            plt.savefig(save_path)
        plt.show()

    def transform(self, image, denorm=True):
        '''
        Transform an image into anime style.
        @Arguments:
            - image: np.ndarray, shape = (batch, height, width, channels)
        @Returns:
            - anime version of the image: np.ndarray
        '''
        with torch.no_grad():
            image = self.preprocess_images(image)
            # with autocast(self.device_type, enabled=self.amp):
            fake = self.G(image)
            fake = fake.detach().cpu().numpy()
            # Channel last
            fake = fake.transpose(0, 2, 3, 1)
            if denorm:
                fake = denormalize_input(fake, dtype=np.uint8)
            return fake
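
    # Note (illustrative): for a single RGB image of shape (H, W, 3),
    # preprocess_images adds the batch dim, so transform() returns an array
    # of shape (1, H, W, 3); callers index [0] to get the stylized frame.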

    def transform_image(self, image):
        '''Transform a single RGB image and return the stylized frame.'''
        anime_img = self.transform(resize_image(image))[0]
        return anime_img

    def transform_in_dir(self, img_dir, dest_dir, max_images=0, img_size=(512, 512)):
        '''
        Read all images from img_dir, transform them, and write the results
        to dest_dir.
        '''
        os.makedirs(dest_dir, exist_ok=True)
        files = os.listdir(img_dir)
        files = [f for f in files if self.is_valid_file(f)]
        print(f'Found {len(files)} images in {img_dir}')
        if max_images:
            files = files[:max_images]
        for fname in tqdm(files):
            image = cv2.imread(os.path.join(img_dir, fname))[:, :, ::-1]  # BGR -> RGB
            image = resize_image(image)
            anime_img = self.transform(image)[0]
            ext = fname.split('.')[-1]
            fname = fname.replace(f'.{ext}', '')
            cv2.imwrite(os.path.join(dest_dir, f'{fname}.jpg'), anime_img[..., ::-1])  # RGB -> BGR
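
    # Example call (hypothetical directories):
    #   Predictor(weight='hayao').transform_in_dir('./inputs', './anime_outputs')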

    # def transform_video(self, input_path, output_path, batch_size=4, start=0, end=0):
    #     end = end or None
    #     # if not os.path.isfile(input_path):
    #     #     raise FileNotFoundError(f'{input_path} does not exist')
    #     # output_dir = "/".join(output_path.split("/")[:-1])
    #     # os.makedirs(output_dir, exist_ok=True)
    #     # is_gg_drive = '/drive/' in output_path
    #     # temp_file = ''
    #     # if is_gg_drive:
    #     #     temp_file = f'tmp_anime.{output_path.split(".")[-1]}'
    #
    #     def transform_and_write(frames, count, writer):
    #         anime_images = self.transform(frames)
    #         for i in range(count):
    #             img = np.clip(anime_images[i], 0, 255).astype(np.uint8)
    #             writer.write(img)
    #
    #     video_capture = cv2.VideoCapture(input_path)
    #     frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    #     frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    #     fps = int(video_capture.get(cv2.CAP_PROP_FPS))
    #     frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    #     if start or end:
    #         start_frame = int(start * fps)
    #         end_frame = int(end * fps) if end else frame_count
    #         video_capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    #         frame_count = end_frame - start_frame
    #     video_writer = cv2.VideoWriter(
    #         output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
    #     print(f'Transforming video {input_path}, {frame_count} frames, size: ({frame_width}, {frame_height})')
    #     batch_shape = (batch_size, frame_height, frame_width, 3)
    #     frames = np.zeros(batch_shape, dtype=np.uint8)
    #     frame_idx = 0
    #     try:
    #         for _ in tqdm(range(frame_count)):
    #             ret, frame = video_capture.read()
    #             if not ret:
    #                 break
    #             frames[frame_idx] = frame
    #             frame_idx += 1
    #             if frame_idx == batch_size:
    #                 transform_and_write(frames, frame_idx, video_writer)
    #                 frame_idx = 0
    #     except Exception as e:
    #         print(e)
    #     finally:
    #         video_capture.release()
    #         video_writer.release()
    #     # if temp_file:
    #     #     shutil.move(temp_file, output_path)
    #     #     print(f'Animation video saved to {output_path}')

    # def transform_video1(self, video, batch_size, start, end):
    #     # end = end or None
    #     # if not os.path.isfile(input_path):
    #     #     raise FileNotFoundError(f'{input_path} does not exist')
    #     # output_dir = "/".join(output_path.split("/")[:-1])
    #     # os.makedirs(output_dir, exist_ok=True)
    #     # is_gg_drive = '/drive/' in output_path
    #     # temp_file = ''
    #     # if is_gg_drive:
    #     #     temp_file = f'tmp_anime.{output_path.split(".")[-1]}'
    #     # def transform_and_save(self, frames, count):
    #     #     transformed_frames = []
    #     #     anime_images = self.transform(frames)
    #     #     for i in range(count):
    #     #         img = np.clip(anime_images[i], 0, 255).astype(np.uint8)
    #     #         transformed_frames.append(img)
    #     #     return transformed_frames
    #
    #     def transform_and_write(frames, count, video_buffer):
    #         anime_images = self.transform(frames)
    #         for i in range(count):
    #             img = np.clip(anime_images[i], 0, 255).astype(np.uint8)
    #             success, encoded_image = cv2.imencode('.jpg', img)
    #             if success:
    #                 video_buffer.append(encoded_image.tobytes())
    #
    #     video_capture = cv2.VideoCapture(video)
    #     frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    #     frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    #     fps = int(video_capture.get(cv2.CAP_PROP_FPS))
    #     frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    #     print(f'Transforming video {frame_count} frames, size: ({frame_width}, {frame_height})')
    #     if start or end:
    #         start_frame = int(start * fps)
    #         end_frame = int(end * fps) if end else frame_count
    #         video_capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    #         frame_count = end_frame - start_frame
    #     # frame_count = len(video_frames)
    #     # transformed_video_frames = []
    #     video_buffer = []
    #     # batch_shape = (batch_size,) + video_frames[0].shape
    #     # frames = np.zeros(batch_shape, dtype=np.uint8)
    #     # frame_idx = 0
    #     batch_shape = (batch_size, frame_height, frame_width, 3)
    #     frames = np.zeros(batch_shape, dtype=np.uint8)
    #     frame_idx = 0
    #     try:
    #         for _ in range(frame_count):
    #             ret, frame = video_capture.read()
    #             if not ret:
    #                 break
    #             frames[frame_idx] = frame
    #             frame_idx += 1
    #             if frame_idx == batch_size:
    #                 transform_and_write(frames, frame_idx, video_buffer)
    #                 frame_idx = 0
    #     except Exception as e:
    #         print(e)
    #     finally:
    #         video_capture.release()
    #     return video_buffer

    def preprocess_images(self, images):
        '''
        Preprocess images for inference.
        @Arguments:
            - images: np.ndarray
        @Returns:
            - images: torch.Tensor
        '''
        images = images.astype(np.float32)
        # Normalize to [-1, 1]
        images = normalize_input(images)
        images = torch.from_numpy(images)
        images = images.to(self.device)
        # Add batch dim
        if len(images.shape) == 3:
            images = images.unsqueeze(0)
        # Channel first
        images = images.permute(0, 3, 1, 2)
        return images

    def is_valid_file(self, fname):
        # Called as self.is_valid_file(...), so it needs the self parameter;
        # compare the extension case-insensitively.
        ext = fname.split('.')[-1].lower()
        return ext in VALID_FORMATS
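

if __name__ == '__main__':
    # Minimal usage sketch, assuming 'hayao' resolves through RELEASED_WEIGHTS
    # and that the input/output paths below are placeholders.
    predictor = Predictor(weight='hayao', device='cpu')
    image = read_image('example.jpg')         # RGB, HWC, uint8
    anime = predictor.transform_image(image)  # stylized RGB frame
    cv2.imwrite('example_anime.jpg', anime[..., ::-1])  # RGB -> BGR for OpenCV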