fix: add timeout to prevent unexpected task
- app.py +3 -0
- modules/model.py +11 -0
app.py CHANGED

@@ -60,6 +60,7 @@ samplers_k_diffusion = [
 # ]
 
 start_time = time.time()
+timeout = 120
 
 scheduler = DDIMScheduler.from_pretrained(
     base_model,
@@ -257,6 +258,8 @@ def inference(
         "sampler_opt": sampler_opt,
         "pww_state": state,
         "pww_attn_weight": g_strength,
+        "start_time": start_time,
+        "timeout": timeout,
     }
 
     if img_input is not None:
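The app-side change only defines the wall-clock budget and forwards it in the request dictionary; the actual check lives in the pipeline (see the modules/model.py hunks below). As a rough illustration of the pattern, here is a minimal, hypothetical sketch in Python. The names run_with_deadline, step_fn, and n_steps are invented for this example and are not part of the Space; the sketch only shows how a recorded start timestamp plus a timeout becomes a per-step guard:

import time

def run_with_deadline(step_fn, n_steps, start_time=-1, timeout=180):
    """Call step_fn(i) up to n_steps times, re-checking the deadline before each step."""
    for i in range(n_steps):
        # Same guard as the commit: only enforce when both values are set.
        if start_time > 0 and timeout > 0:
            assert (time.time() - start_time) < timeout, "inference process timed out"
        step_fn(i)

# Usage: the caller fixes the budget once, exactly like app.py's timeout = 120.
start = time.time()
run_with_deadline(lambda i: time.sleep(0.01), n_steps=50, start_time=start, timeout=120)

Note the two numbers involved: app.py passes a 120-second budget, which overrides the pipeline's default of 180 seconds shown below.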
modules/model.py CHANGED

@@ -6,6 +6,7 @@ import re
 from collections import defaultdict
 from typing import List, Optional, Union
 
+import time
 import k_diffusion
 import numpy as np
 import PIL
@@ -446,6 +447,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
         pww_attn_weight=1.0,
         sampler_name="",
         sampler_opt={},
+        start_time=-1,
+        timeout=180,
         scale_ratio=8.0,
     ):
         sampler = self.get_scheduler(sampler_name)
@@ -504,6 +507,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
 
         def model_fn(x, sigma):
 
+            if start_time > 0 and timeout > 0:
+                assert (time.time() - start_time) < timeout, "inference process timed out"
+
             latent_model_input = torch.cat([x] * 2)
             weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
             encoder_state = {
@@ -617,6 +623,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
         pww_attn_weight=1.0,
         sampler_name="",
         sampler_opt={},
+        start_time=-1,
+        timeout=180,
     ):
         sampler = self.get_scheduler(sampler_name)
         # 1. Check inputs. Raise error if not correct
@@ -667,6 +675,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
 
         def model_fn(x, sigma):
 
+            if start_time > 0 and timeout > 0:
+                assert (time.time() - start_time) < timeout, "inference process timed out"
+
             latent_model_input = torch.cat([x] * 2)
             weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
             encoder_state = {
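Because model_fn is invoked once per sampler step, the guard bounds the whole denoising loop at step granularity rather than interrupting a step in progress. One caveat worth noting: the commit enforces the budget with assert, and assert statements are skipped when Python runs with -O. A small, hypothetical variant (InferenceTimeout and check_deadline are invented names, not part of this repository) keeps the same condition but raises an explicit exception so the guard survives optimized runs:

import time

class InferenceTimeout(RuntimeError):
    """Raised when a generation exceeds its wall-clock budget."""

def check_deadline(start_time: float, timeout: float) -> None:
    # Mirrors the condition added to model_fn: only enforce when both values are set.
    if start_time > 0 and timeout > 0 and (time.time() - start_time) >= timeout:
        raise InferenceTimeout("inference process timed out")

The defaults of start_time=-1 and timeout=180 mean the check is disabled unless a caller (such as app.py above) explicitly supplies a start timestamp.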