import torch import torch.nn as nn from enum import Enum import base64 import json from io import BytesIO from PIL import Image import requests import re from copy import deepcopy class ImageType(Enum): REAL_UP_L = 0 REAL_UP_R = 1 REAL_DOWN_R = 2 REAL_DOWN_L = 3 FAKE = 4 def crop_image_part(image: torch.Tensor, part: ImageType) -> torch.Tensor: size = image.shape[2] // 2 if part == ImageType.REAL_UP_L: return image[:, :, :size, :size] elif part == ImageType.REAL_UP_R: return image[:, :, :size, size:] elif part == ImageType.REAL_DOWN_L: return image[:, :, size:, :size] elif part == ImageType.REAL_DOWN_R: return image[:, :, size:, size:] else: raise ValueError('invalid part') def init_weights(module: nn.Module): if isinstance(module, nn.Conv2d): torch.nn.init.normal_(module.weight, 0.0, 0.02) if isinstance(module, nn.BatchNorm2d): torch.nn.init.normal_(module.weight, 1.0, 0.02) module.bias.data.fill_(0) def load_image_from_local(image_path, image_resize=None): image = Image.open(image_path) if isinstance(image_resize, tuple): image = image.resize(image_resize) return image def load_image_from_url(image_url, rgba_mode=False, image_resize=None, default_image=None): try: image = Image.open(requests.get(image_url, stream=True).raw) if rgba_mode: image = image.convert("RGBA") if isinstance(image_resize, tuple): image = image.resize(image_resize) except Exception as e: image = None if default_image: image = load_image_from_local(default_image, image_resize=image_resize) return image def image_to_base64(image_array): buffered = BytesIO() image_array.save(buffered, format="PNG") image_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8") return f"data:image/png;base64, {image_b64}" def copy_G_params(model): flatten = deepcopy(list(p.data for p in model.parameters())) return flatten def load_params(model, new_param): for p, new_p in zip(model.parameters(), new_param): p.data.copy_(new_p)