File size: 2,294 Bytes
917d2f9
514343b
4fd42f1
420a042
917d2f9
 
 
4fd42f1
8f36cc4
843aeb0
917d2f9
44264ed
917d2f9
 
 
 
 
418bd7c
8f36cc4
44264ed
3367c7d
e838b9b
8f36cc4
1716434
8f36cc4
418bd7c
8f36cc4
20efea7
8f36cc4
20efea7
8f36cc4
ea052a5
8f36cc4
 
917d2f9
8f36cc4
917d2f9
8f36cc4
917d2f9
8f36cc4
917d2f9
8f36cc4
917d2f9
8f36cc4
917d2f9
8f36cc4
917d2f9
8f36cc4
 
917d2f9
8f36cc4
917d2f9
8f36cc4
 
 
917d2f9
8f36cc4
5daf8df
8f36cc4
 
 
5daf8df
2ec3b65
5daf8df
8f36cc4
 
 
 
5daf8df
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import streamlit as st



# Library for Entailment
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load model
# Streamlit re-executes this whole script on every widget interaction, so the
# model is loaded through Streamlit's resource cache and reused across reruns
# instead of calling from_pretrained again each time.
MODEL_NAME = "roberta-large-mnli"


def _cache_resource(func):
    """Wrap *func* with Streamlit's resource cache when one is available.

    Prefers ``st.cache_resource`` (newer Streamlit), then
    ``st.experimental_singleton`` (older releases). If neither exists, *func*
    is returned unchanged so the app still runs, just without caching.
    """
    decorator = getattr(st, "cache_resource", None) or getattr(
        st, "experimental_singleton", None
    )
    return decorator(func) if decorator is not None else func


@_cache_resource
def _load_model(model_name=MODEL_NAME):
    """Return ``(tokenizer, model)`` for *model_name* via ``from_pretrained``."""
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return loaded_tokenizer, loaded_model


tokenizer, text_classification_model = _load_model()



### Streamlit interface ###

# Page heading.
st.title("Text Classification")
st.subheader("Entailment, neutral or contradiction?")

# Input form: widgets inside an st.form only send their values when the
# submit button is pressed; clear_on_submit=False keeps the entered text.
with st.form("submission_form", clear_on_submit=False):
    # Cut-off that the entailment probability is compared against after submit.
    threshold = st.slider(
        "Threshold",
        value=0.7,
        min_value=0.0,
        max_value=1.0,
        step=0.1,
    )

    # The premise / hypothesis pair to classify.
    sentence_1 = st.text_input("Sentence 1 input")
    sentence_2 = st.text_input("Sentence 2 input")

    # True only on the rerun triggered by pressing this button.
    submit_button_compare = st.form_submit_button("Compare Sentences")
       
# If submit_button_compare clicked: run the NLI model on the sentence pair.
if submit_button_compare:
    if not (sentence_1.strip() and sentence_2.strip()):
        # Classifying an empty sentence yields a meaningless score.
        st.warning("Please enter both sentences before comparing.")
    else:
        print("Comparing sentences...")

        ### Text classification - entailment, neutral or contradiction ###

        # Encode the sentences as a pair so the tokenizer inserts the model's
        # own separator tokens and truncation is applied to the pair, instead
        # of hand-building "A</s></s>B" as a single string.
        inputs = tokenizer(
            sentence_1,
            sentence_2,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )

        # Inference only: disable autograd so no gradient graph is built.
        with torch.no_grad():
            logits = text_classification_model(**inputs).logits

        # Softmax over the label dimension; [0] selects the single pair.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

        config = text_classification_model.config
        # (label, percentage) for every class, in label-index order.
        label_scores = [
            (label, round(probabilities[index].item() * 100, 2))
            for index, label in sorted(config.id2label.items())
        ]

        for label, percent in label_scores:
            print(label, ":", percent, "%")

        st.subheader("Text classification for both sentences:")

        for label, percent in label_scores:
            st.write(label, ":", percent, "%")

        # Look up the entailment class by name rather than assuming index 2;
        # fall back to 2 (the previously hard-coded index) if it is absent.
        entailment_index = config.label2id.get("ENTAILMENT", 2)
        # Compare the unrounded probability: rounding first could lift e.g.
        # 0.6951 to 0.7 and wrongly pass a 0.7 threshold.
        entailment_score = probabilities[entailment_index].item()

        if entailment_score >= threshold:
            st.subheader("The statements are very similar!")
        else:
            st.subheader("The statements are not close enough")