paper_qa / background.py
chansung's picture
.
e08d6ea
raw
history blame
3.94 kB
import datasets
import pandas as pd
from huggingface_hub import HfApi
from utils import push_to_hf_hub
from paper.download import download_pdf_from_arxiv
from paper.download import get_papers_from_arxiv_ids
from paper.parser import extract_text_and_figures
from gen.gemini import get_basic_qa, get_deep_qa
def _filter_function(example, ids):
ids_e = example['Requested arXiv IDs']
for iid in ids:
if iid in ids_e:
ids_e.remove(iid)
example['Requested arXiv IDs'] = ids_e
print(example)
return example
def _extract_arxiv_id(text):
print(text)
start = text.find("[") + 1
end = text.find("]")
# Extract the text between brackets
if start != -1 and end != -1:
return text[start:end]
else:
return text
def process_arxiv_ids(gemini_api, hf_repo_id, req_hf_repo_id, hf_token, restart_repo_id, how_many=10):
    """Generate QA records for requested arXiv papers and publish them.

    Reads pending requests from the ``train`` split of ``req_hf_repo_id`` and
    keeps at most ``how_many`` arXiv IDs. For each resulting paper it:

    1. downloads the PDF and extracts its text,
    2. generates basic and follow-up QAs with Gemini,
    3. pushes the record to ``hf_repo_id``, and
    4. removes the handled request from ``req_hf_repo_id``.

    Finally, the Space ``restart_repo_id`` is restarted.

    Failures on a single paper are printed and skipped; processing continues
    with the next paper.

    Args:
        gemini_api: Gemini API key forwarded to the QA generators.
        hf_repo_id: dataset repo that receives the generated QA records.
        req_hf_repo_id: dataset repo holding the pending requests.
        hf_token: Hugging Face token used for pushes and the Space restart.
        restart_repo_id: Space repo restarted after processing.
        how_many: maximum number of requested papers handled per call.
    """
    # Keep a direct handle on the "train" split. It is filtered and pushed
    # back after every paper. Re-indexing ``['train']`` on an already-filtered
    # ``Dataset`` would be a *column* lookup and raise, which made every paper
    # after the first fail to update the request repo.
    req_ds = datasets.load_dataset(req_hf_repo_id)['train']

    arxiv_ids = []
    for row in req_ds:
        req_arxiv_ids = row['Requested arXiv IDs']
        # Only the first request string of each row is used; rows whose first
        # entry is "top" are skipped.
        if req_arxiv_ids and req_arxiv_ids[0] != "top":
            arxiv_ids.append(_extract_arxiv_id(req_arxiv_ids[0]))
    arxiv_ids = arxiv_ids[:how_many]

    if arxiv_ids:
        print(f"1. Get metadata for the papers [{arxiv_ids}]")
        papers = get_papers_from_arxiv_ids(arxiv_ids)
        print("...DONE")

        print("2. Generating QAs for the paper")
        for paper in papers:
            try:
                title = paper['title']
                target_date = paper['target_date']
                abstract = paper['paper']['summary']
                arxiv_id = paper['paper']['id']
                authors = paper['paper']['authors']
                print(f"...PROCESSING ON[{arxiv_id}, {title}]")

                print("......Downloading the paper PDF")
                filename = download_pdf_from_arxiv(arxiv_id)
                print("......DONE")

                print("......Extracting text and figures")
                texts, _ = extract_text_and_figures(filename)
                text = ' '.join(texts)
                print("......DONE")

                print("......Generating the seed(basic) QAs")
                # NOTE(review): ``trucate`` (sic) is presumably the keyword name
                # defined by gen.gemini — confirm before renaming.
                qnas = get_basic_qa(text, gemini_api_key=gemini_api, trucate=30000)
                qnas['title'] = title
                qnas['abstract'] = abstract
                qnas['authors'] = ','.join(authors)
                qnas['arxiv_id'] = arxiv_id
                qnas['target_date'] = target_date
                qnas['full_text'] = text
                print("......DONE")

                print("......Generating the follow-up QAs")
                qnas = get_deep_qa(text, qnas, gemini_api_key=gemini_api, trucate=30000)
                del qnas["qna"]
                print("......DONE")

                print(f"......Exporting to HF Dataset repo at [{hf_repo_id}]")
                df = pd.DataFrame([qnas])
                ds = datasets.Dataset.from_pandas(df)
                ds = ds.cast_column("target_date", datasets.features.Value("timestamp[s]"))
                push_to_hf_hub(ds, hf_repo_id, hf_token)
                print("......DONE")

                print(f"......Updating request arXiv HF Dataset repo at [{req_hf_repo_id}]")
                # Bind the processed request string once, instead of having the
                # lambda close over loop variables.
                processed = [f"[{arxiv_id}] {title}"]
                req_ds = req_ds.map(
                    lambda example: _filter_function(example, processed)
                ).filter(
                    # Drop rows whose request list became empty.
                    lambda example: len(example['Requested arXiv IDs']) > 0
                )
                req_ds.push_to_hub(req_hf_repo_id, token=hf_token)
                print("......DONE")
            except Exception as e:
                # Best-effort: a failing paper is reported and skipped so the
                # remaining papers are still processed.
                print(f".......failed due to exception {e}")

    HfApi(token=hf_token).restart_space(
        repo_id=restart_repo_id, token=hf_token
    )