|
import ast
import io
import json
import os
import re
from datetime import datetime

import google.generativeai as genai
import pandas as pd
import torch
from dotenv import load_dotenv
from img2table.document import Image
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from pdf2image import convert_from_path

from prompt import (
    prompt_entities_chunk,
    prompt_entities_combine,
    prompt_entity_gsd_chunk,
    prompt_entity_gsd_combine,
    prompt_entity_one_chunk,
    prompt_entity_summ_chunk,
    prompt_entity_summ_combine,
    prompt_table,
    prompt_validation,
)
from table_detector import detection_transform, device, model, ocr, outputs_to_objects
|
|
|
# Load variables from a .env file into the process environment before any
# client below reads its API key from os.environ.
load_dotenv()

# Both environment lookups below raise KeyError at import time when the
# variable is missing.
genai.configure(api_key=os.environ['GOOGLE_API_KEY'])

# Deterministic (temperature 0) OpenAI chat model; used by get_entity,
# get_entity_one and get_table.
llm = ChatOpenAI(temperature=0, model_name="gpt-4-turbo")

# Same client class pointed at the Perplexity endpoint; used by validate.
llm_p = ChatOpenAI(
    temperature=0,
    model_name="llama-3-sonar-large-32k-chat",
    api_key=os.environ['PERPLEXITY_API_KEY'],
    base_url="https://api.perplexity.ai",
)

# Gemini model handle configured through genai.configure above.
llm_g = genai.GenerativeModel(model_name='gemini-1.5-pro-latest')
|
|
|
# Prompt templates for the map-reduce extraction in get_entity, keyed by the
# ``types`` value callers pass in. Each entry is [map template, reduce
# template]: index 0 is applied per chunk, index 1 combines the results.
prompts = {
    'gsd': [
        prompt_entity_gsd_chunk,
        prompt_entity_gsd_combine,
    ],
    'summ': [
        prompt_entity_summ_chunk,
        prompt_entity_summ_combine,
    ],
    'all': [
        prompt_entities_chunk,
        prompt_entities_combine,
    ],
}
|
|
|
def get_entity(data):
    """Run a map-reduce LLM extraction over document chunks.

    Args:
        data: A ``(chunks, types)`` pair. ``chunks`` is handed unchanged to
            ``MapReduceDocumentsChain.invoke``; ``types`` is a key of the
            module-level ``prompts`` dict and selects the map template
            (index 0) and the reduce template (index 1).

    Returns:
        The raw output text when ``types == 'summ'``; otherwise the first
        ``{...}`` span of the output text parsed as a Python literal.

    Raises:
        KeyError: ``types`` is not a key of ``prompts``.
        IndexError: the output contains no ``{...}`` span (same exception
            type the previous ``findall(...)[0]`` produced).
        ValueError, SyntaxError: the span is not a valid Python literal.
    """
    chunks, types = data

    # Map step: one LLM call per chunk.
    map_prompt = PromptTemplate.from_template(prompts[types][0])
    map_chain = LLMChain(llm=llm, prompt=map_prompt)

    # Reduce step: the per-chunk outputs are stuffed into one prompt under
    # the "doc_summaries" variable.
    reduce_prompt = PromptTemplate.from_template(prompts[types][1])
    reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
    combine_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name="doc_summaries",
    )

    # The same chain is used both for the final combine and for collapsing
    # intermediate results once they exceed token_max.
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_chain,
        collapse_documents_chain=combine_chain,
        token_max=100000,
    )

    map_reduce_chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name="docs",
        return_intermediate_steps=False,
    )

    result = map_reduce_chain.invoke(chunks)['output_text']
    print(types)
    print(result)

    if types == 'summ':
        return result

    # Raw-string pattern: the former non-raw '\{' is an invalid escape
    # sequence. The pattern stops at the first '}', so nested braces are not
    # supported (unchanged from before).
    match = re.search(r'\{[^}]+\}', result)
    if match is None:
        raise IndexError("no '{...}' literal found in model output")

    # SECURITY NOTE: this used to be eval() on raw model output, which would
    # execute arbitrary expressions. ast.literal_eval only accepts Python
    # literals (dict/list/str/number/...), which is all this parser needs.
    return ast.literal_eval(match.group(0))
