# NLP/zeroshot_clf_helper.py -- zero-shot classification helpers (PyTorch and ONNX paths)
import torch
from onnxruntime.quantization import quantize_dynamic, QuantType
import os
import subprocess
import numpy as np
import pandas as pd


def zero_shot_classification(premise: str, labels: str, model, tokenizer):
    # parse the comma-separated labels and validate that at least two were given
    labels = [l.strip().lower() for l in labels.split(',') if l.strip()]
    if len(labels) < 2:
        raise ValueError("please pass at least 2 comma-separated labels to classify")
    premise = premise.lower()

    labels_prob = []
    for l in labels:
        hypothesis = f'this is an example of {l}'
        inputs = tokenizer.encode(premise, hypothesis,
                                  return_tensors='pt',
                                  truncation='only_first')
        output = model(inputs)
        # softmax over the contradiction (index 0) and entailment (index 2) logits only,
        # then keep the entailment probability
        entail_contra_prob = output['logits'][:, [0, 2]].softmax(dim=1)[:, 1].item()
        labels_prob.append(entail_contra_prob)

    labels_prob_norm = [np.round(100 * c / np.sum(labels_prob), 1) for c in labels_prob]

    df = pd.DataFrame({'labels': labels,
                       'Probability': labels_prob_norm})
    return df


## example
# zero_shot_classification(premise='Tiny worms and breath analyzers could screen for disease while it’s early and treatable',
#                          labels='science, sports, museum')
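
## a minimal sketch (not part of the original helper) of how the `model` / `tokenizer`
## arguments above could be built -- it assumes an NLI checkpoint such as
## facebook/bart-large-mnli, the same checkpoint exported to ONNX below:
# from transformers import AutoTokenizer, AutoModelForSequenceClassification
# tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
# model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')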


def create_onnx_model_zs(art_path='zeroshot_onnx_dir'):
    # export the MNLI model to ONNX with the transformers.onnx CLI, then write a
    # dynamically quantized copy alongside the vanilla ONNX model
    if not os.path.exists(art_path):
        subprocess.run(['python3', '-m', 'transformers.onnx',
                        '--model=facebook/bart-large-mnli',
                        '--feature=sequence-classification',
                        art_path],
                       check=True)
        # create quantized model from the vanilla ONNX export
        quantize_dynamic(f"{art_path}/model.onnx",
                         f"{art_path}/model_quant.onnx",
                         weight_type=QuantType.QUInt8)
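
## a minimal usage sketch (an assumption, not part of the original helper): export the
## model once, then open the quantized artifact with onnxruntime for inference:
# import onnxruntime as ort
# create_onnx_model_zs('zeroshot_onnx_dir')
# session = ort.InferenceSession('zeroshot_onnx_dir/model_quant.onnx',
#                                providers=['CPUExecutionProvider'])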


def zero_shot_classification_onnx(premise, labels, _session, _tokenizer):
    # parse the comma-separated labels and validate that at least two were given
    labels = [l.strip().lower() for l in labels.split(',') if l.strip()]
    if len(labels) < 2:
        raise ValueError("please pass at least 2 comma-separated labels to classify")
    premise = premise.lower()

    labels_prob = []
    for l in labels:
        hypothesis = f'this is an example of {l}'
        inputs = _tokenizer(premise, hypothesis,
                            return_tensors='pt',
                            truncation='only_first')
        input_feed = {
            "input_ids": np.array(inputs['input_ids']),
            "attention_mask": np.array(inputs['attention_mask'])
        }
        output = _session.run(output_names=["logits"], input_feed=input_feed)[0]  # logits as a numpy array
        output = torch.from_numpy(output)
        # softmax over the contradiction (index 0) and entailment (index 2) logits only,
        # then keep the entailment probability
        entail_contra_prob = output[:, [0, 2]].softmax(dim=1)[:, 1].item()
        labels_prob.append(entail_contra_prob)

    labels_prob_norm = [np.round(100 * c / np.sum(labels_prob), 1) for c in labels_prob]

    df = pd.DataFrame({'labels': labels,
                       'Probability': labels_prob_norm})
    return df
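

## a minimal end-to-end sketch (an assumption, not part of the original helper) of the
## ONNX path, reusing the session/tokenizer built as in the sketches above:
# from transformers import AutoTokenizer
# import onnxruntime as ort
# tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
# session = ort.InferenceSession('zeroshot_onnx_dir/model_quant.onnx')
# zero_shot_classification_onnx(premise='Tiny worms and breath analyzers could screen for disease while it’s early and treatable',
#                               labels='science, sports, museum',
#                               _session=session, _tokenizer=tokenizer)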