File size: 2,605 Bytes
b69b754
 
 
4a1e2da
b69b754
 
 
 
 
 
 
 
 
 
 
 
 
81a15fb
158d54a
 
 
 
81a15fb
 
 
b69b754
 
 
 
 
 
 
 
 
 
4a1e2da
b69b754
 
4a1e2da
b69b754
 
 
 
 
 
 
 
 
 
4a1e2da
 
 
 
 
 
81a15fb
4a1e2da
 
 
 
 
 
70fa8a8
4a1e2da
 
 
 
 
70fa8a8
4a1e2da
 
 
70fa8a8
 
 
 
 
 
 
4a1e2da
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import time

import errant
import spacy
import streamlit as st
from happytransformer import HappyTextToText, TTSettings

from highlighter import show_edits, show_highlights

# Hugging Face Hub checkpoint IDs offered in the "Choose model" select box in
# main(). st.selectbox preselects the first entry, so the first checkpoint is
# the model used until the user picks another one.
checkpoints = [
    "aseifert/t5-base-jfleg-wi",
    "aseifert/byt5-base-jfleg-wi",
    "prithivida/grammar_error_correcter_v2",
    "Modfiededition/t5-base-fine-tuned-on-jfleg",
]


def download_spacy_model(model="en"):
    """Ensure the spaCy pipeline ``model`` is installed, downloading it if not.

    Args:
        model: Name of the spaCy pipeline package to check for.

    Returns:
        Always ``True``.
    """
    # NOTE(review): "en" is a spaCy v2 shortcut name; spaCy v3 removed shortcut
    # links (e.g. "en_core_web_sm" is required there) — confirm the pinned version.
    try:
        spacy.load(model)
    except (ImportError, OSError):
        # spaCy reports a missing pipeline package with OSError ("[E050] Can't
        # find model"), not ImportError; catching only ImportError meant the
        # download fallback never ran and the app crashed on a fresh install.
        spacy.cli.download(model)  # type: ignore
    return True


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_model(model_name):
    """Build a HappyTransformer text-to-text wrapper for ``model_name``.

    Cached with ``st.cache`` so each checkpoint is loaded once and reused on
    every Streamlit rerun; ``allow_output_mutation=True`` stops Streamlit from
    checking the returned model object for mutation between runs.
    """
    architecture = "T5"
    wrapper = HappyTextToText(architecture, model_name)
    return wrapper


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_annotator(lang: str):
    """Return an ERRANT annotator for language ``lang``, cached across reruns."""
    errant_annotator = errant.load(lang)
    return errant_annotator


def output(model, args, annotator, input_text):
    """Correct ``input_text`` with ``model`` and render the result in the page.

    Args:
        model: Text-to-text model exposing ``generate_text`` (the
            ``HappyTextToText`` returned by ``get_model``).
        args: Generation settings forwarded to ``model.generate_text``.
        annotator: ERRANT annotator passed to ``show_highlights``.
        input_text: Raw text entered by the user.

    Empty or whitespace-only input shows a warning and returns without calling
    the model. Any failure during generation or rendering is reported with
    ``st.error`` and the script run is halted with ``st.stop()``.
    """
    if not input_text.strip():
        st.warning("Please enter some text to check.")
        return

    # NOTE(review): the spinner emoji looks mis-encoded in this copy of the
    # file — confirm the source bytes.
    with st.spinner("Checking for errors πŸ”"):
        try:
            # "Grammar: " is the task prefix the model is prompted with
            # (presumably the one the checkpoints were fine-tuned on — confirm).
            prefixed_input_text = "Grammar: " + input_text
            result = model.generate_text(prefixed_input_text, args=args).text
            st.success(result)
            show_highlights(annotator, input_text, result)
            # Alternative tabular view of the individual edits (disabled):
            # st.table(show_edits(annotator, input_text, result))
        except Exception as e:  # UI boundary: show the error instead of a traceback
            st.error(f"Some error occurred! {e}")
            st.stop()


def main():
    """Render the writing-assistant page.

    Shows the header, loads the spaCy/ERRANT tooling and the selected
    checkpoint, and offers a text area plus a "Check" button. When the button
    is pressed, the text is corrected via ``output`` and the time spent in that
    call is printed below the footer.
    """
    # NOTE(review): the title emoji looks mis-encoded in this copy of the
    # file — confirm the source bytes.
    st.title("πŸ€— Writing Assistant")
    st.markdown(
        """This writing assistant will proofread any text for you! See my [GitHub repo](https://github.com/aseifert/hf-writing-assistant) for implementation details."""
    )

    download_spacy_model()
    annotator = get_annotator("en")
    checkpoint = st.selectbox("Choose model", checkpoints)
    model = get_model(checkpoint)
    args = TTSettings(num_beams=5, min_length=1, max_length=1024)

    default_text = "it gives him many apprtunites in the life, and i think that being knowledge person is a very wouderful thing to have so we can spend our lives in a successful way and full of happenis."
    input_text = st.text_area(
        label="Original text",
        value=default_text,
    )

    # Measure only the correction itself (not the footer rendering below), with
    # a monotonic clock suited to measuring durations.
    elapsed = None
    if st.button("✍️ Check"):
        start = time.perf_counter()
        output(model, args, annotator, input_text)
        elapsed = time.perf_counter() - start

    st.write("---")
    st.markdown(
        "Built by [@aseifert](https://twitter.com/therealaseifert) during the HF community event – [GitHub repo](https://github.com/aseifert/hf-writing-assistant) – Team Writing Assistant"
    )
    if elapsed is not None:
        st.text(f"prediction took {elapsed:.2f}s")


# Script entry point: `streamlit run` executes this file with __name__ set to
# "__main__", so the page is built on every (re)run.
if __name__ == "__main__":
    main()