didev007 commited on
Commit
f3c5717
1 Parent(s): 741cc58

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +48 -0
  2. model_state_dict.pt +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from simpletransformers.classification import ClassificationModel, ClassificationArgs
2
+ import os
3
+ import torch
4
+ import pandas as pd
5
+ import streamlit as st
6
+
7
+ model_types = {
8
+ "bert": "cahya/bert-base-indonesian-522M",
9
+ "roberta":"cahya/roberta-base-indonesian-522M"}
10
+ model_name = "bert"
11
+
12
+ class_names = ['none', 'E', 'S', 'G']
13
+
14
+ # Create a new instance of the model with the same architecture
15
+ model_args = ClassificationArgs()
16
+ model_args.use_cuda = False # Use CPU
17
+
18
+ loaded_model = ClassificationModel(
19
+ model_name, model_types[model_name], num_labels=len(class_names), args=model_args,use_cuda=False
20
+ )
21
+
22
+ # Load the state dictionary into the model
23
+ loaded_model.model.load_state_dict(torch.load('model_state_dict.pt', map_location=torch.device('cpu')))
24
+
25
+ def run():
26
+ # create form
27
+ with st.form("form"):
28
+ text_input = st.text_input("Enter some text")
29
+ if text_input:
30
+ st.write(f"Text input: {text_input}")
31
+ st.markdown("---")
32
+
33
+ submitted = st.form_submit_button("predict")
34
+
35
+ data_inf = {
36
+ "text" : text_input
37
+ }
38
+
39
+ data_inf = pd.DataFrame([data_inf])
40
+ st.dataframe(data_inf)
41
+
42
+
43
+ if submitted:
44
+ predictions, raw_outputs = loaded_model.predict([data_inf["text"][0]])
45
+ st.write("# ESG Category: ", class_names[int(predictions[0])])
46
+
47
+ if __name__=="__main__":
48
+ run()
model_state_dict.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:baa445d294b8e64c7066819ed04e5574df66f2e7f1e5d57b1e0b1e51b26a0abb
3
+ size 442562325
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ pandas==1.5.3
2
+ simpletransformers==0.70.1
3
+ torch==2.2.1
4
+ streamlit