doevent's picture
Upload 14 files
5004324 verified
raw
history blame
5.53 kB
import torch
from typing import Optional, Union
from huggingface_hub import hf_hub_download
from .sr_pipeline import KandiSuperResPipeline
from KandiSuperRes.model.unet import UNet
from KandiSuperRes.model.unet_sr import UNet as UNet_sr
from KandiSuperRes.movq import MoVQ
def get_sr_model(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float16
) -> UNet_sr:
    """Build the super-resolution UNet and optionally load weights into it.

    Args:
        device: Device the model is moved to.
        weights_path: Path to a checkpoint readable by ``torch.load``. When
            falsy, no weights are loaded and the model keeps whatever
            parameters its constructor produced.
        dtype: Dtype the model is cast to.

    Returns:
        The ``UNet_sr`` instance after ``.to(device, dtype)`` and ``.eval()``.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for an unreadable
        or mismatching checkpoint; such errors are no longer masked by a
        blanket ``except``.
    """
    unet = UNet_sr(
        init_channels=128,
        model_channels=128,
        num_channels=3,
        time_embed_dim=512,
        groups=32,
        dim_mult=(1, 2, 4, 8),
        num_resnet_blocks=(2, 4, 8, 8),
        add_cross_attention=(False, False, False, False),
        add_self_attention=(False, False, False, False),
        feature_pooling_type='attention',
        lowres_cond=True,
    )

    if weights_path:
        # NOTE(security): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        checkpoint = torch.load(weights_path, map_location=torch.device('cpu'))
        # Two checkpoint layouts are accepted: the weights nested under the
        # 'unet' key, or the bare state dict itself. Selecting explicitly
        # (instead of try/except) lets a real mismatch surface with its
        # original error message.
        if 'unet' in checkpoint:
            weights = checkpoint['unet']
        else:
            weights = checkpoint
        unet.load_state_dict(weights)

    unet.to(device=device, dtype=dtype).eval()
    return unet
