# Chris4K / app.py — commit c458050 ("Update app.py"), 6.1 kB.
# (Header captured from the Hugging Face file viewer; kept as a comment so the
# module remains importable.)
import os
import time
import pdfplumber
import docx
import nltk
import gradio as gr
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_text_splitters import TokenTextSplitter
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from nltk import sent_tokenize
from typing import List, Tuple
from transformers import AutoModel, AutoTokenizer
# Download the NLTK "punkt" sentence-tokenizer data at import time.
nltk.download('punkt')

# Directory that uploaded documents are copied into and read back from.
FILES_DIR = './files'

# Supported embedding models: UI label -> Hugging Face Hub model ID.
# IDs are fully qualified ("org/name") so they resolve on the Hub; bare names
# such as "gte-large", "gbert-base" or "multilingual-e5-base" are not
# top-level Hub repositories.
MODELS = {
    'e5-base': "danielheinz/e5-base-sts-en-de",
    'multilingual-e5-base': "intfloat/multilingual-e5-base",
    'paraphrase-miniLM': "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    'paraphrase-mpnet': "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    'gte-large': "thenlper/gte-large",
    'gbert-base': "deepset/gbert-base",
}
class FileHandler:
    """Extract plain text from uploaded documents (.pdf, .docx, .txt)."""

    @staticmethod
    def extract_text(file_path):
        """Return the text content of *file_path*, chosen by its extension.

        The extension comparison is case-insensitive.

        Raises:
            ValueError: if the extension is not .pdf, .docx or .txt.
        """
        ext = os.path.splitext(file_path)[-1].lower()
        extractors = {
            '.pdf': FileHandler._extract_from_pdf,
            '.docx': FileHandler._extract_from_docx,
            '.txt': FileHandler._extract_from_txt,
        }
        extractor = extractors.get(ext)
        if extractor is None:
            raise ValueError(f"Unsupported file type: {ext}")
        return extractor(file_path)

    @staticmethod
    def _extract_from_pdf(file_path):
        """Return the text of every PDF page, joined by single spaces."""
        with pdfplumber.open(file_path) as pdf:
            # Some pdfplumber versions return None for pages without a text
            # layer (e.g. scanned images); treat those as empty so join()
            # does not raise TypeError.
            return ' '.join(page.extract_text() or '' for page in pdf.pages)

    @staticmethod
    def _extract_from_docx(file_path):
        """Return the text of every .docx paragraph, joined by single spaces."""
        doc = docx.Document(file_path)
        return ' '.join(para.text for para in doc.paragraphs)

    @staticmethod
    def _extract_from_txt(file_path):
        """Return the full contents of a UTF-8 text file."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
class EmbeddingModel:
    """Small wrapper around a LangChain ``HuggingFaceEmbeddings`` model."""

    def __init__(self, model_name, max_tokens=None):
        """Load the embedding model identified by *model_name*.

        *max_tokens* is only stored on the instance; nothing in this class
        reads it.
        """
        self.max_tokens = max_tokens
        self.model = HuggingFaceEmbeddings(model_name=model_name)

    def embed(self, chunks: List[str]):
        """Return the embeddings of *chunks*, one vector per chunk."""
        embedder = self.model
        vectors = embedder.embed_documents(chunks)
        return vectors
def _normalize_chunking(chunk_size, overlap_size, default_size=250, default_overlap=20):
    """Coerce UI chunking values to ints, keeping 0 <= overlap < size."""
    size = default_size if chunk_size is None else int(chunk_size)
    overlap = default_overlap if overlap_size is None else int(overlap_size)
    size = max(1, size)
    # LangChain splitters reject an overlap larger than the chunk size; keep it
    # strictly smaller so consecutive chunks always advance through the text.
    overlap = max(0, min(overlap, size - 1))
    return size, overlap


def _read_corpus(directory):
    """Return the text of every regular file in *directory*, in name order.

    Documents are separated by a blank line so the end of one file never
    fuses with the start of the next.
    """
    texts = []
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if os.path.isfile(path):  # skip sub-directories
            texts.append(FileHandler.extract_text(path))
    return "\n\n".join(texts)


def process_files(model_name, split_strategy, chunk_size, overlap_size, max_tokens):
    """Read every document in FILES_DIR, split it into chunks and embed them.

    Args:
        model_name: key into ``MODELS`` selecting the embedding model.
        split_strategy: ``'token'`` selects ``TokenTextSplitter``; any other
            value selects ``RecursiveCharacterTextSplitter``.
        chunk_size: maximum chunk length (tokens for the token splitter,
            characters otherwise). ``None`` falls back to 250.
        overlap_size: overlap between consecutive chunks, in the same unit as
            *chunk_size*. ``None`` falls back to 20. Clamped below
            *chunk_size*.
        max_tokens: forwarded to ``EmbeddingModel``.

    Returns:
        ``(embeddings, chunks)``: one embedding per chunk, in chunk order.

    Raises:
        KeyError: if *model_name* is not a key of ``MODELS``.
        ValueError: if FILES_DIR holds a file type ``FileHandler`` cannot read.
    """
    size, overlap = _normalize_chunking(chunk_size, overlap_size)
    text = _read_corpus(FILES_DIR)

    splitter_cls = TokenTextSplitter if split_strategy == 'token' else RecursiveCharacterTextSplitter
    splitter = splitter_cls(chunk_size=size, chunk_overlap=overlap)
    chunks = splitter.split_text(text)

    # Embed the individual chunks, not the full text.
    model = EmbeddingModel(MODELS[model_name], max_tokens=max_tokens)
    embeddings = model.embed(chunks)
    return embeddings, chunks
def search_embeddings(query, model_name, top_k, chunks=None, embeddings=None):
    """Embed *query* and, when a corpus is supplied, return its top-k matches.

    Args:
        query: search text.
        model_name: key into ``MODELS``.
        top_k: number of matches to return, clamped to ``[1, len(chunks)]``.
            Only used when *chunks* is given.
        chunks: optional sequence of text chunks to search. When omitted, the
            function behaves as before and returns the raw query embedding.
        embeddings: optional precomputed embeddings aligned one-to-one with
            *chunks*; avoids re-embedding the corpus.

    Returns:
        Without *chunks*: the query embedding from ``embed_query``.
        With *chunks*: a list of ``{"chunk": str, "score": float}`` dicts in
        the order returned by FAISS, where ``score`` is the FAISS distance
        (lower means closer).

    Raises:
        KeyError: if *model_name* is not a key of ``MODELS``.
        ValueError: if *embeddings* and *chunks* differ in length.
    """
    model_id = MODELS[model_name]
    texts = None if chunks is None else list(chunks)
    if texts is not None and not texts:
        return []

    model = HuggingFaceEmbeddings(model_name=model_id)
    if texts is None:
        # No corpus to search: return the query vector (original behaviour).
        return model.embed_query(query)

    if embeddings is None:
        index = FAISS.from_texts(texts, model)
    else:
        vectors = list(embeddings)
        if len(vectors) != len(texts):
            raise ValueError("embeddings and chunks must have the same length")
        index = FAISS.from_embeddings(list(zip(texts, vectors)), model)

    k = len(texts) if top_k is None else max(1, min(int(top_k), len(texts)))
    hits = index.similarity_search_with_score(query, k=k)
    return [{"chunk": doc.page_content, "score": float(score)} for doc, score in hits]
def calculate_statistics(embeddings, start_time=None):
    """Summarise an embedding run.

    Args:
        embeddings: sized collection of embedding vectors (one per chunk).
        start_time: optional ``time.time()`` value captured when the run began.

    Returns:
        dict with
          - ``"tokens"``: ``len(embeddings)``. Despite the key name, this is the
            number of embedded chunks, not a token count.
          - ``"time_taken"``: seconds elapsed since *start_time* when it is
            given; otherwise the current ``time.time()`` timestamp, kept for
            existing callers.
    """
    now = time.time()
    time_taken = now if start_time is None else now - start_time
    return {"tokens": len(embeddings), "time_taken": time_taken}
import shutil
import shutil
def upload_file(file, model_name, split_strategy, overlap_size, chunk_size, max_tokens, query, top_k):
    """Store an uploaded document in FILES_DIR, re-index the folder and search it.

    Args:
        file: value produced by ``gr.File``, either a filepath string or an
            object exposing the temporary path as ``.name``. ``None`` if
            nothing was uploaded.
        model_name, split_strategy, overlap_size, chunk_size, max_tokens:
            forwarded to ``process_files``.
        query, top_k: forwarded to ``search_embeddings``.

    Returns:
        ``{"results": ..., "stats": ...}`` on success, where
        ``stats["time_taken"]`` is the wall-clock duration of this call in
        seconds; or ``{"error": str}`` when no file was uploaded.
    """
    started = time.time()
    if file is None:
        return {"error": "No file uploaded."}

    # gr.File yields a path string (or str subclass) or a tempfile-like object,
    # depending on the Gradio version and ``type`` setting.
    file_path = os.fspath(file) if isinstance(file, (str, os.PathLike)) else file.name

    # Copy the upload into the corpus directory, creating it on first use.
    os.makedirs(FILES_DIR, exist_ok=True)
    destination_path = os.path.join(FILES_DIR, os.path.basename(file_path))
    shutil.copyfile(file_path, destination_path)

    embeddings, chunks = process_files(model_name, split_strategy, chunk_size, overlap_size, max_tokens)
    results = search_embeddings(query, model_name, top_k)

    stats = calculate_statistics(embeddings)
    # Report this request's duration instead of an absolute timestamp.
    stats["time_taken"] = time.time() - started
    return {"results": results, "stats": stats}
def _run_from_ui(file, query, model_name, split_strategy, chunk_size, overlap_size, max_tokens, top_k):
    """Map the UI's input order onto ``upload_file`` by keyword.

    gr.Interface passes component values positionally, in the order of
    ``inputs`` below; forwarding by keyword keeps each value on the correct
    ``upload_file`` parameter.
    """
    return upload_file(
        file,
        model_name=model_name,
        split_strategy=split_strategy,
        overlap_size=overlap_size,
        chunk_size=chunk_size,
        max_tokens=max_tokens,
        query=query,
        top_k=top_k,
    )


# Gradio interface. The order of ``inputs`` must match _run_from_ui's parameters.
iface = gr.Interface(
    fn=_run_from_ui,
    inputs=[
        gr.File(label="Upload File"),
        gr.Textbox(label="Search Query"),
        gr.Dropdown(choices=list(MODELS.keys()), value=list(MODELS.keys())[0], label="Embedding Model"),
        gr.Radio(choices=["token", "recursive"], value="recursive", label="Split Strategy"),
        gr.Slider(100, 1000, step=100, value=500, label="Chunk Size"),
        gr.Slider(0, 100, step=10, value=50, label="Overlap Size"),
        gr.Slider(50, 500, step=50, value=200, label="Max Tokens"),
        gr.Slider(1, 10, step=1, value=5, label="Top K"),
    ],
    outputs="json",
)
iface.launch()