"""Streamlit app that classifies a piece of text into an ESG category.

A simpletransformers ``ClassificationModel`` is built on CPU and its weights
are replaced with the fine-tuned state dict stored in ``model_state_dict.pt``.
"""
from simpletransformers.classification import ClassificationModel, ClassificationArgs
import os
import torch
import pandas as pd
import streamlit as st

# Hugging Face checkpoints for each supported architecture.
model_types = {
    "bert": "cahya/bert-base-indonesian-522M",
    "roberta": "cahya/roberta-base-indonesian-522M",
}
model_name = "bert"
# Output labels; the model's predicted integer is used as an index into this list.
class_names = ['none', 'E', 'S', 'G']

# Create a new instance of the model with the same architecture
model_args = ClassificationArgs()
model_args.use_cuda = False  # Use CPU


@st.cache_resource
def _load_model():
    """Build the classifier and load the fine-tuned weights, once per process.

    Streamlit re-executes this whole script on every widget interaction;
    caching the returned model avoids rebuilding it and re-reading the
    checkpoint file on each rerun.
    """
    model = ClassificationModel(
        model_name,
        model_types[model_name],
        num_labels=len(class_names),
        args=model_args,
        use_cuda=False,
    )
    # Load the state dictionary into the model.
    # NOTE: torch.load unpickles the file, so only point it at a trusted checkpoint.
    state_dict = torch.load('model_state_dict.pt', map_location=torch.device('cpu'))
    model.model.load_state_dict(state_dict)
    return model


# Module-level handle kept so existing references to `loaded_model` still work.
loaded_model = _load_model()


def run():
    """Render the input form and, on submit, display the predicted ESG category."""
    # create form
    with st.form("form"):
        text_input = st.text_input("Enter some text")
        if text_input:
            st.write(f"Text input: {text_input}")
        st.markdown("---")
        submitted = st.form_submit_button("predict")

    # Echo the current input as a one-row table.
    data_inf = pd.DataFrame([{"text": text_input}])
    st.dataframe(data_inf)

    if submitted:
        if not text_input.strip():
            # The model would still emit a label for blank input; don't show one.
            st.warning("Please enter some text before predicting.")
            return
        predictions, _ = loaded_model.predict([text_input])
        st.write("# ESG Category: ", class_names[int(predictions[0])])


if __name__ == "__main__":
    run()