|
import argparse |
|
import json |
|
import os |
|
import subprocess |
|
import sys |
|
import webbrowser |
|
from datetime import datetime |
|
from threading import Lock |
|
|
|
import uvicorn |
|
from fastapi import BackgroundTasks, FastAPI, Request |
|
from fastapi.responses import FileResponse |
|
from fastapi.staticfiles import StaticFiles |
|
|
|
import toml |
|
|
|
# ASGI application; the routes and middleware in this module register on it.
app = FastAPI()

# Single-run guard: acquired (non-blocking) by the /api/run handler and
# released by run_train() once the training subprocess has returned.
lock = Lock()
|
|
|
|
|
# Static-file application serving the built frontend bundle.
sf = StaticFiles(directory="frontend/dist")

# Keep the original bound method so the hook below can delegate to it.
_o_fr = sf.file_response


def _hooked_file_response(*args, **kwargs):
    """Build the static-file response and force the JS/CSS content types.

    Delegates to the original ``StaticFiles.file_response`` and then
    overrides the media type for ``.js`` and ``.css`` files.

    Args:
        *args: Positional arguments forwarded unchanged; the first one is
            the path of the file being served.
        **kwargs: Keyword arguments forwarded unchanged. ``full_path`` is
            honoured when the path is not passed positionally.

    Returns:
        The response object produced by the original method, possibly with
        an adjusted media type / Content-Type header.
    """
    # Accept the path either positionally or by keyword instead of assuming
    # args[0] always exists.
    full_path = args[0] if args else kwargs.get("full_path")
    r = _o_fr(*args, **kwargs)

    # str() also covers path-like objects; lower() makes ".JS"/".CSS" match.
    file_name = str(full_path).lower() if full_path is not None else ""
    if file_name.endswith(".js"):
        media_type = "application/javascript"
    elif file_name.endswith(".css"):
        media_type = "text/css"
    else:
        return r

    r.media_type = media_type
    # NOTE(review): Starlette generates the Content-Type header when the
    # response object is constructed, so assigning media_type afterwards does
    # not change what is sent -- confirm against the installed version.
    # Only rewrite the header when one is already present, so responses that
    # deliberately carry no Content-Type are left untouched.
    if "content-type" in r.headers:
        if media_type.startswith("text/"):
            r.headers["content-type"] = f"{media_type}; charset=utf-8"
        else:
            r.headers["content-type"] = media_type
    return r


sf.file_response = _hooked_file_response
|
|
|
# Command-line interface of the GUI launcher; only the listen port is
# configurable.
parser = argparse.ArgumentParser(description="GUI for training network")
parser.add_argument(
    "--port",
    default=28000,
    type=int,
    help="Port to run the server on",
)
|
|
|
def run_train(toml_path: str):
    """Run the training script with the given config and release the lock.

    Launches ``accelerate launch ... ./sd-scripts/train_network.py`` with
    ``--config_file toml_path``, waits for it to finish and prints the
    outcome. The module-level ``lock`` is always released afterwards, so the
    caller must hold it when invoking this function.

    Args:
        toml_path: Path of the TOML configuration file passed to the
            training script.
    """
    print(f"Training started with config file / 训练开始,使用配置文件: {toml_path}")
    args = [
        "accelerate", "launch", "--num_cpu_threads_per_process", "8",
        "./sd-scripts/train_network.py",
        "--config_file", toml_path,
    ]
    try:
        # Pass the argument list directly to the process. Combining a list
        # with shell=True on POSIX executes only the first item and hands the
        # remaining items to the shell itself, so the training arguments
        # would never reach `accelerate`.
        result = subprocess.run(args, env=os.environ)
        if result.returncode != 0:
            print("Training failed / 训练失败")
        else:
            print("Training finished / 训练完成")
    except Exception as e:
        # Best-effort: report launch failures (e.g. executable not found)
        # instead of propagating them out of the background task.
        print(f"An error occurred when training / 创建训练进程时出现致命错误: {e}")
    finally:
        # Allow the next training request regardless of the outcome.
        lock.release()
|
|
|
|
|
@app.post("/api/run")
async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
    """Persist the posted configuration as TOML and schedule a training run.

    The request body is parsed as UTF-8 JSON, converted to TOML and written
    to ``toml/<timestamp>.toml``; ``run_train`` is then queued as a
    background task with that file.

    Only one training may run at a time: the module-level ``lock`` is taken
    here without blocking and is released by ``run_train`` when training
    ends. If anything fails before the task is scheduled, the lock is
    released here so later requests are not rejected forever.

    Args:
        request: Incoming request whose body is the JSON configuration.
        background_tasks: Task queue used to run the training after the
            response has been produced.

    Returns:
        ``{"status": "success"}`` when training was scheduled, or
        ``{"status": "fail", "detail": ...}`` when a training is already
        running.

    Raises:
        Exception: Any error raised while reading, parsing or writing the
            configuration is propagated unchanged (after releasing the lock).
    """
    acquired = lock.acquire(blocking=False)

    if not acquired:
        print("Training is already running / 已有正在进行的训练")
        return {"status": "fail", "detail": "Training is already running"}

    scheduled = False
    try:
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        toml_file = f"toml/{timestamp}.toml"
        toml_data = await request.body()
        j = json.loads(toml_data.decode("utf-8"))
        # Make sure the output directory exists before writing into it.
        os.makedirs("toml", exist_ok=True)
        # Write UTF-8 explicitly: the locale default encoding may not be able
        # to represent non-ASCII values (e.g. paths) present in the config.
        with open(toml_file, "w", encoding="utf-8") as f:
            f.write(toml.dumps(j))
        background_tasks.add_task(run_train, toml_file)
        scheduled = True
    finally:
        # run_train() releases the lock only if it actually gets scheduled;
        # on any earlier failure the lock must be released here.
        if not scheduled:
            lock.release()
    return {"status": "success"}
|
|
|
@app.middleware("http")
async def add_cache_control_header(request, call_next):
    """Stamp every HTTP response with ``Cache-Control: max-age=0``.

    Args:
        request: The incoming request, passed on to the next handler.
        call_next: Callable that produces the downstream response.

    Returns:
        The downstream response with the Cache-Control header set.
    """
    downstream = await call_next(request)
    downstream.headers["Cache-Control"] = "max-age=0"
    return downstream
|
|
|
@app.get("/")
async def index():
    """Serve the frontend's index.html for the site root."""
    entry_page = "./frontend/dist/index.html"
    return FileResponse(entry_page)
|
|
|
|
|
# Serve the remaining frontend assets from the hooked StaticFiles app; this
# mount is registered after the API and index routes defined above.
app.mount("/", sf, name="static")
|
|
|
if __name__ == "__main__":
    # parse_known_args() ignores unrecognised command-line options instead
    # of exiting with an error.
    args, _ = parser.parse_known_args()
    print(f"Server started at http://127.0.0.1:{args.port}")
    if sys.platform == "win32":
        # Set in os.environ, which run_train() passes to the training
        # subprocess. NOTE(review): presumably turns off xformers' Triton
        # code path on Windows -- confirm.
        os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"

    webbrowser.open(f"http://127.0.0.1:{args.port}")
    # Listen on the requested port. It was hard-coded to 28000 before, so
    # --port changed the printed/opened URL but not the actual server port.
    uvicorn.run(app, host="127.0.0.1", port=args.port, log_level="error")
|
|