Spaces:
Running
Running
from model import FluxParams, Flux | |
def build_model(version='base'):
    """Build a ``Flux`` model from one of the predefined size presets.

    Args:
        version: Name of the preset. ``'base'``, ``'small'`` and ``'large'``
            each select their own configuration; any other value selects
            the fallback configuration (hidden_size=1408, depth=16).

    Returns:
        A ``Flux`` instance constructed from the selected ``FluxParams``.
    """
    # Only these four fields differ between presets:
    # (hidden_size, depth, depth_single_blocks, axes_dim).
    # In every preset the axes_dim entries sum to hidden_size // num_heads
    # (num_heads is 16 throughout) -- presumably a constraint imposed by
    # Flux itself; confirm in model.py before adding new presets.
    named_presets = (
        ('base', (768, 12, 24, [16, 16, 16])),
        ('small', (512, 8, 16, [8, 12, 12])),
        ('large', (1024, 12, 24, [16, 24, 24])),
    )
    # Used whenever `version` matches none of the names above.
    selected = (1408, 16, 32, [16, 36, 36])

    for preset_name, preset_values in named_presets:
        if version == preset_name:
            selected = preset_values
            break

    width, n_double, n_single, rope_axes = selected

    config = FluxParams(
        in_channels=32,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=width,
        mlp_ratio=4.0,
        num_heads=16,
        depth=n_double,
        depth_single_blocks=n_single,
        axes_dim=rope_axes,
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
    )
    return Flux(config)