vikhyatk committed on
Commit
60e7a28
1 Parent(s): 6c78975

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -6,6 +6,9 @@ import gradio as gr
6
  from threading import Thread
7
  from transformers import TextIteratorStreamer, AutoTokenizer, AutoModelForCausalLM
8
 
 
 
 
9
  parser = argparse.ArgumentParser()
10
 
11
  if torch.cuda.is_available():
@@ -17,7 +20,8 @@ model_id = "vikhyatk/moondream2"
17
  revision = "2024-04-02"
18
  tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
19
  moondream = AutoModelForCausalLM.from_pretrained(
20
- model_id, trust_remote_code=True, revision=revision
 
21
  ).to(device=device, dtype=dtype)
22
  moondream.eval()
23
 
 
6
  from threading import Thread
7
  from transformers import TextIteratorStreamer, AutoTokenizer, AutoModelForCausalLM
8
 
9
+ import subprocess
10
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
11
+
12
  parser = argparse.ArgumentParser()
13
 
14
  if torch.cuda.is_available():
 
20
  revision = "2024-04-02"
21
  tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
22
  moondream = AutoModelForCausalLM.from_pretrained(
23
+ model_id, trust_remote_code=True, revision=revision,
24
+ attn_implementation="flash_attention_2"
25
  ).to(device=device, dtype=dtype)
26
  moondream.eval()
27