File size: 2,237 Bytes
bea6893
 
 
 
 
b9afdfb
bea6893
 
 
b9afdfb
bea6893
ca0f425
bea6893
 
ca0f425
bea6893
 
ca0f425
bea6893
 
 
 
 
ca0f425
bea6893
 
 
ca0f425
bea6893
 
 
 
 
 
b9afdfb
ca0f425
b9afdfb
 
 
 
 
 
 
 
 
 
 
bea6893
 
 
 
 
 
 
 
 
 
ca0f425
bea6893
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca0f425
bea6893
d229ef6
ca0f425
52f8d54
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from transformers import T5ForConditionalGeneration, T5Tokenizer
import streamlit as st
from PIL import Image
import os


@st.cache(allow_output_mutation=True)
def load_model_cache():
    """Load the Polish VLT5 reason-for-contact tokenizer and model once.

    The Hugging Face token is read from the ``TOKEN_FROM_SECRET`` environment
    variable. When it is unset or empty, ``True`` is passed instead, which tells
    ``transformers`` to use the locally stored Hugging Face credentials.

    Returns:
        A ``(tokenizer, model)`` pair for ``Voicelab/vlt5-base-rfc-v1_2``.
    """
    # NOTE(review): st.cache is deprecated in recent Streamlit releases in
    # favour of st.cache_resource; confirm the pinned Streamlit version.
    checkpoint = "Voicelab/vlt5-base-rfc-v1_2"
    secret = os.environ.get("TOKEN_FROM_SECRET")
    auth_token = secret if secret else True

    tokenizer_pl = T5Tokenizer.from_pretrained(checkpoint, use_auth_token=auth_token)
    model_pl = T5ForConditionalGeneration.from_pretrained(
        checkpoint, use_auth_token=auth_token
    )
    return tokenizer_pl, model_pl


# Branding images; relative paths resolve against the directory the app is
# launched from.
img_full = Image.open("images/vl-logo-nlp-blue.png")
img_short = Image.open("images/sVL-NLP-short.png")
img_favicon = Image.open("images/favicon_vl.png")
# Maximum number of input characters; enforced by trim_length().
max_length: int = 5000
# NOTE(review): not referenced in the visible code; confirm whether it is used.
cache_size: int = 100

# Streamlit requires the page configuration before other page elements render.
st.set_page_config(
    page_title="DEMO - Reason for Contact detection",
    page_icon=img_favicon,
    initial_sidebar_state="expanded",
)

# load_model_cache() returns a (tokenizer, model) pair for the Polish
# checkpoint, so unpack exactly two values; unpacking four raises ValueError.
tokenizer_pl, model_pl = load_model_cache()


def get_predictions(text):
    """Generate a reason-for-contact string for ``text`` with the Polish model.

    The input is tokenized with ``truncation=True``, so overly long inputs are
    cut to the tokenizer's maximum length before generation.

    Args:
        text: The raw user input.

    Returns:
        The decoded first generated sequence, with special tokens removed.
    """
    encoded = tokenizer_pl(text, return_tensors="pt", truncation=True)
    # NOTE(review): num_beam_groups > 1 selects group (diverse) beam search, but
    # no diversity_penalty is set, so the groups may not diverge; confirm intent.
    # no_repeat_ngram_size=1 forbids any token from appearing twice in the output.
    generation_kwargs = dict(
        no_repeat_ngram_size=1,
        num_beams=3,
        num_beam_groups=3,
        min_length=10,
        max_length=100,
    )
    sequences = model_pl.generate(encoded.input_ids, **generation_kwargs)
    return tokenizer_pl.decode(sequences[0], skip_special_tokens=True)


def trim_length():
    """Truncate the ``"input"`` text area's value to ``max_length`` characters.

    Used as the text area's ``on_change`` callback. It rewrites the value held
    in ``st.session_state`` only when the text exceeds the limit.
    """
    current = st.session_state["input"]
    if len(current) <= max_length:
        return
    st.session_state["input"] = current[:max_length]


if __name__ == "__main__":
    st.sidebar.image(img_short)
    st.image(img_full)
    st.title("VLT5 - RfC generation")

    # Keyed so trim_length() can truncate the value through st.session_state.
    user_input = st.text_area(
        label=f"Input text (max {max_length} characters)",
        value="",
        height=300,
        on_change=trim_length,
        key="input",
    )

    # The title only renders a heading; its return value is not needed.
    st.sidebar.title("Model settings")
    # Only the Polish model is offered; the selection is not otherwise used here.
    language = st.sidebar.radio(
        "Select model to test",
        [
            "Polish",
        ],
    )

    result = st.button("Find reason for contact")
    if result:
        # Skip generation on empty input rather than producing arbitrary output.
        if not user_input.strip():
            st.warning("Please enter some text first.")
        else:
            generated_rfc = get_predictions(text=user_input)
            st.text_area("Reason", generated_rfc)
            # Logged to stdout (the server console), including the raw user input.
            print(f"Input: {user_input} ---> Reason for contact: {generated_rfc}")