vilarin committed
Commit 7cb9567
1 parent: a936635

Update app.py

Files changed (1): app.py (+31 -8)
app.py CHANGED
@@ -6,15 +6,16 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import gradio as gr
 from threading import Thread
 
-MODEL_LIST = ["meta-llama/Meta-Llama-3.1-8B-Instruct"]
+MODEL_LIST = ["meta-llama/Meta-Llama-3.1-8B-Instruct", "meta-llama/Meta-Llama-3.1-70B-Instruct"]
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL = os.environ.get("MODEL_ID")
 
-TITLE = "<h1><center>Meta-Llama3.1-8B</center></h1>"
+TITLE = "<h1><center>Meta-Llama3.1-Chat</center></h1>"
 
 PLACEHOLDER = """
 <center>
-<p>Hi! How can I help you today?</p>
+<p>😊Hi! How can I help you today?</p><br>
+<p>✨Select Meta-Llama3.1-8B/70B in Advanced Options</p>
 </center>
 """
 
 
@@ -33,16 +34,26 @@ h3 {
 
 device = "cuda" # for GPU usage or "cpu" for CPU usage
 
-quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_compute_dtype=torch.bfloat16,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4")
+
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL,
+model_8b = AutoModelForCausalLM.from_pretrained(
+    MODEL_LIST[0],
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    quantization_config=quantization_config)
+model_70b = AutoModelForCausalLM.from_pretrained(
+    MODEL_LIST[1],
     torch_dtype=torch.bfloat16,
     device_map="auto",
     quantization_config=quantization_config)
 
-@spaces.GPU()
+@spaces.GPU(duration=120)
 def stream_chat(
     message: str,
     history: list,
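Note on this hunk: moving from load_in_8bit to 4-bit NF4 roughly halves weight memory to about half a byte per parameter, on the order of 4 GB for the 8B weights and 35 GB for the 70B weights before KV-cache and activation overhead. A minimal standalone sketch of the same NF4 load (the model id is just MODEL_LIST[0] from above; this is not part of the commit):

# Sketch: standalone NF4 4-bit load mirroring the config introduced above.
# Assumes transformers, accelerate, and bitsandbytes are installed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # pack weights into 4-bit blocks
    bnb_4bit_compute_dtype=torch.bfloat16,  # dequantize to bf16 for matmuls
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",              # NormalFloat4 rather than plain int4
)

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # MODEL_LIST[0] above
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",          # let accelerate place the quantized layers
    quantization_config=bnb_config,
)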
@@ -52,6 +63,7 @@ def stream_chat(
     top_p: float = 1.0,
     top_k: int = 20,
     penalty: float = 1.2,
+    choice: str = "Meta-Llama-3.1-8B"
 ):
     print(f'message: {message}')
     print(f'history: {history}')
@@ -67,6 +79,11 @@
 
     conversation.append({"role": "user", "content": message})
 
+    if choice == "Meta-Llama-3.1-8B":
+        model = model_8b
+    else:
+        model = model_70b
+
     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
 
     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
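Note on this hunk: the generation path follows the standard thread-plus-streamer pattern, since model.generate blocks until completion. A minimal sketch under that assumption (names follow the surrounding code; the max_new_tokens value is illustrative):

# Sketch of the thread + TextIteratorStreamer pattern used by stream_chat.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, input_ids, max_new_tokens=1024):
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks until done, so park it on a background thread...
    Thread(target=model.generate, kwargs=dict(
        input_ids=input_ids, streamer=streamer,
        max_new_tokens=max_new_tokens)).start()
    # ...and drain the streamer here; each yield re-renders the partial reply.
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer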
@@ -101,7 +118,7 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
         fn=stream_chat,
         chatbot=chatbot,
         fill_height=True,
-        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+        additional_inputs_accordion=gr.Accordion(label="⚙️ Advanced Options", open=False, render=False),
         additional_inputs=[
             gr.Textbox(
                 value="You are a helpful assistant",
@@ -148,6 +165,12 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
                 label="Repetition penalty",
                 render=False,
             ),
+            gr.Radio(
+                ["Meta-Llama-3.1-8B", "Meta-Llama-3.1-70B"],
+                value="Meta-Llama-3.1-8B",
+                label="Load Model",
+                render=False,
+            ),
         ],
         examples=[
             ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
 