|
|
|
def get_entity_one(chunks):
    """Extract entities with a single LLM call.

    Args:
        chunks: Value substituted positionally into
            ``prompt_entity_one_chunk`` via ``str.format``.

    Returns:
        The first ``{...}`` span of the model response parsed as a Python
        literal.

    Raises:
        IndexError: the response contains no ``{...}`` span (same exception
            type the previous ``findall(...)[0]`` produced).
        ValueError, SyntaxError: the span is not a valid Python literal.
    """
    result = llm.invoke(prompt_entity_one_chunk.format(chunks)).content

    print('One')
    print(result)

    # Raw-string pattern: the former non-raw '\{' is an invalid escape
    # sequence. Matching stops at the first '}', so nested braces are not
    # supported (unchanged from before).
    match = re.search(r'\{[^}]+\}', result)
    if match is None:
        raise IndexError("no '{...}' literal found in model output")

    # SECURITY NOTE: this used to be eval() on raw model output, which would
    # execute arbitrary expressions. ast.literal_eval only accepts Python
    # literals.
    return ast.literal_eval(match.group(0))
|
|
|
def _detect_table_crops(images, min_score=0.9):
    """Return a cropped image for every table detected with score > min_score."""
    # Copy the label map once. The previous code appended a "no object" entry
    # to the shared model.config.id2label for every page, so the config kept
    # growing across pages and calls.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    crops = []
    for image in images:
        pixel_values = detection_transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(pixel_values)

        for detected in outputs_to_objects(outputs, image.size, id2label):
            # Check the score first so discarded detections are never
            # cropped or rotated.
            if detected['score'] <= min_score:
                continue
            print(detected)
            crop = image.crop(detected["bbox"])
            if detected["label"] == 'table rotated':
                crop = crop.rotate(270, expand=True)
            crops.append(crop)
    return crops


def _merge_continuation_rows(df_table):
    """Fold rows containing NaN cells into the closest preceding complete row."""
    # Row 0 is always treated as complete so it can start the first group.
    df_table.loc[0] = df_table.loc[0].fillna('')

    # A row without any NaN starts a new group; rows with NaN are treated as
    # continuation lines of the group before them.
    groups = []
    for i in df_table.index:
        if not groups or not df_table.loc[i].isna().any():
            groups.append([i])
        else:
            groups[-1].append(i)

    merged = pd.DataFrame(columns=df_table.columns)
    for group in groups:
        row = df_table.loc[group[0]]
        for idx in group[1:]:
            # Non in-place concatenation: never writes back into df_table.
            row = row + ' ' + df_table.loc[idx].fillna('')
        merged.loc[len(merged)] = row.str.strip()
    return merged


def _parse_list_literal(text):
    """Parse the outermost ``[...]`` span of *text*; return [] when unusable."""
    start, end = text.find('['), text.rfind(']')
    if start == -1 or end < start:
        return []
    # SECURITY NOTE: this used to be eval() on raw model output, which would
    # execute arbitrary expressions. ast.literal_eval only accepts literals.
    try:
        parsed = ast.literal_eval(text[start:end + 1])
    except (SyntaxError, ValueError):
        return []
    return parsed if isinstance(parsed, list) else []


