Update app.py
Browse files
app.py
CHANGED
@@ -2,11 +2,11 @@ import gradio as gr
|
|
2 |
import torch
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
|
5 |
-
# Load the
|
6 |
-
model_name = "
|
7 |
model = AutoModelForCausalLM.from_pretrained(
|
8 |
model_name,
|
9 |
-
torch_dtype=
|
10 |
device_map="auto"
|
11 |
)
|
12 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
@@ -37,34 +37,29 @@ def chat(user_input):
|
|
37 |
# Append user message to the conversation history
|
38 |
messages.append({"role": "user", "content": user_input})
|
39 |
|
40 |
-
# Prepare input for the model
|
41 |
-
|
42 |
-
messages,
|
43 |
-
tokenize=False,
|
44 |
-
add_generation_prompt=True
|
45 |
-
)
|
46 |
-
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
47 |
|
48 |
# Generate a response using the model
|
49 |
try:
|
|
|
50 |
generated_ids = model.generate(
|
51 |
**model_inputs,
|
52 |
-
max_new_tokens=
|
|
|
|
|
|
|
|
|
53 |
)
|
54 |
-
generated_ids =
|
55 |
-
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
56 |
-
]
|
57 |
-
|
58 |
-
# Decode the response
|
59 |
-
response_content = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
60 |
|
61 |
except Exception as e:
|
62 |
-
|
63 |
|
64 |
# Store assistant response in the chat history
|
65 |
-
messages.append({"role": "assistant", "content":
|
66 |
|
67 |
-
return messages,
|
68 |
return messages, ""
|
69 |
|
70 |
# Gradio Interface
|
@@ -96,12 +91,12 @@ with gr.Blocks() as demo:
|
|
96 |
user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
|
97 |
|
98 |
# Connect the buttons to their respective functions
|
99 |
-
output_message = gr.Textbox(label="Output Message"
|
100 |
submit_btn.click(submit_questionnaire, inputs=[name, age, location, gender, ethnicity, height, weight,
|
101 |
style_preference, color_palette, everyday_style], outputs=output_message)
|
102 |
|
103 |
-
reset_btn.click(reset_chat, outputs=[chatbox
|
104 |
-
user_input.submit(chat, inputs=user_input, outputs=[chatbox, user_input]) #
|
105 |
|
106 |
# Run the app
|
107 |
demo.launch()
|
|
|
2 |
import torch
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
|
5 |
+
# Load the Zephyr-7B model
|
6 |
+
model_name = "HuggingFaceH4/zephyr-7b-beta"
|
7 |
model = AutoModelForCausalLM.from_pretrained(
|
8 |
model_name,
|
9 |
+
torch_dtype=torch.bfloat16,
|
10 |
device_map="auto"
|
11 |
)
|
12 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
37 |
# Append user message to the conversation history
|
38 |
messages.append({"role": "user", "content": user_input})
|
39 |
|
40 |
+
# Prepare input for the model using chat template
|
41 |
+
chat_input = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
# Generate a response using the model
|
44 |
try:
|
45 |
+
model_inputs = tokenizer(chat_input, return_tensors="pt").to(model.device)
|
46 |
generated_ids = model.generate(
|
47 |
**model_inputs,
|
48 |
+
max_new_tokens=256,
|
49 |
+
do_sample=True,
|
50 |
+
temperature=0.7,
|
51 |
+
top_k=50,
|
52 |
+
top_p=0.95
|
53 |
)
|
54 |
+
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
except Exception as e:
|
57 |
+
response = f"Error: {str(e)}"
|
58 |
|
59 |
# Store assistant response in the chat history
|
60 |
+
messages.append({"role": "assistant", "content": response})
|
61 |
|
62 |
+
return messages, response
|
63 |
return messages, ""
|
64 |
|
65 |
# Gradio Interface
|
|
|
91 |
user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
|
92 |
|
93 |
# Connect the buttons to their respective functions
|
94 |
+
output_message = gr.Textbox(label="Output Message")
|
95 |
submit_btn.click(submit_questionnaire, inputs=[name, age, location, gender, ethnicity, height, weight,
|
96 |
style_preference, color_palette, everyday_style], outputs=output_message)
|
97 |
|
98 |
+
reset_btn.click(reset_chat, outputs=[chatbox]) # Reset chat
|
99 |
+
user_input.submit(chat, inputs=user_input, outputs=[chatbox, user_input]) # Connect chat input
|
100 |
|
101 |
# Run the app
|
102 |
demo.launch()
|