robinhad committed on
Commit
e0ce993
·
verified ·
1 Parent(s): 6c60cba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -0
app.py CHANGED
@@ -1,5 +1,8 @@
1
  import os
2
  import subprocess
 
 
 
3
  import threading
4
 
5
  # subprocess.check_call([os.sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])
@@ -22,6 +25,7 @@ def load_model():
22
  MODEL_ID,
23
  torch_dtype=torch.bfloat16, # if device == "cuda" else torch.float32,
24
  device_map="auto", # if device == "cuda" else None,
 
25
  ) # .cuda()
26
  print(f"Selected device:", device)
27
  return model, tokenizer, device
 
1
  import os
2
  import subprocess
3
+
4
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
5
+
6
  import threading
7
 
8
  # subprocess.check_call([os.sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])
 
25
  MODEL_ID,
26
  torch_dtype=torch.bfloat16, # if device == "cuda" else torch.float32,
27
  device_map="auto", # if device == "cuda" else None,
28
+ attn_implementation="flash_attention_2",
29
  ) # .cuda()
30
  print(f"Selected device:", device)
31
  return model, tokenizer, device