"""Streamlit demo comparing sentence-entailment predictions from several models.

The user enters a premise and a hypothesis. On submit, the RoBERTa,
DistilBERT and GPT models returned by the ``utils`` loaders each classify the
pair, and the predicted label name is shown for every model.
"""
import streamlit as st
from utils import get_roberta, get_gpt, get_distilbert
import torch

st.title('Sentence Entailment')

col1, col2 = st.columns([1, 1])
with col1:
    sentence1 = st.text_input('Premise')
with col2:
    sentence2 = st.text_input('Hypothesis')

btn = st.button("Submit")

# Maps the argmax index of a model's logits to a human-readable label.
# NOTE(review): assumes every model loaded from utils was fine-tuned with this
# same label order -- confirm against the training configs.
label_dict = {
    0: 'entailment',
    1: 'neutral',
    2: 'contradiction',
}


def _predict(tokenizer, model, *text, **tokenizer_kwargs):
    """Tokenize *text*, run *model* for inference and return the label name.

    Args:
        tokenizer: callable tokenizer; ``*text`` and ``**tokenizer_kwargs`` are
            forwarded to it, and PyTorch tensors are always requested.
        model: model whose output supports ``['logits']`` indexing.
        *text: one string, or a (premise, hypothesis) pair.
        **tokenizer_kwargs: padding/truncation options for the tokenizer.

    Returns:
        The ``label_dict`` entry for the highest-scoring logit.
    """
    inputs = tokenizer(*text, return_tensors="pt", **tokenizer_kwargs)
    # Pure inference: skip autograd bookkeeping so each request does not
    # allocate a gradient graph.
    with torch.no_grad():
        logits = model(**inputs)['logits']
    # argmax over the flattened logits; with a single input pair this is the
    # predicted class index.
    return label_dict[logits.argmax().item()]


if btn:
    if not sentence1.strip() or not sentence2.strip():
        # Running the models on empty text would display a meaningless label.
        st.warning('Please enter both a premise and a hypothesis.')
    else:
        # Get RoBERTa output (premise/hypothesis encoded as a sentence pair).
        roberta_tokenizer, roberta_model = get_roberta()
        st.write('ROBERTA', _predict(
            roberta_tokenizer, roberta_model, sentence1, sentence2,
            padding=True, truncation=True, max_length=512,
        ))

        # Get DistilBERT output (same sentence-pair encoding).
        distilbert_tokenizer, distilbert_model = get_distilbert()
        st.write('DistilBERT', _predict(
            distilbert_tokenizer, distilbert_model, sentence1, sentence2,
            padding=True, truncation=True, max_length=512,
        ))

        # Get GPT output: the pair is joined into one string with a literal
        # ' [SEP] ' marker and padded to a fixed length.
        # NOTE(review): padding='max_length' presumably mirrors how the GPT
        # model was fine-tuned, and relies on get_gpt() configuring a pad
        # token -- confirm in utils.
        gpt_tokenizer, gpt_model = get_gpt()
        st.write('GPT', _predict(
            gpt_tokenizer, gpt_model, sentence1 + ' [SEP] ' + sentence2,
            truncation=True, padding='max_length', max_length=512,
        ))