Update app.py
app.py CHANGED
@@ -1,5 +1,6 @@
 # app.py
 
+import transformers
 import torch
 import gradio as gr
 from transformers import (
@@ -9,26 +10,44 @@ from transformers import (
 )
 
 ###############################################################################
-#
+# Debug Print Section
 ###############################################################################
-MODEL_ID = "FrameRateTech/DamageScan-llama-8b-instruct-merged"
+MODEL_ID = "FrameRateTech/DamageScan-llama-8b-instruct-merged"
+print("Transformers version:", transformers.__version__)
+
+# Attempt to load the tokenizer once just to see what happens
+try:
+    tokenizer_test = AutoTokenizer.from_pretrained(
+        MODEL_ID,
+        use_fast=False,
+        trust_remote_code=True
+    )
+    print("tokenizer_test =", tokenizer_test)
+    print("type(tokenizer_test) =", type(tokenizer_test))
+except Exception as e:
+    print("AutoTokenizer failed with exception:", e)
+    raise e
+
+# If it's returning False, bail out early so we don't crash below
+if tokenizer_test is False:
+    raise ValueError("AutoTokenizer returned False, meaning it failed to load properly.")
 
 ###############################################################################
-#
+# 1. Load Tokenizer
 ###############################################################################
+# Now load the real tokenizer for your app
 tokenizer = AutoTokenizer.from_pretrained(
     MODEL_ID,
     use_fast=False,
     trust_remote_code=True
 )
 
-#
+# If `tokenizer` is not False, set pad_token_id if needed
 if getattr(tokenizer, "pad_token_id", None) is None:
-    # If no pad token is defined, fall back to eos_token_id
     tokenizer.pad_token_id = getattr(tokenizer, "eos_token_id", None)
 
 ###############################################################################
-#
+# 2. Load Model
 ###############################################################################
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
@@ -39,7 +58,7 @@ model = AutoModelForCausalLM.from_pretrained(
 model.eval()
 
 ###############################################################################
-#
+# 3. Default Generation Settings
 ###############################################################################
 default_gen_config = GenerationConfig(
     temperature=0.7,
@@ -50,17 +69,9 @@ default_gen_config = GenerationConfig(
 )
 
 ###############################################################################
-#
+# 4. Helper: Convert Chatbot Messages to Prompt
 ###############################################################################
 def messages_to_prompt(messages):
-    """
-    Convert a list of chat messages (role/content) into a text prompt.
-    Example of messages:
-    [
-        {"role": "user", "content": "..."},
-        {"role": "assistant", "content": "..."}
-    ]
-    """
     conversation = ""
     for msg in messages:
         if msg["role"] == "user":
@@ -70,16 +81,10 @@ def messages_to_prompt(messages):
     return conversation
 
 ###############################################################################
-#
+# 5. Generation Function
 ###############################################################################
 def predict(messages, temperature, top_p, max_new_tokens):
-    """
-    Takes the current conversation (messages) and returns an updated list
-    of messages with the model's response appended.
-    """
    prompt_text = messages_to_prompt(messages) + "Assistant:"
-
-    # Create a GenerationConfig on the fly with user settings
    gen_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
@@ -87,31 +92,22 @@ def predict(messages, temperature, top_p, max_new_tokens):
        repetition_penalty=1.1,
        max_new_tokens=max_new_tokens,
    )
-
    with torch.no_grad():
-        # Tokenize and move to GPU
        inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
-        # Generate
        outputs = model.generate(**inputs, generation_config=gen_config)
-        # Decode the output
        full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # The newly generated text is the difference between the prompt and the full output
    generated_reply = full_text[len(prompt_text):].strip()
-
-    # Append the model's reply
    messages.append({"role": "assistant", "content": generated_reply})
    return messages
 
 ###############################################################################
-#
+# 6. Build the Gradio Interface
 ###############################################################################
 with gr.Blocks() as demo:
     gr.Markdown("<h1 align='center'>DamageScan 8B Instruct Chatbot</h1>")
 
     with gr.Row():
         with gr.Column():
-            # "type='messages'" => each item is a dict {"role": ..., "content": ...}
             chatbot = gr.Chatbot(label="Chat History", type="messages")
         with gr.Column():
             gr.Markdown("### Generation Settings")
@@ -125,17 +121,14 @@ with gr.Blocks() as demo:
                minimum=64, maximum=2048, value=256, step=64, label="Max New Tokens"
            )
 
-    # User input box
    user_input = gr.Textbox(lines=1, label="Your Message", placeholder="Type here...")
    send_btn = gr.Button("Send")
 
-    # Function that appends the user's input to the chat, calls the model, and returns the updated chat
    def user_submit(message_history, user_text, temp, top_p, max_tokens):
        message_history.append({"role": "user", "content": user_text})
        updated_messages = predict(message_history, temp, top_p, max_tokens)
        return updated_messages, ""
 
-    # Link the button and the textbox "Enter" key to user_submit
    send_btn.click(
        user_submit,
        inputs=[chatbot, user_input, temperature_slider, top_p_slider, max_tokens_slider],
@@ -147,5 +140,4 @@
        outputs=[chatbot, user_input],
    )
 
-# Launch the Gradio interface with a queue for concurrency
 demo.queue().launch()
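For reference, the prompt `predict` feeds the model is just the flattened chat history with a bare `Assistant:` suffix. The hunks above elide the body of the role loop in `messages_to_prompt`, so the sketch below fills it in with hypothetical `User:`/`Assistant:` line prefixes; the trailing `"Assistant:"` in `predict` suggests this shape, but the exact separators are an assumption, not something the commit shows.

```python
# Minimal sketch of the prompt flattening. The "User: "/"Assistant: " prefixes
# are assumed; the diff only shows the trailing "Assistant:" suffix.
def messages_to_prompt(messages):
    conversation = ""
    for msg in messages:
        if msg["role"] == "user":
            conversation += f"User: {msg['content']}\n"        # assumed format
        elif msg["role"] == "assistant":
            conversation += f"Assistant: {msg['content']}\n"   # assumed format
    return conversation

history = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi, how can I help?"},
    {"role": "user", "content": "Describe the model."},
]
print(messages_to_prompt(history) + "Assistant:")
# User: Hello
# Assistant: Hi, how can I help?
# User: Describe the model.
# Assistant:
```

Note that recovering the reply with `full_text[len(prompt_text):]` only works while decoding with `skip_special_tokens=True` reproduces the prompt text byte-for-byte; a tokenizer that normalizes whitespace on round-trip would shift that offset.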
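One detail both sides of the diff rely on: many causal-LM checkpoints ship without a `pad_token_id`, and `generate` expects one as soon as padding or attention masks are involved, hence the `eos_token_id` fallback. A standalone illustration of the same fallback, using `gpt2` purely as a stand-in for a pad-less checkpoint (the DamageScan tokenizer may or may not define one):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # stand-in: gpt2 ships without a pad token
print(tok.pad_token_id)                      # None

# Same fallback the commit applies: reuse the end-of-sequence token for padding.
if getattr(tok, "pad_token_id", None) is None:
    tok.pad_token_id = getattr(tok, "eos_token_id", None)

print(tok.pad_token_id)                      # 50256, gpt2's eos_token_id
```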