nyanko7 commited on
Commit
49ba457
1 Parent(s): 9d68af7

fix: add timeout to prevent unexcepted tasl

Browse files
Files changed (2) hide show
  1. app.py +3 -0
  2. modules/model.py +11 -0
app.py CHANGED
@@ -60,6 +60,7 @@ samplers_k_diffusion = [
60
  # ]
61
 
62
  start_time = time.time()
 
63
 
64
  scheduler = DDIMScheduler.from_pretrained(
65
  base_model,
@@ -257,6 +258,8 @@ def inference(
257
  "sampler_opt": sampler_opt,
258
  "pww_state": state,
259
  "pww_attn_weight": g_strength,
 
 
260
  }
261
 
262
  if img_input is not None:
 
60
  # ]
61
 
62
  start_time = time.time()
63
+ timeout = 120
64
 
65
  scheduler = DDIMScheduler.from_pretrained(
66
  base_model,
 
258
  "sampler_opt": sampler_opt,
259
  "pww_state": state,
260
  "pww_attn_weight": g_strength,
261
+ "start_time": start_time,
262
+ "timeout": timeout,
263
  }
264
 
265
  if img_input is not None:
modules/model.py CHANGED
@@ -6,6 +6,7 @@ import re
6
  from collections import defaultdict
7
  from typing import List, Optional, Union
8
 
 
9
  import k_diffusion
10
  import numpy as np
11
  import PIL
@@ -446,6 +447,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
446
  pww_attn_weight=1.0,
447
  sampler_name="",
448
  sampler_opt={},
 
 
449
  scale_ratio=8.0,
450
  ):
451
  sampler = self.get_scheduler(sampler_name)
@@ -504,6 +507,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
504
 
505
  def model_fn(x, sigma):
506
 
 
 
 
507
  latent_model_input = torch.cat([x] * 2)
508
  weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
509
  encoder_state = {
@@ -617,6 +623,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
617
  pww_attn_weight=1.0,
618
  sampler_name="",
619
  sampler_opt={},
 
 
620
  ):
621
  sampler = self.get_scheduler(sampler_name)
622
  # 1. Check inputs. Raise error if not correct
@@ -667,6 +675,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
667
 
668
  def model_fn(x, sigma):
669
 
 
 
 
670
  latent_model_input = torch.cat([x] * 2)
671
  weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
672
  encoder_state = {
 
6
  from collections import defaultdict
7
  from typing import List, Optional, Union
8
 
9
+ import time
10
  import k_diffusion
11
  import numpy as np
12
  import PIL
 
447
  pww_attn_weight=1.0,
448
  sampler_name="",
449
  sampler_opt={},
450
+ start_time=-1,
451
+ timeout=180,
452
  scale_ratio=8.0,
453
  ):
454
  sampler = self.get_scheduler(sampler_name)
 
507
 
508
  def model_fn(x, sigma):
509
 
510
+ if start_time > 0 and timeout > 0:
511
+ assert (time.time() - start_time) < timeout, "inference process timed out"
512
+
513
  latent_model_input = torch.cat([x] * 2)
514
  weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
515
  encoder_state = {
 
623
  pww_attn_weight=1.0,
624
  sampler_name="",
625
  sampler_opt={},
626
+ start_time=-1,
627
+ timeout=180,
628
  ):
629
  sampler = self.get_scheduler(sampler_name)
630
  # 1. Check inputs. Raise error if not correct
 
675
 
676
  def model_fn(x, sigma):
677
 
678
+ if start_time > 0 and timeout > 0:
679
+ assert (time.time() - start_time) < timeout, "inference process timed out"
680
+
681
  latent_model_input = torch.cat([x] * 2)
682
  weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
683
  encoder_state = {