|
from dotenv import load_dotenv |
|
from img2table.document import Image |
|
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain |
|
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain |
|
from langchain.chains.combine_documents.stuff import StuffDocumentsChain |
|
from langchain.chains.llm import LLMChain |
|
from langchain.prompts import PromptTemplate |
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
from langchain_openai import ChatOpenAI |
|
from pdf2image import convert_from_path |
|
from prompt import * |
|
from table_detector import detection_transform, device, model, ocr, outputs_to_objects |
|
|
|
import io |
|
import json |
|
import os |
|
import pandas as pd |
|
import re |
|
import torch |
|
|
|
# Populate os.environ from a local .env file before any chat client below
# reads its API key.
load_dotenv()

# Prompt pairs per extraction mode, looked up by Process.get_entity:
# element 0 is the per-chunk (map) prompt, element 1 the combine (reduce)
# prompt. The prompt strings themselves come from the `prompt` module.
prompts = dict(
    gsd=[prompt_entity_gsd_chunk, prompt_entity_gsd_combine],
    summ=[prompt_entity_summ_chunk, prompt_entity_summ_combine],
    all=[prompt_entities_chunk, prompt_entities_combine],
)
|
|
|
class Process:
    """Extract entities from text chunks and PDF tables with a chat LLM.

    The backing chat model is chosen from the model name given to the
    constructor. Table extraction combines the module-level table detector,
    img2table OCR and the LLM to produce gene / SNP / disease records.
    """

    # First flat "{...}" span in free-form LLM output. Nested braces are not
    # supported. The raw string avoids the invalid-escape warning the previous
    # non-raw pattern triggered.
    _DICT_RE = re.compile(r'(\{[^}]+\})')

    def __init__(self, llm):
        """Create the chat model wrapper.

        Args:
            llm: Model name. Names starting with 'gpt' use ChatOpenAI, names
                starting with 'gemini' use ChatGoogleGenerativeAI, and any
                other name is sent to Perplexity's endpoint through ChatOpenAI.

        Raises:
            KeyError: a Perplexity model is requested but PERPLEXITY_API_KEY
                is not set in the environment.
        """
        if llm.startswith('gpt'):
            self.llm = ChatOpenAI(temperature=0, model_name=llm)
        elif llm.startswith('gemini'):
            self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
        else:
            self.llm = ChatOpenAI(
                temperature=0,
                model_name=llm,
                api_key=os.environ['PERPLEXITY_API_KEY'],
                base_url="https://api.perplexity.ai",
            )

    @staticmethod
    def _parse_literal(text):
        """Parse *text* as a Python literal, falling back to JSON.

        LLM output is untrusted, so this deliberately avoids eval():
        ast.literal_eval accepts only literals. JSON is tried second so that
        null/true/false payloads also parse.

        Raises:
            SyntaxError / ValueError: *text* is neither a Python literal nor
                valid JSON. The literal_eval error is re-raised.
        """
        import ast

        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError) as literal_err:
            try:
                return json.loads(text)
            except ValueError:
                raise literal_err from None

    @classmethod
    def _extract_dict(cls, text):
        """Return the first flat "{...}" span in *text*, parsed as a literal.

        Raises:
            IndexError: no "{...}" span is present. This is the same
                exception type the previous findall(...)[0] raised.
            SyntaxError / ValueError: the span is not a valid literal.
        """
        match = cls._DICT_RE.search(text)
        if match is None:
            raise IndexError('no {...} object found in LLM output')
        return cls._parse_literal(match.group(1))

    def get_entity(self, data):
        """Run a map-reduce extraction over document chunks.

        Args:
            data: (chunks, types) pair.
                chunks: passed straight to the map-reduce chain's invoke().
                    Presumably a list of LangChain Documents; confirm with
                    callers.
                types: key into the module-level `prompts` mapping
                    ('gsd', 'summ' or 'all').

        Returns:
            For types == 'summ', the reduced text. Otherwise, the first
            "{...}" literal in that text, parsed into a Python object.

        Raises:
            KeyError: unknown *types*.
            IndexError: the reduced text contains no "{...}" span.
            SyntaxError / ValueError: the span is not a valid literal.
        """
        chunks, types = data
        map_template, reduce_template = prompts[types]

        map_chain = LLMChain(
            llm=self.llm, prompt=PromptTemplate.from_template(map_template)
        )
        reduce_chain = LLMChain(
            llm=self.llm, prompt=PromptTemplate.from_template(reduce_template)
        )

        # All mapped outputs are stuffed into the reduce prompt's
        # {doc_summaries} variable.
        combine_chain = StuffDocumentsChain(
            llm_chain=reduce_chain, document_variable_name="doc_summaries"
        )
        # The same chain doubles as the collapse step when the mapped outputs
        # exceed token_max.
        reduce_documents_chain = ReduceDocumentsChain(
            combine_documents_chain=combine_chain,
            collapse_documents_chain=combine_chain,
            token_max=100000,
        )
        # Each chunk fills the map prompt's {docs} variable.
        map_reduce_chain = MapReduceDocumentsChain(
            llm_chain=map_chain,
            reduce_documents_chain=reduce_documents_chain,
            document_variable_name="docs",
            return_intermediate_steps=False,
        )

        result = map_reduce_chain.invoke(chunks)['output_text']
        print(types)
        print(result)

        if types == 'summ':
            return result
        return self._extract_dict(result)

    def get_entity_one(self, chunks):
        """Extract entities with a single LLM call over *chunks*.

        *chunks* is substituted into prompt_entity_one_chunk via str.format.

        Returns:
            The first "{...}" literal in the reply, parsed into a Python object.

        Raises:
            IndexError: the reply contains no "{...}" span.
            SyntaxError / ValueError: the span is not a valid literal.
        """
        result = self.llm.invoke(prompt_entity_one_chunk.format(chunks)).content
        print('One')
        print(result)
        return self._extract_dict(result)

    @staticmethod
    def _detect_tables(images, min_score):
        """Crop every detected table region scoring above *min_score*.

        Pages are processed in order, and detections within a page in the
        order returned by outputs_to_objects.
        """
        # Work on a copy. The previous code wrote into model.config.id2label
        # on every page, adding a new "no object" key each time and mutating
        # the shared model.
        id2label = dict(model.config.id2label)
        id2label[len(id2label)] = "no object"

        tables = []
        for image in images:
            pixel_values = detection_transform(image).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = model(pixel_values)

            for detected in outputs_to_objects(outputs, image.size, id2label):
                if not detected['score'] > min_score:
                    continue
                print(detected)
                cropped = image.crop(detected["bbox"])
                if detected["label"] == 'table rotated':
                    # PIL rotates counter-clockwise; presumably this makes
                    # rotated tables upright for OCR.
                    cropped = cropped.rotate(270, expand=True)
                tables.append(cropped)
        return tables

    @staticmethod
    def _table_to_json(table):
        """OCR a cropped table image into indented JSON records.

        Returns:
            The JSON string, or None when img2table finds no table.
        """
        buffer = io.BytesIO()
        table.save(buffer, format='PNG')
        extracted_tables = Image(buffer).extract_tables(
            ocr=ocr, implicit_rows=True, borderless_tables=True, min_confidence=0
        )
        if not extracted_tables:
            return None

        # One concat instead of concatenating pairwise in a loop, which was
        # quadratic in the number of sub-tables.
        df_table = pd.concat(
            [extracted.df for extracted in extracted_tables], ignore_index=True
        ).fillna('')
        return json.dumps(json.loads(df_table.to_json(orient='records')), indent=2)

    @classmethod
    def _parse_table_rows(cls, text):
        """Flatten the LLM's table answer into (gene, snp, disease) triples.

        The outermost "[...]" span of *text* is parsed as a list of records
        with 'Genes', 'SNPs' and 'Diseases' keys. Each SNP yields one triple.
        Output is best-effort: unparsable text gives [], and records that are
        not dicts, lack a key, or have a non-collection SNPs value are skipped.
        A bare string SNPs value counts as a single SNP rather than being
        iterated character by character.
        """
        start, end = text.find('['), text.rfind(']')
        if start == -1 or end < start:
            return []
        try:
            records = cls._parse_literal(text[start:end + 1])
        except (SyntaxError, ValueError):
            return []
        if not isinstance(records, (list, tuple)):
            return []

        rows = []
        for record in records:
            if not isinstance(record, dict):
                continue
            try:
                gene = record['Genes']
                snp_values = record['SNPs']
                disease = record['Diseases']
            except KeyError:
                continue
            if isinstance(snp_values, str):
                snp_values = [snp_values]
            elif not isinstance(snp_values, (list, tuple, set)):
                continue
            rows.extend((gene, snp, disease) for snp in snp_values)
        return rows

    def get_table(self, path, min_score=0.9):
        """Extract gene / SNP / disease records from tables in a PDF.

        Each page of the PDF at *path* is rendered to an image. Table regions
        scoring above *min_score* are cropped, OCR'd into JSON records, and
        passed to the LLM with prompt_table.

        Args:
            path: Path to the PDF file.
            min_score: Detector confidence threshold. The default of 0.9
                matches the previously hard-coded value.

        Returns:
            (genes, snps, diseases) as parallel lists with one entry per SNP.
            A record's gene and disease are repeated for each of its SNPs.
        """
        tables = self._detect_tables(convert_from_path(path), min_score)

        genes, snps, diseases = [], [], []
        for table in tables:
            str_json_table = self._table_to_json(table)
            if str_json_table is None:
                continue

            result = self.llm.invoke(prompt_table.format(str_json_table)).content
            print('table')
            print(result)

            for gene, snp, disease in self._parse_table_rows(result):
                genes.append(gene)
                snps.append(snp)
                diseases.append(disease)

        print(genes, snps, diseases)
        return genes, snps, diseases
|
|