def get_table(path):
    """Extract gene / SNP / disease triples from the tables of a PDF.

    Pipeline: render the PDF pages to images, detect tables on each page,
    OCR every detected table into a DataFrame, merge multi-line rows, and
    ask the LLM (``prompt_table``) to extract records from the table JSON.

    Args:
        path: Path of the PDF, passed to ``pdf2image.convert_from_path``.

    Returns:
        Three parallel lists ``(genes, snps, diseases)`` with one entry per
        SNP found in the model's records.

    Raises:
        KeyError: a parsed record lacks 'Genes', 'SNPs' or 'Diseases'.
    """
    start_time = datetime.now()

    def elapsed_minutes():
        return round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2)

    images = convert_from_path(path)
    print('PDF to Image', elapsed_minutes(), "minutes")

    tables = _detect_table_crops(images)
    print('Detect table from image', elapsed_minutes(), "minutes")

    genes = []
    snps = []
    diseases = []

    for table in tables:
        # img2table reads the crop from an in-memory PNG.
        buffer = io.BytesIO()
        table.save(buffer, format='PNG')
        extracted_tables = Image(buffer).extract_tables(
            ocr=ocr,
            implicit_rows=True,
            borderless_tables=True,
            min_confidence=0,
        )
        if not extracted_tables:
            continue

        # Stack every table found in the crop into one frame with a single
        # concat instead of re-concatenating inside a loop.
        df_table = pd.concat(
            [extracted.df for extracted in extracted_tables]
        ).reset_index(drop=True)
        if df_table.empty:
            # Nothing was OCR'd; the row merge below would fail on .loc[0].
            continue

        df_table_cleaned = _merge_continuation_rows(df_table)

        json_table = df_table_cleaned.to_json(orient='records')
        str_json_table = json.dumps(json.loads(json_table), indent=2)

        result = llm.invoke(prompt_table.format(str_json_table)).content
        print('table')
        print(result)

        for res in _parse_list_literal(result):
            res_gene = res['Genes']
            res_snp = res['SNPs']
            res_disease = res['Diseases']

            # A bare string would otherwise be iterated character by
            # character and yield one bogus "SNP" per letter.
            if isinstance(res_snp, str):
                res_snp = [res_snp]

            for snp in res_snp:
                genes.append(res_gene)
                snps.append(snp)
                diseases.append(res_disease)

    print('OCR table to extract', elapsed_minutes(), "minutes")
    print(genes, snps, diseases)

    return genes, snps, diseases
|
|
|
def _normalize_snp(snp):
    """Return *snp* as an ``rs<digits>`` id, '' if empty, or None if invalid."""
    if snp == '' or re.fullmatch(r'rs\d+', snp):
        return snp
    if re.fullmatch(r'ts\d+', snp):
        # Presumably an OCR misread of the leading 'r'.
        return 'r' + snp[1:]
    if re.fullmatch(r's\d+', snp):
        return 'r' + snp
    if re.fullmatch(r'\d+', snp):
        return 'rs' + snp
    return None


def validate(df):
    """Clean extracted gene/SNP rows and have the LLM validate them.

    Cleaning steps: drop rows without a gene, fill remaining NaN with '',
    upper-case genes, lower-case SNPs, split multi-gene cells into one row
    per gene, normalise SNP ids and drop rows whose SNP cannot be
    normalised.

    Args:
        df: DataFrame with at least 'Genes', 'SNPs' and 'Diseases' columns.

    Returns:
        ``(df, df_val)``: the cleaned frame, and a frame built from the
        records returned by ``llm_p`` for ``prompt_validation``, cross-joined
        with the remaining (non gene/SNP/disease) columns of the first
        cleaned row.
    """
    df = df[df['Genes'].notna()].reset_index(drop=True)
    df = df.fillna('')
    df['Genes'] = df['Genes'].str.upper()
    df['SNPs'] = df['SNPs'].str.lower()

    # One row per gene for cells such as "A/B", "A-B" or "A|B". The previous
    # loop iterated the original index while inserting rows, so rows near the
    # end were never examined and a third gene in a cell was silently lost.
    # explode() keeps every part and leaves the copies adjacent, in order.
    df['Genes'] = df['Genes'].map(lambda gene: re.split(r'[-/|]', gene))
    df = df.explode('Genes', ignore_index=True)

    # Normalise SNP ids; rows that cannot be normalised become None and are
    # dropped. Empty SNPs are kept as ''.
    df['SNPs'] = df['SNPs'].map(_normalize_snp)
    df = df[df['SNPs'].notna()].reset_index(drop=True)

    json_table = df[['Genes', 'SNPs', 'Diseases']].to_json(orient='records')
    str_json_table = json.dumps(json.loads(json_table), indent=2)

    result = llm_p.invoke(input=prompt_validation.format(str_json_table)).content
    print('val')
    print(result)

    # Keep only the outermost [...] span of the response.
    start, end = result.find('['), result.rfind(']')
    # SECURITY NOTE: this used to be eval() on raw model output, which would
    # execute arbitrary expressions. ast.literal_eval only accepts literals.
    try:
        records = ast.literal_eval(result[start:end + 1]) if start != -1 else []
    except (SyntaxError, ValueError):
        records = []

    df_val = pd.DataFrame(records)
    # Attach the remaining columns of the first cleaned row to every
    # validated record.
    df_val = df_val.merge(
        df.head(1).drop(['Genes', 'SNPs', 'Diseases'], axis=1),
        how='cross',
    )

    return df, df_val