# app.py — Persian GPT-2 chatbot with RAG retrieval (Gradio front end).
# Originally published as "Yavar/app.py" on Hugging Face Spaces.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import gradio as gr
from sentence_transformers import SentenceTransformer
import faiss
# Suppress torch.compile/dynamo errors (falls back to eager) to avoid the
# meta-device issues seen with this model on some hosts.
torch._dynamo.config.suppress_errors = True
torch.set_default_dtype(torch.float32)
# Pick GPU when available, otherwise CPU; everything below runs on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pretrained Persian GPT-2; this app performs no fine-tuning.
model_name = "HooshvareLab/gpt2-fa"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 ships without a pad token; reuse EOS so padded tokenization works.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # fp16 only on GPU; CPU inference stays in fp32.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
).to(device)
# Differential-privacy parameters for the Laplace mechanism (see add_noise).
epsilon = 1.0  # Privacy budget (larger = less noise)
delta = 1e-5  # Privacy parameter (unused by the pure Laplace mechanism)
sensitivity = 1.0  # Assumed sensitivity of the query
apply_dp = False  # Toggle differential privacy in inference (set to True to enable)
# Rolling in-process conversation memory: list of {"user": ..., "bot": ...}.
conversation_history = []
# RAG components, populated at startup by train_model()/build_rag_index().
embedder = None  # SentenceTransformer used to embed corpus lines and queries
index = None  # FAISS exact-L2 index over the corpus embeddings
texts = []  # Raw corpus lines backing the index
# Load the RAG corpus from training_data.txt in the working directory.
def load_training_data():
    """Read training_data.txt and return its non-empty, stripped lines.

    Also stores the result in the module-level ``texts`` list used by the
    retriever. Returns an empty list when the file is missing or unreadable,
    so the app can degrade to generation without retrieval context.
    """
    global texts
    try:
        with open("training_data.txt", "r", encoding="utf-8") as handle:
            texts = [row.strip() for row in handle if row.strip()]
    except FileNotFoundError:
        print("Error: training_data.txt not found in the root directory.")
        return []
    except Exception as e:
        print(f"Error reading training_data.txt: {e}")
        return []
    print(f"Loaded {len(texts)} training examples from training_data.txt")
    return texts
# Build the FAISS retrieval index over the corpus lines.
def build_rag_index(texts):
    """Embed *texts* and construct an exact-L2 FAISS index over them.

    Stores the encoder and index in module globals and returns
    ``(embedder, index)``; on any failure returns ``(None, None)`` so the
    chatbot can still answer without retrieval.
    """
    global embedder, index
    try:
        # Multilingual paraphrase model suits conversational Persian; kept on
        # CPU and encoded in small batches to bound memory use.
        embedder = SentenceTransformer('sentence-transformers/paraphrase-xlm-r-multilingual-v1', device='cpu')
        vectors = embedder.encode(texts, convert_to_tensor=True, batch_size=8).cpu().numpy()
        index = faiss.IndexFlatL2(vectors.shape[1])
        index.add(vectors)
    except Exception as e:
        print(f"Error building RAG index: {e}")
        return None, None
    print("RAG index built successfully")
    return embedder, index
# One-time startup: load the corpus and build the retriever (no fine-tuning).
def train_model():
    """Initialise RAG state; the language model itself stays pretrained."""
    global texts, embedder, index
    texts = load_training_data()
    if not texts:
        print("No training data available. Skipping RAG index build.")
        return
    build_rag_index(texts)
    print("Using pretrained Persian GPT-2 model without fine-tuning.")
def add_noise(tensor, sensitivity, epsilon, delta):
    """Return *tensor* perturbed with Laplace noise for differential privacy.

    The Laplace mechanism uses scale = sensitivity / epsilon; *delta* is
    accepted for interface symmetry but is not used by the pure Laplace
    scheme. The result is a new tensor with the input's dtype and device.
    """
    b = sensitivity / epsilon
    perturbation = np.random.laplace(loc=0.0, scale=b, size=tensor.shape)
    return tensor + torch.tensor(perturbation, dtype=tensor.dtype, device=tensor.device)
