|
import os
from dataclasses import dataclass, field
from typing import List, Literal, Optional

import torch
|
|
|
# Spellings accepted as "enabled" for boolean environment flags; compared
# case-insensitively after stripping surrounding whitespace.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def env_flag(name: str, default: bool = False) -> bool:
    """Read environment variable *name* as a boolean flag.

    Args:
        name: Name of the environment variable to read.
        default: Value returned when the variable is not set at all.

    Returns:
        True if the variable is set to "1", "true", "yes" or "on" (any case,
        surrounding whitespace ignored); False for any other set value,
        including the empty string; *default* when the variable is unset.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUTHY_ENV_VALUES


# Default for ``Config.use_safety_checker``. Evaluated once, when this module
# is imported; e.g. ``SAFETY_CHECKER=true`` (or "True", "1", "yes") enables it.
SAFETY_CHECKER = env_flag("SAFETY_CHECKER")
|
|
|
|
|
@dataclass
class Config:
    """
    The configuration for the API.

    Groups server settings, model/pipeline selection and runtime options.
    Every field has a default, so ``Config()`` is valid with no arguments.
    """

    # --- Server ----------------------------------------------------------
    # Address to listen on; "0.0.0.0" is the IPv4 wildcard (all interfaces).
    host: str = "0.0.0.0"
    # Port to listen on.
    port: int = 9090
    # Number of workers — presumably handed to the server runner; TODO confirm.
    workers: int = 1

    # --- Model / pipeline ------------------------------------------------
    # Generation mode: text-to-image or image-to-image.
    mode: Literal["txt2img", "img2img"] = "txt2img"
    # Model identifier (e.g. a Hugging Face Hub repo id) or a local path.
    model_id_or_path: str = "stabilityai/sd-turbo"
    # Optional LCM-LoRA identifier. None presumably means "load no LoRA";
    # confirm against the loader that consumes this field.
    lcm_lora_id: Optional[str] = None
    # VAE identifier; defaults to the "madebyollin/taesd" repo id.
    vae_id: str = "madebyollin/taesd"

    # --- Runtime ---------------------------------------------------------
    # Torch device to run on. NOTE(review): the default assumes a CUDA GPU
    # is available when the pipeline is actually used.
    device: torch.device = torch.device("cuda")
    # Torch dtype, half precision by default (presumably for weights and
    # inference; TODO confirm).
    dtype: torch.dtype = torch.float16
    # Acceleration backend to use for inference.
    acceleration: Literal["none", "xformers", "tensorrt"] = "xformers"

    # --- Generation ------------------------------------------------------
    # Timestep indices used by the pipeline — presumably indices into the
    # denoising schedule; TODO confirm. default_factory gives each instance
    # its own list (dataclasses reject a bare mutable list default).
    t_index_list: List[int] = field(default_factory=lambda: [0, 16, 32, 45])
    # Number of warm-up iterations (presumably run before serving; confirm).
    warmup: int = 10
    # Whether to enable the safety checker. The default is the module-level
    # SAFETY_CHECKER flag, which is read from the environment once, when the
    # module is imported.
    use_safety_checker: bool = SAFETY_CHECKER
|
|