NLP / zeroshot_clf_helper.py
ashishraics's picture
structure using config.yaml
d77ac81
raw
history blame
7.45 kB
import torch
from onnxruntime.quantization import quantize_dynamic,QuantType
import os
import subprocess
import numpy as np
import pandas as pd
import transformers
from pathlib import Path
import streamlit as st
import yaml
def read_yaml(file_path):
with open(file_path, "r") as f:
return yaml.safe_load(f)
config = read_yaml('config.yaml')
zs_chkpt=config['ZEROSHOT_CLF']['zs_chkpt']
zs_mdl_dir=config['ZEROSHOT_CLF']['zs_mdl_dir']
zs_onnx_mdl_dir=config['ZEROSHOT_CLF']['zs_onnx_mdl_dir']
zs_onnx_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_mdl_name']
zs_onnx_quant_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_quant_mdl_name']
zs_mlm_chkpt=config['ZEROSHOT_MLM']['zs_mlm_chkpt']
zs_mlm_mdl_dir=config['ZEROSHOT_MLM']['zs_mlm_mdl_dir']
zs_mlm_onnx_mdl_dir=config['ZEROSHOT_MLM']['zs_mlm_onnx_mdl_dir']
zs_mlm_onnx_mdl_name=config['ZEROSHOT_MLM']['zs_mlm_onnx_mdl_name']
##example
# zero_shot_classification(premise='Tiny worms and breath analyzers could screen for disease while it’s early and treatable',
# labels='science, sports, museum')
def zero_shot_classification(premise: str, labels: str, model, tokenizer):
"""
Args:
premise:
labels:
model:
tokenizer:
Returns:
"""
try:
labels=labels.split(',')
labels=[l.lower() for l in labels]
except:
raise Exception("please pass atleast 2 labels to classify")
premise=premise.lower()
labels_prob=[]
for l in labels:
hypothesis= f'this is an example of {l}'
input = tokenizer.encode(premise,hypothesis,
return_tensors='pt',
truncation_strategy='only_first')
output = model(input)
entail_contra_prob = output['logits'][:,[0,2]].softmax(dim=1)[:,1].item() #only normalizing entail & contradict probabilties
labels_prob.append(entail_contra_prob)
labels_prob_norm=[np.round(100*c/np.sum(labels_prob),1) for c in labels_prob]
df=pd.DataFrame({'labels':labels,
'Probability':labels_prob_norm})
return df
def create_onnx_model_zs_nli(zs_chkpt,zs_onnx_mdl_dir):
"""
Args:
zs_onnx_mdl_dir:
Returns:
"""
# create onnx model using
if not os.path.exists(zs_onnx_mdl_dir):
try:
subprocess.run(['python3', '-m', 'transformers.onnx',
f'--model={zs_chkpt}',
'--feature=sequence-classification',
'--atol=1e-3',
zs_onnx_mdl_dir])
except Exception as e:
print(e)
# #create quanitzed model from vanila onnx
# quantize_dynamic(f"{zs_onnx_mdl_dir}/{zs_onnx_mdl_name}",
# f"{zs_onnx_mdl_dir}/{zs_onnx_quant_mdl_name}",
# weight_type=QuantType.QUInt8)
else:
pass
def zero_shot_classification_nli_onnx(premise,labels,_session,_tokenizer,hypothesis="This is an example of"):
"""
Args:
premise:
labels:
_session:
_tokenizer:
hypothesis:
Returns:
"""
try:
labels=labels.split(',')
labels=[l.lower() for l in labels]
except:
raise Exception("please pass atleast 2 labels to classify")
premise=premise.lower()
labels_prob=[]
for l in labels:
hypothesis= f"{hypothesis} {l}"
inputs = _tokenizer(premise,hypothesis,
return_tensors='pt',
truncation_strategy='only_first')
input_feed = {
"input_ids": np.array(inputs['input_ids']),
"attention_mask": np.array((inputs['attention_mask']))
}
output = _session.run(output_names=["logits"],input_feed=dict(input_feed))[0] #returns logits as array
output=torch.from_numpy(output)
entail_contra_prob = output[:,[0,2]].softmax(dim=1)[:,1].item() #only normalizing entail & contradict probabilties
labels_prob.append(entail_contra_prob)
labels_prob_norm=[np.round(100*c/np.sum(labels_prob),1) for c in labels_prob]
df=pd.DataFrame({'labels':labels,
'Probability':labels_prob_norm})
return df
def create_onnx_model_zs_mlm(zs_mlm_chkpt,zs_mlm_onnx_mdl_dir):
"""
Args:
_model:
_tokenizer:
zs_mlm_onnx_mdl_dir:
Returns:
"""
if not os.path.exists(zs_mlm_onnx_mdl_dir):
try:
subprocess.run(['python3', '-m', 'transformers.onnx',
f'--model={zs_mlm_chkpt}',
'--feature=masked-lm',
zs_mlm_onnx_mdl_dir])
except:
pass
else:
pass
def zero_shot_classification_fillmask_onnx(premise,hypothesis,labels,_session,_tokenizer):
"""
Args:
premise:
hypothesis:
labels:
_session:
_tokenizer:
Returns:
"""
try:
labels=labels.split(',')
labels=[l.lower().rstrip().lstrip() for l in labels]
except:
raise Exception("please pass atleast 2 labels to classify")
premise=premise.lower()
hypothesis=hypothesis.lower()
final_input= f"{premise}.{hypothesis} [MASK]" #this can change depending on chkpt, this is for bert-base-uncased chkpt
_inputs=_tokenizer(final_input,padding=True, truncation=True,return_tensors="pt")
## lowers the performance
# premise_token_ids=_tokenizer.encode(premise,add_special_tokens=False)
# hypothesis_token_ids=_tokenizer.encode(hypothesis,add_special_tokens=False)
#
# #creating inputs ids
# input_ids=[_tokenizer.cls_token_id]+premise_token_ids+[_tokenizer.sep_token_id]+hypothesis_token_ids+[_tokenizer.sep_token_id]
# input_ids=np.array(input_ids)
#
# #creating token type ids
# premise_len=len(premise_token_ids)
# hypothesis_len=len(hypothesis_token_ids)
# token_type_ids=np.array([0]*(premise_len+2)+[1]*(hypothesis_len+1))
#
# #creating attention mask
# attention_mask=np.array([1]*(premise_len+hypothesis_len+3))
#
# input_feed={
# 'input_ids': np.expand_dims(input_ids,axis=0),
# 'token_type_ids': np.expand_dims(token_type_ids,0),
# 'attention_mask': np.expand_dims(attention_mask,0)
# }
input_feed={
'input_ids': np.array(_inputs['input_ids']),
'token_type_ids': np.array(_inputs['token_type_ids']),
'attention_mask': np.array(_inputs['attention_mask'])
}
output=_session.run(output_names=['logits'],input_feed=dict(input_feed))[0]
mask_token_index = np.argwhere(_inputs["input_ids"] == _tokenizer.mask_token_id)[1,0]
mask_token_logits=output[0,mask_token_index,:]
#seacrh for logits of input labels
#encode the labels and get the label id -
labels_logits=[]
for l in labels:
encoded_label=_tokenizer.encode(l)[1]
labels_logits.append(mask_token_logits[encoded_label])
#do a softmax on the logits
labels_logits=np.array(labels_logits)
labels_logits=torch.from_numpy(labels_logits)
labels_logits=labels_logits.softmax(dim=0)
output= {'Labels':labels,
'Probability':labels_logits}
df_output = pd.DataFrame(output)
df_output['Probability'] = df_output['Probability'].apply(lambda x: np.round(100*x, 1))
return df_output