| import streamlit as st |
| import torch |
| from transformers import AutoModelForSequenceClassification as ASC |
| from transformers import AutoTokenizer as AT |
|
|
MODEL_NAME = "rickxzo/albert-large-v2-s.a.m-nli"


@st.cache_resource
def _load_nli_model(name=MODEL_NAME):
    """Load the NLI sequence classifier and its tokenizer from the Hugging Face hub.

    Streamlit re-executes this script from the top on every widget interaction.
    ``st.cache_resource`` keeps the returned objects alive across reruns, so the
    checkpoint is downloaded/deserialized once instead of on every keystroke.

    Args:
        name: Hub identifier (or local path) of the checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    return ASC.from_pretrained(name), AT.from_pretrained(name)


# Module-level names kept so the rest of the script can use them unchanged.
model, tokenizer = _load_nli_model()
|
|
def infer(sentence1, sentence2):
    """Predict the relation class index for a (premise, hypothesis) pair.

    Args:
        sentence1: Premise text.
        sentence2: Hypothesis text.

    Returns:
        The index (Python ``int``) of the highest-scoring class produced by the
        module-level ``model`` for this single pair.
    """
    inputs = tokenizer(sentence1, sentence2, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(**inputs).logits

    # Softmax is monotonic, so argmax over raw logits selects the same class.
    # Reduce over the class axis explicitly instead of over the flattened tensor,
    # then take the single pair's prediction.
    return logits.argmax(dim=-1)[0].item()
|
|
# Predicted class index -> message rendered to the user.
# NOTE(review): assumes the checkpoint's label order is
# 0 = entailment, 1 = neutral, 2 = contradiction — confirm against model.config.id2label.
_VERDICTS = {
    2: "#### **Contradicting Statements Detected!**",
    1: "#### **Neutral Statements Detected.**",
    0: "#### **Entailing Statements Detected.**",
}

st.title("Contradiction Detector using AlBERT model")
premise = st.text_area("Enter the premise: ")
hypothesis = st.text_area("Enter the hypothesis: ")

# Run the model only once both boxes hold real text; whitespace-only input is ignored.
if premise.strip() and hypothesis.strip():
    k = infer(premise, hypothesis)
    verdict = _VERDICTS.get(k)
    if verdict is None:
        # Surface an unexpected label instead of silently rendering nothing.
        st.error(f"Unexpected label index from model: {k}")
    else:
        st.write(verdict)
|
|
|
|
|
|