"""Gradio demo that tags each sentence of the input text with a sustainability label.

The text is split into sentences with NLTK, each sentence is scored by the
``BERTClass`` model from ``fin_readability_sustainability`` (paired with the
``roberta-base`` tokenizer), and each score is bucketed into 'sustainable',
'non-sustainable' or '-' for display in a ``gr.HighlightedText`` component.
"""
import gradio as gr
from transformers import AutoTokenizer, AutoModelForTokenClassification, RobertaTokenizer
import torch
import nltk
from nltk.tokenize import sent_tokenize
from fin_readability_sustainability import BERTClass, do_predict
import pandas as pd

# Sentence-tokenizer data. NLTK >= 3.8.2 looks up 'punkt_tab' in sent_tokenize,
# older releases use 'punkt'. nltk.download reports (does not raise on) a
# package missing from its index, so requesting both is safe on either version.
nltk.download('punkt')
nltk.download('punkt_tab')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Score cut-offs applied to the classifier output in get_sustainability.
NON_SUSTAINABLE_THRESHOLD = 4.384316  # score >= this -> 'non-sustainable'
SUSTAINABLE_THRESHOLD = 1.423736  # score <= this -> 'sustainable'

tokenizer_sus = RobertaTokenizer.from_pretrained('roberta-base')
# NOTE(review): "sustanability" is passed through unchanged; it looks like a
# typo but may be a key BERTClass relies on — confirm before correcting.
model_sustain = BERTClass(2, "sustanability")
model_sustain.to(device)
model_sustain.load_state_dict(
    torch.load('sustainability_model.bin', map_location=device)['model_state_dict']
)
# nn.Module instances start in training mode and load_state_dict does not
# change that; switch to inference mode so any dropout layers are disabled
# and predictions are deterministic.
model_sustain.eval()


def get_sustainability(text):
    """Label every sentence of *text* by its sustainability score.

    Args:
        text: Free-form input from the Gradio textbox.

    Returns:
        A list of ``(sentence, label)`` tuples for ``gr.HighlightedText``,
        where label is 'non-sustainable', 'sustainable' or '-' (score between
        the two thresholds). Empty/whitespace-only/None input yields ``[]``.
    """
    sentences = sent_tokenize(text) if text else []
    # Nothing to classify: skip the model rather than hand do_predict an
    # empty DataFrame.
    if not sentences:
        return []

    df = pd.DataFrame({'sentence': sentences})
    predictions = do_predict(model_sustain, tokenizer_sus, df)
    # Element [1] of do_predict's result is used as one score per sentence,
    # in the same order as df. The thresholds exceed 1, so these scores are
    # presumably not probabilities — confirm against do_predict.
    return [
        (sentence, _label_for_score(score))
        for sentence, score in zip(df['sentence'].values, predictions[1])
    ]


def _label_for_score(score):
    """Map one classifier score to the label shown in the UI."""
    if score >= NON_SUSTAINABLE_THRESHOLD:
        return 'non-sustainable'
    if score <= SUSTAINABLE_THRESHOLD:
        return 'sustainable'
    # Scores strictly between the two thresholds are left unlabeled.
    return '-'


iface = gr.Interface(
    fn=get_sustainability,
    inputs="textbox",
    title="CONBERT",
    description="SUSTAINABILITY TOOL",
    outputs=gr.HighlightedText(),
    allow_flagging="never",
)
iface.launch()