Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import numpy as np | |
| import gradio as gr | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
# Suppress torch.compile/dynamo errors to avoid meta-device issues.
torch._dynamo.config.suppress_errors = True
torch.set_default_dtype(torch.float32)

# Select GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained Persian GPT-2 checkpoint.
model_name = "HooshvareLab/gpt2-fa"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 ships with no pad token; reuse EOS so padded batches tokenize cleanly.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # fp16 on GPU for memory; fp32 on CPU (fp16 CPU inference is unsupported/slow).
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
).to(device)

# Differential-privacy parameters (consumed by add_noise()).
epsilon = 1.0  # privacy budget (larger = less noise)
delta = 1e-5  # privacy parameter (not read by the Laplace mechanism below)
sensitivity = 1.0  # sensitivity of the query
apply_dp = False  # toggle differential privacy in inference (set to True to enable)

# Rolling in-memory conversation history (list of {"user", "bot"} dicts).
conversation_history = []

# RAG components, populated lazily by train_model().
embedder = None  # SentenceTransformer instance
index = None  # FAISS index over the corpus embeddings
texts = []  # raw corpus lines backing the index
| # Load training data from training_data.txt in the root directory | |
def load_training_data():
    """Load the RAG corpus from training_data.txt in the working directory.

    Reads the file line by line, keeping each non-empty stripped line.
    Stores the result in the module-level ``texts`` list and returns it;
    returns an empty list when the file is missing or unreadable.
    """
    global texts
    try:
        with open("training_data.txt", "r", encoding="utf-8") as fh:
            stripped = (raw.strip() for raw in fh)
            texts = [line for line in stripped if line]
    except FileNotFoundError:
        print("Error: training_data.txt not found in the root directory.")
        return []
    except Exception as e:
        print(f"Error reading training_data.txt: {e}")
        return []
    print(f"Loaded {len(texts)} training examples from training_data.txt")
    return texts
| # Build RAG index | |
def build_rag_index(texts):
    """Embed ``texts`` and build a FAISS L2 index over the embeddings.

    Assigns the module-level ``embedder`` and ``index`` and returns them
    as a tuple; returns ``(None, None)`` if anything fails.
    """
    global embedder, index
    try:
        # Multilingual paraphrase model — better suited to conversational Persian.
        embedder = SentenceTransformer(
            'sentence-transformers/paraphrase-xlm-r-multilingual-v1', device='cpu'
        )
        # Small batch size keeps peak memory low on CPU-only hosts.
        vectors = embedder.encode(
            texts, convert_to_tensor=True, batch_size=8
        ).cpu().numpy()
        index = faiss.IndexFlatL2(vectors.shape[1])
        index.add(vectors)
        print("RAG index built successfully")
        return embedder, index
    except Exception as e:
        print(f"Error building RAG index: {e}")
        return None, None
| # Initialize model and RAG (no fine-tuning) | |
def train_model():
    """Initialize the RAG pipeline: load the corpus, then index it.

    No fine-tuning happens here — the pretrained Persian GPT-2 model is
    used as-is; "training" only means building the retrieval index.
    """
    global texts, embedder, index
    texts = load_training_data()
    if not texts:
        print("No training data available. Skipping RAG index build.")
        return
    build_rag_index(texts)
    print("Using pretrained Persian GPT-2 model without fine-tuning.")
def add_noise(tensor, sensitivity, epsilon, delta):
    """Return ``tensor`` plus Laplace noise for differential privacy.

    The noise scale is ``sensitivity / epsilon`` (the classic Laplace
    mechanism). ``delta`` is accepted for interface symmetry but is not
    read by this implementation.
    """
    scale = sensitivity / epsilon
    perturbation = torch.tensor(
        np.random.laplace(0, scale, tensor.shape),
        dtype=tensor.dtype,
        device=tensor.device,
    )
    return tensor + perturbation
