File size: 1,391 Bytes
f3c5717
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from simpletransformers.classification import ClassificationModel, ClassificationArgs
import os
import torch
import pandas as pd
import streamlit as st

# Pretrained checkpoint identifiers, keyed by the simpletransformers model
# type that is passed as the first argument to ClassificationModel.
model_types = {
    "bert": "cahya/bert-base-indonesian-522M",
    "roberta": "cahya/roberta-base-indonesian-522M",
}
# Model type (key of ``model_types``) this app actually uses.
model_name = "bert"

# Output labels, indexed by the integer class id the model predicts.
class_names = ["none", "E", "S", "G"]

# Create a new instance of the model with the same architecture, on CPU.
model_args = ClassificationArgs()
model_args.use_cuda = False  # Use CPU


@st.cache_resource
def _load_model():
    """Build the classifier and load the fine-tuned weights into it.

    Streamlit re-executes this script from the top on every widget
    interaction. Caching the constructed model means the model is built and
    ``model_state_dict.pt`` is read once per server process, not on every
    rerun.

    Returns:
        A ClassificationModel with the state dict from
        ``model_state_dict.pt`` loaded into its underlying ``.model``.
    """
    model = ClassificationModel(
        model_name,
        model_types[model_name],
        num_labels=len(class_names),
        args=model_args,
        use_cuda=False,
    )
    # NOTE(review): torch.load unpickles the file; only load a trusted
    # checkpoint (weights_only=True is available on torch >= 1.13).
    state_dict = torch.load(
        "model_state_dict.pt", map_location=torch.device("cpu")
    )
    model.model.load_state_dict(state_dict)
    return model


# Load the state dictionary into the model (cached across Streamlit reruns).
loaded_model = _load_model()

def run():
    """Render the ESG classification form and show a prediction on submit.

    The entered text is echoed back and displayed as a one-row DataFrame.
    When the "predict" button is pressed with non-blank text,
    ``loaded_model`` classifies it and the matching label from
    ``class_names`` is shown. Blank input produces a warning instead of a
    prediction.
    """
    # create form
    with st.form("form"):
        text_input = st.text_input("Enter some text")
        if text_input:
            st.write(f"Text input: {text_input}")
        st.markdown("---")

        submitted = st.form_submit_button("predict")

    data_inf = pd.DataFrame([{"text": text_input}])
    st.dataframe(data_inf)

    if not submitted:
        return
    if not text_input.strip():
        # The model would still return a label for empty input, which is
        # meaningless, so ask for text instead of predicting.
        st.warning("Please enter some text before predicting.")
        return

    predictions, _raw_outputs = loaded_model.predict([text_input])
    st.write("# ESG Category: ", class_names[int(predictions[0])])

# Entry point: Streamlit executes this file as __main__, so the app renders.
if __name__ == "__main__":
    run()