"""Extract gene / SNP / disease entities from documents and PDF tables with LLMs.

The module wires together three kinds of work:

* map-reduce prompting over document chunks (``get_entity``) and a single
  prompt over one chunk (``get_entity_one``);
* table detection on PDF page images, OCR into dataframes and an LLM pass
  that turns each table into ``Genes`` / ``SNPs`` / ``Diseases`` records
  (``get_table``);
* clean-up and LLM-based validation of the collected records (``validate``).

Importing this module loads the ``.env`` file and requires the
``GOOGLE_API_KEY`` and ``PERPLEXITY_API_KEY`` environment variables.
"""
import ast
import io
import json
import os
import re
from datetime import datetime

import google.generativeai as genai
import pandas as pd
import torch
from dotenv import load_dotenv
from img2table.document import Image
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from pdf2image import convert_from_path

from prompt import (
    prompt_entity_gsd_chunk,
    prompt_entity_gsd_combine,
    prompt_entity_summ_chunk,
    prompt_entity_summ_combine,
    prompt_entities_chunk,
    prompt_entities_combine,
    prompt_entity_one_chunk,
    prompt_table,
    prompt_validation,
)
from table_detector import detection_transform, device, model, ocr, outputs_to_objects

load_dotenv()
genai.configure(api_key=os.environ['GOOGLE_API_KEY'])

# Chat model used for entity extraction and table interpretation.
llm = ChatOpenAI(temperature=0, model_name="gpt-4-turbo")
# Chat model reached through the Perplexity endpoint; used by validate().
llm_p = ChatOpenAI(
    temperature=0,
    model_name="llama-3-sonar-large-32k-chat",
    api_key=os.environ['PERPLEXITY_API_KEY'],
    base_url="https://api.perplexity.ai",
)
# Gemini client; not referenced by the functions defined in this module.
llm_g = genai.GenerativeModel(model_name='gemini-1.5-pro-latest')

# Extraction mode -> [map (per-chunk) prompt, reduce (combine) prompt].
prompts = {
    'gsd': [prompt_entity_gsd_chunk, prompt_entity_gsd_combine],
    'summ': [prompt_entity_summ_chunk, prompt_entity_summ_combine],
    'all': [prompt_entities_chunk, prompt_entities_combine],
}

# Default minimum detection score for a table crop to be kept.
TABLE_SCORE_THRESHOLD = 0.9

# First flat ``{...}`` object in an LLM answer (nested braces are not handled).
_ENTITY_DICT_RE = re.compile(r'(\{[^}]+\})')
# Characters treated as separators between several gene names in one cell.
_GENE_SEPARATOR_RE = re.compile(r'[-/|]')
# SNP identifier shapes recognised by validate().
_SNP_VALID_RE = re.compile(r'rs\d+|')   # well-formed id, or empty string
_SNP_TS_RE = re.compile(r'ts\d+')       # 'rs' read as 'ts'
_SNP_S_RE = re.compile(r's\d+')         # leading 'r' missing
_SNP_DIGITS_RE = re.compile(r'\d+')     # whole 'rs' prefix missing


def _elapsed_minutes(start_time):
    """Return the minutes elapsed since *start_time*, rounded to 2 decimals."""
    return round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2)


def _parse_literal(text):
    """Parse *text* as a Python literal, falling back to JSON.

    LLM output is untrusted, so it is never passed to ``eval``.  The JSON
    fallback accepts answers that use ``null`` / ``true`` / ``false``.

    Raises:
        ValueError, SyntaxError: if *text* is neither a literal nor JSON.
    """
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError) as literal_error:
        try:
            return json.loads(text)
        except ValueError:
            raise literal_error


def _parse_entity_dict(text):
    """Parse the first ``{...}`` object found in the LLM answer *text*.

    Raises:
        IndexError: if *text* contains no ``{...}`` object (exception type
            kept from the previous ``re.findall(...)[0]`` implementation).
        ValueError, SyntaxError: if the object cannot be parsed.
    """
    match = _ENTITY_DICT_RE.search(text)
    if match is None:
        raise IndexError('no {...} object found in LLM response')
    return _parse_literal(match.group(1))


def _parse_llm_list(text):
    """Parse the outermost ``[...]`` span of *text*; return ``[]`` on failure."""
    snippet = text[text.find('['):text.rfind(']') + 1]
    if not snippet:
        return []
    try:
        parsed = _parse_literal(snippet)
    except (ValueError, SyntaxError):
        return []
    return parsed if isinstance(parsed, list) else []


