code5ecure committed on
Commit
12a3d1e
·
verified ·
1 Parent(s): 4dd6073

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -55,7 +55,7 @@ def load_training_data():
55
  def build_rag_index(texts):
56
  global embedder, index
57
  try:
58
- embedder = SentenceTransformer('xmanii/maux-gte-persian', device='cpu') # Use CPU to save memory
59
  embeddings = embedder.encode(texts, convert_to_tensor=True, batch_size=8).cpu().numpy() # Smaller batch size
60
  dimension = embeddings.shape[1]
61
  index = faiss.IndexFlatL2(dimension)
@@ -101,14 +101,17 @@ def chat(message, history):
101
  if embedder and index:
102
  try:
103
  query_emb = embedder.encode(message, convert_to_tensor=True).cpu().numpy()
104
- D, I = index.search(query_emb, k=5) # Increased k for better context
105
  retrieved = [texts[i] for i in I[0] if i >= 0 and i < len(texts)]
106
  context = "\n".join(retrieved)
107
  except Exception as e:
108
  print(f"Error in RAG retrieval: {e}")
109
 
110
- # Prepare prompt with context
111
- prompt = f"Context: {context}\nUser: {message}\nBot:"
 
 
 
112
 
113
  # Tokenize input
114
  inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
@@ -118,9 +121,11 @@ def chat(message, history):
118
  outputs = model.generate(
119
  input_ids=inputs["input_ids"],
120
  attention_mask=inputs["attention_mask"],
121
- max_length=100, # Increased for better responses
122
- num_beams=7, # Increased for better quality
123
  no_repeat_ngram_size=2,
 
 
124
  early_stopping=True,
125
  )
126
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
55
  def build_rag_index(texts):
56
  global embedder, index
57
  try:
58
+ embedder = SentenceTransformer('sentence-transformers/paraphrase-xlm-r-multilingual-v1', device='cpu') # Better for conversational Persian
59
  embeddings = embedder.encode(texts, convert_to_tensor=True, batch_size=8).cpu().numpy() # Smaller batch size
60
  dimension = embeddings.shape[1]
61
  index = faiss.IndexFlatL2(dimension)
 
101
  if embedder and index:
102
  try:
103
  query_emb = embedder.encode(message, convert_to_tensor=True).cpu().numpy()
104
+ D, I = index.search(query_emb, k=10) # Increased k for better context
105
  retrieved = [texts[i] for i in I[0] if i >= 0 and i < len(texts)]
106
  context = "\n".join(retrieved)
107
  except Exception as e:
108
  print(f"Error in RAG retrieval: {e}")
109
 
110
+ # Include conversation history (last 3 exchanges)
111
+ history_context = "\n".join([f"User: {h['user']} -> Bot: {h['bot']}" for h in conversation_history[-3:]]) if conversation_history else ""
112
+
113
+ # Prepare prompt with context and history
114
+ prompt = f"شما یک چت‌بات فارسی مفید و دوستانه هستید. فقط به سؤال کاربر پاسخ کوتاه و مرتبط بدهید و از اطلاعات زمینه فقط برای کمک به پاسخ استفاده کنید:\nContext: {context}\nHistory: {history_context}\nUser: {message}\nBot:"
115
 
116
  # Tokenize input
117
  inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
 
121
  outputs = model.generate(
122
  input_ids=inputs["input_ids"],
123
  attention_mask=inputs["attention_mask"],
124
+ max_length=150,
125
+ num_beams=10,
126
  no_repeat_ngram_size=2,
127
+ temperature=0.8, # Slightly increased for better diversity
128
+ top_p=0.9, # Added for better response quality
129
  early_stopping=True,
130
  )
131
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)