"""Zero-shot text classification via an NLI model, in PyTorch and ONNX Runtime flavours.

Each candidate label is turned into the hypothesis "this is an example of <label>"
and scored against the premise; per-label scores are then normalised to percentages.
"""
import os
import subprocess
import sys

import numpy as np
import pandas as pd
import torch
import yaml
from onnxruntime.quantization import QuantType, quantize_dynamic


def read_yaml(file_path):
    """Load and return the YAML document at ``file_path`` using ``yaml.safe_load``."""
    with open(file_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


# Module-level configuration, read once at import time from ./config.yaml.
config = read_yaml('config.yaml')
zs_chkpt = config['ZEROSHOT_CLF']['zs_chkpt']
zs_mdl_dir = config['ZEROSHOT_CLF']['zs_mdl_dir']
zs_onnx_mdl_dir = config['ZEROSHOT_CLF']['zs_onnx_mdl_dir']
zs_onnx_mdl_name = config['ZEROSHOT_CLF']['zs_onnx_mdl_name']
zs_onnx_quant_mdl_name = config['ZEROSHOT_CLF']['zs_onnx_quant_mdl_name']

_LABEL_ERROR = "please pass atleast 2 labels to classify"


def _parse_labels(labels):
    """Split a comma-separated label string into stripped, lower-cased, non-empty labels.

    Raises:
        TypeError: if ``labels`` is not a string.
        ValueError: if fewer than two non-empty labels remain.
    """
    try:
        parts = labels.split(',')
    except AttributeError as err:
        raise TypeError(_LABEL_ERROR) from err
    # Strip so "science, sports" yields "sports", not " sports" (which would
    # change both the hypothesis text and the returned label column).
    parsed = [part.strip().lower() for part in parts]
    parsed = [label for label in parsed if label]
    if len(parsed) < 2:
        raise ValueError(_LABEL_ERROR)
    return parsed


def _hypothesis(label):
    """Build the NLI hypothesis sentence for one candidate label."""
    return f'this is an example of {label}'


def _entailment_score(logits):
    """Return the entailment probability after a softmax over columns 0 and 2 only.

    The neutral column is dropped so only entailment vs. contradiction is compared.
    Assumes the checkpoint's label order is (contradiction, neutral, entailment);
    verify against the model's ``id2label``. ``.item()`` requires a single-row batch.
    """
    return logits[:, [0, 2]].softmax(dim=1)[:, 1].item()


def _to_percentage_frame(labels, scores):
    """Normalise scores to percentages (1 decimal) and pair them with their labels."""
    total = np.sum(scores)
    percentages = [np.round(100 * score / total, 1) for score in scores]
    return pd.DataFrame({'labels': labels, 'Probability': percentages})


def zero_shot_classification(premise: str, labels: str, model, tokenizer) -> pd.DataFrame:
    """Score ``premise`` against each comma-separated label with a PyTorch NLI model.

    Args:
        premise: Text to classify (lower-cased before tokenisation).
        labels: Comma-separated candidate labels, e.g. ``"science, sports, museum"``.
        model: Sequence-classification model whose output exposes ``'logits'``.
        tokenizer: Matching tokenizer supporting ``encode(text, text_pair, ...)``.

    Returns:
        DataFrame with columns ``labels`` and ``Probability`` (percent, summing to ~100).

    Raises:
        TypeError / ValueError: if ``labels`` is not a string or has fewer than 2 labels.
    """
    parsed_labels = _parse_labels(labels)
    premise = premise.lower()
    scores = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for label in parsed_labels:
            input_ids = tokenizer.encode(premise, _hypothesis(label),
                                         return_tensors='pt',
                                         truncation='only_first')
            output = model(input_ids)
            scores.append(_entailment_score(output['logits']))
    return _to_percentage_frame(parsed_labels, scores)


## example
# zero_shot_classification(premise='Tiny worms and breath analyzers could screen for disease while it’s early and treatable',
#                          labels='science, sports, museum')


def create_onnx_model_zs(zs_onnx_mdl_dir=zs_onnx_mdl_dir, *,
                         model_name='valhalla/distilbart-mnli-12-1',
                         quantize=False):
    """Export the zero-shot model to ONNX via ``transformers.onnx`` unless the directory exists.

    Failures are reported with ``print`` and not raised (best-effort, as before).
    Note: an existing ``zs_onnx_mdl_dir`` (even from a failed, partial export)
    causes the export to be skipped entirely.

    Args:
        zs_onnx_mdl_dir: Output directory for the exported model.
        model_name: Hub checkpoint to export (kept independent of ``zs_chkpt``
            in config; confirm whether they should match).
        quantize: If True, also write a dynamically quantized (QUInt8) copy
            ``zs_onnx_quant_mdl_name`` next to ``zs_onnx_mdl_name``.
    """
    if os.path.exists(zs_onnx_mdl_dir):
        return
    try:
        # sys.executable guarantees the same interpreter/environment as the caller;
        # check=True turns a non-zero exit into an exception instead of silence.
        subprocess.run([sys.executable, '-m', 'transformers.onnx',
                        f'--model={model_name}',
                        '--feature=sequence-classification',
                        zs_onnx_mdl_dir],
                       check=True)
    except (OSError, subprocess.SubprocessError) as e:
        print(e)
        return
    if quantize:
        quantize_dynamic(f"{zs_onnx_mdl_dir}/{zs_onnx_mdl_name}",
                         f"{zs_onnx_mdl_dir}/{zs_onnx_quant_mdl_name}",
                         weight_type=QuantType.QUInt8)


def zero_shot_classification_onnx(premise, labels, _session, _tokenizer):
    """Score ``premise`` against each comma-separated label with an ONNX Runtime session.

    Args:
        premise: Text to classify (lower-cased before tokenisation).
        labels: Comma-separated candidate labels.
        _session: ONNX Runtime session taking ``input_ids``/``attention_mask``
            and producing a ``logits`` output.
        _tokenizer: Matching callable tokenizer.

    Returns:
        DataFrame with columns ``labels`` and ``Probability`` (percent, summing to ~100).

    Raises:
        TypeError / ValueError: if ``labels`` is not a string or has fewer than 2 labels.
    """
    parsed_labels = _parse_labels(labels)
    premise = premise.lower()
    scores = []
    for label in parsed_labels:
        # 'pt' tensors are int64 on every platform, matching the exported graph's inputs.
        inputs = _tokenizer(premise, _hypothesis(label),
                            return_tensors='pt', truncation='only_first')
        input_feed = {
            "input_ids": inputs['input_ids'].numpy(),
            "attention_mask": inputs['attention_mask'].numpy(),
        }
        logits = _session.run(output_names=["logits"], input_feed=input_feed)[0]
        scores.append(_entailment_score(torch.from_numpy(logits)))
    return _to_percentage_frame(parsed_labels, scores)