tkharisov7 commited on
Commit
6406410
·
1 Parent(s): 26e3b7e
Files changed (1) hide show
  1. script.py +53 -0
script.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ st.markdown("# Automatic Essay Scoring for IELTS Writing Task 2")
4
+ st.markdown("## Please enter your question and essay below:")
5
+ st.markdown("**Disclaimer: This is a demo app and the results are not accurate. Model is trained on small dataset and is not robust enough to generalize well. Main application is to determine scores from 6 to 9. Scores below 6 are not accurate.**")
6
+
7
+ st.markdown("### Question:")
8
+ question = st.text_input("Enter your question here")
9
+
10
+ st.markdown("### Essay:")
11
+ essay = st.text_input("Enter your essay here")
12
+
13
+ @st.cache_resource
14
+ def get_pipeline():
15
+ from transformers import Pipeline
16
+
17
+ class AESIELTSPipeline(Pipeline):
18
+ def _sanitize_parameters(self, **kwargs):
19
+ return kwargs, {}, {}
20
+
21
+ def preprocess(self, inputs):
22
+ question, essay = inputs
23
+ encoding = self.tokenizer(question, essay, return_tensors='pt', padding='max_length', truncation=True, max_length=512)
24
+ input_ids = encoding['input_ids']
25
+ attention_mask = encoding['attention_mask']
26
+ return {'input_ids': input_ids, 'attention_mask': attention_mask}
27
+
28
+ def _forward(self, input):
29
+ output = self.model(**input)
30
+ return output[0].item()
31
+
32
+ def postprocess(self, output):
33
+ return output
34
+
35
+ from transformers.pipelines import PIPELINE_REGISTRY
36
+ from transformers import DistilBertForSequenceClassification
37
+
38
+ PIPELINE_REGISTRY.register_pipeline(
39
+ "aes-ielts",
40
+ AESIELTSPipeline,
41
+ pt_model=DistilBertForSequenceClassification
42
+ )
43
+
44
+ from transformers import pipeline
45
+ pipe = pipeline("aes-ielts", model="tkharisov7/aes-ielts")
46
+ return pipe
47
+
48
+ pipe = get_pipeline()
49
+ predictions = pipe((question, essay))
50
+
51
+ st.markdown("### Estimated Score:")
52
+
53
+ st.markdown(f"**{predictions}**")