Spaces:
Running
on
Zero
Running
on
Zero
Add musubi-tuner integration: clone repository at startup and set default directory and repo URL
Browse files- .gitignore +1 -0
- app.py +54 -3
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
/musubi-tuner
|
app.py
CHANGED
|
@@ -24,6 +24,12 @@ DEFAULT_DATASET_CONFIG = "/workspace/auto/dataset_QIE.toml"
|
|
| 24 |
DEFAULT_MODELS_ROOT = DEFAULT_MODELS_DIR # "/workspace/Qwen-Image_models"
|
| 25 |
WORKSPACE_AUTO_DIR = "/workspace/auto"
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
TRAINING_DIR = Path(__file__).resolve().parent
|
| 29 |
|
|
@@ -150,6 +156,49 @@ def _pick_shell() -> str:
|
|
| 150 |
raise RuntimeError("No POSIX shell found. Please install bash or sh.")
|
| 151 |
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
def run_training(
|
| 154 |
dataset_name: str,
|
| 155 |
caption: str,
|
|
@@ -274,10 +323,12 @@ def _startup_download_models() -> None:
|
|
| 274 |
|
| 275 |
|
| 276 |
if __name__ == "__main__":
|
| 277 |
-
# 1)
|
|
|
|
|
|
|
|
|
|
| 278 |
_startup_download_models()
|
| 279 |
|
| 280 |
-
#
|
| 281 |
ui = build_ui()
|
| 282 |
ui.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
|
| 283 |
-
|
|
|
|
| 24 |
DEFAULT_MODELS_ROOT = DEFAULT_MODELS_DIR # "/workspace/Qwen-Image_models"
|
| 25 |
WORKSPACE_AUTO_DIR = "/workspace/auto"
|
| 26 |
|
| 27 |
+
# musubi-tuner settings
|
| 28 |
+
DEFAULT_MUSUBI_TUNER_DIR = os.environ.get("MUSUBI_TUNER_DIR", "/musubi-tuner")
|
| 29 |
+
DEFAULT_MUSUBI_TUNER_REPO = os.environ.get(
|
| 30 |
+
"MUSUBI_TUNER_REPO", "https://github.com/kohya-ss/musubi-tuner.git"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
|
| 34 |
TRAINING_DIR = Path(__file__).resolve().parent
|
| 35 |
|
|
|
|
| 156 |
raise RuntimeError("No POSIX shell found. Please install bash or sh.")
|
| 157 |
|
| 158 |
|
| 159 |
+
def _is_git_repo(path: str) -> bool:
|
| 160 |
+
try:
|
| 161 |
+
out = subprocess.run(
|
| 162 |
+
["git", "-C", path, "rev-parse", "--is-inside-work-tree"],
|
| 163 |
+
capture_output=True,
|
| 164 |
+
text=True,
|
| 165 |
+
check=False,
|
| 166 |
+
)
|
| 167 |
+
return out.returncode == 0 and out.stdout.strip() == "true"
|
| 168 |
+
except Exception:
|
| 169 |
+
return False
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _startup_clone_musubi_tuner() -> None:
|
| 173 |
+
target = DEFAULT_MUSUBI_TUNER_DIR
|
| 174 |
+
repo = DEFAULT_MUSUBI_TUNER_REPO
|
| 175 |
+
parent = os.path.dirname(target.rstrip("/\\")) or "/"
|
| 176 |
+
try:
|
| 177 |
+
os.makedirs(parent, exist_ok=True)
|
| 178 |
+
except Exception:
|
| 179 |
+
pass
|
| 180 |
+
|
| 181 |
+
if os.path.isdir(target) and _is_git_repo(target):
|
| 182 |
+
print(f"[QIE] musubi-tuner exists at {target}; pulling latest...")
|
| 183 |
+
try:
|
| 184 |
+
subprocess.run(["git", "-C", target, "fetch", "--all", "--prune"], check=False)
|
| 185 |
+
subprocess.run(["git", "-C", target, "pull", "--ff-only"], check=False)
|
| 186 |
+
except Exception as e:
|
| 187 |
+
print(f"[QIE] git pull failed: {e}")
|
| 188 |
+
return
|
| 189 |
+
|
| 190 |
+
if os.path.exists(target) and not _is_git_repo(target):
|
| 191 |
+
print(f"[QIE] Warning: {target} exists and is not a git repo. Skipping clone.")
|
| 192 |
+
return
|
| 193 |
+
|
| 194 |
+
print(f"[QIE] Cloning musubi-tuner into {target} from {repo} ...")
|
| 195 |
+
try:
|
| 196 |
+
subprocess.run(["git", "clone", "--depth", "1", repo, target], check=True)
|
| 197 |
+
print("[QIE] Clone completed.")
|
| 198 |
+
except subprocess.CalledProcessError as e:
|
| 199 |
+
print(f"[QIE] Clone failed: {e}")
|
| 200 |
+
|
| 201 |
+
|
| 202 |
def run_training(
|
| 203 |
dataset_name: str,
|
| 204 |
caption: str,
|
|
|
|
| 323 |
|
| 324 |
|
| 325 |
if __name__ == "__main__":
|
| 326 |
+
# 1) Ensure musubi-tuner is cloned before anything else
|
| 327 |
+
_startup_clone_musubi_tuner()
|
| 328 |
+
|
| 329 |
+
# 2) Download models at startup (blocking by design)
|
| 330 |
_startup_download_models()
|
| 331 |
|
| 332 |
+
# 3) Launch Gradio app
|
| 333 |
ui = build_ui()
|
| 334 |
ui.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
|
|
|