timdettmers committed on
Commit
cfde7ef
1 Parent(s): c0f8c05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -4,6 +4,7 @@ import datetime
4
  import os
5
  from threading import Event, Thread
6
  from uuid import uuid4
 
7
 
8
  import gradio as gr
9
  import requests
@@ -19,7 +20,8 @@ from transformers import (
19
 
20
 
21
  # model_name = "lmsys/vicuna-7b-delta-v1.1"
22
- model_name = "timdettmers/guanaco-33b-merged"
 
23
  max_new_tokens = 1536
24
 
25
  auth_token = os.getenv("HF_TOKEN", None)
@@ -28,12 +30,18 @@ print(f"Starting to load the model {model_name} into memory")
28
 
29
  m = AutoModelForCausalLM.from_pretrained(
30
  model_name,
31
- load_in_8bit=True,
 
 
 
 
 
32
  torch_dtype=torch.bfloat16,
33
  device_map={"": 0}
34
  )
 
35
  m.eval()
36
- tok = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
37
  tok.bos_token_id = 1
38
 
39
  stop_token_ids = [0]
@@ -172,7 +180,7 @@ with gr.Blocks(
172
  ) as demo:
173
  conversation_id = gr.State(get_uuid)
174
  gr.Markdown(
175
- """<h1><center>Guanaco-33b playground</center></h1>
176
  """
177
  )
178
  chatbot = gr.Chatbot().style(height=500)
 
4
  import os
5
  from threading import Event, Thread
6
  from uuid import uuid4
7
+ from peft import PeftModel
8
 
9
  import gradio as gr
10
  import requests
 
20
 
21
 
22
  # model_name = "lmsys/vicuna-7b-delta-v1.1"
23
+ model_name = "decapoda-research/llama-65b-hf"
24
+
25
  max_new_tokens = 1536
26
 
27
  auth_token = os.getenv("HF_TOKEN", None)
 
30
 
31
  m = AutoModelForCausalLM.from_pretrained(
32
  model_name,
33
+ quantization_config=transformers.BitsAndBytesConfig(
34
+ load_in_4bit=True,
35
+ bnb_4bit_compute_dtype=torch.bfloat16,
36
+ bnb_4bit_use_double_quant=True,
37
+ bnb_4bit_quant_type='nf4' # {'fp4', 'nf4'}
38
+ ),
39
  torch_dtype=torch.bfloat16,
40
  device_map={"": 0}
41
  )
42
+ m = PeftModel.from_pretrained(m, 'timdettmers/guanaco-65b')
43
  m.eval()
44
+ tok = LlamaTokenizer.from_pretrained("decapoda-research/llama-65b-hf")
45
  tok.bos_token_id = 1
46
 
47
  stop_token_ids = [0]
 
180
  ) as demo:
181
  conversation_id = gr.State(get_uuid)
182
  gr.Markdown(
183
+ """<h1><center>Guanaco-65b playground</center></h1>
184
  """
185
  )
186
  chatbot = gr.Chatbot().style(height=500)