sounar committed on
Commit
d16c5f3
·
verified ·
1 Parent(s): b37e8c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -28
app.py CHANGED
@@ -3,15 +3,7 @@ import torch
3
  from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
4
  import gradio as gr
5
  from PIL import Image
6
-
7
- # First, let's check if flash-attn is installed
8
- try:
9
- import flash_attn
10
- FLASH_ATTN_AVAILABLE = True
11
- except ImportError:
12
- FLASH_ATTN_AVAILABLE = False
13
- print("Flash Attention is not installed. Using default attention mechanism.")
14
- print("To install Flash Attention, run: pip install flash-attn --no-build-isolation")
15
 
16
  # Get API token from environment variable
17
  api_token = os.getenv("HF_TOKEN").strip()
@@ -24,23 +16,15 @@ bnb_config = BitsAndBytesConfig(
24
  bnb_4bit_compute_dtype=torch.float16
25
  )
26
 
27
- # Initialize model with conditional Flash Attention
28
- model_args = {
29
- "quantization_config": bnb_config,
30
- "device_map": "auto",
31
- "torch_dtype": torch.float16,
32
- "trust_remote_code": True,
33
- "token": api_token
34
- }
35
-
36
- # Only add flash attention if available
37
- if FLASH_ATTN_AVAILABLE:
38
- model_args["attn_implementation"] = "flash_attention_2"
39
-
40
  # Initialize model and tokenizer
41
  model = AutoModel.from_pretrained(
42
  "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
43
- **model_args
 
 
 
 
 
44
  )
45
 
46
  tokenizer = AutoTokenizer.from_pretrained(
@@ -100,11 +84,6 @@ demo = gr.Interface(
100
 
101
  # Launch the Gradio app
102
  if __name__ == "__main__":
103
- # Print installation instructions if Flash Attention is not available
104
- if not FLASH_ATTN_AVAILABLE:
105
- print("\nTo enable Flash Attention 2 for better performance, please install it using:")
106
- print("pip install flash-attn --no-build-isolation")
107
-
108
  demo.launch(
109
  share=True,
110
  server_name="0.0.0.0",
 
3
  from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
4
  import gradio as gr
5
  from PIL import Image
6
+ from torchvision.transforms import ToTensor
 
 
 
 
 
 
 
 
7
 
8
# Get API token from the environment.
# Fail fast with a clear message if HF_TOKEN is unset: os.getenv would
# return None and the subsequent .strip() would raise an opaque
# AttributeError ('NoneType' object has no attribute 'strip').
_raw_token = os.getenv("HF_TOKEN")
if _raw_token is None:
    raise RuntimeError(
        "HF_TOKEN environment variable is not set; "
        "export a Hugging Face access token before starting the app."
    )
api_token = _raw_token.strip()
 
16
  bnb_4bit_compute_dtype=torch.float16
17
  )
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
# Initialize the multimodal model with 4-bit quantization (bnb_config is
# defined above).
#
# Flash Attention 2 is optional: transformers raises ImportError at load
# time if attn_implementation="flash_attention_2" is requested while the
# flash-attn package is absent, so only request it when the import
# succeeds and otherwise fall back to the default attention kernel.
_model_kwargs = {
    "quantization_config": bnb_config,  # 4-bit quantization settings
    "device_map": "auto",               # let accelerate place layers on available devices
    "torch_dtype": torch.float16,
    "trust_remote_code": True,          # this model repo ships custom modeling code
    "token": api_token,
}
try:
    import flash_attn  # noqa: F401 -- presence check only
    _model_kwargs["attn_implementation"] = "flash_attention_2"
except ImportError:
    print("flash-attn is not installed; using the default attention implementation.")
    print("To enable it, run: pip install flash-attn --no-build-isolation")

model = AutoModel.from_pretrained(
    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
    **_model_kwargs,
)
29
 
30
  tokenizer = AutoTokenizer.from_pretrained(
 
84
 
85
  # Launch the Gradio app
86
  if __name__ == "__main__":
 
 
 
 
 
87
  demo.launch(
88
  share=True,
89
  server_name="0.0.0.0",