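"""Smoke test for RCTDiffusionPipeline: exercise the SD v1.4 tokenizer and text
encoder, build a UNet matching the pipeline's configuration, then load a trained
checkpoint and generate one image."""
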
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

from rct_diffusion_pipeline import RCTDiffusionPipeline

torch_device = "cuda"
# Test of the Stable Diffusion v1.4 text tokenizer and encoder.
tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
).to(torch_device)
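
# The encoder returns last hidden states shaped (batch, 77, 768); the UNet below
# consumes these through cross-attention (hence cross_attention_dim=768).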
test1 = tokenizer(['aleppo pine tree, common oak tree'], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
#test3 = tokenizer([1.0, 0.0, .05], is_split_into_words=True, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    test1 = text_encoder(test1.input_ids.to(torch_device))[0]

test2 = tokenizer('dark green', padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    test2 = text_encoder(test2.input_ids.to(torch_device))[0]
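
# Quick sanity check on the embeddings; assumes the stock SD v1.4 CLIP encoder
# (77 tokens, 768-dim hidden states), adjust if the checkpoint differs.
assert test1.shape == test2.shape == (1, tokenizer.model_max_length, 768)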

unet = UNet2DConditionModel(
    sample_size=32, in_channels=4, out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=768, block_out_channels=(320, 640, 1280, 1280), norm_num_groups=32)
unet = unet.to(torch_device, dtype=torch.float16)

# Keep normalization layers in float32 so their statistics accumulate at full precision.
# UNet2DConditionModel uses GroupNorm, not BatchNorm2d, so match on GroupNorm
# (the original BatchNorm2d check never fired and left every norm layer in float16).
for layer in unet.modules():
    if isinstance(layer, nn.GroupNorm):
        layer.float()
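
# A plain DDPM noise schedule; note that 20 training timesteps is far below the
# usual 1000, presumably kept small so this test runs quickly.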
scheduler = DDPMScheduler(num_train_timesteps=20)
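
# The pretrained SD VAE downsamples by a factor of 8, so the 32x32 latents the
# UNet works on decode to 256x256 images.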
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
vae = vae.to(torch_device, dtype=torch.float16)
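
# Load the trained pipeline from the local 'rct_foliage_249' checkpoint instead of
# assembling it from the components above (the commented-out constructor shows the
# expected component order). The three positional arguments in the call below
# appear to be object descriptions plus two color prompts.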
#pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
pipeline = RCTDiffusionPipeline.from_pretrained('rct_foliage_249')
output = pipeline(['(cabbage) pagoda tree'], ['(dark) green'], ['brown'])
output[0].save('out.png')
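# Round-trip check: write the pipeline back out so it can be reloaded with from_pretrained.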
pipeline.save_pretrained('test')
print('test')