# VN_law_chat / retrieval.py
import os

# Point the Hugging Face cache at /tmp before importing transformers-based
# libraries, so the setting takes effect when they initialize.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'

import faiss
import numpy as np
import pandas as pd
import requests
from sentence_transformers import SentenceTransformer
# URLs of files from Hugging Face
base_url = "https://huggingface.co/datasets/manhteky123/LawVietnamese/resolve/main/"
data_url = f"{base_url}data.csv"
faiss_index_url = f"{base_url}faiss_index.bin"
vectors_url = f"{base_url}vectors.npy"
# Download a file from a URL to local disk
def download_to_disk(url, filename):
    response = requests.get(url)
    if response.status_code == 200:
        with open(filename, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded {url} to {filename}.")
    else:
        raise Exception(f"Failed to download {url}: {response.status_code}")
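
# A streamed variant (a minimal sketch, not used below): writes the response
# in chunks instead of holding the whole file in memory, which can help if
# vectors.npy or faiss_index.bin grow large. The chunk size is an assumption.
def download_to_disk_streamed(url, filename, chunk_size=1 << 20):
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        with open(filename, 'wb') as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    print(f"Downloaded {url} to {filename}.")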
# Download the necessary files to disk
data_file_path = 'data.csv'
faiss_index_file_path = 'faiss_index.bin'
vectors_file_path = 'vectors.npy'
download_to_disk(data_url, data_file_path)
download_to_disk(faiss_index_url, faiss_index_file_path)
download_to_disk(vectors_url, vectors_file_path)
# Read the CSV data from the downloaded file
df = pd.read_csv(data_file_path)
# Use the 'truncated_text' column
column_name = 'truncated_text'
# Load SentenceTransformer
model = SentenceTransformer('intfloat/multilingual-e5-small')
# Read FAISS index from file
index = faiss.read_index(faiss_index_file_path)
# Load the precomputed passage embeddings (kept for completeness; retrieval
# below queries the FAISS index directly)
vectors = np.load(vectors_file_path)
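
# Optional sanity checks (a sketch): the index dimensionality should match the
# model's embedding size, and the index should hold one vector per CSV row,
# since retrieval maps FAISS result positions straight to dataframe rows.
assert index.d == model.get_sentence_embedding_dimension(), "index/model dim mismatch"
assert index.ntotal == len(df), "index size does not match data.csv rows"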
def retrieve_documents(query, k=5, threshold=0.7):
    # Encode the query into the same embedding space as the indexed passages.
    # Note: E5-family models are typically used with a "query: " prefix; the
    # encoding here must match however the index vectors were originally built.
    query_vector = model.encode([query], convert_to_numpy=True)
    # Search the index: D holds distances, I holds row positions into df.
    D, I = index.search(query_vector, k)
    # Map L2 distances to a (0, 1] score; smaller distance -> higher similarity.
    similarities = 1 / (1 + D[0])
    filtered_documents = []
    for i, similarity in enumerate(similarities):
        if similarity >= threshold:
            filtered_documents.append(df.iloc[I[0][i]][column_name])
    return filtered_documents
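
# Example usage (a minimal sketch; the query string is illustrative only):
if __name__ == "__main__":
    # Hypothetical Vietnamese query: "Deadline for paying personal income tax"
    sample_query = "Thời hạn nộp thuế thu nhập cá nhân"
    for doc in retrieve_documents(sample_query, k=5, threshold=0.7):
        print(doc[:200], "...")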