Portx committed on
Commit
a89e2e5
·
verified ·
1 Parent(s): 8290e1d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +16 -2
handler.py CHANGED
@@ -5,7 +5,7 @@ from PIL import Image
5
  import os
6
  import base64
7
 
8
- #run("pip install flash-attn --no-build-isolation", shell=True, check=True)
9
  run("pip install --upgrade pip", shell=True, check=True)
10
  run("pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124", shell=True, check=True)
11
 
@@ -14,12 +14,25 @@ run("pip install torch torchvision torchaudio --extra-index-url https://download
14
 
15
 
16
 
 
 
 
 
 
 
 
 
 
 
17
  from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
18
 
19
  model_id = "ibm-granite/granite-vision-3.2-2b"
20
 
21
  bnb_config = BitsAndBytesConfig(
22
  load_in_4bit=True,
 
 
 
23
  llm_int8_skip_modules=["vision_tower", "lm_head"],
24
  llm_int8_enable_fp32_cpu_offload=True
25
  )
@@ -69,7 +82,8 @@ class PromptSet:
69
  class EndpointHandler():
70
  def __init__(self, path=""):
71
  self.model=AutoModelForVision2Seq.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16,
72
- quantization_config=bnb_config)
 
73
  self.processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
74
 
75
  def __call__(self, data):
 
5
  import os
6
  import base64
7
 
8
+ run("pip install flash-attn --no-build-isolation", shell=True, check=True)
9
  run("pip install --upgrade pip", shell=True, check=True)
10
  run("pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124", shell=True, check=True)
11
 
 
14
 
15
 
16
 
17
+
18
+ try:
19
+ import flash_attn
20
+ print("FlashAttention is installed")
21
+ USE_FLASH_ATTENTION = True
22
+ except ImportError:
23
+ print("FlashAttention is not installed")
24
+ USE_FLASH_ATTENTION = False
25
+
26
+
27
  from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
28
 
29
  model_id = "ibm-granite/granite-vision-3.2-2b"
30
 
31
  bnb_config = BitsAndBytesConfig(
32
  load_in_4bit=True,
33
+ bnb_4bit_use_double_quant=True,
34
+ bnb_4bit_quant_type="nf4",
35
+ bnb_4bit_compute_dtype=torch.bfloat16,
36
  llm_int8_skip_modules=["vision_tower", "lm_head"],
37
  llm_int8_enable_fp32_cpu_offload=True
38
  )
 
82
  class EndpointHandler():
83
  def __init__(self, path=""):
84
  self.model=AutoModelForVision2Seq.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16,
85
+ quantization_config=bnb_config,
86
+ _attn_implementation="flash_attention_2" if USE_FLASH_ATTENTION else None,)
87
  self.processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
88
 
89
  def __call__(self, data):