FrameRateTech committed on
Commit a3a27cd · verified · 1 Parent(s): addeff3

Update app.py

Files changed (1)
  1. app.py +16 -24
app.py CHANGED
@@ -1,6 +1,5 @@
 import gradio as gr
 import torch
-from peft import PeftModel, PeftConfig
 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
@@ -10,34 +9,32 @@ from transformers import (
 # ---------------------------------------------------------------------
 # 1. Model Configuration
 # ---------------------------------------------------------------------
-ADAPTER_ID = "FrameRateTech/DamageScan-llama-8b-instruct-merged"
-
-# Load adapter config to find base model name
-peft_config = PeftConfig.from_pretrained(ADAPTER_ID)
-BASE_MODEL_ID = peft_config.base_model_name_or_path
+MODEL_ID = "FrameRateTech/DamageScan-llama-8b-instruct-merged"
 
 # ---------------------------------------------------------------------
 # 2. Load Tokenizer
 # ---------------------------------------------------------------------
-tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=False)
+# For many LLaMA-based models, you often need use_fast=False and sometimes trust_remote_code=True
+tokenizer = AutoTokenizer.from_pretrained(
+    MODEL_ID,
+    use_fast=False,
+    # trust_remote_code=True,  # Uncomment if needed for custom code
+)
+
+# Ensure we have a valid pad token
 if tokenizer.pad_token_id is None:
     tokenizer.pad_token_id = tokenizer.eos_token_id
 
 # ---------------------------------------------------------------------
-# 3. Load Base Model + LoRA Weights
+# 3. Load the Model
 # ---------------------------------------------------------------------
-# If you need 8-bit to save VRAM, add load_in_8bit=True and device_map="auto"
-base_model = AutoModelForCausalLM.from_pretrained(
-    BASE_MODEL_ID,
-    torch_dtype=torch.float16,
-    device_map="auto",
-)
-
-model = PeftModel.from_pretrained(
-    base_model,
-    ADAPTER_ID,
+# If you want to load 8-bit weights for VRAM savings, set load_in_8bit=True
+# and device_map="auto". Otherwise, below loads in FP16.
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID,
     torch_dtype=torch.float16,
     device_map="auto",
+    # trust_remote_code=True,  # Uncomment if needed for custom code
 )
 
 model.eval()
@@ -111,11 +108,10 @@ def predict(messages, temperature, top_p, max_new_tokens):
 # 7. Build the Gradio Interface
 # ---------------------------------------------------------------------
 with gr.Blocks() as demo:
-    gr.Markdown("<h1 align='center'>FrameRateTech/DamageScan-llama-8b-instruct-merged Chatbot</h1>")
+    gr.Markdown("<h1 align='center'>DamageScan 8B Instruct Chatbot</h1>")
 
     with gr.Row():
         with gr.Column():
-            # type="messages" => each message is a dict with 'role' and 'content'
             chatbot = gr.Chatbot(label="Chat History", type="messages")
         with gr.Column():
             gr.Markdown("### Generation Settings")
@@ -129,17 +125,14 @@ with gr.Blocks() as demo:
         minimum=64, maximum=2048, value=256, step=64, label="Max New Tokens"
     )
 
-    # Chat state is stored in 'chatbot' since type="messages"
     user_input = gr.Textbox(lines=1, label="Your Message", placeholder="Type here...")
     send_btn = gr.Button("Send")
 
-    # Append user input to chat, generate model reply, then clear input
     def user_submit(message_history, user_text, temp, top_p, max_tokens):
         message_history.append({"role": "user", "content": user_text})
         updated_messages = predict(message_history, temp, top_p, max_tokens)
         return updated_messages, ""
 
-    # Send button or pressing Enter triggers user_submit
     send_btn.click(
         user_submit,
         inputs=[chatbot, user_input, temperature_slider, top_p_slider, max_tokens_slider],
@@ -151,5 +144,4 @@ with gr.Blocks() as demo:
         outputs=[chatbot, user_input],
     )
 
-# Launch the Gradio app
 demo.queue().launch()
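The dropped `peft` import reflects what the new repo name already says: the Space now points at a checkpoint with the LoRA weights folded in, so no adapter loading is needed at runtime. For reference, a merged checkpoint like this is typically produced offline with PEFT's `merge_and_unload()`. A minimal sketch, with a hypothetical unmerged adapter repo and a placeholder base model ID (the real base is recorded in the adapter's config, as the removed code showed):

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder: actual base not shown in this diff
ADAPTER_ID = "FrameRateTech/DamageScan-llama-8b-instruct"  # hypothetical unmerged adapter repo

base = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.float16)
# Fold the LoRA deltas into the base weights, then drop the adapter wrappers.
merged = PeftModel.from_pretrained(base, ADAPTER_ID).merge_and_unload()

merged.save_pretrained("DamageScan-llama-8b-instruct-merged")
AutoTokenizer.from_pretrained(BASE_MODEL_ID).save_pretrained("DamageScan-llama-8b-instruct-merged")
```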
 
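The new comment in section 3 mentions `load_in_8bit=True` for VRAM savings. On recent transformers releases, 8-bit loading is requested through a `BitsAndBytesConfig` rather than the bare kwarg, and it requires the `bitsandbytes` package. A sketch of that variant of the load:

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "FrameRateTech/DamageScan-llama-8b-instruct-merged",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # needs bitsandbytes installed
    device_map="auto",  # let accelerate place the quantized weights
)
```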
 
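The `predict` function that the UI wires up sits outside this diff's hunks; only its signature is visible in the hunk header. A hypothetical implementation consistent with that signature and the messages-format history, assuming the module-level `tokenizer` and `model` defined above (not the repo's actual code):

```python
import torch

def predict(messages, temperature, top_p, max_new_tokens):
    # Render the chat history (list of {"role", "content"} dicts) with the
    # model's chat template and tokenize in one step.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            do_sample=True,  # sampling, so temperature/top_p take effect
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Keep only the newly generated tokens, then append the reply to the history.
    reply = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
    messages.append({"role": "assistant", "content": reply})
    return messages
```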