Drake / utilis.py
kausthubkannan17's picture
feat: OCR support
8163d1a
import json
import os
import uuid
from typing import Dict
from typing import List
from typing import Tuple
from urllib.parse import parse_qs, urlparse

import assemblyai as aai
import pytesseract
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import DeepLake
from langchain_core.documents.base import Document
from pdf2image import convert_from_bytes
from pytesseract import Output
from youtube_transcript_api import YouTubeTranscriptApi
class Processing:
    """
    Turn PDF text, scanned PDFs, audio/video transcripts and YouTube captions into chunked documents
    and upload them to a DeepLake Vector-DB.
    """

    def __init__(self, dataset_path: str, embedding_model_name: str, chunk_size=500, chunk_overlap=5):
        """
        Initialize embedding model, text splitter, transcriber and Vector-DB

        Parameters:
        dataset_path (str): Path to the dataset in the Vector-DB (opened as "hub://<dataset_path>")
        embedding_model_name (str): Name of the HuggingFace model to be used for embeddings
        chunk_size (int): Size of each chunk to be processed
        chunk_overlap (int): Overlap between each chunk
        """
        self.dataset_path = dataset_path
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        self.transcriber = aai.Transcriber()
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            encode_kwargs={'normalize_embeddings': False}
        )
        self.db = DeepLake(dataset_path=f"hub://{self.dataset_path}",
                           embedding=self.embedding_model,
                           exec_option="compute_engine"
                           )

    def _add_metadata(self, documents: List[Document], url: str, id: str, source: str, file_type: str,
                      course_tag="") -> Tuple[List[Document], Dict[str, str]]:
        """
        Add metadata to the documents

        Parameters:
        documents (List[Document]): List of documents to add metadata to
        url (str): URL the documents were obtained from
        id (str): ID of the documents
        source (str): Source of the documents
        file_type (str): Type of the documents
        course_tag (str): Tag to identify the course the documents belongs to

        Returns:
        (List[Document], Dict[str, str]): The same documents with metadata attached, along with the metadata
        """
        metadata = {
            "id": id,
            "source": source,
            "url": url,
            "file_type": file_type,
            "course_tag": course_tag
        }
        for doc in documents:
            # Each chunk gets its own copy: sharing one dict would make a later edit of a single
            # chunk's metadata show up on every other chunk and on the returned dict.
            doc.metadata = dict(metadata)
        return documents, metadata

    @staticmethod
    def _extract_video_id(url: str) -> str:
        """
        Extract the video ID from a YouTube URL

        Handles "watch?v=" links (with any extra query parameters), "youtu.be/<id>" short links and
        "/embed/", "/shorts/", "/live/", "/v/" paths. Input that is not recognised (for example a bare
        video ID) goes through the historical prefix-stripping logic, so it is returned as before.

        Parameters:
        url (str): YouTube URL (or bare video ID)

        Returns:
        str: The video ID
        """
        cleaned = url.strip()
        parsed = urlparse(cleaned)
        host = (parsed.hostname or "").lower()
        segments = [part for part in parsed.path.split("/") if part]

        video_id = ""
        if host == "youtu.be":
            # Short links carry the ID as the first path segment
            video_id = segments[0] if segments else ""
        elif host == "youtube.com" or host.endswith(".youtube.com"):
            if len(segments) > 1 and segments[0] in ("embed", "shorts", "live", "v"):
                video_id = segments[1]
            else:
                # Regular watch links: take only the "v" parameter and ignore "t", "list", "si", ...
                video_id = parse_qs(parsed.query).get("v", [""])[0]
        if video_id:
            return video_id

        # Fallback: the original prefix stripping, kept so unrecognised input behaves as it used to
        if cleaned.startswith("https://www.youtube.com/watch?v="):
            return cleaned.replace("https://www.youtube.com/watch?v=", "")
        return cleaned.replace("https://youtu.be/", "")

    def load_pdf(self, text) -> Tuple[List[Document], Dict[str, str]]:
        """
        Split already-extracted PDF text into chunks and add metadata

        Parameters:
        text (str): Text content of the PDF

        Returns:
        (List[Document], Dict[str, str]): List of documents with metadata added along with the metadata
        """
        pdf_chunk = self.text_splitter.create_documents([text])
        print("Created document chunks")
        # No URL exists for an uploaded document, so the placeholder "NaN" is stored and a random ID generated
        return self._add_metadata(pdf_chunk, url="NaN", id=str(uuid.uuid4()), source="document", file_type="pdf")

    def load_scanned_pdf(self, file) -> str:
        """
        Extract text from scanned PDF file using OCR

        Parameters:
        file: Content of the scanned PDF, handed to pdf2image's convert_from_bytes

        Returns:
        str: Text extracted from the scanned PDF file, one line per page (each page ends with a newline)
        """
        images = convert_from_bytes(file)
        page_texts = []
        for image in images:
            # Perform OCR on the image
            ocr_data = pytesseract.image_to_data(image, lang='eng', output_type=Output.DICT)
            # The "text" column also has empty entries for the page/block/paragraph/line rows;
            # dropping them avoids long runs of spaces in the extracted text.
            words = [token for token in ocr_data.get('text', []) if token.strip()]
            page_texts.append(" ".join(words))
        return "".join(page_text + "\n" for page_text in page_texts)

    def load_transcript(self, url) -> Tuple[List[Document], Dict[str, str]]:
        """
        Transcribe an audio/video source, split the transcript into chunks and add metadata

        Parameters:
        url: Source handed to the AssemblyAI transcriber

        Returns:
        (List[Document], Dict[str, str]): List of documents with metadata added along with the metadata

        Raises:
        RuntimeError: If the transcription service reports an error for this source
        """
        transcript = self.transcriber.transcribe(url)
        # A failed transcription carries no text; fail here with the service's message instead of
        # letting the text splitter break on a missing transcript.
        if transcript.status == aai.TranscriptStatus.error:
            raise RuntimeError(f"Transcription failed: {transcript.error}")
        print("Transcribed")
        transcript_chunk = self.text_splitter.create_documents([transcript.text])
        print("Created transcript chunks")
        return self._add_metadata(transcript_chunk, url="NaN", id=str(uuid.uuid4()), source="custom_video",
                                  file_type="transcript")

    def load_yt_transcript(self, url) -> Tuple[List[Document], Dict[str, str]]:
        """
        Load YouTube transcript, split into chunks and add metadata

        Parameters:
        url (str): URL of the YouTube video

        Returns:
        (List[Document], Dict[str, str]): List of documents with metadata added along with the metadata.
        The metadata "id" is the video ID and "url" is the URL exactly as passed in.
        """
        video_id = self._extract_video_id(url)
        # NOTE(review): get_transcript is the classic static API of youtube-transcript-api; newer
        # releases may expose a different entry point - confirm against the pinned version.
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
        print("Downloaded transcript")
        transcript_text = ' '.join(line['text'] for line in transcript)
        yt_transcript_chunk = self.text_splitter.create_documents([transcript_text])
        print("Created YouTube transcript chunks")
        return self._add_metadata(yt_transcript_chunk, url=url, id=video_id, source="youtube", file_type="transcript")

    def upload_to_db(self, documents: List[Document]):
        """
        Upload documents to Vector-DB

        Parameters:
        documents (List[Document]): List of documents to upload to Vector-DB
        """
        print("Embedding and uploading to Vector-DB...")
        self.db.add_documents(documents)
        print("Uploaded to Vector-DB")
