tangzhy commited on
Commit
a9764a0
·
verified ·
1 Parent(s): 54becda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -2
app.py CHANGED
@@ -12,6 +12,9 @@ from transformers import (
12
  TextIteratorStreamer,
13
  )
14
 
 
 
 
15
  DESCRIPTION = """\
16
  # ORLM LLaMA-3-8B
17
 
@@ -24,18 +27,26 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
24
 
25
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26
 
 
 
 
 
 
 
 
27
  model_id = "CardinalOperations/ORLM-LLaMA-3-8B"
28
  tokenizer = AutoTokenizer.from_pretrained(model_id)
29
  model = AutoModelForCausalLM.from_pretrained(
30
  model_id,
31
  device_map="auto",
32
- quantization_config=BitsAndBytesConfig(load_in_8bit=True),
 
33
  )
34
  model.config.sliding_window = 4096
35
  model.eval()
36
 
37
 
38
- @spaces.GPU(duration=100)
39
  def generate(
40
  message: str,
41
  chat_history: list[tuple[str, str]],
@@ -63,6 +74,7 @@ def generate(
63
  temperature=temperature,
64
  num_beams=1,
65
  repetition_penalty=repetition_penalty,
 
66
  )
67
  t = Thread(target=model.generate, kwargs=generate_kwargs)
68
  t.start()
 
12
  TextIteratorStreamer,
13
  )
14
 
15
+ import subprocess
16
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
17
+
18
  DESCRIPTION = """\
19
  # ORLM LLaMA-3-8B
20
 
 
27
 
28
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
29
 
30
+ quantization_config = BitsAndBytesConfig(
31
+ load_in_4bit=True,
32
+ bnb_4bit_compute_dtype=torch.bfloat16,
33
+ bnb_4bit_use_double_quant=True,
34
+ bnb_4bit_quant_type= "nf4")
35
+ # quantization_config = BitsAndBytesConfig(load_in_8bit=True)
36
+
37
  model_id = "CardinalOperations/ORLM-LLaMA-3-8B"
38
  tokenizer = AutoTokenizer.from_pretrained(model_id)
39
  model = AutoModelForCausalLM.from_pretrained(
40
  model_id,
41
  device_map="auto",
42
+ attn_implementation="flash_attention_2",
43
+ # quantization_config=quantization_config,
44
  )
45
  model.config.sliding_window = 4096
46
  model.eval()
47
 
48
 
49
+ @spaces.GPU(duration=120)
50
  def generate(
51
  message: str,
52
  chat_history: list[tuple[str, str]],
 
74
  temperature=temperature,
75
  num_beams=1,
76
  repetition_penalty=repetition_penalty,
77
+ eos_token_id=[tok.eos_token_id],
78
  )
79
  t = Thread(target=model.generate, kwargs=generate_kwargs)
80
  t.start()