ashishraics commited on
Commit
e867b58
1 Parent(s): 5430ad2

updated architecture

Browse files
Files changed (1) hide show
  1. zeroshot_clf.py +44 -0
zeroshot_clf.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import streamlit
3
+ import torch
4
+ from transformers import AutoModelForSequenceClassification,AutoTokenizer
5
+ import numpy as np
6
+ import plotly.express as px
7
+
8
+ model=AutoModelForSequenceClassification.from_pretrained('zero_shot_clf/')
9
+ tokenizer=AutoTokenizer.from_pretrained('zero_shot_clf/')
10
+
11
+ def zero_shot_classification(premise:str,labels:str,model=model,tokenizer=tokenizer):
12
+ try:
13
+ labels=labels.split(',')
14
+ labels=[l.lower() for l in labels]
15
+ except:
16
+ raise Exception("please pass atleast 2 labels to classify")
17
+
18
+ premise=premise.lower()
19
+
20
+ labels_prob=[]
21
+
22
+ for l in labels:
23
+
24
+ hypothesis= f'this is an example of {l}'
25
+
26
+ input = tokenizer.encode(premise,hypothesis,
27
+ return_tensors='pt',
28
+ truncation_strategy='only_first')
29
+ output = model(input)
30
+ entail_contra_prob = output['logits'][:,[0,2]].softmax(dim=1)[:,1].item()
31
+ labels_prob.append(entail_contra_prob)
32
+
33
+ labels_prob_norm=[np.round(100*c/np.sum(labels_prob),1) for c in labels_prob]
34
+
35
+ df=pd.DataFrame({'labels':labels,
36
+ 'Probability':labels_prob_norm})
37
+ fig=px.bar(x=df['Probability'],
38
+ y=df['labels'])
39
+ return streamlit.plotly_chart(fig)
40
+
41
+ # zero_shot_classification(premise='Tiny worms and breath analyzers could screen for disease while it’s early and treatable',
42
+ # labels='science, sports, museum')
43
+
44
+