""" Download pre-trained models from Google drive. """
import os
import re
import zipfile
import logging

import requests
from tqdm import tqdm
import fire
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(filename)s - %(message)s",
    datefmt="%d/%m/%Y %H:%M:%S",
    level=logging.INFO)
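# Google Drive share links for the downloadable pre-trained model archives.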
"", "", "", "","",""
MODEL_TO_URL = {
    'PathologyEmoryPubMedBERT': 'https://drive.google.com/open?id=1l_el_mYXoTIQvGwKN2NZbp97E4svH4Fh',
    'PathologyEmoryBERT': 'https://drive.google.com/open?id=11vzo6fJBw1RcdHVBAh6nnn8yua-4kj2IX',
    'ClinicalBERT': 'https://drive.google.com/open?id=1UK9HqSspVneK8zGg7B93vIdTGKK9MI_v',
    'BlueBERT': 'https://drive.google.com/open?id=1o-tcItErOiiwqZ-YRa3sMM3hGB4d3WkP',
    'BioBERT': 'https://drive.google.com/open?id=1m7EkWkFBIBuGbfwg7j0R_WINNnYk3oS9',
    'BERT': 'https://drive.google.com/open?id=1SB_AQAAsHkF79iSAaB3kumYT1rwcOJru',
    'single_tfidf': 'https://drive.google.com/open?id=1-hxf7sKRtFGMOenlafdkeAr8_9pOz6Ym',
    'branch_tfidf': 'https://drive.google.com/open?id=1pDSnwLFn3YzPRac9rKFV_FN9kdzj2Lb0'
}
"""
For large Files, Drive requires a Virus Check.
This function is reponsivle to extract the link from the button confirmation
"""
def get_url_from_gdrive_confirmation(contents):
url = ""
for line in contents.splitlines():
m = re.search(r'href="(\/uc\?export=download[^"]+)', line)
if m:
url = "https://docs.google.com" + m.groups()[0]
url = url.replace("&", "&")
break
m = re.search('id="downloadForm" action="(.+?)"', line)
if m:
url = m.groups()[0]
url = url.replace("&", "&")
break
m = re.search('"downloadUrl":"([^"]+)', line)
if m:
url = m.groups()[0]
url = url.replace("\\u003d", "=")
url = url.replace("\\u0026", "&")
break
m = re.search('<p class="uc-error-subcaption">(.*)</p>', line)
if m:
error = m.groups()[0]
raise RuntimeError(error)
if not url:
return None
return url
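
# Illustrative only: on a confirmation page containing a form such as
#     <form id="downloadForm" action="https://docs.google.com/uc?export=download&confirm=...">
# the parser above returns that form's "action" attribute as the direct link.
# (The snippet is a hypothetical example, not captured from a real page.)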
def download_file_from_google_drive(id, destination):
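    """
    Fetch a Drive file by id: request the export URL, follow the virus-scan
    confirmation link or confirm-token cookie if Drive returns one, then
    stream the payload to `destination`.
    """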
URL = "https://docs.google.com/uc?export=download"
session = requests.Session()
response = session.get(URL, params={ 'id' : id }, stream=True)
URL_new = get_url_from_gdrive_confirmation(response.text)
if URL_new != None:
URL = URL_new
response = session.get(URL, params={ 'id' : id }, stream=True)
token = get_confirm_token(response)
if token:
params = { 'id' : id, 'confirm' : token }
response = session.get(URL, params=params, stream=True)
save_response_content(response, destination)
def get_confirm_token(response):
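    """Return the value of Drive's 'download_warning' cookie, or None."""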
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None
def save_response_content(response, destination):
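    """Stream the response body to `destination` in 32 KiB chunks, with a tqdm progress bar."""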
    CHUNK_SIZE = 32768
    with open(destination, "wb") as f:
        for chunk in tqdm(response.iter_content(CHUNK_SIZE)):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
def check_if_exist(model: str = "single_tfidf"):
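    """
    Return True if `model` is already on disk: tfidf models need non-empty
    'classifiers' and 'vectorizers' folders under all_labels_hierarchy;
    other models need more than one entry in their higher_order_hierarchy folder.
    """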
    if model == "single_vectorizer":
        model = "single_tfidf"
    if model == "branch_vectorizer":
        model = "branch_tfidf"
    project_dir = os.path.dirname(os.path.abspath(__file__))
    if model is not None:
        if model in ['single_tfidf', 'branch_tfidf']:
            path = 'models/all_labels_hierarchy/'
            path_model = os.path.join(project_dir, path, model, 'classifiers')
            path_vectorizer = os.path.join(project_dir, path, model, 'vectorizers')
            if os.path.exists(path_model) and os.path.exists(path_vectorizer):
                if len(os.listdir(path_model)) > 0 and len(os.listdir(path_vectorizer)) > 0:
                    return True
        else:
            path = 'models/higher_order_hierarchy/'
            path_folder = os.path.join(project_dir, path, model)
            if os.path.exists(path_folder):
                if len(os.listdir(path_folder)) > 1:
                    return True
    return False
def download_model(all_labels='single_tfidf', higher_order='PathologyEmoryPubMedBERT'):
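    """
    Download and unpack the requested all_labels and higher_order archives,
    skipping any model that check_if_exist() already reports as present.
    """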
    project_dir = os.path.dirname(os.path.abspath(__file__))
    path_all_labels = 'models/all_labels_hierarchy/'
    path_higher_order = 'models/higher_order_hierarchy/'

    def extract_model(path_file, name):
        os.makedirs(os.path.join(project_dir, path_file), exist_ok=True)
        file_destination = os.path.join(project_dir, path_file, name + '.zip')
        file_id = MODEL_TO_URL[name].split('id=')[-1]
        logging.info(f'Downloading {name} model (~1000MB zip archive)')
        download_file_from_google_drive(file_id, file_destination)
        logging.info('Extracting model from archive (~1300MB folder) to ' + os.path.dirname(file_destination))
        with zipfile.ZipFile(file_destination, 'r') as zip_ref:
            zip_ref.extractall(path=os.path.dirname(file_destination))
        logging.info('Removing archive')
        os.remove(file_destination)
        logging.info('Done.')

    if higher_order is not None:
        if not check_if_exist(higher_order):
            extract_model(path_higher_order, higher_order)
        else:
            logging.info('Model ' + str(higher_order) + ' already exists')
    if all_labels is not None:
        if not check_if_exist(all_labels):
            extract_model(path_all_labels, all_labels)
        else:
            logging.info('Model ' + str(all_labels) + ' already exists')
def download(all_labels: str = "single_tfidf", higher_order: str = "PathologyEmoryPubMedBERT"):
    """
    Input options:
        all_labels   : single_tfidf, branch_tfidf
        higher_order : PathologyEmoryPubMedBERT, PathologyEmoryBERT, ClinicalBERT, BlueBERT, BioBERT, BERT
    """
    all_labels_options = ["single_tfidf", "branch_tfidf"]
    higher_order_option = ["PathologyEmoryPubMedBERT", "PathologyEmoryBERT", "ClinicalBERT", "BlueBERT", "BioBERT", "BERT"]
    if all_labels not in all_labels_options or higher_order not in higher_order_option:
        print("\n\tPlease provide a valid model for downloading")
        print("\n\t\tall_labels: " + " ".join(all_labels_options))
        print("\n\t\thigher_order: " + " ".join(higher_order_option))
        exit()
    download_model(all_labels, higher_order)
if __name__ == "__main__":
    fire.Fire(download)
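
# Example CLI usage via python-fire (assuming this script is saved as
# download_model.py; the filename is an assumption):
#     python download_model.py
#     python download_model.py --all_labels=branch_tfidf --higher_order=BioBERT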