# NOTE(review): the following lines were scrape residue from a Hugging Face
# Space file listing (status text, file size, commit hash, line-number gutter)
# and made the file invalid Python; preserved here as comments.
# Spaces: Runtime error | File size: 4,679 Bytes | commit bf15361
import argparse
import os
import torch
from data_loader.loader import ContentData
from models.unet import UNetModel
from diffusers import AutoencoderKL
from models.diffusion import Diffusion
import torchvision
from parse_config import cfg, cfg_from_file, assert_and_infer_cfg
from utils.util import fix_seed
from PIL import Image
import torchvision.transforms as transforms
class OneDMInference:
    """Inference wrapper for the One-DM handwriting-generation diffusion model.

    Loads a trained UNet checkpoint and a pretrained Stable Diffusion VAE,
    then renders arbitrary text as handwriting conditioned on a style
    reference image and its Laplacian-filtered counterpart.
    """

    def __init__(self, model_path, cfg_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Build all sub-models from a checkpoint and a config file.

        Args:
            model_path: Path to the UNet ``state_dict`` checkpoint.
            cfg_path: Path to the YAML config consumed by ``cfg_from_file``.
            device: Torch device string; defaults to CUDA when available.
        """
        self.device = device
        # cfg is a module-level global; it must be populated before the
        # UNet is constructed below, since _initialize_unet reads cfg.MODEL.*.
        cfg_from_file(cfg_path)
        assert_and_infer_cfg()
        fix_seed(cfg.TRAIN.SEED)
        # Sub-models and sampler.
        self.unet = self._initialize_unet(model_path)
        self.vae = self._initialize_vae()
        self.diffusion = Diffusion(device=self.device)
        self.content_loader = ContentData()
        # Style/laplace inputs are consumed as single-channel float tensors.
        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor()
        ])

    def _initialize_unet(self, model_path):
        """Construct the UNet from cfg.MODEL.* and load checkpoint weights."""
        unet = UNetModel(
            in_channels=cfg.MODEL.IN_CHANNELS,
            model_channels=cfg.MODEL.EMB_DIM,
            out_channels=cfg.MODEL.OUT_CHANNELS,
            num_res_blocks=cfg.MODEL.NUM_RES_BLOCKS,
            attention_resolutions=(1,1),
            channel_mult=(1, 1),
            num_heads=cfg.MODEL.NUM_HEADS,
            context_dim=cfg.MODEL.EMB_DIM
        ).to(self.device)
        # weights_only=True prevents unpickling arbitrary objects from the
        # checkpoint file (safe-loading; requires torch >= 2.0).
        unet.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))
        unet.eval()
        return unet

    def _initialize_vae(self):
        """Load the frozen Stable Diffusion v1.5 VAE used for latent decoding."""
        vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
        vae = vae.to(self.device)
        # Inference only: freeze all VAE parameters.
        vae.requires_grad_(False)
        return vae

    def _load_image(self, image_path):
        """Load an image from disk as a (1, H, W) grayscale float tensor."""
        # Context manager ensures the underlying file handle is closed
        # promptly (plain Image.open leaks it until garbage collection).
        with Image.open(image_path) as image:
            return self.transform(image)

    def generate(self, text, style_path, laplace_path, output_dir,
                sample_method='ddim', sampling_timesteps=50, eta=0.0):
        """Generate handwritten renderings of *text* in the reference style.

        Args:
            text: The string to render.
            style_path: Path to the style reference image.
            laplace_path: Path to the Laplacian-filtered style image.
            output_dir: Directory for output PNGs (created if missing).
            sample_method: 'ddim' or 'ddpm'.
            sampling_timesteps: DDIM step count (ignored for DDPM).
            eta: DDIM stochasticity parameter (ignored for DDPM).

        Returns:
            List of file paths of the saved grayscale PNGs.

        Raises:
            ValueError: If ``sample_method`` is not 'ddim' or 'ddpm'.
        """
        # Load style and laplace images; add a batch dimension.
        style_input = self._load_image(style_path).unsqueeze(0).to(self.device)
        laplace = self._load_image(laplace_path).unsqueeze(0).to(self.device)
        # Prepare the glyph/content reference for the text.
        text_ref = self.content_loader.get_content(text)
        text_ref = text_ref.to(self.device).repeat(1, 1, 1, 1)
        # Initial latent noise; width scales with text length.
        # NOTE(review): assumes text_ref is (batch, n_chars, H, W) and the
        # VAE downsamples by 8 with 32 px per character — confirm upstream.
        x = torch.randn((text_ref.shape[0], 4, style_input.shape[2]//8,
                        (text_ref.shape[1]*32)//8)).to(self.device)
        # Generate image latents and decode them.
        if sample_method == 'ddim':
            sampled_images = self.diffusion.ddim_sample(
                self.unet, self.vae, style_input.shape[0],
                x, style_input, laplace, text_ref,
                sampling_timesteps, eta
            )
        elif sample_method == 'ddpm':
            sampled_images = self.diffusion.ddpm_sample(
                self.unet, self.vae, style_input.shape[0],
                x, style_input, laplace, text_ref
            )
        else:
            # Previously an unknown method fell through and crashed with a
            # confusing NameError on sampled_images; fail fast instead.
            raise ValueError(
                f"Unknown sample_method: {sample_method!r} (expected 'ddim' or 'ddpm')"
            )
        # Save generated images as grayscale PNGs.
        os.makedirs(output_dir, exist_ok=True)
        output_paths = []
        for idx, image in enumerate(sampled_images):
            im = torchvision.transforms.ToPILImage()(image)
            image = im.convert("L")
            output_path = os.path.join(output_dir, f"{text}_{idx}.png")
            image.save(output_path)
            output_paths.append(output_path)
        return output_paths
def main():
    """CLI entry point: parse arguments, build the model, run generation."""
    arg_specs = (
        ('--model_path', 'Path to the One-DM model checkpoint'),
        ('--cfg_path', 'Path to the config file'),
        ('--text', 'Text to generate'),
        ('--style_path', 'Path to style image'),
        ('--laplace_path', 'Path to laplace image'),
        ('--output_dir', 'Output directory'),
    )
    parser = argparse.ArgumentParser()
    for flag, help_text in arg_specs:
        parser.add_argument(flag, required=True, help=help_text)
    opts = parser.parse_args()

    engine = OneDMInference(opts.model_path, opts.cfg_path)
    results = engine.generate(
        opts.text,
        opts.style_path,
        opts.laplace_path,
        opts.output_dir
    )
    print(f"Generated images saved at: {results}")
# Script entry point. (A stray trailing "|" — scrape residue — previously
# followed main() and made this line a syntax error; removed.)
if __name__ == "__main__":
    main()