|
from fastapi import FastAPI |
|
|
|
from config import args |
|
from device import device, torch_dtype |
|
from app_init import init_app |
|
from user_queue import user_data |
|
from util import get_pipeline_class |
|
|
|
|
|
# Startup banner: echo the resolved runtime settings to stdout, one
# "LABEL: value" line each, so the process output shows how this server was
# configured.

# Values that are already bound (imported) objects; looking them up up front
# cannot fail, so they can be listed directly.
for _label, _value in (
    ("DEVICE:", device),
    ("TORCH_DTYPE:", torch_dtype),
):
    print(_label, _value)

# Settings read off `args`. Each attribute is fetched with getattr right before
# its own print, so the read order matches one-print-per-setting code exactly.
for _label, _option in (
    ("PIPELINE:", "pipeline"),
    ("SAFETY_CHECKER:", "safety_checker"),
    ("TORCH_COMPILE:", "torch_compile"),
):
    print(_label, getattr(args, _option))

# Drop the loop temporaries so they do not linger as module-level globals.
del _label, _value, _option

# The web application object; kept at module level under the name `app`.
app = FastAPI()

# Look up the pipeline implementation selected by `args.pipeline`, build it
# for the chosen device/dtype, then hand the app, shared user data, settings
# and pipeline to init_app for wiring.
pipeline_class = get_pipeline_class(args.pipeline)
pipeline = pipeline_class(args, device, torch_dtype)
init_app(app, user_data, args, pipeline)
|
|