Spaces:
Running
Running
File size: 1,838 Bytes
afe1a07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
from model import FluxParams, Flux
def build_model(version='base'):
    """Construct a ``Flux`` model for one of the predefined size presets.

    The presets differ only in ``hidden_size``, ``depth``,
    ``depth_single_blocks`` and ``axes_dim``; every other ``FluxParams``
    field is passed with the same value for all of them.

    Args:
        version: Preset name. ``'base'``, ``'small'`` and ``'large'`` are
            matched by equality; any other value does not raise and instead
            selects the fourth configuration (hidden size 1408, depth 16).

    Returns:
        The ``Flux`` instance built from the selected ``FluxParams``.
    """
    # Each row: (name, hidden_size, depth, depth_single_blocks, axes_dim).
    # The table is rebuilt on every call so each model gets its own
    # ``axes_dim`` list object.
    named_presets = (
        ('base', 768, 12, 24, [16, 16, 16]),
        ('small', 512, 8, 16, [8, 12, 12]),
        ('large', 1024, 12, 24, [16, 24, 24]),
    )

    # Start from the fallback configuration, used when ``version`` matches
    # none of the names above.
    hidden_size = 1408
    depth = 16
    depth_single_blocks = 32
    axes_dim = [16, 36, 36]

    for name, *dims in named_presets:
        if version == name:
            hidden_size, depth, depth_single_blocks, axes_dim = dims
            break

    return Flux(
        FluxParams(
            in_channels=32,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=hidden_size,
            mlp_ratio=4.0,
            num_heads=16,
            depth=depth,
            depth_single_blocks=depth_single_blocks,
            axes_dim=axes_dim,
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        )
    )