Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -14,8 +14,11 @@ def load_model_cache():
|
|
14 |
model_pl = T5ForConditionalGeneration.from_pretrained(
|
15 |
"Voicelab/vlt5-base-rfc-v1_2", use_auth_token=auth_token
|
16 |
)
|
|
|
|
|
|
|
17 |
|
18 |
-
return tokenizer_pl, model_pl
|
19 |
|
20 |
|
21 |
img_full = Image.open("images/vl-logo-nlp-blue.png")
|
@@ -25,24 +28,35 @@ max_length: int = 5000
|
|
25 |
cache_size: int = 100
|
26 |
|
27 |
st.set_page_config(
|
28 |
-
page_title="DEMO - Reason for Contact
|
29 |
page_icon=img_favicon,
|
30 |
initial_sidebar_state="expanded",
|
31 |
)
|
32 |
|
33 |
-
tokenizer_pl, model_pl = load_model_cache()
|
34 |
|
35 |
|
36 |
-
def get_predictions(text):
|
37 |
input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
predicted_rfc = tokenizer_pl.decode(output[0], skip_special_tokens=True)
|
47 |
return predicted_rfc
|
48 |
|
@@ -55,7 +69,7 @@ def trim_length():
|
|
55 |
if __name__ == "__main__":
|
56 |
st.sidebar.image(img_short)
|
57 |
st.image(img_full)
|
58 |
-
st.title("VLT5 -
|
59 |
|
60 |
generated_keywords = ""
|
61 |
user_input = st.text_area(
|
@@ -66,16 +80,29 @@ if __name__ == "__main__":
|
|
66 |
key="input",
|
67 |
)
|
68 |
|
69 |
-
|
70 |
-
|
71 |
"Select model to test",
|
72 |
[
|
73 |
-
"Polish",
|
|
|
74 |
],
|
75 |
)
|
76 |
|
77 |
result = st.button("Find reason for contact")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
if result:
|
79 |
-
generated_rfc = get_predictions(text=user_input)
|
80 |
-
st.text_area(
|
81 |
print(f"Input: {user_input} ---> Reason for contact: {generated_rfc}")
|
|
|
14 |
model_pl = T5ForConditionalGeneration.from_pretrained(
|
15 |
"Voicelab/vlt5-base-rfc-v1_2", use_auth_token=auth_token
|
16 |
)
|
17 |
+
model_det_pl = T5ForConditionalGeneration.from_pretrained(
|
18 |
+
"Voicelab/vlt5-base-rfc-detector-1.0", use_auth_token=auth_token
|
19 |
+
)
|
20 |
|
21 |
+
return tokenizer_pl, model_pl, model_det_pl
|
22 |
|
23 |
|
24 |
img_full = Image.open("images/vl-logo-nlp-blue.png")
|
|
|
28 |
cache_size: int = 100
|
29 |
|
30 |
st.set_page_config(
|
31 |
+
page_title="DEMO - Reason for Contact generation",
|
32 |
page_icon=img_favicon,
|
33 |
initial_sidebar_state="expanded",
|
34 |
)
|
35 |
|
36 |
+
tokenizer_pl, model_pl, model_det_pl = load_model_cache()
|
37 |
|
38 |
|
39 |
+
def get_predictions(text, mode):
    """Generate a Reason-for-Contact (RfC) string for *text* with the model
    selected by *mode*.

    Args:
        text: Input text — a whole conversation for the generation model, or
            a single turn for the detection model.
        mode: Sidebar radio value selecting the model:
            "Polish - RfC Generation" or "Polish - RfC Detection".

    Returns:
        The decoded RfC string. For the detection model this may be empty
        when no reason for contact is detected.

    Raises:
        ValueError: If *mode* is not one of the supported options.
    """
    input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
    if mode == "Polish - RfC Generation":
        output = model_pl.generate(
            input_ids,
            no_repeat_ngram_size=1,
            num_beams=3,
            num_beam_groups=3,
            min_length=10,
            max_length=100,
        )
    elif mode == "Polish - RfC Detection":
        # BUG FIX: the original called `model.generate(...)`, but no name
        # `model` exists — load_model_cache() returns `model_det_pl` for
        # the detector; use it here.
        output = model_det_pl.generate(
            input_ids,
            no_repeat_ngram_size=2,
            num_beams=3,
            num_beam_groups=3,
            repetition_penalty=1.5,
            diversity_penalty=2.0,
            length_penalty=2.0,
        )
    else:
        # Guard: without this, `output` would be unbound (UnboundLocalError)
        # for any unexpected mode value.
        raise ValueError(f"Unknown mode: {mode!r}")
    predicted_rfc = tokenizer_pl.decode(output[0], skip_special_tokens=True)
    return predicted_rfc
|
62 |
|
|
|
69 |
if __name__ == "__main__":
|
70 |
st.sidebar.image(img_short)
|
71 |
st.image(img_full)
|
72 |
+
st.title("VLT5 - Reason for Contact generator")
|
73 |
|
74 |
generated_keywords = ""
|
75 |
user_input = st.text_area(
|
|
|
80 |
key="input",
|
81 |
)
|
82 |
|
83 |
+
# Sidebar model picker. NOTE: st.sidebar.title() returns a layout object,
# not a value — the original assigned it to `mode` only to overwrite it on
# the next line, so the dead assignment is dropped.
st.sidebar.title("Model settings")
mode = st.sidebar.radio(
    "Select model to test",
    [
        "Polish - RfC Generation",
        "Polish - RfC Detection",
    ],
)

result = st.button("Find reason for contact")

# BUG FIX: the original compared `mode` against labels with "(accepts whole
# conversation)" / "(accepts one turn)" suffixes that the radio widget never
# produces, so neither branch ran and `text_area` stayed undefined — the
# st.text_area(text_area, ...) call below then raised NameError. Compare
# against the actual radio values instead.
if mode == "Polish - RfC Generation":
    print("You selected RfC Generation model.")
    print("-- Input: Whole conversation. Should specify roles (e.g. AGENT: Hello, how can I help you? CLIENT: Hi, I would like to report a stolen card.")
    print("-- Output: Reason for calling for the whole conversation.")
    text_area = "Put a whole conversation or full e-mail here."

elif mode == "Polish - RfC Detection":
    print("You selected RfC Detection model.")
    print("-- Input: A single turn from the conversation e.g. 'Hello, how can I help you?' or 'Hi, I would like to report a stolen card.'")
    print("-- Output: Model will return an empty string if a turn possibly does not includes Reason for Calling, or a sentence if the RfC is detected.")
    text_area = "Put a single turn or a few sentences here."

else:
    # Defensive default so the output widget always has a label even if the
    # radio options change without this branch being updated.
    text_area = "Model output"

if result:
    generated_rfc = get_predictions(text=user_input, mode=mode)
    st.text_area(text_area, generated_rfc)
    print(f"Input: {user_input} ---> Reason for contact: {generated_rfc}")
|