File size: 3,317 Bytes
c4f4dc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import fitz  # PyMuPDF
from docx import Document
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import pickle
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from fastapi import FastAPI, UploadFile, File
from typing import List

# ASGI application object; the /upload/ and /query/ routes below are
# registered on it via @app.post.
app = FastAPI()

def extract_text_from_pdf(pdf_path):
    """Return the plain text of every page of a PDF, concatenated in page order.

    Args:
        pdf_path: Path of the PDF file; passed straight to ``fitz.open``.

    Returns:
        The ``get_text()`` output of each page joined with no separator,
        or ``""`` for a document without pages.
    """
    # The document is used as a context manager so its file handle is
    # released even if extraction fails part-way (the original never closed
    # it, which also keeps the file locked on Windows).
    with fitz.open(pdf_path) as doc:
        # Iterating a Document yields its pages in order; a single join avoids
        # repeated string concatenation on long documents.
        return "".join(page.get_text() for page in doc)

def extract_text_from_docx(docx_path):
    """Return the text of a Word (.docx) document, one paragraph per line.

    Args:
        docx_path: Path of the .docx file; passed to ``docx.Document``.

    Returns:
        The ``text`` of each body paragraph, joined with ``"\\n"``.
    """
    paragraphs = Document(docx_path).paragraphs
    return "\n".join(paragraph.text for paragraph in paragraphs)

# Sentence encoder shared by the upload and query endpoints; its output
# dimension also sizes a freshly created FAISS index.
embedding_model = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2')

# Hugging Face API token, read from the environment. Startup fails fast with a
# clear message when it is missing.
api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
if not api_token:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")

# Confirm the token was found without echoing any of it: even a prefix of a
# credential should not end up in console output or collected logs.
print("API Token: loaded from HUGGINGFACEHUB_API_TOKEN")

# Client for the hosted GPT-2 model on the Hugging Face Inference API.
# NOTE(review): `llm` is not used by the endpoints in this file.
llm = HuggingFaceEndpoint(
    endpoint_url="https://api-inference.huggingface.co/models/gpt2",
    # The token has its own constructor field. Per the LangChain API, entries
    # in `model_kwargs` are sent as generation parameters with each request,
    # so passing it there as "api_key" did not authenticate and shipped the
    # secret inside every request payload.
    huggingfacehub_api_token=api_token,
)

# LangChain embeddings wrapper constructed with the library's defaults.
# NOTE(review): not referenced by the endpoints in this file (they embed via
# `embedding_model`). It presumably loads its own default model at import
# time; confirm it is needed before paying that startup cost.
embedding = HuggingFaceEmbeddings()

# Persisted FAISS index. A missing file is replaced by an empty flat L2 index
# sized to the embedding model, written out immediately; otherwise the stored
# index is loaded.
index_path = "faiss_index.pkl"
if not os.path.exists(index_path):
    dimension = embedding_model.get_sentence_embedding_dimension()
    index = faiss.IndexFlatL2(dimension)
    with open(index_path, "wb") as index_file:
        pickle.dump(index, index_file)
else:
    # SECURITY NOTE(review): unpickling can execute arbitrary code, so this is
    # only safe while faiss_index.pkl is written solely by this service.
    # faiss.write_index / faiss.read_index are the documented persistence API
    # (confirm the installed faiss version pickles Index objects at all). The
    # /upload/ endpoint writes this same pickle format, so both sides would
    # have to switch together.
    with open(index_path, "rb") as index_file:
        index = pickle.load(index_file)

# Accepted filename suffixes (matched case-insensitively), each mapped to the
# extractor that turns a file of that type into plain text.
_TEXT_EXTRACTORS = {
    ".pdf": extract_text_from_pdf,
    ".docx": extract_text_from_docx,
}


def _supported_suffix(filename):
    """Return the matching key of ``_TEXT_EXTRACTORS`` for *filename*, or None."""
    name = (filename or "").lower()  # UploadFile.filename may be None
    return next((sfx for sfx in _TEXT_EXTRACTORS if name.endswith(sfx)), None)


def _extract_uploaded_text(suffix, content):
    """Write *content* to a private temporary file and extract its text."""
    import tempfile  # local import: only the upload path needs it

    # A per-call temporary directory keeps concurrent uploads from clobbering
    # each other's files and is deleted (file included) when extraction ends.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "upload" + suffix)
        with open(tmp_path, "wb") as tmp_file:
            tmp_file.write(content)
        return _TEXT_EXTRACTORS[suffix](tmp_path)


def _save_index():
    """Pickle the global ``index`` to ``index_path``, replacing the file atomically."""
    tmp_path = index_path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(index, f)
    # os.replace swaps the file in a single step, so a crash mid-write cannot
    # leave a truncated index for the next startup to load.
    os.replace(tmp_path, index_path)


@app.post("/upload/")
async def upload_file(files: List[UploadFile] = File(...)):
    """Extract text from uploaded PDF/DOCX files and add it to the FAISS index.

    Every non-blank line of extracted text is embedded with
    ``embedding_model`` and appended to the global ``index``, which is then
    persisted to ``index_path``.

    NOTE(review): extraction and encoding run synchronously inside this async
    handler, so they block the event loop while they work.

    Args:
        files: Uploaded files; names must end in ``.pdf`` or ``.docx``
            (any letter case).

    Returns:
        ``{"message": "Files processed successfully"}`` on success, or
        ``{"error": "Unsupported file format"}`` if any file has another
        extension. In that case nothing from the batch is indexed.
    """
    # Validate the whole batch first. Previously an unsupported file returned
    # an error after earlier files had already been added to the in-memory
    # index, and the next successful upload would persist them.
    suffixes = [_supported_suffix(file.filename) for file in files]
    if None in suffixes:
        return {"error": "Unsupported file format"}

    for file, suffix in zip(files, suffixes):
        content = await file.read()
        text = _extract_uploaded_text(suffix, content)

        # One vector per non-blank line. Blank lines would all map to the same
        # meaningless vector and crowd out real matches at query time.
        sentences = [line for line in text.split("\n") if line.strip()]
        if not sentences:
            continue
        embeddings = embedding_model.encode(sentences)
        index.add(np.asarray(embeddings, dtype="float32"))

    _save_index()
    return {"message": "Files processed successfully"}

@app.post("/query/")
async def query(text: str, k: int = 5):
    """Return labels for the indexed vectors nearest to *text*.

    Args:
        text: Query string, embedded with ``embedding_model``.
        k: Maximum number of neighbours to return. The default of 5 matches
            the previously hard-coded value.

    Returns:
        ``{"top_documents": [...]}``, one ``"Document <id>"`` label per hit
        in the order FAISS returns them. ``<id>`` is the vector's position in
        the index (the upload endpoint adds one vector per text line), not a
        file identifier. Returns ``{"error": ...}`` when ``k`` is not positive.
    """
    if k < 1:
        return {"error": "k must be a positive integer"}

    # Asking for more neighbours than there are stored vectors only produces
    # -1 padding, so cap k. This also bounds the result arrays for huge k.
    k = min(k, index.ntotal)
    if k == 0:
        # Nothing has been indexed yet.
        return {"top_documents": []}

    query_embedding = embedding_model.encode([text])
    _, ids = index.search(np.asarray(query_embedding, dtype="float32"), k=k)

    # -1 marks an unfilled result slot; skip it defensively.
    return {"top_documents": [f"Document {idx}" for idx in ids[0] if idx != -1]}

if __name__ == "__main__":
    import uvicorn

    # Development entry point: serve on all interfaces. The port can be
    # overridden through the conventional PORT environment variable and
    # defaults to the original 8000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))