Akis Giannoukos committed on
Commit
d044724
·
1 Parent(s): 628be0b

Disable global Torch compile/dynamo settings to prevent cudagraph assertion errors and remove the deprecated safe generation function for improved code clarity.

Browse files
Files changed (1) hide show
  1. app.py +7 -37
app.py CHANGED
@@ -1,4 +1,9 @@
1
  import os
 
 
 
 
 
2
  import json
3
  import re
4
  import time
@@ -91,39 +96,6 @@ def get_textgen_pipeline():
91
  )
92
  return _gen_pipe
93
 
94
- def _safe_hf_generate(pipe, prompt: str, **gen_kwargs):
95
- """Call HF generate pipeline with best-effort fallbacks to avoid TorchDynamo/Inductor issues."""
96
- try:
97
- return pipe(prompt, **gen_kwargs)
98
- except Exception:
99
- # Best-effort: disable dynamo via env and retry once
100
- try:
101
- os.environ["TORCHDYNAMO_DISABLE"] = "1"
102
- os.environ["TORCH_COMPILE_DISABLE"] = "1"
103
- os.environ["TORCHINDUCTOR_FREEZE"] = "1"
104
- except Exception:
105
- pass
106
- try:
107
- # Also disable cudagraphs if available
108
- try:
109
- import torch._inductor.config as _inductor_cfg # type: ignore
110
- _inductor_cfg.triton.cudagraphs = False
111
- except Exception:
112
- pass
113
- return pipe(prompt, **gen_kwargs)
114
- except Exception:
115
- # Final fallback: CPU pipeline generation
116
- try:
117
- from transformers import pipeline as hf_pipeline
118
- cpu_pipe = hf_pipeline(
119
- task="text-generation",
120
- model=pipe.model,
121
- tokenizer=pipe.tokenizer,
122
- device=-1,
123
- )
124
- return cpu_pipe(prompt, **gen_kwargs)
125
- except Exception:
126
- raise
127
 
128
 
129
  def set_current_model_id(new_model_id: str) -> str:
@@ -381,8 +353,7 @@ def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
381
  import torch._dynamo as _dynamo # type: ignore
382
  except Exception:
383
  _dynamo = None
384
- gen = _safe_hf_generate(
385
- pipe,
386
  prompt,
387
  max_new_tokens=96,
388
  temperature=0.7,
@@ -428,8 +399,7 @@ def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str,
428
  import torch._dynamo as _dynamo # type: ignore
429
  except Exception:
430
  _dynamo = None
431
- gen = _safe_hf_generate(
432
- pipe,
433
  prompt,
434
  max_new_tokens=256,
435
  temperature=0.0,
 
1
  import os
2
+
3
+ # Disable torch compile/dynamo globally to avoid cudagraph assertion errors
4
+ os.environ["TORCHDYNAMO_DISABLE"] = "1"
5
+ os.environ["TORCH_COMPILE_DISABLE"] = "1"
6
+
7
  import json
8
  import re
9
  import time
 
96
  )
97
  return _gen_pipe
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
 
101
  def set_current_model_id(new_model_id: str) -> str:
 
353
  import torch._dynamo as _dynamo # type: ignore
354
  except Exception:
355
  _dynamo = None
356
+ gen = pipe(
 
357
  prompt,
358
  max_new_tokens=96,
359
  temperature=0.7,
 
399
  import torch._dynamo as _dynamo # type: ignore
400
  except Exception:
401
  _dynamo = None
402
+ gen = pipe(
 
403
  prompt,
404
  max_new_tokens=256,
405
  temperature=0.0,