def get_entity(data):
    """Run a map-reduce extraction over document chunks.

    Args:
        data: A ``(chunks, types)`` pair.  ``types`` selects the prompt pair
            from ``prompts`` (``'gsd'``, ``'summ'`` or ``'all'``); ``chunks``
            is handed unchanged to the map-reduce chain (presumably a list of
            LangChain documents - confirm against the caller).

    Returns:
        The raw answer text when ``types == 'summ'``; otherwise the first
        ``{...}`` object of the answer, parsed into a Python object.

    Raises:
        KeyError: if ``types`` is not a key of ``prompts``.
        IndexError: if a non-summary answer contains no ``{...}`` object.
        ValueError, SyntaxError: if that object cannot be parsed.
    """
    chunks, types = data
    map_template, reduce_template = prompts[types]

    map_chain = LLMChain(llm=llm, prompt=PromptTemplate.from_template(map_template))
    reduce_chain = LLMChain(llm=llm, prompt=PromptTemplate.from_template(reduce_template))

    combine_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name="doc_summaries",
    )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_chain,
        collapse_documents_chain=combine_chain,
        token_max=100000,
    )
    map_reduce_chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name="docs",
        return_intermediate_steps=False,
    )

    result = map_reduce_chain.invoke(chunks)['output_text']
    print(types)
    print(result)

    if types == 'summ':
        return result
    return _parse_entity_dict(result)


def get_entity_one(chunks):
    """Extract entities from a single chunk with one LLM call.

    Args:
        chunks: Value substituted into ``prompt_entity_one_chunk``.

    Returns:
        The first ``{...}`` object of the answer, parsed into a Python object.

    Raises:
        IndexError: if the answer contains no ``{...}`` object.
        ValueError, SyntaxError: if that object cannot be parsed.
    """
    result = llm.invoke(prompt_entity_one_chunk.format(chunks)).content
    print('One')
    print(result)
    return _parse_entity_dict(result)


def _detect_tables(images, score_threshold):
    """Return cropped images of the tables detected on each page image."""
    # Work on a copy: writing the extra label into model.config.id2label
    # itself would add another "no object" entry for every page processed.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    tables = []
    for image in images:
        pixel_values = detection_transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(pixel_values)
        for detected in outputs_to_objects(outputs, image.size, id2label):
            # TODO: what is the perfect threshold?
            if not detected['score'] > score_threshold:
                continue
            cropped_table = image.crop(detected["bbox"])
            if detected["label"] == 'table rotated':
                cropped_table = cropped_table.rotate(270, expand=True)
            print(detected)
            tables.append(cropped_table)
    return tables


def _merge_wrapped_rows(df_table):
    """Merge dataframe rows that belong to one visual table row.

    A row without any missing cell starts a new group; the following rows
    that do have missing cells are appended to it cell by cell, separated by
    a space.  The first row always starts a group.
    """
    df_table.loc[0] = df_table.loc[0].fillna('')

    groups = []
    current = []
    for i in df_table.index:
        if not df_table.loc[i].isna().any():
            if current:
                groups.append(current)
            current = []
        current.append(i)
    groups.append(current)

    merged = pd.DataFrame(columns=df_table.columns)
    for group in groups:
        row_str = df_table.loc[group[0]].copy()
        for idx in group[1:]:
            row_str = row_str + (' ' + df_table.loc[idx].fillna(''))
        merged.loc[len(merged)] = row_str.str.strip()
    return merged


