# chatbot-api / app.py
# Author: eabybabu — "Optimized chatbot for speed with CUDA & quantization" (commit 5cfc235)
import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the API token securely from Hugging Face Secrets (never hard-code it).
HF_TOKEN = os.getenv("HF_TOKEN")

# Model repository providing both the tokenizer and the causal-LM weights.
MODEL_NAME = "eabybabu/chatbot_model"  # Replace with your actual model name
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)

# Prefer GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=HF_TOKEN).to(device)
model.eval()  # inference only — disables dropout / training-mode layers

# Dynamic int8 quantization in PyTorch is CPU-only; applying it to a model
# already moved to CUDA is unsupported and can error or silently pull weights
# back to CPU. Quantize only on CPU; on GPU keep full precision (CUDA is the
# speed win there anyway).
if device == "cpu":
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
# βœ… Function to generate chatbot responses with chat history
def chatbot_response(user_input, chat_history):
    """Generate a reply to *user_input* and append the turn to the history.

    Args:
        user_input: The latest message typed by the user.
        chat_history: List of ``(user_message, bot_reply)`` tuples from
            previous turns (mutated in place).

    Returns:
        ``(chat_history, "")`` on success — the updated history plus an empty
        string that clears the input textbox. On failure the history is
        returned unchanged together with an ``"Error: ..."`` message that is
        shown in the textbox instead.
    """
    try:
        # Rebuild the running conversation as a flat text prompt.
        chat_context = " ".join(
            f"User: {msg}\nChatbot: {resp}" for msg, resp in chat_history
        )
        prompt = f"{chat_context}\nUser: {user_input}\nChatbot:"

        inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)

        # NOTE: temperature/top_k/top_p are silently IGNORED by generate()
        # unless do_sample=True (the original ran greedy decoding).
        # no_grad() avoids building an autograd graph during inference.
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=200,  # cap on *new* tokens, not prompt+reply
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.9,
                repetition_penalty=1.5,
                num_return_sequences=1,
            )

        # Decode only the newly generated tokens; decoding outputs[0] whole
        # would echo the entire prompt/history back to the user.
        response = tokenizer.decode(
            outputs[0][inputs.shape[-1]:], skip_special_tokens=True
        )

        # De-duplicate repeated sentences while PRESERVING their order.
        # (The original used set(), which scrambles sentence order.)
        response = ". ".join(dict.fromkeys(response.split(". ")))

        chat_history.append((user_input, response))
        return chat_history, ""  # "" clears the input textbox
    except Exception as e:
        # Surface the failure in the textbox instead of crashing the UI.
        return chat_history, f"Error: {str(e)}"
# βœ… Create Gradio UI with Chat History
# Build the Gradio UI with persistent per-session chat history.
with gr.Blocks() as demo:
    gr.Markdown("# πŸ€– Cybersecurity Chatbot")
    gr.Markdown("Ask me anything about ISO 27001, ISO 27005, MITRE ATT&CK, and more!")

    chatbot = gr.Chatbot(label="Chat History")
    user_input = gr.Textbox(label="Type your question:")
    submit_btn = gr.Button("Ask Chatbot")
    chat_history = gr.State([])  # holds (user, bot) tuples across turns

    # Fire the handler both on button click AND on Enter in the textbox
    # (the original only wired the button, so Enter did nothing).
    submit_btn.click(
        chatbot_response,
        inputs=[user_input, chat_history],
        outputs=[chatbot, user_input],
    )
    user_input.submit(
        chatbot_response,
        inputs=[user_input, chat_history],
        outputs=[chatbot, user_input],
    )

# Launch the Gradio app (blocking call).
demo.launch()