def get_T2I_unet(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> (UNet, Optional[torch.Tensor], Optional[dict]):
    """Create the ``UNet`` model and, if a checkpoint is given, restore it.

    Args:
        device: Device the network is moved to.
        weights_path: Path to a checkpoint readable by ``torch.load``. The
            checkpoint must provide the keys ``'unet'`` and
            ``'null_embedding'``. When falsy, nothing is loaded.
        dtype: Dtype the network is cast to.

    Returns:
        A 2-tuple ``(unet, null_embedding)``. ``null_embedding`` is the
        checkpoint's ``'null_embedding'`` entry, or ``None`` when no
        checkpoint was loaded. (The signature's return annotation lists
        three items; two are actually returned.)
    """
    architecture = dict(
        model_channels=384,
        num_channels=4,
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )
    network = UNet(**architecture)

    null_embedding = None
    if weights_path:
        # Always deserialize onto the CPU first; the move to the target
        # device happens once, below.
        cpu = torch.device('cpu')
        checkpoint = torch.load(weights_path, map_location=cpu)
        null_embedding = checkpoint['null_embedding']
        network.load_state_dict(checkpoint['unet'])

    placed = network.to(device=device, dtype=dtype)
    placed.eval()
    return network, null_embedding
def get_movq(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> MoVQ:
    """Instantiate ``MoVQ`` from a fixed config and optionally load weights.

    Args:
        device: Device the model is moved to.
        weights_path: Path to a state dict readable by ``torch.load``. When
            falsy, nothing is loaded.
        dtype: Dtype the model is cast to.

    Returns:
        The ``MoVQ`` instance after ``.to(device, dtype)`` and ``.eval()``.
    """
    # Fixed configuration handed to the MoVQ constructor as a single dict.
    config = {
        'double_z': False,
        'z_channels': 4,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 256,
        'ch_mult': [1, 2, 2, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [32],
        'dropout': 0.0,
        'tile_sample_min_size': 1024,
        'tile_overlap_factor_enc': 0.0,
        'tile_overlap_factor_dec': 0.25,
        'use_tiling': True
    }
    model = MoVQ(config)

    if weights_path:
        # The file is expected to be the state dict itself (no wrapper key).
        cpu = torch.device('cpu')
        weights = torch.load(weights_path, map_location=cpu)
        model.load_state_dict(weights)

    placed = model.to(device=device, dtype=dtype)
    placed.eval()
    return model
def get_SR_pipeline(
    device: Union[str, torch.device],
    fp16: bool = True,
    flash: bool = True,
    scale: int = 2,
    cache_dir: str = '/tmp/KandiSuperRes/',
    movq_path: Optional[str] = None,
    refiner_path: Optional[str] = None,
    unet_sr_path: Optional[str] = None,
) -> Optional[KandiSuperResPipeline]:
    """Assemble a ``KandiSuperResPipeline``, downloading missing weights.

    Any weight path left as ``None`` is fetched with ``hf_hub_download`` into
    ``cache_dir``.

    Args:
        device: Device every sub-model is placed on.
        fp16: Use ``torch.float16`` for the SR model (and the refiner in
            flash mode) when true, else ``torch.float32``. MoVQ is always
            built in ``torch.float32``.
        flash: Build the flash pipeline (SR model + MoVQ + refiner) instead
            of the SR-model-only pipeline.
        scale: Upscaling factor. Flash mode supports only 2; non-flash mode
            has downloadable weights for 2 and 4.
        cache_dir: Download cache directory passed to ``hf_hub_download``.
        movq_path: Local MoVQ weights (flash mode only).
        refiner_path: Local refiner weights (flash mode only).
        unet_sr_path: Local SR UNet weights.

    Returns:
        The configured pipeline, or ``None`` when ``flash`` is requested with
        a scale other than 2 (a message is printed in that case).

    Raises:
        ValueError: In non-flash mode, when ``unet_sr_path`` is not given and
            ``scale`` has no published checkpoint. Previously this case
            silently produced a pipeline around a model with no weights
            loaded.
    """
    dtype = torch.float16 if fp16 else torch.float32

    if flash:
        if scale != 2:
            # Soft failure preserved from the original behaviour: report and
            # hand back None rather than raising.
            print('Flash model for x4 scale is not implemented.')
            return None

        device_map = {
            'movq': device, 'refiner': device, 'sr_model': device
        }
        dtype_map = {
            'movq': torch.float32, 'refiner': dtype, 'sr_model': dtype
        }

        if movq_path is None:
            print('Download movq weights')
            movq_path = hf_hub_download(
                repo_id="ai-forever/Kandinsky3.1", filename='weights/movq.pt', cache_dir=cache_dir
            )
        if refiner_path is None:
            print('Download refiner weights')
            refiner_path = hf_hub_download(
                repo_id="ai-forever/Kandinsky3.1", filename='weights/kandinsky3_flash.pt', cache_dir=cache_dir
            )
        if unet_sr_path is None:
            print('Download KandiSuperRes Flash weights')
            unet_sr_path = hf_hub_download(
                repo_id="ai-forever/KandiSuperRes", filename='KandiSuperRes_flash_x2.pt', cache_dir=cache_dir
            )

        sr_model = get_sr_model(device_map['sr_model'], unet_sr_path, dtype=dtype_map['sr_model'])
        movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])
        # The null embedding returned alongside the refiner is not used here.
        refiner, _ = get_T2I_unet(device_map['refiner'], refiner_path, dtype=dtype_map['refiner'])
        return KandiSuperResPipeline(
            scale, device_map, dtype_map, flash, sr_model, movq, refiner
        )

    if unet_sr_path is None:
        checkpoint_by_scale = {
            4: 'KandiSuperRes.ckpt',
            2: 'KandiSuperRes_x2.ckpt',
        }
        if scale not in checkpoint_by_scale:
            # Without a checkpoint get_sr_model would skip loading entirely,
            # so fail loudly instead of returning an unusable pipeline.
            raise ValueError(
                f'Unsupported scale {scale!r}: pretrained weights are available '
                'only for scale 2 and 4; pass unet_sr_path to use custom weights.'
            )
        unet_sr_path = hf_hub_download(
            repo_id="ai-forever/KandiSuperRes", filename=checkpoint_by_scale[scale], cache_dir=cache_dir
        )

    sr_model = get_sr_model(device, unet_sr_path, dtype=dtype)
    return KandiSuperResPipeline(scale, device, dtype, flash, sr_model)