def update_model(user_input, response):
    """Record one user/bot exchange in the rolling conversation memory.

    Caps the memory at the 100 most recent exchanges and returns a short
    status string describing what was recorded.
    """
    global conversation_history
    conversation_history.append({"user": user_input, "bot": response})
    # Evict the oldest exchange once the cap is exceeded.
    while len(conversation_history) > 100:
        conversation_history.pop(0)
    return f"Learning from: {user_input} -> {response}"
def chat(message, history):
    """Generate a reply to *message* using RAG context plus recent history.

    *history* is the Gradio-managed chat log (unused here — the app keeps its
    own rolling memory in ``conversation_history``). Returns the bot's reply
    as a string and records the exchange via update_model().
    """
    model.eval()  # inference only: disable dropout etc.
    # ---- RAG retrieval ----------------------------------------------------
    context = ""
    if embedder is not None and index is not None:
        try:
            query_emb = embedder.encode(message, convert_to_tensor=True).cpu().numpy()
            # FIX: FAISS search() requires a 2-D (n_queries, dim) float32
            # array; encode() on a single string returns a 1-D vector.
            query_emb = np.asarray(query_emb, dtype=np.float32).reshape(1, -1)
            D, I = index.search(query_emb, k=10)  # top-10 neighbours for context
            # FAISS pads missing neighbours with -1; filter them out.
            retrieved = [texts[i] for i in I[0] if 0 <= i < len(texts)]
            context = "\n".join(retrieved)
        except Exception as e:
            print(f"Error in RAG retrieval: {e}")
    # ---- Prompt assembly --------------------------------------------------
    # Only the last 3 exchanges, to stay inside the 128-token truncation below.
    history_context = "\n".join([f"User: {h['user']} -> Bot: {h['bot']}" for h in conversation_history[-3:]]) if conversation_history else ""
    # System instruction (Persian): "You are a helpful, friendly Persian
    # chatbot. Answer the user's question briefly and relevantly, using the
    # context information only as an aid."
    prompt = f"شما یک چت‌بات فارسی مفید و دوستانه هستید. فقط به سؤال کاربر پاسخ کوتاه و مرتبط بدهید و از اطلاعات زمینه فقط برای کمک به پاسخ استفاده کنید:\nContext: {context}\nHistory: {history_context}\nUser: {message}\nBot:"
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
    # ---- Generation (beam search) -----------------------------------------
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            # FIX: max_new_tokens bounds the reply itself; the previous
            # max_length=150 counted the (up to 128-token) prompt too,
            # leaving almost no room for the answer.
            max_new_tokens=100,
            num_beams=10,
            no_repeat_ngram_size=2,
            early_stopping=True,
            # FIX: temperature/top_p were previously passed but are ignored
            # by beam search without do_sample=True, so they are removed.
            # FIX: explicit pad_token_id silences the generate() warning.
            pad_token_id=tokenizer.eos_token_id,
        )
    # FIX: decode only the newly generated tokens — the raw output begins
    # with the echoed prompt, which must not be shown to the user.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    # ---- Optional differential privacy ------------------------------------
    if apply_dp:
        # NOTE(review): this replaces the reply with the argmax over noised
        # logits of the *prompt* positions — a rough DP demonstration rather
        # than a proper private decoding scheme. Kept as in the original.
        logits = model(**inputs).logits
        noisy_logits = add_noise(logits, sensitivity, epsilon, delta)
        response_ids = torch.argmax(noisy_logits, dim=-1)
        response = tokenizer.decode(response_ids[0], skip_special_tokens=True)
    # Persist this exchange in the rolling memory.
    update_model(message, response)
    return response
# One-time startup: load the corpus and build the retrieval index.
train_model()
# Gradio chat UI; ChatInterface calls chat(message, history) per turn.
iface = gr.ChatInterface(
    fn=chat,
    title="Persian GPT-2 Chatbot with RAG",
    description="Chat with pretrained Persian GPT-2 model using training_data.txt as RAG knowledge base."
)
if __name__ == "__main__":
    # Bind all interfaces on port 7860 (Hugging Face Spaces convention).
    iface.launch(server_name="0.0.0.0", server_port=7860)