aeb56 committed
Commit 2f60fd7 · 1 Parent(s): 74fe23d

Fix flash attention error by patching model config to use eager attention

Files changed (1)
  1. app.py +9 -1
app.py CHANGED
@@ -40,6 +40,13 @@ class ChatBot:
         )
 
         self.model.eval()
+
+        # Patch model config to avoid flash attention issues
+        if hasattr(self.model.config, '_attn_implementation'):
+            self.model.config._attn_implementation = "eager"
+        if hasattr(self.model.config, 'attn_implementation'):
+            self.model.config.attn_implementation = "eager"
+
         self.loaded = True
 
         # Get GPU distribution info
@@ -85,7 +92,7 @@ class ChatBot:
         inputs = self.tokenizer(prompt, return_tensors="pt")
         inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
 
-        # Generate
+        # Generate with explicit attention settings
         with torch.no_grad():
             outputs = self.model.generate(
                 **inputs,
@@ -94,6 +101,7 @@ class ChatBot:
                 top_p=top_p,
                 do_sample=temperature > 0,
                 pad_token_id=self.tokenizer.eos_token_id,
+                use_cache=True,  # Enable KV caching
             )
 
         # Decode
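
Reader note, not part of the commit: recent transformers releases (roughly 4.36 and later) also let the attention backend be requested directly at load time via the attn_implementation argument to from_pretrained, which makes a post-hoc config patch like the one above unnecessary. The sketch below is hypothetical; the model id and the other load arguments are placeholders, since the loading code in app.py is not shown in this diff.

    # Hypothetical sketch of selecting the attention backend at load time.
    # MODEL_ID and the dtype/device_map settings are placeholders, not values
    # taken from this repository's app.py.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL_ID = "org/model-name"  # placeholder model id

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,    # assumed precision
        device_map="auto",            # assumed multi-GPU placement, as hinted by "GPU distribution info"
        attn_implementation="eager",  # request eager attention instead of flash attention up front
    )
    model.eval()

With the backend chosen at load time, the hasattr checks on _attn_implementation / attn_implementation become redundant, though they remain a harmless safety net if the config attribute name differs across model classes.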