import base64
import json
import logging
import re
from enum import Enum
from io import BytesIO

import requests
import torch
import torch.nn as nn
from PIL import Image

logger = logging.getLogger(__name__)

# Default network timeout (seconds) for ``load_image_from_url``. Without a
# timeout, ``requests.get`` can block forever on an unresponsive server.
DEFAULT_URL_TIMEOUT = 30


class ImageType(Enum):
    """Labels for the four quadrants of a real image, plus ``FAKE``.

    The ``REAL_*`` members are the quadrants accepted by ``crop_image_part``.
    ``FAKE`` is not a quadrant and is rejected there. It presumably labels a
    generated image (TODO confirm against callers).
    """

    REAL_UP_L = 0
    REAL_UP_R = 1
    REAL_DOWN_R = 2
    REAL_DOWN_L = 3
    FAKE = 4


def crop_image_part(image: torch.Tensor, part: ImageType) -> torch.Tensor:
    """Return one spatial quadrant of a 4-D image tensor.

    Args:
        image: 4-D tensor whose dims 2 and 3 are split. They are assumed to
            be height and width of a ``(batch, channels, H, W)`` tensor
            (TODO confirm).
        part: the quadrant to return; must be one of the ``REAL_*`` members.

    Returns:
        A slice (view) of ``image``. For odd sizes, the lower/right quadrants
        receive the extra row/column.

    Raises:
        ValueError: if ``part`` is not a quadrant (e.g. ``ImageType.FAKE``).
    """
    # Halve height and width independently so non-square inputs are cut at
    # their true centre rather than at the height midpoint on both axes.
    half_h = image.shape[2] // 2
    half_w = image.shape[3] // 2
    if part == ImageType.REAL_UP_L:
        return image[:, :, :half_h, :half_w]
    if part == ImageType.REAL_UP_R:
        return image[:, :, :half_h, half_w:]
    if part == ImageType.REAL_DOWN_L:
        return image[:, :, half_h:, :half_w]
    if part == ImageType.REAL_DOWN_R:
        return image[:, :, half_h:, half_w:]
    raise ValueError('invalid part')


def init_weights(module: nn.Module) -> None:
    """Initialise ``Conv2d`` / ``BatchNorm2d`` parameters in place.

    ``Conv2d`` weights are drawn from N(0, 0.02). ``BatchNorm2d`` weights are
    drawn from N(1, 0.02) and its bias is zeroed. Other module types are left
    untouched. The single-module signature suits ``model.apply(init_weights)``.
    """
    if isinstance(module, nn.Conv2d):
        nn.init.normal_(module.weight, 0.0, 0.02)
    elif isinstance(module, nn.BatchNorm2d):
        # With affine=False, BatchNorm has weight/bias set to None; skip them
        # instead of crashing with an AttributeError.
        if module.weight is not None:
            nn.init.normal_(module.weight, 1.0, 0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def _resize_if_requested(image, image_resize):
    """Resize ``image`` when ``image_resize`` is a ``(width, height)`` tuple or list."""
    # Lists are accepted too, since sizes read from JSON configs arrive as lists.
    if isinstance(image_resize, (tuple, list)):
        return image.resize(tuple(image_resize))
    return image


def load_image_from_local(image_path, image_resize=None):
    """Open an image from disk, optionally resizing it.

    Args:
        image_path: anything ``PIL.Image.open`` accepts (path or file object).
        image_resize: target ``(width, height)`` as a tuple or list. Any
            other value (including ``None``) means "do not resize".

    Returns:
        A ``PIL.Image.Image``. When no resize is requested, the image is
        returned exactly as ``Image.open`` produced it (lazily loaded).
    """
    image = Image.open(image_path)
    return _resize_if_requested(image, image_resize)


def load_image_from_url(image_url, rgba_mode=False, image_resize=None,
                        default_image=None, timeout=DEFAULT_URL_TIMEOUT):
    """Download and decode an image, falling back to a local default on failure.

    Args:
        image_url: URL to fetch with ``requests.get``.
        rgba_mode: if true, convert the image to ``"RGBA"``.
        image_resize: target ``(width, height)`` as a tuple or list. Also
            applied to the fallback image.
        default_image: local path loaded via ``load_image_from_local`` when
            the download or decoding fails. If falsy, ``None`` is returned
            on failure.
        timeout: seconds passed to ``requests.get``. ``None`` waits forever.

    Returns:
        A fully decoded ``PIL.Image.Image``, the fallback image, or ``None``.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        # Decode from the (content-decoded) body rather than ``response.raw``.
        # The raw stream skips gzip/deflate decoding and would keep the image
        # lazily tied to the network stream.
        image = Image.open(BytesIO(response.content))
        # Force decoding now, so corrupt data triggers the fallback below
        # instead of failing later in the caller.
        image.load()
        if rgba_mode:
            image = image.convert("RGBA")
        image = _resize_if_requested(image, image_resize)
    except Exception:
        # Best-effort by design: any network/HTTP/decoding error falls back
        # to default_image. The error is logged rather than silently swallowed.
        logger.warning("Failed to load image from %s", image_url, exc_info=True)
        image = None
        if default_image:
            image = load_image_from_local(default_image, image_resize=image_resize)
    return image


def image_to_base64(image_array):
    """Encode a PIL image as a PNG ``data:`` URI.

    Args:
        image_array: a ``PIL.Image.Image`` (despite the name), saved as PNG.

    Returns:
        A string of the form ``"data:image/png;base64,<payload>"``.
    """
    buffered = BytesIO()
    image_array.save(buffered, format="PNG")
    image_b64 = base64.b64encode(buffered.getvalue()).decode("ascii")
    # No space after the comma: RFC 2397 defines none, and strict data-URI
    # parsers reject one.
    return f"data:image/png;base64,{image_b64}"