Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -55,7 +55,7 @@ def load_training_data():
|
|
| 55 |
def build_rag_index(texts):
|
| 56 |
global embedder, index
|
| 57 |
try:
|
| 58 |
-
embedder = SentenceTransformer('
|
| 59 |
embeddings = embedder.encode(texts, convert_to_tensor=True, batch_size=8).cpu().numpy() # Smaller batch size
|
| 60 |
dimension = embeddings.shape[1]
|
| 61 |
index = faiss.IndexFlatL2(dimension)
|
|
@@ -101,14 +101,17 @@ def chat(message, history):
|
|
| 101 |
if embedder and index:
|
| 102 |
try:
|
| 103 |
query_emb = embedder.encode(message, convert_to_tensor=True).cpu().numpy()
|
| 104 |
-
D, I = index.search(query_emb, k=
|
| 105 |
retrieved = [texts[i] for i in I[0] if i >= 0 and i < len(texts)]
|
| 106 |
context = "\n".join(retrieved)
|
| 107 |
except Exception as e:
|
| 108 |
print(f"Error in RAG retrieval: {e}")
|
| 109 |
|
| 110 |
-
#
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
# Tokenize input
|
| 114 |
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
|
|
@@ -118,9 +121,11 @@ def chat(message, history):
|
|
| 118 |
outputs = model.generate(
|
| 119 |
input_ids=inputs["input_ids"],
|
| 120 |
attention_mask=inputs["attention_mask"],
|
| 121 |
-
max_length=
|
| 122 |
-
num_beams=
|
| 123 |
no_repeat_ngram_size=2,
|
|
|
|
|
|
|
| 124 |
early_stopping=True,
|
| 125 |
)
|
| 126 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
| 55 |
def build_rag_index(texts):
|
| 56 |
global embedder, index
|
| 57 |
try:
|
| 58 |
+
embedder = SentenceTransformer('sentence-transformers/paraphrase-xlm-r-multilingual-v1', device='cpu') # Better for conversational Persian
|
| 59 |
embeddings = embedder.encode(texts, convert_to_tensor=True, batch_size=8).cpu().numpy() # Smaller batch size
|
| 60 |
dimension = embeddings.shape[1]
|
| 61 |
index = faiss.IndexFlatL2(dimension)
|
|
|
|
| 101 |
if embedder and index:
|
| 102 |
try:
|
| 103 |
query_emb = embedder.encode(message, convert_to_tensor=True).cpu().numpy()
|
| 104 |
+
D, I = index.search(query_emb, k=10) # Increased k for better context
|
| 105 |
retrieved = [texts[i] for i in I[0] if i >= 0 and i < len(texts)]
|
| 106 |
context = "\n".join(retrieved)
|
| 107 |
except Exception as e:
|
| 108 |
print(f"Error in RAG retrieval: {e}")
|
| 109 |
|
| 110 |
+
# Include conversation history (last 3 exchanges)
|
| 111 |
+
history_context = "\n".join([f"User: {h['user']} -> Bot: {h['bot']}" for h in conversation_history[-3:]]) if conversation_history else ""
|
| 112 |
+
|
| 113 |
+
# Prepare prompt with context and history
|
| 114 |
+
prompt = f"شما یک چتبات فارسی مفید و دوستانه هستید. فقط به سؤال کاربر پاسخ کوتاه و مرتبط بدهید و از اطلاعات زمینه فقط برای کمک به پاسخ استفاده کنید:\nContext: {context}\nHistory: {history_context}\nUser: {message}\nBot:"
|
| 115 |
|
| 116 |
# Tokenize input
|
| 117 |
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
|
|
|
|
| 121 |
outputs = model.generate(
|
| 122 |
input_ids=inputs["input_ids"],
|
| 123 |
attention_mask=inputs["attention_mask"],
|
| 124 |
+
max_length=150,
|
| 125 |
+
num_beams=10,
|
| 126 |
no_repeat_ngram_size=2,
|
| 127 |
+
temperature=0.8, # Slightly increased for better diversity
|
| 128 |
+
top_p=0.9, # Added for better response quality
|
| 129 |
early_stopping=True,
|
| 130 |
)
|
| 131 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|