def update_model(user_input, response):
    """Record one user/bot exchange in the rolling conversation history.

    The history is capped at the 100 most recent exchanges; the oldest
    entries are dropped first. Returns a short status string.
    """
    global conversation_history
    conversation_history.append({"user": user_input, "bot": response})
    # Evict from the front once the cap is exceeded.
    while len(conversation_history) > 100:
        conversation_history.pop(0)
    return f"Learning from: {user_input} -> {response}"
def chat(message, history):
    """Gradio chat handler: answer ``message`` using RAG context + history.

    Retrieves related corpus lines, assembles a Persian prompt with the
    last three exchanges, generates a reply with beam search, optionally
    re-derives it from DP-noised logits, and records the exchange.

    Args:
        message: The user's input text.
        history: Gradio-managed chat history (unused; we keep our own).

    Returns:
        The generated bot response string.
    """
    model.eval()  # inference mode (disables dropout)

    # --- RAG retrieval ---
    context = ""
    if embedder is not None and index is not None:
        try:
            query_emb = embedder.encode(message, convert_to_tensor=True).cpu().numpy()
            # FIX: faiss Index.search requires a 2-D (n_queries, dim) array.
            # encode() of a single string returns 1-D, which previously made
            # every search raise and get swallowed — context was always empty.
            D, I = index.search(query_emb.reshape(1, -1), k=10)
            retrieved = [texts[i] for i in I[0] if 0 <= i < len(texts)]
            context = "\n".join(retrieved)
        except Exception as e:
            print(f"Error in RAG retrieval: {e}")

    # Last 3 exchanges from our own rolling history.
    history_context = "\n".join(
        [f"User: {h['user']} -> Bot: {h['bot']}" for h in conversation_history[-3:]]
    ) if conversation_history else ""

    # Persian system instruction + retrieved context + history + user turn.
    prompt = f"شما یک چتبات فارسی مفید و دوستانه هستید. فقط به سؤال کاربر پاسخ کوتاه و مرتبط بدهید و از اطلاعات زمینه فقط برای کمک به پاسخ استفاده کنید:\nContext: {context}\nHistory: {history_context}\nUser: {message}\nBot:"

    inputs = tokenizer(
        prompt, return_tensors="pt", padding=True, truncation=True, max_length=128
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            # FIX: max_new_tokens budgets only the reply. The old max_length=150
            # included the (up to 128-token) prompt, leaving <=22 new tokens.
            max_new_tokens=100,
            num_beams=10,
            no_repeat_ngram_size=2,
            # FIX: dropped temperature/top_p — they are ignored by beam search
            # unless do_sample=True, and only triggered warnings here.
            early_stopping=True,
        )
    # FIX: decode only the newly generated tokens so the prompt (instructions,
    # context, history) is not echoed back to the user as part of the reply.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # Optional differential privacy on the model outputs.
    if apply_dp:
        # NOTE(review): this takes an argmax over the *prompt* positions'
        # logits rather than running a fresh generation pass — dubious as a
        # DP decoding scheme, but preserved as the original behavior.
        logits = model(**inputs).logits
        noisy_logits = add_noise(logits, sensitivity, epsilon, delta)
        response_ids = torch.argmax(noisy_logits, dim=-1)
        response = tokenizer.decode(response_ids[0], skip_special_tokens=True)

    # Record the exchange in the rolling history.
    update_model(message, response)
    return response
# Build the RAG index at startup (no fine-tuning is performed).
train_model()

# Gradio chat UI wired to the chat() handler.
iface = gr.ChatInterface(
    fn=chat,
    title="Persian GPT-2 Chatbot with RAG",
    description="Chat with pretrained Persian GPT-2 model using training_data.txt as RAG knowledge base."
)

if __name__ == "__main__":
    # Bind all interfaces on 7860 — the standard Hugging Face Spaces port.
    iface.launch(server_name="0.0.0.0", server_port=7860)