def get_table(path, score_threshold=TABLE_SCORE_THRESHOLD):
    """Extract gene / SNP / disease triples from the tables of a PDF.

    Each page is rendered to an image, tables are detected and cropped, each
    crop is OCR-ed into a dataframe and the dataframe is sent to the LLM as
    JSON together with ``prompt_table``.

    Args:
        path: Path of the PDF, passed to ``pdf2image.convert_from_path``.
        score_threshold: Detections must score strictly above this value to
            be kept.  Defaults to ``TABLE_SCORE_THRESHOLD`` (0.9).

    Returns:
        Three parallel lists ``(genes, snps, diseases)`` holding one entry
        per SNP reported by the LLM.
    """
    start_time = datetime.now()
    images = convert_from_path(path)
    print('PDF to Image', _elapsed_minutes(start_time), "minutes")

    tables = _detect_tables(images, score_threshold)
    print('Detect table from image', _elapsed_minutes(start_time), "minutes")

    genes = []
    snps = []
    diseases = []
    for table in tables:
        buffer = io.BytesIO()
        table.save(buffer, format='PNG')
        image = Image(buffer)

        # Extract to dataframe
        extracted_tables = image.extract_tables(
            ocr=ocr, implicit_rows=True, borderless_tables=True, min_confidence=0
        )
        if not extracted_tables:
            continue

        # Combine multiple dataframes into one with a fresh 0..n-1 index.
        df_table = pd.concat(
            [extracted.df for extracted in extracted_tables], ignore_index=True
        )
        if df_table.empty:
            continue
        df_table_cleaned = _merge_wrapped_rows(df_table)

        # Ask LLM with JSON data
        json_table = df_table_cleaned.to_json(orient='records')
        str_json_table = json.dumps(json.loads(json_table), indent=2)
        result = llm.invoke(prompt_table.format(str_json_table)).content
        print('table')
        print(result)

        for entry in _parse_llm_list(result):
            if not isinstance(entry, dict) or not {'Genes', 'SNPs', 'Diseases'} <= entry.keys():
                print('Skipping malformed table record', entry)
                continue
            snp_values = entry['SNPs']
            if not snp_values:
                continue
            if isinstance(snp_values, str):
                # A bare string must count as one SNP, not one per character.
                snp_values = [snp_values]
            for snp in snp_values:
                genes.append(entry['Genes'])
                snps.append(snp)
                diseases.append(entry['Diseases'])

    print('OCR table to extract', _elapsed_minutes(start_time), "minutes")
    print(genes, snps, diseases)
    return genes, snps, diseases


def _split_gene_names(gene):
    """Split a gene cell on '-', '/' or '|' into its non-empty names.

    Returns ``[gene]`` unchanged when the cell is not a string or holds fewer
    than two names.
    NOTE(review): '-' also occurs inside single gene symbols; the separator
    set is kept from the previous implementation - confirm it is intended.
    """
    if not isinstance(gene, str):
        return [gene]
    names = [name.strip() for name in _GENE_SEPARATOR_RE.split(gene) if name.strip()]
    return names if len(names) > 1 else [gene]


def _normalize_snp(snp):
    """Return *snp* with a repaired ``rs`` prefix, or ``None`` if unusable.

    The empty string is accepted as-is (a row without a SNP).
    """
    if not isinstance(snp, str):
        return None
    if _SNP_VALID_RE.fullmatch(snp):
        return snp
    if _SNP_TS_RE.fullmatch(snp):
        # The previous code rebuilt 'ts...' here, leaving the id unchanged.
        return 'r' + snp[1:]
    if _SNP_S_RE.fullmatch(snp):
        return 'r' + snp
    if _SNP_DIGITS_RE.fullmatch(snp):
        return 'rs' + snp
    return None


def validate(df):
    """Clean extracted records and ask the LLM to validate them.

    Steps: drop rows without a gene, upper-case genes and lower-case SNPs,
    give every gene name of a multi-gene cell its own row, repair or drop
    malformed SNP ids, then send ``Genes`` / ``SNPs`` / ``Diseases`` to
    ``llm_p`` with ``prompt_validation``.

    Args:
        df: DataFrame with at least ``Genes``, ``SNPs`` and ``Diseases``
            columns.

    Returns:
        ``(df, df_val)``: the cleaned frame, and the LLM's validated records
        cross-joined with the remaining columns of the first cleaned row.
    """
    df = df[df['Genes'].notna()].reset_index(drop=True)
    df = df.fillna('')
    df['Genes'] = df['Genes'].str.upper()
    df['SNPs'] = df['SNPs'].str.lower()

    # One row per gene name.  Building the rows with explode() visits every
    # original row exactly once; inserting rows while looping over the index
    # shifted later rows and left some unprocessed.
    if not df.empty:
        df['Genes'] = df['Genes'].map(_split_gene_names)
        df = df.explode('Genes').reset_index(drop=True)

    # Repair SNP ids lacking a proper 'rs' prefix; drop the unrecognisable.
    normalized = df['SNPs'].map(_normalize_snp)
    keep = normalized.notna()
    df = df[keep].copy()
    df['SNPs'] = normalized[keep]
    df.reset_index(drop=True, inplace=True)

    # Validate genes and diseases with LLM
    json_table = df[['Genes', 'SNPs', 'Diseases']].to_json(orient='records')
    str_json_table = json.dumps(json.loads(json_table), indent=2)
    result = llm_p.invoke(input=prompt_validation.format(str_json_table)).content
    print('val')
    print(result)

    df_val = pd.DataFrame(_parse_llm_list(result))
    df_val = df_val.merge(
        df.head(1).drop(['Genes', 'SNPs', 'Diseases'], axis=1), 'cross'
    )
    # TODO: How to validate genes and SNPs with ground truth?
    return df, df_val