ESG_Indonesia / app.py
didev007's picture
Upload 3 files
f3c5717 verified
raw
history blame contribute delete
No virus
1.39 kB
from simpletransformers.classification import ClassificationModel, ClassificationArgs
import os
import torch
import pandas as pd
import streamlit as st
# Supported architecture keys mapped to the pretrained checkpoint each one uses.
model_types = {
    "bert": "cahya/bert-base-indonesian-522M",
    "roberta": "cahya/roberta-base-indonesian-522M",
}
model_name = "bert"
# Label names; the predicted class id is used as an index into this list.
class_names = ['none', 'E', 'S', 'G']

# Create a new instance of the model with the same architecture
model_args = ClassificationArgs()
model_args.use_cuda = False  # Use CPU


@st.cache_resource
def _load_model():
    """Build the classifier and load its weights from ``model_state_dict.pt``.

    Streamlit re-executes the whole script on every user interaction.
    Caching the model as a resource means it is constructed and its state
    dict is read from disk only once per server process instead of on each
    rerun.

    Returns:
        The ``ClassificationModel`` configured for CPU inference with the
        saved state dict applied.
    """
    model = ClassificationModel(
        model_name,
        model_types[model_name],
        num_labels=len(class_names),
        args=model_args,
        use_cuda=False,
    )
    # Load the state dictionary into the model; map_location keeps every
    # tensor on the CPU regardless of the device it was saved from.
    state_dict = torch.load(
        'model_state_dict.pt', map_location=torch.device('cpu')
    )
    model.model.load_state_dict(state_dict)
    return model


loaded_model = _load_model()
def run():
    """Render the text-input form and display the predicted ESG category.

    The form collects a single piece of text. The text is echoed back as a
    one-row table, and once the "predict" button is pressed the module-level
    ``loaded_model`` classifies it; the predicted class id is mapped to a
    label through ``class_names``.

    Submitting empty or whitespace-only text shows a warning instead of
    running the classifier on meaningless input.
    """
    # create form
    with st.form("form"):
        text_input = st.text_input("Enter some text")
        if text_input:
            st.write(f"Text input: {text_input}")
        st.markdown("---")
        submitted = st.form_submit_button("predict")

    # Show the inference input as a one-row table.
    data_inf = pd.DataFrame([{"text": text_input}])
    st.dataframe(data_inf)

    if not submitted:
        return

    # Guard: do not classify an empty string.
    if not text_input.strip():
        st.warning("Please enter some text before predicting.")
        return

    # predict() takes a list of texts; the raw model outputs are not needed.
    predictions, _ = loaded_model.predict([text_input])
    st.write("# ESG Category: ", class_names[int(predictions[0])])
# Render the app only when this file is executed as the main script.
if __name__ == "__main__":
    run()