# basic_chatbot / app.py
# Commit b02a875 (arcsu1): "Fix response cleaning to remove special tokens"
from flask import Flask, jsonify, request, render_template
from flask_cors import CORS
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
# Flask application with CORS enabled so a browser front-end on another
# origin can call the JSON endpoints.
app = Flask(__name__)
CORS(app)
# Global variables for model and tokenizer
MODEL_PATH = "./models/fine-tuned-gpt2"  # local directory holding the fine-tuned GPT-2 weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU when present
tokenizer = None  # set by load_chatbot_model()
model = None  # set by load_chatbot_model()
def load_chatbot_model():
    """Populate the module-level ``tokenizer``/``model`` pair from MODEL_PATH.

    Idempotent: a second call is a no-op once the model is loaded.
    """
    global tokenizer, model
    if model is not None:
        return  # already loaded — nothing to do
    print(f"Loading chatbot model from {MODEL_PATH}...")
    print(f"Using device: {device}")
    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_PATH)
    model = GPT2LMHeadModel.from_pretrained(MODEL_PATH)
    model.to(device)
    print("Model loaded successfully!")

# Load model on startup
load_chatbot_model()
@app.route("/")
def index():
    """Render the single-page chat UI."""
    # Flask resolves this relative to the app's ./templates directory.
    return render_template('index.html')
@app.route("/api")
def root():
    """Report basic service metadata as JSON."""
    info = {
        "message": "Chatbot API",
        "status": "running",
        "model": "fine-tuned-gpt2",
        "device": str(device),
    }
    return jsonify(info)
@app.route("/health")
def health():
    """Liveness probe: reports whether the model finished loading."""
    payload = {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": str(device),
    }
    return jsonify(payload)
@app.route("/chat", methods=["POST"])
def chat():
    """
    Generate a chatbot response based on conversation history.

    Expects a JSON body: {"user": [past + current user messages],
    "ai": [past AI replies]}. Returns {"response": str, "device": str}
    on success, 400 for a missing/malformed JSON body, 500 when the
    model is unavailable or generation fails.
    """
    if model is None or tokenizer is None:
        return jsonify({"error": "Model not loaded"}), 500
    try:
        data = request.get_json()
        if data is None:
            # A missing or malformed JSON body is a client error, not a
            # server failure — previously this surfaced as a generic 500.
            return jsonify({"error": "Invalid JSON body"}), 400
        user_messages = data.get("user", [])
        ai_messages = data.get("ai", [])

        # Keep only the most recent exchanges; slicing already handles
        # lists shorter than the window, so no length guard is needed.
        user_msgs = user_messages[-7:]
        ai_msgs = ai_messages[-6:]

        # Build the prompt: past (user, AI) turns, then the current user
        # message followed by an open "<AI>" tag for the model to complete.
        combined_prompt = ""
        for user_message, ai_message in zip(user_msgs[:-1], ai_msgs):
            combined_prompt += f"<user> {user_message}{tokenizer.eos_token}<AI> {ai_message}{tokenizer.eos_token}"
        if user_msgs:
            combined_prompt += f"<user> {user_msgs[-1]}{tokenizer.eos_token}<AI>"

        # Tokenize and generate. no_grad() stops autograd from recording
        # the forward pass — pure inference needs no gradient graph.
        inputs = tokenizer.encode(combined_prompt, return_tensors="pt").to(device)
        attention_mask = torch.ones(inputs.shape, device=device)
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=50,
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=2,
                # NOTE: temperature/top_k/top_p were removed — they only
                # take effect when do_sample=True, so under beam search
                # they were ignored (with warnings) and changed nothing.
                pad_token_id=tokenizer.eos_token_id,
                attention_mask=attention_mask,
                repetition_penalty=1.2,
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the new AI reply: text after the last "<AI>"
        # marker, truncated at any hallucinated follow-up user turn.
        if "<AI>" in response:
            response = response.split("<AI>")[-1].strip()
        if "<user>" in response:
            response = response.split("<user>")[0].strip()
        # Clean up any remaining special tokens
        response = response.replace("<AI>", "").replace("<user>", "").strip()

        # Never return an empty string to the UI.
        if not response:
            response = "I'm not sure how to respond to that."

        return jsonify({
            "response": response,
            "device": str(device)
        })
    except Exception as e:
        # Generation/tokenization failures vary widely; report the message.
        return jsonify({"error": str(e)}), 500
# Bind on all interfaces; port 7860 is the Hugging Face Spaces convention.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860, debug=False)