class PromptCreate:
    """
    Build a few-shot prompt from numbered JSON example files and save it.
    """

    def __init__(self, example_path, save_path):
        """
        Parameters:
        example_path (str): Directory containing the example files named "example_<i>.json"
        save_path (str): Destination handed to the few-shot prompt's save()
        """
        self.examples = []
        self.example_prompt = None
        self.few_shot_prompt = None
        self.example_path = example_path
        self.file_name = "example_{i}.json"
        self.save_path = save_path

    def _example_indices(self):
        """
        Returns:
        List[int]: Sorted indices i >= 1 for which a file named self.file_name.format(i=i) exists in example_path
        """
        prefix, _, suffix = self.file_name.partition("{i}")
        indices = []
        for entry in os.listdir(self.example_path):
            if not (entry.startswith(prefix) and entry.endswith(suffix)):
                continue
            middle = entry[len(prefix):len(entry) - len(suffix)]
            if not middle.isdecimal():
                continue
            index = int(middle)
            # Accept only names the template itself would produce (rejects e.g. "example_01.json")
            if index >= 1 and self.file_name.format(i=index) == entry:
                indices.append(index)
        return sorted(indices)

    def load_examples(self):
        """
        Load every "example_<i>.json" file found in example_path, in numeric order, and append the
        parsed content to self.examples

        Only files that are actually present are read, so unrelated directory entries or gaps in the
        numbering neither trigger spurious "not found" messages nor hide higher-numbered examples.
        Files that disappear or contain invalid JSON are reported on stdout and skipped.
        """
        for i in self._example_indices():
            filename = os.path.join(self.example_path, self.file_name.format(i=i))
            try:
                # Read as UTF-8 explicitly instead of relying on the platform default encoding
                with open(filename, "r", encoding="utf-8") as json_file:
                    self.examples.append(json.load(json_file))
            except FileNotFoundError:
                print(f"File {filename} not found.")
            except json.JSONDecodeError:
                print(f"Error decoding JSON from file {filename}.")

    def create_prompt_template(self, input_variables, template_string):
        """
        Create the PromptTemplate used to format each example

        Parameters:
        input_variables: Variable names used by the template
        template_string (str): Template text for a single example
        """
        self.example_prompt = PromptTemplate(input_variables=input_variables, template=template_string)

    def create_few_shot_prompt(self, prefix, suffix):
        """
        Create the FewShotPromptTemplate from the loaded examples and the example prompt

        Parameters:
        prefix (str): Text placed before the examples
        suffix (str): Text placed after the examples
        """
        self.few_shot_prompt = FewShotPromptTemplate(
            examples=self.examples, example_prompt=self.example_prompt, prefix=prefix, suffix=suffix
        )

    def save_prompt(self):
        """
        Save the few-shot prompt to save_path

        self.few_shot_prompt is None until create_few_shot_prompt has been called.
        """
        self.few_shot_prompt.save(self.save_path)