from dotenv import load_dotenv
from img2table.document import Image
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from pdf2image import convert_from_path
from prompt import *
from table_detector import detection_transform, device, model, ocr, outputs_to_objects

import io
import json
import os
import pandas as pd
import re
import torch

load_dotenv()

# Map each extraction mode to its (map prompt, combine prompt) pair.
prompts = {
    'alls': [prompt_entity_chunk, prompt_entity_combine],
    'gsd': [prompt_entity_gsd_chunk, prompt_entity_gsd_combine],
    'summ': [prompt_entity_summ_chunk, prompt_entity_summ_combine],
    'all': [prompt_entities_chunk, prompt_entities_combine]
}


class Process():

    def __init__(self, llm):
        # Pick the chat backend from the model name; anything that is not
        # GPT or Gemini is routed to the Perplexity API via the OpenAI client.
        if llm.startswith('gpt'):
            self.llm = ChatOpenAI(temperature=0, model_name=llm)
        elif llm.startswith('gemini'):
            self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
        else:
            self.llm = ChatOpenAI(
                temperature=0,
                model_name=llm,
                api_key=os.environ['PERPLEXITY_API_KEY'],
                base_url="https://api.perplexity.ai"
            )

    def get_entity(self, data):
        chunks, types = data

        # Map step: extract entities from each chunk.
        map_template = prompts[types][0]
        map_prompt = PromptTemplate.from_template(map_template)
        map_chain = LLMChain(llm=self.llm, prompt=map_prompt)

        # Reduce step: merge the per-chunk results into one answer.
        reduce_template = prompts[types][1]
        reduce_prompt = PromptTemplate.from_template(reduce_template)
        reduce_chain = LLMChain(llm=self.llm, prompt=reduce_prompt)

        combine_chain = StuffDocumentsChain(
            llm_chain=reduce_chain, document_variable_name="doc_summaries"
        )
        reduce_documents_chain = ReduceDocumentsChain(
            combine_documents_chain=combine_chain,
            collapse_documents_chain=combine_chain,
            token_max=100000,
        )
        map_reduce_chain = MapReduceDocumentsChain(
            llm_chain=map_chain,
            reduce_documents_chain=reduce_documents_chain,
            document_variable_name="docs",
            return_intermediate_steps=False,
        )

        result = map_reduce_chain.invoke(chunks)['output_text']
        print(types)
        print(result)

        if types != 'summ':
            # Parse the first JSON-like object in the response and pad every
            # column to the same length before building the DataFrame.
            result = eval(re.findall(r'(\{[^}]+\})', result)[0])
            max_len = max([len(result[k]) for k in result])
            for k in result:
                while len(result[k]) < max_len:
                    result[k].append('')

            return pd.DataFrame(result)

        return result

    def get_entity_one(self, chunks):
        result = self.llm.invoke(prompt_entity_one_chunk.format(chunks)).content
        print('One')
        print(result)

        result = re.findall(r'(\{[^}]+\})', result)[0]

        return eval(result)

    def get_table(self, path):
        images = convert_from_path(path)
        tables = []

        # Loop pages
        for image in images:
            pixel_values = detection_transform(image).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = model(pixel_values)

            id2label = model.config.id2label
            id2label[len(model.config.id2label)] = "no object"
            detected_tables = outputs_to_objects(outputs, image.size, id2label)

            # Loop table in page (if any)
            for idx in range(len(detected_tables)):
                cropped_table = image.crop(detected_tables[idx]["bbox"])
                if detected_tables[idx]["label"] == 'table rotated':
                    cropped_table = cropped_table.rotate(270, expand=True)

                # TODO: what is the perfect threshold?
                if detected_tables[idx]['score'] > 0.9:
                    print(detected_tables[idx])
                    tables.append(cropped_table)

        df_result = pd.DataFrame()

        # Loop tables
        for table in tables:
            buffer = io.BytesIO()
            table.save(buffer, format='PNG')
            image = Image(buffer)

            # Extract to dataframe
            extracted_tables = image.extract_tables(ocr=ocr,
                                                    implicit_rows=True,
                                                    borderless_tables=True,
                                                    min_confidence=0)
            if len(extracted_tables) == 0:
                continue

            # Combine multiple dataframe
            df_table = extracted_tables[0].df
            for extracted_table in extracted_tables[1:]:
                df_table = pd.concat([df_table, extracted_table.df]).reset_index(drop=True)

            df_table = df_table.fillna('')

            # Ask LLM with JSON data
            json_table = df_table.to_json(orient='records')
            str_json_table = json.dumps(json.loads(json_table), indent=2)

            result = self.llm.invoke(prompt_table.format(str_json_table)).content
            print('table')
            print(result)

            # Keep only the JSON array in the response; fall back to an empty
            # list if it cannot be parsed.
            result = result[result.find('['):result.rfind(']') + 1]
            try:
                result = eval(result)
            except SyntaxError:
                result = []

            df_result = pd.concat([df_result, pd.DataFrame(result)], ignore_index=True)

        return df_result

    def get_rsid(self, df, text):
        # Collect every rsID (e.g. rs12345) mentioned in the raw text and
        # append it as an extra row in the 'rsID' column.
        rsids = re.findall(r'(rs[\d]{3,})', text)
        df_rsid = pd.DataFrame(rsids, columns=['rsID'])
        df = pd.concat([df, df_rsid]).fillna('').reset_index(drop=True)

        return df
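

# Usage sketch (illustrative only): one way the Process class could be driven
# end to end. The model name, PDF path, and chunk construction below are
# assumptions for demonstration; they are not defined or required by this module.
if __name__ == '__main__':
    from langchain_core.documents import Document

    process = Process('gpt-4o-mini')  # hypothetical model name

    # Table extraction works directly on the pages of a PDF file.
    df_tables = process.get_table('paper.pdf')  # hypothetical path

    # Entity extraction expects LangChain Documents plus a prompt-set key
    # ('alls', 'gsd', 'summ', or 'all').
    chunks = [Document(page_content='...one chunk of the paper text...')]
    df_entities = process.get_entity((chunks, 'all'))

    # Augment the table results with any rsIDs found in the raw text.
    df_tables = process.get_rsid(df_tables, '...full text of the paper...')

    print(df_entities)
    print(df_tables)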