wldmr committed on
Commit
afe7c67
1 Parent(s): c2cc0a7
Files changed (3) hide show
  1. InferenceServer.py +17 -0
  2. app.py +78 -2
  3. requirements.txt +3 -0
InferenceServer.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Inference service module: builds the punctuation-restoration model once at
# import time and exposes it through a FastAPI application object.
from fastapi import FastAPI
from rpunct import RestorePuncts

# The two prints bracket model construction so start-up progress is visible
# in the server log. The instance is module-level so the request handlers
# below can share it.
print("Loading Model...")
rpunct = RestorePuncts()

# Application object on which the route handlers in this module are registered.
app = FastAPI()
print("Models loaded !")
9
+
10
+ @app.get("/")
11
def read_root():
    """Return the static landing payload for the service.

    NOTE(review): the braces build a one-element ``set``, not a ``dict``.
    Presumably a mapping such as ``{"message": ...}`` was intended; confirm
    before changing it, because that would alter the response shape.
    """
    landing_payload = {"Homepage!"}
    return landing_payload
13
+
14
@app.get("/restore")
def get_correction(input_sentence: str):
    """Return the sentence with punctuation and letter case restored.

    The handler is registered on the literal path ``/restore`` rather than on
    a path parameter, so unrelated single-segment URLs are not routed here.

    Args:
        input_sentence: Raw text to restore; it is not a path parameter, so
            it is read from the ``input_sentence`` query parameter.

    Returns:
        A dict with the single key ``"corrected_sentence"`` holding the
        result of ``rpunct.punctuate(input_sentence, lang="en")``.
    """
    return {"corrected_sentence": rpunct.punctuate(input_sentence, lang="en")}
app.py CHANGED
@@ -1,5 +1,81 @@
1
  import streamlit as st
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from multiprocessing import Process
3
+ import json
4
+ import requests
5
+ import time
6
+ import os
7
 
 
 
8
 
9
def start_server():
    """Launch the inference API with uvicorn on port 8080.

    ``os.system`` waits for the command to finish, so this call does not
    return until the uvicorn process exits.
    """
    uvicorn_command = "uvicorn InferenceServer:app --port 8080 --host 0.0.0.0 --workers 2"
    os.system(uvicorn_command)
12
+
13
def load_models():
    """Start the inference server as a separate process, once, and mark the session ready.

    If nothing is listening on port 8080 yet, ``start_server`` is launched in
    a daemon ``Process`` and the port is polled once per second until it
    accepts connections. If the port is already in use the server is assumed
    to be running. On success ``st.session_state['models_loaded']`` is set to
    ``True``.

    If the server process exits before the port opens, an error is shown and
    the current script run is stopped; the session flag stays unset so a
    later rerun can retry.
    """
    server_port = 8080  # must match the port start_server() passes to uvicorn
    if not is_port_in_use(server_port):
        with st.spinner(text="Loading model, please wait..."):
            proc = Process(target=start_server, args=(), daemon=True)
            proc.start()
            while not is_port_in_use(server_port):
                # Without this check a server that fails to launch would
                # leave the page polling forever.
                if not proc.is_alive():
                    st.error("Model server failed to start.")
                    st.stop()
                time.sleep(1)
            st.success("Model server started.")
    else:
        st.success("Model server already running...")
    st.session_state['models_loaded'] = True
25
+
26
def is_port_in_use(port: int) -> bool:
    """Report whether something accepts TCP connections on a local port.

    Args:
        port: TCP port number to probe on the loopback interface.

    Returns:
        ``True`` if a connection to ``127.0.0.1:port`` succeeds, else ``False``.
    """
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        # Probe loopback explicitly: 0.0.0.0 is a bind wildcard and is not a
        # valid connect target on every platform.
        return probe.connect_ex(('127.0.0.1', port)) == 0
31
+
32
# Seed the per-session readiness flag on the first run of the script so later
# reads of st.session_state['models_loaded'] never hit a missing key.
if 'models_loaded' not in st.session_state:
    st.session_state.models_loaded = False
34
+
35
def get_correction(input_text):
    """Send the text to the inference service and render the highlighted result.

    Args:
        input_text: Sentence typed or selected by the user.

    Raises:
        requests.HTTPError: If the service answers with an error status.
        requests.ConnectionError: If the service cannot be reached.
    """
    st.markdown('##### Corrected text:')
    st.write('')
    with st.spinner('Wait for it...'):
        # Pass the sentence through ``params`` so requests percent-encodes it;
        # characters such as '&', '#', '+' or '%' would otherwise corrupt the
        # query string.
        correct_response = requests.get(
            "http://127.0.0.1:8080/restore",
            params={"input_sentence": input_text},
        )
        # Fail with the HTTP error itself instead of a KeyError on the
        # missing "corrected_sentence" field.
        correct_response.raise_for_status()
        corrected_sentence = correct_response.json()["corrected_sentence"]
        result = diff_strings(corrected_sentence, input_text)
        st.markdown(result, unsafe_allow_html=True)
46
+
47
def diff_strings(output_text, input_text):
    """Build an HTML fragment that highlights the words the model changed.

    Both texts are split on single spaces. Each word of ``output_text`` that
    does not occur verbatim among the words of ``input_text`` is wrapped in a
    bold, green ``<span>``; the other words are emitted as they are. Every
    word is followed by one space, so the result ends with a trailing space.

    Args:
        output_text: Corrected sentence returned by the model.
        input_text: Sentence originally entered by the user.

    Returns:
        The HTML fragment as a string. ``&``, ``<`` and ``>`` in the words are
        escaped, because the fragment is built for rendering as raw HTML.
    """
    import html

    # Split the input once; a set gives O(1) membership per output word.
    original_words = set(input_text.split(" "))
    pieces = []
    for word in output_text.split(" "):
        safe_word = html.escape(word, quote=False)
        if word in original_words:
            pieces.append(safe_word + " ")
        else:
            pieces.append(
                '<span style="font-weight:bold; color:rgb(142, 208, 129);">'
                + safe_word + '</span>' + " "
            )
    return "".join(pieces)
56
+
57
if __name__ == "__main__":
    # Page header, attribution and a short description of the model.
    st.title('Rpunct Streamlit App')
    st.subheader('For Punctuation and Upper Case restoration')
    st.markdown("Model by [felflare](https://huggingface.co/felflare/bert-restore-punctuation) and app example by [anuragshas](https://huggingface.co/spaces/anuragshas/restore-punctuation-demo)", unsafe_allow_html=True)
    st.markdown("Model restores the following punctuations -- [! ? . , - : ; ' ] and also the upper-casing of words.")

    # Un-punctuated, lower-case sentences offered in the example picker.
    sample_sentences = [
        "my name is clara and i live in berkeley california",
        "in 2018 cornell researchers built a high-powered detector",
        "lorem ipsum has been the industrys standard dummy text ever since the 1500s when an unknown printer took a galley of type and scrambled it to make a type specimen book",
    ]

    # Load the models only while the session flag is still unset.
    if not st.session_state['models_loaded']:
        load_models()

    # The selected example pre-fills the free-text box; the content of the
    # text box is what gets sent for correction.
    chosen_example = st.selectbox(label="Choose an example", options=sample_sentences)
    st.write("(or)")
    sentence = st.text_input(label="Input sentence", value=chosen_example)

    # Skip the request when the box is empty or whitespace only.
    if sentence.strip():
        get_correction(sentence)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ git+https://github.com/braunschweig/rpunct
2
+ fastapi
3
+ uvicorn