# contract-rag / match.py
# (Hugging Face upload header: uploaded by poemsforaphrodite via
#  huggingface_hub, commit b459d4c, verified)
import argparse
import os
import json
from openai import OpenAI
from PyPDF2 import PdfReader
from pinecone import Pinecone, ServerlessSpec
from dotenv import load_dotenv
import tiktoken
load_dotenv()
print("Starting the script...")
# Set up argument parser
parser = argparse.ArgumentParser(description="Process PDFs and match letters.")
parser.add_argument("--test", action="store_true", help="Run in test mode")
args = parser.parse_args()
# Set up OpenAI client (API key comes from the OPENAI_API_KEY env var,
# loaded above via load_dotenv()).
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
print("OpenAI client set up.")
# Set up Pinecone (PINECONE_API_KEY env var).
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
index_name = "match"
print(f"Pinecone client set up. Using index name: {index_name}")
# Check if the index exists, if not, create it
if index_name not in pc.list_indexes().names():
    print(f"Index '{index_name}' not found. Creating new index...")
    pc.create_index(
        name=index_name,
        dimension=3072,  # dimension for text-embedding-3-large
        metric="cosine",
        spec=ServerlessSpec(
            cloud="aws",
            # NOTE(review): confirm serverless indexes are available in
            # us-west-2 for this account/plan — verify against Pinecone docs.
            region="us-west-2"
        )
    )
    print(f"Created new index: {index_name}")
else:
    print(f"Index '{index_name}' already exists.")
# Handle to the (existing or just-created) index used by all helpers below.
index = pc.Index(index_name)
print("Pinecone index initialized.")
def get_embedding(text):
    """Embed *text* with OpenAI's text-embedding-3-large model.

    Returns the raw embedding vector (a list of floats, 3072-dim for
    this model) for the single input string.
    """
    print("Getting embedding for text...")
    api_response = client.embeddings.create(
        input=text, model="text-embedding-3-large"
    )
    print("Embedding obtained.")
    embedding = api_response.data[0].embedding
    return embedding
def split_text(text, max_tokens=8000):
    """Split *text* into chunks of at most *max_tokens* tokens.

    Uses the tokenizer for text-embedding-3-large so chunk sizes align
    with the embedding model's context limit.  Returns a list of decoded
    string chunks; empty *text* yields an empty list.
    """
    encoding = tiktoken.encoding_for_model("text-embedding-3-large")
    tokens = encoding.encode(text)
    # Slice the token list directly instead of accumulating one token at
    # a time — identical chunk boundaries with far less Python overhead.
    return [
        encoding.decode(tokens[start:start + max_tokens])
        for start in range(0, len(tokens), max_tokens)
    ]
def save_pdf_to_pinecone(file_path, file_name):
    """Extract text from the PDF at *file_path*, chunk it, embed each
    chunk, and upsert the vectors into the Pinecone index.

    Vector ids are "<file_name>_chunk_<i>"; each vector's metadata
    stores the chunk text, source file name, and chunk index.
    """
    print(f"Processing file: {file_path}")
    pdf_reader = PdfReader(file_path)
    # extract_text() can return None for pages with no extractable text;
    # coalesce to "" so concatenation never raises TypeError.  join()
    # also avoids quadratic string += over many pages.
    content = "".join(page.extract_text() or "" for page in pdf_reader.pages)
    print(f"Extracted {len(content)} characters from the PDF.")
    chunks = split_text(content)
    print(f"Split content into {len(chunks)} chunks.")
    for i, chunk in enumerate(chunks):
        embedding = get_embedding(chunk)
        chunk_id = f"{file_name}_chunk_{i}"
        print(f"Upserting vector for {chunk_id} to Pinecone...")
        index.upsert(vectors=[
            {
                "id": chunk_id,
                "values": embedding,
                "metadata": {"content": chunk, "file_name": file_name, "chunk_index": i}
            }
        ])
        print(f"Vector for {chunk_id} upserted successfully.")
def match_letter(letter_content):
    """Find the stored document most similar to *letter_content*.

    Embeds the letter, queries Pinecone for the single nearest vector,
    and returns that vector's source file name — or None when the index
    returns no matches.
    """
    print("Matching letter content...")
    query_embedding = get_embedding(letter_content)
    search_results = index.query(vector=query_embedding, top_k=1, include_metadata=True)
    hits = search_results['matches']
    if not hits:
        print("No match found.")
        return None
    matched_file = hits[0]['metadata']['file_name']
    print(f"Best match found: {matched_file}")
    return matched_file
def process_matches_file(file_path, test_mode=False):
    """Read the matches JSON at *file_path* and ingest each referenced
    input PDF (looked up under docs/in/) into Pinecone.

    With *test_mode* set, nothing is ingested: the function reports what
    it would do and stops after the first entry.
    """
    print(f"Processing matches file: {file_path}")
    with open(file_path, 'r') as f:
        matches = json.load(f)
    print(f"Found {len(matches)} matches in the file.")
    for input_file in matches:
        input_path = os.path.join('docs', 'in', input_file)
        if test_mode:
            # Dry run: announce, then bail out after the first entry.
            print(f"Test mode: Would process input file: {input_path}")
            print("Test mode: Stopping after first match.")
            break
        if os.path.exists(input_path):
            print(f"Processing input file: {input_path}")
            save_pdf_to_pinecone(input_path, input_file)
        else:
            print(f"Input file not found: {input_path}")
def clear_index():
    """Delete every vector in the Pinecone index, printing the outcome.

    Failures are caught and reported rather than raised, so callers can
    continue regardless.
    """
    print(f"Clearing all vectors from index '{index_name}'...")
    try:
        index.delete(delete_all=True)
    except Exception as e:
        print(f"Error clearing index: {e}")
    else:
        print(f"All vectors deleted from index '{index_name}'.")
def main():
    """Drive the script: optionally clear the index, ingest the PDFs
    listed in matches.json, then run an interactive matching loop."""
    print("Starting main function...")
    if not args.test:
        # Offer a destructive reset before re-ingesting documents.
        answer = input("Do you want to clear the database before processing? (y/n): ").lower()
        if answer == 'y':
            clear_index()
    process_matches_file('matches.json', args.test)
    # Interactive loop: embed each pasted letter and report its best match.
    print("Enter the content of your letter (or 'q' to quit):")
    while True:
        letter = input()
        if letter.lower() == 'q':
            break
        result = match_letter(letter)
        if result:
            print(f"The best match for your letter is: {result}")
        else:
            print("No match found for your letter.")
# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()