Spaces:
Runtime error
Runtime error
Upload dashboard
Browse files- Data_Loading.py +119 -0
- README.md +1 -1
- codebook_demo.json +108 -0
- data_demo.csv +35 -0
- helpers.py +630 -0
- pages/1_Codebook_Design.py +731 -0
- pages/2_Codebook_Advanced_Edit.py +251 -0
- pages/3_Apply_Codebook.py +222 -0
- requirements.txt +10 -0
Data_Loading.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
import streamlit as st
|
5 |
+
from st_aggrid import AgGrid, DataReturnMode
|
6 |
+
|
7 |
+
current = os.path.dirname(os.path.realpath(__file__))
|
8 |
+
parent = os.path.dirname(current)
|
9 |
+
sys.path.append(parent)
|
10 |
+
from helpers import apply_style, get_idx_column, read_csv_from_web, read_json_from_web
|
11 |
+
|
12 |
+
apply_style()
|
13 |
+
|
14 |
+
codebook = {}
|
15 |
+
|
16 |
+
st.markdown(
|
17 |
+
"""
|
18 |
+
# Codebook Creation/Edition Tool based on the PR-ENT Approach.
|
19 |
+
### *Rethinking the Event Coding Pipeline with Prompt Entailment*
|
20 |
+
### Author: Anonymized for submission"
|
21 |
+
##### Version: 1.0
|
22 |
+
"""
|
23 |
+
)
|
24 |
+
st.markdown("***********")
|
25 |
+
|
26 |
+
st.markdown(
|
27 |
+
"""
|
28 |
+
## Data Loading
|
29 |
+
"""
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
st.markdown(
|
34 |
+
"""
|
35 |
+
### Upload a CSV of event descriptions.
|
36 |
+
"""
|
37 |
+
)
|
38 |
+
uploaded_file = st.file_uploader("Upload a csv file containing event descriptions")
|
39 |
+
if uploaded_file is not None:
|
40 |
+
st.session_state.data = read_csv_from_web(uploaded_file)
|
41 |
+
|
42 |
+
|
43 |
+
if "data" in st.session_state:
|
44 |
+
# Filter will be reset if the page is left and then used again
|
45 |
+
loading_df = st.text("Loading data display...")
|
46 |
+
st.write(
|
47 |
+
"""
|
48 |
+
The below display of the data can be used to filter the data. Click on the *3 bars logo* when hovering over a column name and the filtering
|
49 |
+
tool will appear. Filters are kept in memory on the whole dashboard as long as the `Reset Filters` button is not clicked.
|
50 |
+
|
51 |
+
Current limitation: If a filter is set and the user change page. Then it can not be modified anymore and needs to be reset.
|
52 |
+
"""
|
53 |
+
)
|
54 |
+
if "filtered_df" not in st.session_state:
|
55 |
+
st.session_state.filtered_df = st.session_state.data
|
56 |
+
if st.button("Reset Filters"):
|
57 |
+
st.session_state.filtered_df = st.session_state.data
|
58 |
+
|
59 |
+
st.session_state.filtered_df = AgGrid(
|
60 |
+
st.session_state.filtered_df,
|
61 |
+
height=400,
|
62 |
+
data_return_mode=DataReturnMode.FILTERED,
|
63 |
+
update_mode="MANUAL",
|
64 |
+
)["data"]
|
65 |
+
|
66 |
+
if "text_column_design_perm" not in st.session_state:
|
67 |
+
st.session_state[
|
68 |
+
"text_column_design_perm"
|
69 |
+
] = st.session_state.filtered_df.columns[0]
|
70 |
+
|
71 |
+
def callback_function(mod, key):
|
72 |
+
st.session_state[mod] = st.session_state[key]
|
73 |
+
|
74 |
+
st.write("Select the column which contains the event descriptions.")
|
75 |
+
st.selectbox(
|
76 |
+
"Select the event description column:",
|
77 |
+
st.session_state.filtered_df.columns,
|
78 |
+
key="text_column_design",
|
79 |
+
on_change=callback_function,
|
80 |
+
args=("text_column_design_perm", "text_column_design"),
|
81 |
+
index=get_idx_column(
|
82 |
+
st.session_state["text_column_design_perm"],
|
83 |
+
list(st.session_state.filtered_df.columns),
|
84 |
+
),
|
85 |
+
)
|
86 |
+
loading_df.text("")
|
87 |
+
|
88 |
+
# Remove NaN Texts
|
89 |
+
if st.button("Remove Empty Event Descriptions"):
|
90 |
+
st.session_state.filtered_df = st.session_state.filtered_df.dropna(
|
91 |
+
subset=[st.session_state["text_column_design_perm"]]
|
92 |
+
)
|
93 |
+
|
94 |
+
|
95 |
+
st.write("********")
|
96 |
+
st.markdown("## Optional Upload")
|
97 |
+
|
98 |
+
|
99 |
+
st.markdown(
|
100 |
+
"""
|
101 |
+
### Upload a codebook if available. It needs to be in the format used in this dashboard.
|
102 |
+
"""
|
103 |
+
)
|
104 |
+
uploaded_codebook = st.file_uploader("Upload a codebook if available (OPTIONAL)")
|
105 |
+
if uploaded_codebook is not None:
|
106 |
+
codebook = read_json_from_web(uploaded_codebook)
|
107 |
+
st.session_state.codebook = codebook
|
108 |
+
|
109 |
+
st.markdown(
|
110 |
+
"""
|
111 |
+
### Upload a validated dataset (accept, reject, ignored) in the format of this dashboard.
|
112 |
+
"""
|
113 |
+
)
|
114 |
+
|
115 |
+
uploaded_validated_data = st.file_uploader(
|
116 |
+
"Upload a json file containing validated data (OPTIONAL)"
|
117 |
+
)
|
118 |
+
if uploaded_validated_data is not None:
|
119 |
+
st.session_state.validated_data = read_json_from_web(uploaded_validated_data)
|
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: gray
|
|
5 |
colorTo: gray
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.10.0
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
---
|
11 |
|
|
|
5 |
colorTo: gray
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.10.0
|
8 |
+
app_file: Data_Loading.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
codebook_demo.json
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"events": {
|
3 |
+
"Arrest": {
|
4 |
+
"all": [],
|
5 |
+
"any": [
|
6 |
+
"People were arrested."
|
7 |
+
],
|
8 |
+
"not": {},
|
9 |
+
"not_any": [],
|
10 |
+
"not_all": [],
|
11 |
+
"all_any_rel": "OR",
|
12 |
+
"not_all_any_rel": "OR"
|
13 |
+
},
|
14 |
+
"Destruction": {
|
15 |
+
"all": [],
|
16 |
+
"any": [
|
17 |
+
"This event involves arson."
|
18 |
+
],
|
19 |
+
"not": {},
|
20 |
+
"not_any": [],
|
21 |
+
"not_all": [],
|
22 |
+
"all_any_rel": "OR",
|
23 |
+
"not_all_any_rel": "OR"
|
24 |
+
},
|
25 |
+
"Killing": {
|
26 |
+
"all": [],
|
27 |
+
"any": [
|
28 |
+
"This event involves killing."
|
29 |
+
],
|
30 |
+
"not": {},
|
31 |
+
"not_any": [],
|
32 |
+
"not_all": [],
|
33 |
+
"all_any_rel": "OR",
|
34 |
+
"not_all_any_rel": "OR"
|
35 |
+
},
|
36 |
+
"Looting": {
|
37 |
+
"all": [],
|
38 |
+
"any": [],
|
39 |
+
"not": {},
|
40 |
+
"not_any": [],
|
41 |
+
"not_all": [],
|
42 |
+
"all_any_rel": "OR",
|
43 |
+
"not_all_any_rel": "OR"
|
44 |
+
},
|
45 |
+
"Other": {
|
46 |
+
"all": [],
|
47 |
+
"any": [],
|
48 |
+
"not": {},
|
49 |
+
"not_any": [],
|
50 |
+
"not_all": [],
|
51 |
+
"all_any_rel": "OR",
|
52 |
+
"not_all_any_rel": "OR"
|
53 |
+
},
|
54 |
+
"Explosions": {
|
55 |
+
"all": [],
|
56 |
+
"any": [
|
57 |
+
"This event involves explosive."
|
58 |
+
],
|
59 |
+
"not": {},
|
60 |
+
"not_any": [],
|
61 |
+
"not_all": [],
|
62 |
+
"all_any_rel": "OR",
|
63 |
+
"not_all_any_rel": "OR"
|
64 |
+
},
|
65 |
+
"Kidnapping": {
|
66 |
+
"all": [],
|
67 |
+
"any": [
|
68 |
+
"People were kidnapped.",
|
69 |
+
"This event involves kidnapping."
|
70 |
+
],
|
71 |
+
"not": {},
|
72 |
+
"not_any": [],
|
73 |
+
"not_all": [],
|
74 |
+
"all_any_rel": "OR",
|
75 |
+
"not_all_any_rel": "OR"
|
76 |
+
},
|
77 |
+
"Sexual Violence": {
|
78 |
+
"all": [],
|
79 |
+
"any": [
|
80 |
+
"This event involves rape.",
|
81 |
+
"People were abused."
|
82 |
+
],
|
83 |
+
"not": {},
|
84 |
+
"not_any": [],
|
85 |
+
"not_all": [],
|
86 |
+
"all_any_rel": "OR",
|
87 |
+
"not_all_any_rel": "OR"
|
88 |
+
},
|
89 |
+
"Protests": {
|
90 |
+
"all": [],
|
91 |
+
"any": [
|
92 |
+
"This event involves protest.",
|
93 |
+
"This event involves demonstration.",
|
94 |
+
"This event involves protester."
|
95 |
+
],
|
96 |
+
"not": {},
|
97 |
+
"not_any": [],
|
98 |
+
"not_all": [],
|
99 |
+
"all_any_rel": "OR",
|
100 |
+
"not_all_any_rel": "OR"
|
101 |
+
}
|
102 |
+
},
|
103 |
+
"templates": [
|
104 |
+
"This event involves [Z].",
|
105 |
+
"People were [Z]."
|
106 |
+
],
|
107 |
+
"add_words": []
|
108 |
+
}
|
data_demo.csv
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Event Descriptions
|
2 |
+
"An organization abducts a group of 50 women and children. 'They started loading our women and children into their vehicles, threatening to shoot whoever disobeyed them. Everybody was scared,' a witness stated. The date of the attack was not included in the report, but is likely to have occurred on 13/10 or 14/10."
|
3 |
+
"On 7 July 2019, unidentified gunmen abducted a man. No requests for ransom have been reported."
|
4 |
+
"On 2 March 2020, a driver was abducted by militiamen in the region. The person was later released in a village."
|
5 |
+
"On 23 August, a group attacked a village, abducting 6 people."
|
6 |
+
05 May 2021. An unidentified armed group abducted a girl in the town. [women targeted: girls]
|
7 |
+
Mass demonstrations honoring the bravery of Revolution martyrs.
|
8 |
+
"The national organization of Teachers, has organised a mass protest both against the substantive issue and against police brutality on December 11th, which is the international day of human rights."
|
9 |
+
"Answering the call of the religious organization following the Friday prayer, citizens held two protest sit-ins to express their support for discriminated people and denounce normalization of relations with the neighboring state. [size=no report]"
|
10 |
+
Families affected by the collapse of their building organized a protest gathering to demand from the local authorities to relocate them. [size=no report]
|
11 |
+
"On 15 September, gold miners held a protest sit-in to demand the closure of illegal gold refineries. [size=no report]"
|
12 |
+
"Around 14 August 2021 (between 13 - 18 August), police intelligence agents arrested the national party leader, a journalist and 8 other members of the party in the city. Reason for arrest not clear."
|
13 |
+
"5 arrested, houses ransacked, due to suspected links with terrorist organizations, then extradited "
|
14 |
+
"On 15 June 2020, Police Forces arrested 70 civilians, in an ongoing crackdown against religious minorities."
|
15 |
+
Arrests: Police detain a journalist who recently published a politically disagreeable article.
|
16 |
+
About 23 suspected political thugs were arrested by the men of the state police command during the just concluded national assembly poll.
|
17 |
+
"Around 13 June 2019 (as reported), local militiamen set a settlement ablaze. Event connected to an earlier attack on residents by pastoralists."
|
18 |
+
"Around 4 March, suspected separatists destroyed a number of houses and property in the region."
|
19 |
+
"On 30 January, an unidentified armed group has razed an Ebola handwashing station."
|
20 |
+
"On 12 November 2019, unknown individuals set ablaze a half hectare of soja harvest."
|
21 |
+
"On 11 February 2020, suspected fighters stole about 30 cattle during a raid."
|
22 |
+
"On 3 December, at least 12 illegal checkpoints were erected by unknown troops. The soldiers reportedly extort money from travelers."
|
23 |
+
"Around 12 October 2021 (between 12 - 13 October), unknown gunmen seized livestock belonging to a councilor in the village."
|
24 |
+
"Pro-government militia have set up more than 20 checkpoints along the road, demanding fees from vehicles, robbing passengers, stealing from lorries & seizing vehicles"
|
25 |
+
"2 armed men, believed to be part of a vigilante militia, robbed a NGO camp, which has led to a reduction in the NGO operations in the area. No casualties were reported."
|
26 |
+
Members of a peacekeeping unit have been accused of raping a young girl. [women targeted: girls]
|
27 |
+
"On 12 April 2020, presumed militants raped a woman."
|
28 |
+
Peacekeepers allegedly carried out sexual abuse on civilians.
|
29 |
+
2 people were killed during an attack by 20 militiamen in the chiefdom of. 1 person was killed and 2 raped. The militiamen looted several house and fired into the air as they raided the areas.
|
30 |
+
A five-year-old girl was raped and then killed by unidentified men. [women targeted: girls]
|
31 |
+
"Grenade explodes, killing 2 - affiliated to ruling party"
|
32 |
+
"On 25 February 2021, terrorist militants threw a hand grenade at a house. Two soldiers were injured in the explosion."
|
33 |
+
"On 8 August 2020, air force targeted a Military Faculty. No casualties reported."
|
34 |
+
Air forces drop 6 bombs in the region
|
35 |
+
"On 13 November, a cart was struck by an IED about 3km north, one civilian was killed and another severely wounded, the two donkeys that pulled the cart were also killed. The IED was most likely planted by terrorist militants."
|
helpers.py
ADDED
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import string
|
3 |
+
from time import time
|
4 |
+
|
5 |
+
import en_core_web_lg
|
6 |
+
import inflect
|
7 |
+
import nltk
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
import streamlit as st
|
11 |
+
from nltk.tokenize import sent_tokenize
|
12 |
+
from transformers import pipeline
|
13 |
+
|
14 |
+
# Set constant values
|
15 |
+
INFLECT_ENGINE = inflect.engine()
|
16 |
+
TOP_K = 30
|
17 |
+
NLI_LIMIT = 0.9
|
18 |
+
|
19 |
+
st.set_page_config(layout="wide")
|
20 |
+
|
21 |
+
|
22 |
+
def get_top_k():
|
23 |
+
return TOP_K
|
24 |
+
|
25 |
+
|
26 |
+
def get_nli_limit():
|
27 |
+
return NLI_LIMIT
|
28 |
+
|
29 |
+
|
30 |
+
### Streamlit specific
|
31 |
+
@st.cache(allow_output_mutation=True)
|
32 |
+
def load_model_prompting():
|
33 |
+
return pipeline("fill-mask", model="distilbert-base-uncased")
|
34 |
+
|
35 |
+
|
36 |
+
@st.cache(allow_output_mutation=True)
|
37 |
+
def load_model_nli():
|
38 |
+
try:
|
39 |
+
return pipeline(
|
40 |
+
task="sentiment-analysis", model="roberta-large-mnli", device="mps"
|
41 |
+
)
|
42 |
+
except:
|
43 |
+
return pipeline(task="sentiment-analysis", model="roberta-large-mnli")
|
44 |
+
|
45 |
+
|
46 |
+
@st.cache(allow_output_mutation=True)
|
47 |
+
def load_spacy_pipeline():
|
48 |
+
return en_core_web_lg.load()
|
49 |
+
|
50 |
+
|
51 |
+
@st.cache()
|
52 |
+
def download_punkt():
|
53 |
+
nltk.download("punkt")
|
54 |
+
|
55 |
+
|
56 |
+
download_punkt()
|
57 |
+
|
58 |
+
|
59 |
+
@st.experimental_memo(max_entries=1)
|
60 |
+
def read_json_from_web(uploaded_json):
|
61 |
+
return json.load(uploaded_json)
|
62 |
+
|
63 |
+
|
64 |
+
@st.experimental_memo(max_entries=1)
|
65 |
+
def read_csv_from_web(uploaded_file):
|
66 |
+
"""Read CSV from the streamlit interface
|
67 |
+
|
68 |
+
:param uploaded_file: File to read
|
69 |
+
:type uploaded_file: UploadedFile (BytesIO)
|
70 |
+
:return: Dataframe
|
71 |
+
:rtype: pandas DataFrame
|
72 |
+
"""
|
73 |
+
try:
|
74 |
+
# Try first to read comma separated and semicolon separated files
|
75 |
+
data = pd.read_csv(uploaded_file, sep=None, engine="python")
|
76 |
+
# If both are not correct, then it will error and go to the except
|
77 |
+
except pd.errors.ParserError:
|
78 |
+
# This should be the case when there is no separator (1 column csv)
|
79 |
+
# Reset the IO object due to the previous crash
|
80 |
+
uploaded_file.seek(0)
|
81 |
+
# Use standard reading of CSV (no separator)
|
82 |
+
data = pd.read_csv(uploaded_file)
|
83 |
+
return data
|
84 |
+
|
85 |
+
|
86 |
+
def apply_style():
|
87 |
+
# Avoid having ellipsis in the multi select options
|
88 |
+
styl = """
|
89 |
+
<style>
|
90 |
+
.stMultiSelect span{
|
91 |
+
max-width: none;
|
92 |
+
|
93 |
+
}
|
94 |
+
</style>
|
95 |
+
"""
|
96 |
+
st.markdown(styl, unsafe_allow_html=True)
|
97 |
+
|
98 |
+
# Set color of multiselect to red
|
99 |
+
st.markdown(
|
100 |
+
"""
|
101 |
+
<style>
|
102 |
+
span[data-baseweb="tag"] {
|
103 |
+
background-color: red !important;
|
104 |
+
}
|
105 |
+
</style>
|
106 |
+
""",
|
107 |
+
unsafe_allow_html=True,
|
108 |
+
)
|
109 |
+
|
110 |
+
hide_st_style = """
|
111 |
+
<style>
|
112 |
+
#MainMenu {visibility: hidden;}
|
113 |
+
footer {visibility: hidden;}
|
114 |
+
header {visibility: hidden;}
|
115 |
+
</style>
|
116 |
+
"""
|
117 |
+
st.markdown(hide_st_style, unsafe_allow_html=True)
|
118 |
+
|
119 |
+
|
120 |
+
def choose_text_menu(text):
|
121 |
+
if "text" not in st.session_state:
|
122 |
+
st.session_state.text = "Several demonstrators were injured."
|
123 |
+
text = st.text_area("Event description", st.session_state.text)
|
124 |
+
|
125 |
+
return text
|
126 |
+
|
127 |
+
|
128 |
+
def initiate_widget_st_state(widget_key, perm_key, default_value):
|
129 |
+
if perm_key not in st.session_state:
|
130 |
+
st.session_state[perm_key] = default_value
|
131 |
+
if widget_key not in st.session_state:
|
132 |
+
st.session_state[widget_key] = st.session_state[perm_key]
|
133 |
+
|
134 |
+
|
135 |
+
def get_idx_column(col_name, col_list):
|
136 |
+
if col_name in col_list:
|
137 |
+
return col_list.index(col_name)
|
138 |
+
else:
|
139 |
+
return 0
|
140 |
+
|
141 |
+
|
142 |
+
def callback_add_to_multiselect(str_to_add, multiselect_key, text_input_key, *keys):
|
143 |
+
if len(str_to_add) == 0:
|
144 |
+
st.warning("Word is empty, did you press Enter on the field text?")
|
145 |
+
return
|
146 |
+
current_dict = st.session_state
|
147 |
+
*dict_keys, item_keys = keys
|
148 |
+
try:
|
149 |
+
for key in dict_keys:
|
150 |
+
current_dict = current_dict[key]
|
151 |
+
current_dict[item_keys].append(str_to_add)
|
152 |
+
except KeyError as e:
|
153 |
+
raise KeyError(keys) from e
|
154 |
+
|
155 |
+
if multiselect_key in st.session_state:
|
156 |
+
st.session_state[multiselect_key].append(str_to_add)
|
157 |
+
else:
|
158 |
+
st.session_state[multiselect_key] = [str_to_add]
|
159 |
+
|
160 |
+
st.session_state[text_input_key] = ""
|
161 |
+
|
162 |
+
|
163 |
+
# Split the text into sentences. Necessary for NLI models
|
164 |
+
def split_sentences(text):
|
165 |
+
return sent_tokenize(text)
|
166 |
+
|
167 |
+
|
168 |
+
def get_num_sentences_in_list_text(list_texts):
|
169 |
+
num_sentences = 0
|
170 |
+
for text in list_texts:
|
171 |
+
num_sentences += len(split_sentences(text))
|
172 |
+
return num_sentences
|
173 |
+
|
174 |
+
|
175 |
+
###### Prompting
|
176 |
+
def query_model_prompting(model, text, prompt_with_mask, top_k, targets):
|
177 |
+
"""Query the prompting model
|
178 |
+
|
179 |
+
:param model: Prompting model object
|
180 |
+
:type model: Huggingface pipeline object
|
181 |
+
:param text: Event description (context)
|
182 |
+
:type text: str
|
183 |
+
:param prompt_with_mask: Prompt with a mask
|
184 |
+
:type prompt_with_mask: str
|
185 |
+
:param top_k: Number of tokens to output
|
186 |
+
:type top_k: integer
|
187 |
+
:param targets: Restrict the answer to these possible tokens
|
188 |
+
:type targets: list
|
189 |
+
:return: Results of the prompting model
|
190 |
+
:rtype: list of dict
|
191 |
+
"""
|
192 |
+
sequence = text + prompt_with_mask
|
193 |
+
output_tokens = model(sequence, top_k=top_k, targets=targets)
|
194 |
+
|
195 |
+
return output_tokens
|
196 |
+
|
197 |
+
|
198 |
+
def do_sentence_entailment(sentence, hypothesis, model):
|
199 |
+
"""Concatenate context and hypothesis then perform entailment
|
200 |
+
|
201 |
+
:param sentence: Event description (context), 1 sentence
|
202 |
+
:type sentence: str
|
203 |
+
:param hypothesis: Mask filled with a token
|
204 |
+
:type hypothesis: str
|
205 |
+
:param model: NLI Model
|
206 |
+
:type model: Huggingface pipeline
|
207 |
+
:return: DataFrame containing the result of the entailment
|
208 |
+
:rtype: pandas DataFrame
|
209 |
+
"""
|
210 |
+
text = sentence + "</s></s>" + hypothesis
|
211 |
+
res = model(text, return_all_scores=True)
|
212 |
+
df_res = pd.DataFrame(res[0])
|
213 |
+
df_res["label"] = df_res["label"].apply(lambda x: x.lower())
|
214 |
+
df_res.columns = ["Label", "Score"]
|
215 |
+
return df_res
|
216 |
+
|
217 |
+
|
218 |
+
def softmax(x):
|
219 |
+
"""Compute softmax values for each sets of scores in x."""
|
220 |
+
return np.exp(x) / np.sum(np.exp(x), axis=0)
|
221 |
+
|
222 |
+
|
223 |
+
def get_singular_form(word):
|
224 |
+
"""Get the singular form of a word
|
225 |
+
|
226 |
+
:param word: word
|
227 |
+
:type word: string
|
228 |
+
:return: singular form of the word
|
229 |
+
:rtype: string
|
230 |
+
"""
|
231 |
+
if INFLECT_ENGINE.singular_noun(word):
|
232 |
+
return INFLECT_ENGINE.singular_noun(word)
|
233 |
+
else:
|
234 |
+
return word
|
235 |
+
|
236 |
+
|
237 |
+
######### NLI + PROMPTING
|
238 |
+
def do_text_entailment(text, hypothesis, model):
|
239 |
+
"""
|
240 |
+
Do entailment for each sentence of the event description as
|
241 |
+
model was trained on sentence pair
|
242 |
+
|
243 |
+
:param text: Event Description (context)
|
244 |
+
:type text: str
|
245 |
+
:param hypothesis: Mask filled with a token
|
246 |
+
:type hypothesis: str
|
247 |
+
:param model: Model NLI
|
248 |
+
:type model: Huggingface pipeline
|
249 |
+
:return: List of entailment results for each sentence of the text
|
250 |
+
:rtype: list
|
251 |
+
"""
|
252 |
+
text_entailment_results = []
|
253 |
+
for i, sentence in enumerate(split_sentences(text)):
|
254 |
+
df_score = do_sentence_entailment(sentence, hypothesis, model)
|
255 |
+
text_entailment_results.append((sentence, hypothesis, df_score))
|
256 |
+
return text_entailment_results
|
257 |
+
|
258 |
+
|
259 |
+
def get_true_entailment(text_entailment_results, nli_limit):
|
260 |
+
"""
|
261 |
+
From the result of each sentence entailment, extract the maximum entailment score and
|
262 |
+
check if it's higher than the entailment threshold.
|
263 |
+
"""
|
264 |
+
true_hypothesis_list = []
|
265 |
+
max_score = 0
|
266 |
+
for sentence_entailment in text_entailment_results:
|
267 |
+
df_score = sentence_entailment[2]
|
268 |
+
score = df_score[df_score["Label"] == "entailment"]["Score"].values.max()
|
269 |
+
if score > max_score:
|
270 |
+
max_score = score
|
271 |
+
if max_score > nli_limit:
|
272 |
+
true_hypothesis_list.append((sentence_entailment[1], np.round(max_score, 2)))
|
273 |
+
return list(set(true_hypothesis_list))
|
274 |
+
|
275 |
+
|
276 |
+
def run_model_nli(data, batch_size, model_nli, use_tf=False):
|
277 |
+
if not use_tf:
|
278 |
+
return model_nli(data, top_k=3, batch_size=batch_size)
|
279 |
+
else:
|
280 |
+
raise NotImplementedError
|
281 |
+
# return run_pipeline_on_gpu(data, batch_size, model_nli["tokenizer"], model_nli["model"])
|
282 |
+
|
283 |
+
|
284 |
+
def prompt_to_nli_batching(
|
285 |
+
text,
|
286 |
+
prompt,
|
287 |
+
model_prompting,
|
288 |
+
nli_model,
|
289 |
+
nlp,
|
290 |
+
top_k=10,
|
291 |
+
nli_limit=0.5,
|
292 |
+
targets=None,
|
293 |
+
additional_words=None,
|
294 |
+
remove_lemma=False,
|
295 |
+
use_tf=False,
|
296 |
+
):
|
297 |
+
# Check if text has end ponctuation
|
298 |
+
if text[-1] not in string.punctuation:
|
299 |
+
text += "."
|
300 |
+
prompt_masked = prompt.format(model_prompting.tokenizer.mask_token)
|
301 |
+
output_prompting = query_model_prompting(
|
302 |
+
model_prompting, text, prompt_masked, top_k, targets=targets
|
303 |
+
)
|
304 |
+
if remove_lemma:
|
305 |
+
output_prompting = filter_prompt_output_by_lemma(prompt, output_prompting, nlp)
|
306 |
+
full_batch_concat = []
|
307 |
+
prompt_tokens = []
|
308 |
+
for token in output_prompting:
|
309 |
+
hypothesis = prompt.format(token["token_str"])
|
310 |
+
for i, sentence in enumerate(split_sentences(text)):
|
311 |
+
full_batch_concat.append(sentence + "</s></s>" + hypothesis)
|
312 |
+
prompt_tokens.append((token["token_str"], token["score"]))
|
313 |
+
|
314 |
+
# Add words that must be tried for entailment
|
315 |
+
# Also increase batch_size
|
316 |
+
if additional_words:
|
317 |
+
for i, sentence in enumerate(split_sentences(text)):
|
318 |
+
for token in additional_words:
|
319 |
+
hypothesis = prompt.format(token)
|
320 |
+
full_batch_concat.append(sentence + "</s></s>" + hypothesis)
|
321 |
+
prompt_tokens.append((token, 1))
|
322 |
+
top_k = top_k + 1
|
323 |
+
results_nli = run_model_nli(full_batch_concat, top_k, nli_model, use_tf)
|
324 |
+
# Get entailed tokens
|
325 |
+
entailed_tokens = []
|
326 |
+
for i, res in enumerate(results_nli):
|
327 |
+
entailed_tokens.extend(
|
328 |
+
[
|
329 |
+
(get_singular_form(prompt_tokens[i][0]), x["score"])
|
330 |
+
for x in res
|
331 |
+
if ((x["label"] == "ENTAILMENT") & (x["score"] > nli_limit))
|
332 |
+
]
|
333 |
+
)
|
334 |
+
if entailed_tokens:
|
335 |
+
entailed_tokens = list(
|
336 |
+
pd.DataFrame(entailed_tokens).groupby(0).max()[1].items()
|
337 |
+
)
|
338 |
+
|
339 |
+
return entailed_tokens, list(set(prompt_tokens))
|
340 |
+
|
341 |
+
|
342 |
+
def remove_similar_lemma_from_list(prompt, list_words, nlp):
|
343 |
+
## Compute a dictionnary with the lemma for all tokens
|
344 |
+
## If there is a duplicate lemma then the dictionnary value will be a list of the corresponding tokens
|
345 |
+
lemma_dict = {}
|
346 |
+
for each in list_words:
|
347 |
+
mask_filled = nlp(prompt.strip(".").format(each))
|
348 |
+
lemma_dict.setdefault([x.lemma_ for x in mask_filled][-1], []).append(each)
|
349 |
+
|
350 |
+
## Get back the list of tokens
|
351 |
+
## If multiple tokens available then take the shortest one
|
352 |
+
new_token_list = []
|
353 |
+
for key in lemma_dict.keys():
|
354 |
+
if len(lemma_dict[key]) >= 1:
|
355 |
+
new_token_list.append(min(lemma_dict[key], key=len))
|
356 |
+
else:
|
357 |
+
raise ValueError("Lemma dict has 0 corresponding words")
|
358 |
+
return new_token_list
|
359 |
+
|
360 |
+
|
361 |
+
def filter_prompt_output_by_lemma(prompt, output_prompting, nlp):
|
362 |
+
"""
|
363 |
+
Remove all similar lemmas from the prompt output (e.g. "protest", "protests")
|
364 |
+
"""
|
365 |
+
list_words = [x["token_str"] for x in output_prompting]
|
366 |
+
new_token_list = remove_similar_lemma_from_list(prompt, list_words, nlp)
|
367 |
+
return [x for x in output_prompting if x["token_str"] in new_token_list]
|
368 |
+
|
369 |
+
|
370 |
+
# Streamlit specific run functions
|
371 |
+
@st.experimental_memo(max_entries=1024)
|
372 |
+
def do_prent(text, template, top_k, nli_limit, additional_words=None):
|
373 |
+
"""Function used to execute PRENT model
|
374 |
+
|
375 |
+
:param text: Event text
|
376 |
+
:type text: string
|
377 |
+
:param template: Template with mask
|
378 |
+
:type template: string
|
379 |
+
:param top_k: Maximum tokens to output from prompting model
|
380 |
+
:type top_k: int
|
381 |
+
:param nli_limit: Threshold of entailment for NLI [0,1]
|
382 |
+
:type nli_limit: float
|
383 |
+
:param additional_words: List of words that bypass prompting and goes directly to NLI, defaults to None
|
384 |
+
:type additional_words: list, optional
|
385 |
+
:return: (Results Entailment, Results Prompting)
|
386 |
+
:rtype: tuple
|
387 |
+
"""
|
388 |
+
results_nli, results_pr = prompt_to_nli_batching(
|
389 |
+
text,
|
390 |
+
template,
|
391 |
+
load_model_prompting(),
|
392 |
+
load_model_nli(),
|
393 |
+
load_spacy_pipeline(),
|
394 |
+
top_k=top_k,
|
395 |
+
nli_limit=nli_limit,
|
396 |
+
targets=None,
|
397 |
+
additional_words=additional_words,
|
398 |
+
remove_lemma=True,
|
399 |
+
)
|
400 |
+
return results_nli, results_pr
|
401 |
+
|
402 |
+
|
403 |
+
def get_additional_words():
|
404 |
+
"""Extract the additional words from the codebook
|
405 |
+
|
406 |
+
:return: list of additional words
|
407 |
+
:rtype: list
|
408 |
+
"""
|
409 |
+
if "add_words" in st.session_state.codebook:
|
410 |
+
additional_words = st.session_state.codebook["add_words"]
|
411 |
+
else:
|
412 |
+
additional_words = None
|
413 |
+
return additional_words
|
414 |
+
|
415 |
+
|
416 |
+
def run_prent(
|
417 |
+
text="", templates=[], additional_words=None, progress=True, display_text=True
|
418 |
+
):
|
419 |
+
"""Execute PRENT over a list of templates and display streamlit widgets
|
420 |
+
|
421 |
+
:param text: Event description, defaults to ""
|
422 |
+
:type text: str, optional
|
423 |
+
:param templates: Templates with a mask, defaults to []
|
424 |
+
:type templates: list, optional
|
425 |
+
:param additional_words: List of words to bypass prompting, defaults to None
|
426 |
+
:type additional_words: list, optional
|
427 |
+
:param progress: Display or not the progress bar, defaults to True
|
428 |
+
:type progress: bool, optional
|
429 |
+
:return: (results of prent, computation time)
|
430 |
+
:rtype: tuple
|
431 |
+
"""
|
432 |
+
# Check if there is any template and event description available
|
433 |
+
if not templates:
|
434 |
+
st.warning("Template list is empty. Please add one.")
|
435 |
+
return None, None
|
436 |
+
if not text:
|
437 |
+
st.warning("Event description is empty.")
|
438 |
+
return None, None
|
439 |
+
|
440 |
+
# Display text only when computing
|
441 |
+
if display_text:
|
442 |
+
temp_text = st.empty()
|
443 |
+
temp_text.markdown("**Event Descriptions:** {}".format(text))
|
444 |
+
|
445 |
+
# Start progress bar
|
446 |
+
if progress:
|
447 |
+
progress_bar = st.progress(0)
|
448 |
+
num_prent_call = len(templates)
|
449 |
+
num_sentences = get_num_sentences_in_list_text([text])
|
450 |
+
iter = 0
|
451 |
+
t0 = time()
|
452 |
+
|
453 |
+
# We set the radio choice of streamlit to Ignore at first
|
454 |
+
if "accept_reject_text_perm" in st.session_state:
|
455 |
+
st.session_state["accept_reject_text_perm"] = "Ignore"
|
456 |
+
|
457 |
+
res = {}
|
458 |
+
for template in templates:
|
459 |
+
template = template.replace("[Z]", "{}")
|
460 |
+
results_nli, results_pr = do_prent(
|
461 |
+
text,
|
462 |
+
template,
|
463 |
+
top_k=TOP_K,
|
464 |
+
nli_limit=NLI_LIMIT,
|
465 |
+
additional_words=additional_words,
|
466 |
+
)
|
467 |
+
# Results_nli contains % of entailment, we only care about the tokens string
|
468 |
+
res[template] = [x[0] for x in results_nli]
|
469 |
+
|
470 |
+
# Update progress bar
|
471 |
+
iter += 1
|
472 |
+
if progress:
|
473 |
+
progress_bar.progress((1 / num_prent_call) * (iter))
|
474 |
+
if display_text:
|
475 |
+
temp_text.markdown("")
|
476 |
+
time_comput = (time() - t0) / num_sentences
|
477 |
+
# This check is done otherwise the time of computation is replaced by the
|
478 |
+
# time of computation when using cached value
|
479 |
+
if not time_comput < st.session_state.time_comput / 5:
|
480 |
+
st.session_state.time_comput = int(time_comput)
|
481 |
+
|
482 |
+
# Store some results
|
483 |
+
res["templates_used"] = templates
|
484 |
+
res["additional_words_used"] = additional_words
|
485 |
+
return res, time_comput
|
486 |
+
|
487 |
+
|
488 |
+
####### Find event types based on codebook and PRENT results
|
489 |
+
def check_any_conds(cond_any, list_res):
|
490 |
+
"""Function that evaluates the "OR" conditions of the codebook versus the list of filled templates
|
491 |
+
|
492 |
+
:param cond_any: List of groundtruth filled templates
|
493 |
+
:type cond_any: list
|
494 |
+
:param list_res: A list of the filled templates given by PRENT
|
495 |
+
:type list_res: list
|
496 |
+
:return: True if any groundtruth template is inside the list given by PRENT
|
497 |
+
:rtype: bool
|
498 |
+
"""
|
499 |
+
cond_any = list(cond_any)
|
500 |
+
condition = False
|
501 |
+
# Return False if there is no any condition
|
502 |
+
if not cond_any:
|
503 |
+
return False
|
504 |
+
for cond in cond_any:
|
505 |
+
# With the current codebook design, this should never be true.
|
506 |
+
# Before it was possible to have recursion to check AND conditions inside an OR condition
|
507 |
+
if isinstance(cond, dict):
|
508 |
+
condition = check_all_conds(cond["all"], list_res)
|
509 |
+
else:
|
510 |
+
# Check lowercase version of templates
|
511 |
+
if cond.lower() in [x.lower() for x in list_res]:
|
512 |
+
condition = True
|
513 |
+
# Exit function as the other templates won't change the outcome
|
514 |
+
return condition
|
515 |
+
return condition
|
516 |
+
|
517 |
+
|
518 |
+
def check_all_conds(cond_all, list_res):
|
519 |
+
"""Function that evaluates the "AND" conditions of the codebook versus the list of filled templates
|
520 |
+
|
521 |
+
:param cond_all: List of groundtruth filled templates
|
522 |
+
:type cond_all: list
|
523 |
+
:param list_res: A list of the filled templates given by PRENT
|
524 |
+
:type list_res: list
|
525 |
+
:return: True if all groundtruth template are inside the list given by PRENT
|
526 |
+
:rtype: bool
|
527 |
+
"""
|
528 |
+
cond_all = list(cond_all)
|
529 |
+
# Return False if there is no all condition
|
530 |
+
if not cond_all:
|
531 |
+
return False
|
532 |
+
# Start bool on True, and put it to false if any template is missing
|
533 |
+
condition = True
|
534 |
+
for cond in cond_all:
|
535 |
+
# With the current codebook design, this should never be true.
|
536 |
+
# Before it was possible to have recursion to check OR conditions inside an AND condition
|
537 |
+
if isinstance(cond, dict):
|
538 |
+
condition = check_any_conds(cond["any"])
|
539 |
+
else:
|
540 |
+
# Check lowercase version of templates
|
541 |
+
if not (cond.lower() in [x.lower() for x in list_res]):
|
542 |
+
condition = False
|
543 |
+
# Exit function as the other templates won't change the outcome
|
544 |
+
return condition
|
545 |
+
return condition
|
546 |
+
|
547 |
+
|
548 |
+
def find_event_types(codebook, list_res):
|
549 |
+
"""This function evaluates the codebook and then outputs a list of events types corresponding to the given results of PRENT (list of filled templates).
|
550 |
+
|
551 |
+
:param codebook: A codebook in the format given by the dashboard
|
552 |
+
:type codebook: dict
|
553 |
+
:param list_res: A list of the filled templates given by PRENT
|
554 |
+
:type list_res: list
|
555 |
+
:return: List of event type
|
556 |
+
:rtype: list
|
557 |
+
"""
|
558 |
+
list_event_type = []
|
559 |
+
# Iterate over all defined event types
|
560 |
+
for event_type in codebook["events"]:
|
561 |
+
code_event = codebook["events"][event_type]
|
562 |
+
|
563 |
+
is_not_all_event, is_not_any_event, is_not_event = False, False, False
|
564 |
+
is_all_event, is_any_event, is_event = False, False, False
|
565 |
+
|
566 |
+
# First check if NOT conditions are met
|
567 |
+
# e.g. a filled template that is contrary to the event is present
|
568 |
+
if "not_all" in code_event:
|
569 |
+
cond_all = code_event["not_all"]
|
570 |
+
if check_all_conds(cond_all, list_res):
|
571 |
+
is_not_all_event = True
|
572 |
+
if "not_any" in code_event:
|
573 |
+
cond_any = code_event["not_any"]
|
574 |
+
if check_any_conds(cond_any, list_res):
|
575 |
+
is_not_any_event = True
|
576 |
+
|
577 |
+
# Next we need to check if the "not_all" and "not_any" are related
|
578 |
+
# by an "OR" or "AND".
|
579 |
+
# This latest case needs special care because one of two list can
|
580 |
+
# be empty so False
|
581 |
+
if code_event["not_all_any_rel"] == "AND":
|
582 |
+
if is_not_all_event and (not code_event["not_any"]):
|
583 |
+
# If all TRUE and ANY is empty (so false)
|
584 |
+
is_not_event = True
|
585 |
+
elif is_not_any_event and (not code_event["not_all"]):
|
586 |
+
# If any TRUE and ALL is empty (so false)
|
587 |
+
is_not_event = True
|
588 |
+
if is_not_all_event and is_not_any_event:
|
589 |
+
is_not_event = True
|
590 |
+
elif code_event["not_all_any_rel"] == "OR":
|
591 |
+
if is_not_all_event or is_not_any_event:
|
592 |
+
is_not_event = True
|
593 |
+
|
594 |
+
# The other checks are not necessary if this is true, so we go
|
595 |
+
# to the next iteration
|
596 |
+
if is_not_event:
|
597 |
+
continue
|
598 |
+
|
599 |
+
# Similar to the previous checks but this time we look for templates that should be present
|
600 |
+
if "all" in code_event:
|
601 |
+
cond_all = code_event["all"]
|
602 |
+
## Then check if All conditions are met, if not exit
|
603 |
+
if check_all_conds(cond_all, list_res):
|
604 |
+
is_all_event = True
|
605 |
+
if "any" in code_event:
|
606 |
+
## Finally check if Any conditions is met, if not exit
|
607 |
+
cond_any = code_event["any"]
|
608 |
+
if check_any_conds(cond_any, list_res):
|
609 |
+
is_any_event = True
|
610 |
+
|
611 |
+
# This case needs special care because one of two list can
|
612 |
+
# be empty so False
|
613 |
+
if code_event["all_any_rel"] == "AND":
|
614 |
+
if is_all_event and (not code_event["any"]):
|
615 |
+
# If all TRUE and ANY is empty (so false)
|
616 |
+
is_event = True
|
617 |
+
elif is_any_event and (not code_event["all"]):
|
618 |
+
# If any TRUE and ALL is empty (so false)
|
619 |
+
is_event = True
|
620 |
+
elif is_all_event and is_any_event:
|
621 |
+
is_event = True
|
622 |
+
elif code_event["all_any_rel"] == "OR":
|
623 |
+
if is_all_event or is_any_event:
|
624 |
+
is_event = True
|
625 |
+
|
626 |
+
# If all checks are correct, then we can add the event type to the output list
|
627 |
+
if is_event:
|
628 |
+
list_event_type.append(event_type)
|
629 |
+
|
630 |
+
return list_event_type
|
pages/1_Codebook_Design.py
ADDED
@@ -0,0 +1,731 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime as datetime
|
2 |
+
import hashlib
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import pandas as pd
|
8 |
+
import streamlit as st
|
9 |
+
|
10 |
+
current = os.path.dirname(os.path.realpath(__file__))
|
11 |
+
parent = os.path.dirname(current)
|
12 |
+
sys.path.append(parent)
|
13 |
+
from helpers import (
|
14 |
+
apply_style,
|
15 |
+
callback_add_to_multiselect,
|
16 |
+
choose_text_menu,
|
17 |
+
do_prent,
|
18 |
+
find_event_types,
|
19 |
+
get_additional_words,
|
20 |
+
get_idx_column,
|
21 |
+
get_nli_limit,
|
22 |
+
get_num_sentences_in_list_text,
|
23 |
+
get_top_k,
|
24 |
+
initiate_widget_st_state,
|
25 |
+
run_prent,
|
26 |
+
)
|
27 |
+
|
28 |
+
# Set constant values
|
29 |
+
TOP_K = get_top_k()
|
30 |
+
NLI_LIMIT = get_nli_limit()
|
31 |
+
|
32 |
+
### Styling
|
33 |
+
# Needs to be done first
|
34 |
+
apply_style()
|
35 |
+
|
36 |
+
# Avoid having ellipsis in the multi select options
|
37 |
+
styl = """
|
38 |
+
<style>
|
39 |
+
.stMultiSelect span{
|
40 |
+
max-width: none;
|
41 |
+
|
42 |
+
}
|
43 |
+
</style>
|
44 |
+
"""
|
45 |
+
st.markdown(styl, unsafe_allow_html=True)
|
46 |
+
|
47 |
+
# Set color of multiselect to red
|
48 |
+
st.markdown(
|
49 |
+
"""
|
50 |
+
<style>
|
51 |
+
span[data-baseweb="tag"] {
|
52 |
+
background-color: red !important;
|
53 |
+
}
|
54 |
+
</style>
|
55 |
+
""",
|
56 |
+
unsafe_allow_html=True,
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
def validated_metric_per_event_types(validated_dataset):
|
61 |
+
"""Compute the accuracy metrics of the validated dataset
|
62 |
+
for each event type. Compute True Positive, False Negative,
|
63 |
+
True Negative, False Positive.
|
64 |
+
|
65 |
+
:param validated_dataset: Dictionary containing results of PRENT validated by the user
|
66 |
+
:type validated_dataset: dict
|
67 |
+
:return: Dictionnary containing accuracy metric for all event types
|
68 |
+
:rtype: dict
|
69 |
+
"""
|
70 |
+
dict_acc = {}
|
71 |
+
|
72 |
+
for key, val in validated_dataset.items():
|
73 |
+
# Compute the event types based on the computed templates of PRENT
|
74 |
+
pred_event_types = find_event_types(
|
75 |
+
st.session_state.codebook, val["filled_templates"]
|
76 |
+
)
|
77 |
+
true_event_types = val["event_types"]
|
78 |
+
# Compute only accuracy for accepted samples
|
79 |
+
if val["decision"] == "Accept":
|
80 |
+
# Iterate over all possible event types
|
81 |
+
for event_type in st.session_state.codebook["events"].keys():
|
82 |
+
dict_acc.setdefault(event_type, {})
|
83 |
+
dict_acc[event_type].setdefault("TP", 0)
|
84 |
+
dict_acc[event_type].setdefault("FN", 0)
|
85 |
+
dict_acc[event_type].setdefault("FP", 0)
|
86 |
+
dict_acc[event_type].setdefault("TN", 0)
|
87 |
+
if (event_type in true_event_types) and (
|
88 |
+
event_type in pred_event_types
|
89 |
+
):
|
90 |
+
dict_acc[event_type]["TP"] += 1
|
91 |
+
elif (event_type in true_event_types) and not (
|
92 |
+
event_type in pred_event_types
|
93 |
+
):
|
94 |
+
dict_acc[event_type]["FN"] += 1
|
95 |
+
elif not (event_type in true_event_types) and (
|
96 |
+
event_type in pred_event_types
|
97 |
+
):
|
98 |
+
dict_acc[event_type]["FP"] += 1
|
99 |
+
else:
|
100 |
+
dict_acc[event_type]["TN"] += 1
|
101 |
+
|
102 |
+
# Normalize metrics
|
103 |
+
if dict_acc:
|
104 |
+
for event_type in st.session_state.codebook["events"].keys():
|
105 |
+
dict_acc[event_type]["Accuracy"] = (
|
106 |
+
dict_acc[event_type]["TP"] + dict_acc[event_type]["TN"]
|
107 |
+
) / (
|
108 |
+
dict_acc[event_type]["TP"]
|
109 |
+
+ dict_acc[event_type]["TN"]
|
110 |
+
+ dict_acc[event_type]["FP"]
|
111 |
+
+ dict_acc[event_type]["FN"]
|
112 |
+
)
|
113 |
+
|
114 |
+
return dict_acc
|
115 |
+
|
116 |
+
|
117 |
+
def store_validated_data(
|
118 |
+
text,
|
119 |
+
decision,
|
120 |
+
text_idx,
|
121 |
+
templates,
|
122 |
+
additional_words,
|
123 |
+
list_event_type,
|
124 |
+
prent_params=(TOP_K, NLI_LIMIT),
|
125 |
+
):
|
126 |
+
"""Function used to store the results of PRENT in a DataFrame and in the
|
127 |
+
session state of Streamlit.
|
128 |
+
|
129 |
+
:param text: Event description
|
130 |
+
:type text: string
|
131 |
+
:param decision: Decision of the user (Accept/Reject/Ignore)
|
132 |
+
:type decision: string
|
133 |
+
:param text_idx: Index of the event
|
134 |
+
:type text_idx: int
|
135 |
+
:param templates: List of template used
|
136 |
+
:type templates: list
|
137 |
+
:param additional_words: List of additional words used
|
138 |
+
:type additional_words: list
|
139 |
+
:param list_event_type: List of event type found by PRENT and Codebook
|
140 |
+
:type list_event_type: list
|
141 |
+
:param prent_params: Parameters of PRENT, defaults to (TOP_K, NLI_LIMIT)
|
142 |
+
:type prent_params: tuple, optional
|
143 |
+
"""
|
144 |
+
if "validated_data" not in st.session_state:
|
145 |
+
st.session_state["validated_data"] = {}
|
146 |
+
|
147 |
+
# Generate an index if the text is not coming from a csv
|
148 |
+
if not text_idx:
|
149 |
+
# Create a hash of 8 digits of the text to put as index
|
150 |
+
data_idx = str(
|
151 |
+
"manual_{}".format(
|
152 |
+
int(
|
153 |
+
hashlib.sha256(text.encode("utf-8")).hexdigest(),
|
154 |
+
16,
|
155 |
+
)
|
156 |
+
% 10**8
|
157 |
+
)
|
158 |
+
)
|
159 |
+
else:
|
160 |
+
data_idx = str(text_idx)
|
161 |
+
|
162 |
+
if data_idx not in st.session_state["validated_data"]:
|
163 |
+
st.session_state["validated_data"][data_idx] = {}
|
164 |
+
st.session_state["validated_data"][data_idx]["text"] = text
|
165 |
+
st.session_state["validated_data"][data_idx]["templates"] = [
|
166 |
+
template.replace("{}", "[Z]") for template in templates
|
167 |
+
]
|
168 |
+
st.session_state["validated_data"][data_idx]["additional_words"] = additional_words
|
169 |
+
st.session_state["validated_data"][data_idx]["event_types"] = list_event_type
|
170 |
+
st.session_state["validated_data"][data_idx][
|
171 |
+
"filled_templates"
|
172 |
+
] = list_filled_templates
|
173 |
+
st.session_state["validated_data"][data_idx]["decision"] = decision
|
174 |
+
st.session_state["validated_data"][data_idx]["prent_params"] = prent_params
|
175 |
+
|
176 |
+
|
177 |
+
### Initialize session state variables
|
178 |
+
if "codebook" not in st.session_state:
|
179 |
+
st.session_state.codebook = {}
|
180 |
+
st.session_state.codebook.setdefault("events", {})
|
181 |
+
st.session_state.codebook["templates"] = []
|
182 |
+
if "text" not in st.session_state:
|
183 |
+
st.session_state.text = ""
|
184 |
+
if "res" not in st.session_state:
|
185 |
+
st.session_state.res = None
|
186 |
+
if "accept_reject_text_perm" not in st.session_state:
|
187 |
+
st.session_state.accept_reject_text_perm = None
|
188 |
+
if "validated_data" not in st.session_state:
|
189 |
+
st.session_state["validated_data"] = {}
|
190 |
+
if "time_comput" not in st.session_state:
|
191 |
+
st.session_state.time_comput = 20
|
192 |
+
if "rerun" not in st.session_state:
|
193 |
+
st.session_state.rerun = False
|
194 |
+
if "recompute_all_templates" not in st.session_state:
|
195 |
+
st.session_state.recompute_all_templates = False
|
196 |
+
|
197 |
+
|
198 |
+
def reset_computation_results():
|
199 |
+
"""Reset cached values in session state related to computations"""
|
200 |
+
st.session_state.res = {}
|
201 |
+
st.session_state.recompute_all_templates = True
|
202 |
+
st.session_state["accept_reject_text_perm"] = "Ignore"
|
203 |
+
st.session_state.rerun = True
|
204 |
+
|
205 |
+
|
206 |
+
def get_all_filled_templates(results):
|
207 |
+
"""Create the filled templates from PRENT results. Merging template with mask
|
208 |
+
with the entailed tokens.
|
209 |
+
|
210 |
+
:param results: Dictionary containing PRENT results
|
211 |
+
:type results: dict
|
212 |
+
:return: List of all entailed templates
|
213 |
+
:rtype: list
|
214 |
+
"""
|
215 |
+
filled_templates = []
|
216 |
+
templates_used = [x.replace("[Z]", "{}") for x in results["templates_used"]]
|
217 |
+
for template in templates_used:
|
218 |
+
filled_template = [template.format(x) for x in results[template]]
|
219 |
+
filled_templates.extend(filled_template)
|
220 |
+
|
221 |
+
return filled_templates
|
222 |
+
|
223 |
+
|
224 |
+
# Split streamlit dashboard
|
225 |
+
col_intro_left, col_intro_righter = st.columns([8, 8])
|
226 |
+
with col_intro_left:
|
227 |
+
st.markdown(
|
228 |
+
""" # Codebook Design
|
229 |
+
"""
|
230 |
+
)
|
231 |
+
|
232 |
+
|
233 |
+
def load_demo(
|
234 |
+
codebook_path="codebook_demo.json",
|
235 |
+
validated_data_path="validated_data_demo.json",
|
236 |
+
csv_data_path="data_demo.csv",
|
237 |
+
):
|
238 |
+
"""Load demonstration files from disk
|
239 |
+
|
240 |
+
:param codebook_path: path to codebook, defaults to "codebook_demo.json"
|
241 |
+
:type codebook_path: str, optional
|
242 |
+
:param validated_data_path: path to validated dataset, defaults to "validated_data_demo.json"
|
243 |
+
:type validated_data_path: str, optional
|
244 |
+
:param csv_data_path: path to raw data, defaults to "data_demo.csv"
|
245 |
+
:type csv_data_path: str, optional
|
246 |
+
"""
|
247 |
+
st.session_state.codebook = json.load(open(codebook_path))
|
248 |
+
st.session_state.validated_data = json.load(open(validated_data_path))
|
249 |
+
st.session_state.data = pd.read_csv(csv_data_path, delimiter=";")
|
250 |
+
st.session_state.filtered_df = st.session_state.data
|
251 |
+
st.session_state.text_column_design_perm = "Event Descriptions"
|
252 |
+
st.session_state["multiselect_classes"] = list(
|
253 |
+
st.session_state.codebook["events"].keys()
|
254 |
+
)
|
255 |
+
st.session_state.text_idx = 0
|
256 |
+
st.session_state.text = (
|
257 |
+
"On 23 August, a group attacked a village, abducting 6 people."
|
258 |
+
)
|
259 |
+
st.session_state.text_display = (
|
260 |
+
"On 23 August, a group attacked a village, abducting 6 people."
|
261 |
+
)
|
262 |
+
st.session_state["text_options_valid_perm"] = "From CSV"
|
263 |
+
st.session_state["text_options_valid"] = "From CSV"
|
264 |
+
|
265 |
+
|
266 |
+
def clear_all():
|
267 |
+
"""Cleare session state"""
|
268 |
+
for each in st.session_state:
|
269 |
+
del st.session_state[each]
|
270 |
+
st.experimental_rerun()
|
271 |
+
|
272 |
+
|
273 |
+
# Add two buttons in the sidebar to load and clear the demo
|
274 |
+
with st.sidebar:
|
275 |
+
if st.button("Load Demo"):
|
276 |
+
load_demo()
|
277 |
+
|
278 |
+
if st.button("Clear Demo"):
|
279 |
+
clear_all()
|
280 |
+
|
281 |
+
st.write("********")
|
282 |
+
|
283 |
+
|
284 |
+
with st.sidebar:
|
285 |
+
# Next function used for callback when download
|
286 |
+
def update_codebook_save_time():
|
287 |
+
st.session_state.save_codebook_time = (
|
288 |
+
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M:%S %z")
|
289 |
+
)
|
290 |
+
|
291 |
+
if st.download_button(
|
292 |
+
label="Download codebook as JSON",
|
293 |
+
data=json.dumps(st.session_state.codebook, indent=3).encode("ASCII"),
|
294 |
+
file_name="codebook.json",
|
295 |
+
mime="application/json",
|
296 |
+
):
|
297 |
+
update_codebook_save_time()
|
298 |
+
if "save_codebook_time" in st.session_state:
|
299 |
+
st.write("Saved on: " + st.session_state.save_codebook_time)
|
300 |
+
|
301 |
+
|
302 |
+
with st.sidebar:
|
303 |
+
# Next function used for callback when download
|
304 |
+
def update_validated_save_time():
|
305 |
+
st.session_state.save_validated_time = (
|
306 |
+
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M:%S %z")
|
307 |
+
)
|
308 |
+
|
309 |
+
if st.download_button(
|
310 |
+
label="Download labeled data",
|
311 |
+
data=json.dumps(st.session_state["validated_data"], indent=3).encode("ASCII"),
|
312 |
+
file_name="validated_data.json",
|
313 |
+
mime="application/json",
|
314 |
+
):
|
315 |
+
update_validated_save_time()
|
316 |
+
if "save_validated_time" in st.session_state:
|
317 |
+
st.write("Saved on: " + st.session_state.save_validated_time)
|
318 |
+
|
319 |
+
# Add text to sidebar
|
320 |
+
with st.sidebar:
|
321 |
+
st.write("********")
|
322 |
+
st.markdown(
|
323 |
+
"""
|
324 |
+
#### Manual:
|
325 |
+
|
326 |
+
1. Set the list of possible event types
|
327 |
+
2. Select the input mode of the data (Manual or CSV)
|
328 |
+
3. If the codebook is empty, write a default template
|
329 |
+
- `This event involves [Z].` is a good starting point
|
330 |
+
4. Write/Select an event description
|
331 |
+
5. Run PR-ENT
|
332 |
+
6. Check the event type classification
|
333 |
+
- If it is correct then select Accept and return to step 4.
|
334 |
+
- If it is wrong then select Reject and populate the codebook with the appropriate filled templates. The classification is updated for each change, when it is correct, click Accept.
|
335 |
+
7. Return to step 4
|
336 |
+
|
337 |
+
#### Tips & Tricks:
|
338 |
+
|
339 |
+
- If you start a codebook from scratch, it may be easier to pass a manual text example for each event type to get a first codebook draft
|
340 |
+
- Current codebook accuracy based on labeled data can be found in the top right
|
341 |
+
- The approach does not aim for perfect accuracy and some failures can happen, e.g. some event descriptions can produce filled templates that are not satisfactory.
|
342 |
+
"""
|
343 |
+
)
|
344 |
+
|
345 |
+
# Add accuracy table
|
346 |
+
with col_intro_righter:
|
347 |
+
accuracy = st.empty()
|
348 |
+
# We fill the table with the last acc to avoid having it disappearing each time
|
349 |
+
if "acc_df" in st.session_state:
|
350 |
+
accuracy.table(
|
351 |
+
st.session_state.acc_df.loc["Accuracy":"Accuracy"].style.format("{:.2}")
|
352 |
+
)
|
353 |
+
performance_container = st.expander("Detailed Performances")
|
354 |
+
|
355 |
+
|
356 |
+
st.write("*********")
|
357 |
+
col_left, col_right = st.columns(2)
|
358 |
+
|
359 |
+
# Add widgets to add event type and choose text input
|
360 |
+
with col_intro_left:
|
361 |
+
with st.expander("Event Types List"):
|
362 |
+
st.markdown(
|
363 |
+
"""
|
364 |
+
## Select Event Types.
|
365 |
+
"""
|
366 |
+
)
|
367 |
+
|
368 |
+
if "class_list_perm" not in st.session_state:
|
369 |
+
st.session_state["class_list_perm"] = []
|
370 |
+
|
371 |
+
# Text field + button to add new event types to multiselect
|
372 |
+
new_class = st.text_input(
|
373 |
+
"Add a new event type", "", key="new_class_text_input"
|
374 |
+
)
|
375 |
+
st.button(
|
376 |
+
"Add Class",
|
377 |
+
on_click=callback_add_to_multiselect,
|
378 |
+
args=(
|
379 |
+
new_class,
|
380 |
+
"multiselect_classes",
|
381 |
+
"new_class_text_input",
|
382 |
+
"class_list_perm",
|
383 |
+
),
|
384 |
+
)
|
385 |
+
# Multiselect to choose event types
|
386 |
+
if "multiselect_classes" not in st.session_state:
|
387 |
+
st.session_state["multiselect_classes"] = list(
|
388 |
+
st.session_state.codebook["events"].keys()
|
389 |
+
)
|
390 |
+
class_list = st.multiselect(
|
391 |
+
"Event Type List",
|
392 |
+
set(
|
393 |
+
st.session_state["class_list_perm"]
|
394 |
+
+ list(st.session_state.codebook["events"].keys())
|
395 |
+
),
|
396 |
+
st.session_state["multiselect_classes"],
|
397 |
+
key="multiselect_classes",
|
398 |
+
)
|
399 |
+
st.session_state["class_list_perm"] = class_list
|
400 |
+
|
401 |
+
with st.expander("Select Text Input Mode (Manual, CSV)"):
|
402 |
+
st.write(
|
403 |
+
"""
|
404 |
+
Choose the text input of the event descriptions. Three choices:
|
405 |
+
- Manual: One event description can be manually input
|
406 |
+
- From CSV: If a CSV of event descriptions was provided
|
407 |
+
"""
|
408 |
+
)
|
409 |
+
|
410 |
+
def callback_radio_text_choice():
|
411 |
+
st.session_state.text = ""
|
412 |
+
st.session_state.text_display = ""
|
413 |
+
|
414 |
+
initiate_widget_st_state(
|
415 |
+
"text_options_valid", "text_options_valid_perm", "Manual"
|
416 |
+
)
|
417 |
+
st.session_state["text_options_valid_perm"] = st.radio(
|
418 |
+
"Choose text input",
|
419 |
+
["Manual", "From CSV"],
|
420 |
+
index=get_idx_column(
|
421 |
+
st.session_state["text_options_valid"], ["Manual", "From CSV"]
|
422 |
+
),
|
423 |
+
key="text_options_valid",
|
424 |
+
on_change=callback_radio_text_choice,
|
425 |
+
horizontal=True,
|
426 |
+
)
|
427 |
+
|
428 |
+
|
429 |
+
with col_left:
|
430 |
+
if st.session_state["text_options_valid_perm"] == "Manual":
|
431 |
+
text = choose_text_menu("")
|
432 |
+
# Reset all computations if text has changed
|
433 |
+
if text != st.session_state.text:
|
434 |
+
reset_computation_results()
|
435 |
+
st.session_state.text_idx = None
|
436 |
+
st.session_state.text = text
|
437 |
+
st.session_state.text_display = text
|
438 |
+
elif st.session_state["text_options_valid_perm"] == "From CSV":
|
439 |
+
if st.button("Select Random Text"):
|
440 |
+
sample = st.session_state.filtered_df.sample(n=1).iloc[0]
|
441 |
+
text = sample[st.session_state["text_column_design_perm"]]
|
442 |
+
idx = sample.name
|
443 |
+
if text != st.session_state.text:
|
444 |
+
reset_computation_results()
|
445 |
+
st.session_state.text = text
|
446 |
+
st.session_state.text_idx = idx
|
447 |
+
st.session_state.text_display = st.session_state.text
|
448 |
+
|
449 |
+
expected_time = st.session_state.time_comput * get_num_sentences_in_list_text(
|
450 |
+
[st.session_state.text]
|
451 |
+
)
|
452 |
+
if st.button("Run PR-ENT / Expected time: {}sec".format(expected_time)):
|
453 |
+
if "templates" in st.session_state.codebook:
|
454 |
+
templates = st.session_state.codebook["templates"]
|
455 |
+
else:
|
456 |
+
templates = []
|
457 |
+
st.warning("No template in codebook. Please add one.")
|
458 |
+
|
459 |
+
additional_words = get_additional_words()
|
460 |
+
st.session_state.res = {}
|
461 |
+
res, time_comput = run_prent(st.session_state.text, templates, additional_words)
|
462 |
+
st.session_state.res = res
|
463 |
+
|
464 |
+
st.write("**Event Descriptions:** {}".format(st.session_state.text_display))
|
465 |
+
ev_desc = st.empty()
|
466 |
+
radio_empty = st.empty()
|
467 |
+
|
468 |
+
if st.session_state.res:
|
469 |
+
list_filled_templates = get_all_filled_templates(st.session_state.res)
|
470 |
+
|
471 |
+
list_event_type = find_event_types(
|
472 |
+
st.session_state.codebook, list_filled_templates
|
473 |
+
)
|
474 |
+
event_type_text = ev_desc.markdown(
|
475 |
+
"**Current Event Types Classification**: {}".format(
|
476 |
+
"; ".join(list_event_type)
|
477 |
+
)
|
478 |
+
)
|
479 |
+
|
480 |
+
if "accept_reject_text_perm" not in st.session_state:
|
481 |
+
st.session_state["accept_reject_text_perm"] = "Ignore"
|
482 |
+
|
483 |
+
def callback_function(mod, key):
|
484 |
+
st.session_state[mod] = st.session_state[key]
|
485 |
+
|
486 |
+
radio_empty.radio(
|
487 |
+
"Accept or Reject Coding",
|
488 |
+
["Ignore", "Accept", "Reject"],
|
489 |
+
key="accept_reject_text",
|
490 |
+
on_change=callback_function,
|
491 |
+
args=(
|
492 |
+
"accept_reject_text_perm",
|
493 |
+
"accept_reject_text",
|
494 |
+
),
|
495 |
+
index=get_idx_column(
|
496 |
+
st.session_state["accept_reject_text_perm"],
|
497 |
+
["Ignore", "Accept", "Reject"],
|
498 |
+
),
|
499 |
+
horizontal=True,
|
500 |
+
)
|
501 |
+
|
502 |
+
decision = st.session_state["accept_reject_text_perm"]
|
503 |
+
text_idx = st.session_state.text_idx
|
504 |
+
text = st.session_state.text
|
505 |
+
store_validated_data(
|
506 |
+
text,
|
507 |
+
decision,
|
508 |
+
text_idx,
|
509 |
+
st.session_state.res["templates_used"],
|
510 |
+
st.session_state.res["additional_words_used"],
|
511 |
+
list_event_type,
|
512 |
+
prent_params=(TOP_K, NLI_LIMIT),
|
513 |
+
)
|
514 |
+
|
515 |
+
|
516 |
+
with col_right:
|
517 |
+
|
518 |
+
if (
|
519 |
+
st.session_state["accept_reject_text_perm"] == "Reject"
|
520 |
+
) or not st.session_state.codebook["templates"]:
|
521 |
+
with st.expander("Add Templates + Explanation"):
|
522 |
+
st.markdown(
|
523 |
+
"""
|
524 |
+
## Add Templates
|
525 |
+
"""
|
526 |
+
)
|
527 |
+
st.markdown(
|
528 |
+
"""
|
529 |
+
For each template added. PR-ENT will be run on the selected text.
|
530 |
+
"""
|
531 |
+
)
|
532 |
+
|
533 |
+
if "templates" not in st.session_state.codebook:
|
534 |
+
st.session_state.codebook["templates"] = []
|
535 |
+
|
536 |
+
template = st.text_input(
|
537 |
+
"Template with a mask [Z].", "This event involves [Z]."
|
538 |
+
)
|
539 |
+
|
540 |
+
if st.button("Add template"):
|
541 |
+
if template not in st.session_state.codebook["templates"]:
|
542 |
+
## Add template to codebook
|
543 |
+
st.session_state.codebook["templates"].append(template)
|
544 |
+
|
545 |
+
additional_words = get_additional_words()
|
546 |
+
prompt = template.replace("[Z]", "{}")
|
547 |
+
results_nli, _ = do_prent(
|
548 |
+
st.session_state.text,
|
549 |
+
prompt,
|
550 |
+
TOP_K,
|
551 |
+
NLI_LIMIT,
|
552 |
+
additional_words,
|
553 |
+
)
|
554 |
+
tokens_nli = [x[0] for x in results_nli]
|
555 |
+
|
556 |
+
# Update result table with new template
|
557 |
+
if not st.session_state["res"]:
|
558 |
+
st.session_state.res = {}
|
559 |
+
st.session_state.res["additional_words_used"] = additional_words
|
560 |
+
st.session_state.res["templates_used"] = []
|
561 |
+
st.session_state.res[prompt] = tokens_nli
|
562 |
+
st.session_state.res["templates_used"].append(template)
|
563 |
+
st.write("Template '{}' added.".format(template))
|
564 |
+
else:
|
565 |
+
st.write("Template '{}' already added.".format(template))
|
566 |
+
|
567 |
+
if st.session_state.codebook["templates"]:
|
568 |
+
with st.expander("Populate Codebook Explanation"):
|
569 |
+
st.markdown(
|
570 |
+
"""
|
571 |
+
## Set the filled template to each class.
|
572 |
+
For each class you can select one or more filled templates. When the evaluation will
|
573 |
+
be made, these templates will be compared with the results of PR-ENT. There are 4 options:
|
574 |
+
- ALL: If **ALL** of these filled templates are present in the results of PR-ENT then this event type is correct
|
575 |
+
- ANY: If **ANY** of these filled templates is present in the results of PR-ENT then this event type is correct
|
576 |
+
- NOT ALL: If **ALL** of these filled templates are present in the results of PR-ENT, then this event type is **not** correct
|
577 |
+
- e.g. You may want to remove all *explosions* events from a class *Killings*.
|
578 |
+
- NOT ANY: If **ANY** of these filled templates is present in the results of PR-ENT, then this event type is **not** correct
|
579 |
+
|
580 |
+
Moreover, **ANY/ALL** and **NOT ANY/ NOT ALL** can be made in relation by a **AND / OR** condition.
|
581 |
+
"""
|
582 |
+
)
|
583 |
+
|
584 |
+
st.write("***************")
|
585 |
+
st.write("### Populate Codebook")
|
586 |
+
if not class_list:
|
587 |
+
st.warning("No event type in codebook.")
|
588 |
+
|
589 |
+
tokens_list = get_all_filled_templates(st.session_state.res)
|
590 |
+
|
591 |
+
for event_type in class_list:
|
592 |
+
st.session_state.codebook["events"].setdefault(event_type, {})
|
593 |
+
event_type_chosen = event_type
|
594 |
+
with st.expander(event_type):
|
595 |
+
|
596 |
+
def declare_ms_event_templates(
|
597 |
+
widget_key, widget_display, codebook_key
|
598 |
+
):
|
599 |
+
if widget_key not in st.session_state:
|
600 |
+
st.session_state[widget_key] = st.session_state.codebook[
|
601 |
+
"events"
|
602 |
+
][event_type_chosen].setdefault(codebook_key, [])
|
603 |
+
|
604 |
+
tokens_all = st.multiselect(
|
605 |
+
widget_display,
|
606 |
+
set(
|
607 |
+
list(
|
608 |
+
tokens_list
|
609 |
+
+ st.session_state.codebook["events"][
|
610 |
+
event_type_chosen
|
611 |
+
].setdefault(codebook_key, [])
|
612 |
+
)
|
613 |
+
),
|
614 |
+
st.session_state[widget_key],
|
615 |
+
key=widget_key,
|
616 |
+
)
|
617 |
+
st.session_state.codebook["events"][event_type_chosen][
|
618 |
+
codebook_key
|
619 |
+
] = tokens_all
|
620 |
+
|
621 |
+
declare_ms_event_templates(
|
622 |
+
"ms_all_{}".format(event_type_chosen), "ALL", "all"
|
623 |
+
)
|
624 |
+
|
625 |
+
st.session_state.codebook["events"][event_type_chosen][
|
626 |
+
"all_any_rel"
|
627 |
+
] = st.selectbox(
|
628 |
+
"Relation",
|
629 |
+
["AND", "OR"],
|
630 |
+
index=get_idx_column(
|
631 |
+
st.session_state.codebook["events"][
|
632 |
+
event_type_chosen
|
633 |
+
].setdefault("all_any_rel", "OR"),
|
634 |
+
["AND", "OR"],
|
635 |
+
),
|
636 |
+
key="select_relation_any_all_{}".format(event_type_chosen),
|
637 |
+
)
|
638 |
+
|
639 |
+
declare_ms_event_templates(
|
640 |
+
"ms_any_{}".format(event_type_chosen), "ANY", "any"
|
641 |
+
)
|
642 |
+
|
643 |
+
declare_ms_event_templates(
|
644 |
+
"ms_not_all_{}".format(event_type_chosen), "NOT ALL", "not_all"
|
645 |
+
)
|
646 |
+
|
647 |
+
st.session_state.codebook["events"][event_type_chosen][
|
648 |
+
"not_all_any_rel"
|
649 |
+
] = st.selectbox(
|
650 |
+
"Relation",
|
651 |
+
["AND", "OR"],
|
652 |
+
index=get_idx_column(
|
653 |
+
st.session_state.codebook["events"][
|
654 |
+
event_type_chosen
|
655 |
+
].setdefault("not_all_any_rel", "OR"),
|
656 |
+
["AND", "OR"],
|
657 |
+
),
|
658 |
+
key="select_relation_not_any_all_{}".format(event_type_chosen),
|
659 |
+
)
|
660 |
+
|
661 |
+
declare_ms_event_templates(
|
662 |
+
"ms_not_any_{}".format(event_type_chosen), "NOT ANY", "not_any"
|
663 |
+
)
|
664 |
+
|
665 |
+
# Workaround to avoid the expanders closing after first modification
|
666 |
+
# I have no explanation for the bug
|
667 |
+
if st.session_state.rerun:
|
668 |
+
st.session_state.rerun = False
|
669 |
+
st.experimental_rerun()
|
670 |
+
|
671 |
+
|
672 |
+
if "validated_data" in st.session_state:
|
673 |
+
recompute = False
|
674 |
+
performance_container.markdown(
|
675 |
+
"If a new template is added, the previous labeled samples needs to be recomputed with it. The next button allows that, however it can take some time depending on the number of samples."
|
676 |
+
)
|
677 |
+
if performance_container.button(
|
678 |
+
"Recompute Missing Templates", key="recompute_temp"
|
679 |
+
):
|
680 |
+
prog_bar = performance_container.progress(0)
|
681 |
+
for i, datapoint in enumerate(st.session_state["validated_data"].values()):
|
682 |
+
if not set(st.session_state.codebook["templates"]).issubset(
|
683 |
+
set(datapoint["templates"])
|
684 |
+
):
|
685 |
+
# Get templates that are missing from results but present in codebook
|
686 |
+
# These happens if templates are added a posteriori
|
687 |
+
missing_templates = list(
|
688 |
+
set(st.session_state.codebook["templates"])
|
689 |
+
- set(set(datapoint["templates"]))
|
690 |
+
)
|
691 |
+
recompute = True
|
692 |
+
# For now additional words are not recomputed
|
693 |
+
if not set(st.session_state.codebook["add_words"]).issubset(
|
694 |
+
set(datapoint["additional_words"])
|
695 |
+
):
|
696 |
+
missing_add_words = list(
|
697 |
+
set(st.session_state.codebook["add_words"])
|
698 |
+
- set(set(datapoint["additional_words"]))
|
699 |
+
)
|
700 |
+
recompute = True
|
701 |
+
else:
|
702 |
+
missing_add_words = None
|
703 |
+
|
704 |
+
if recompute:
|
705 |
+
res, _ = run_prent(
|
706 |
+
datapoint["text"],
|
707 |
+
missing_templates,
|
708 |
+
missing_add_words,
|
709 |
+
progress=False,
|
710 |
+
)
|
711 |
+
datapoint["filled_templates"].extend(get_all_filled_templates(res))
|
712 |
+
datapoint["templates"].extend(missing_templates)
|
713 |
+
prog_bar.progress(
|
714 |
+
(1 / len(st.session_state["validated_data"].values())) * (i + 1)
|
715 |
+
)
|
716 |
+
|
717 |
+
st.session_state.acc_df = pd.DataFrame(
|
718 |
+
validated_metric_per_event_types(st.session_state["validated_data"])
|
719 |
+
)
|
720 |
+
accuracy.table(
|
721 |
+
st.session_state.acc_df.loc["Accuracy":"Accuracy"].style.format("{:.2}")
|
722 |
+
)
|
723 |
+
performance_container.markdown("### Performances on labeled dataset")
|
724 |
+
performance_container.dataframe(st.session_state.acc_df.style.format("{:.3}"))
|
725 |
+
|
726 |
+
if st.session_state.res:
|
727 |
+
list_filled_templates = get_all_filled_templates(st.session_state.res)
|
728 |
+
list_event_type = find_event_types(st.session_state.codebook, list_filled_templates)
|
729 |
+
ev_desc.markdown(
|
730 |
+
"**Current Event Types Classification**: {}".format("; ".join(list_event_type))
|
731 |
+
)
|
pages/2_Codebook_Advanced_Edit.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import streamlit as st
|
4 |
+
|
5 |
+
from helpers import apply_style, callback_add_to_multiselect, get_idx_column
|
6 |
+
|
7 |
+
apply_style()
|
8 |
+
|
9 |
+
|
10 |
+
# Avoid having ellipsis in the multi select options
|
11 |
+
styl = """
|
12 |
+
<style>
|
13 |
+
.stMultiSelect span{
|
14 |
+
max-width: none;
|
15 |
+
|
16 |
+
}
|
17 |
+
</style>
|
18 |
+
"""
|
19 |
+
st.markdown(styl, unsafe_allow_html=True)
|
20 |
+
|
21 |
+
st.write("# Codebook Edit")
|
22 |
+
|
23 |
+
st.write(
|
24 |
+
"""In this tab you can:
|
25 |
+
- Add or remove templates
|
26 |
+
- Add or remove additional answer candidates
|
27 |
+
- Modify the filled templates by adding new ones manually"""
|
28 |
+
)
|
29 |
+
|
30 |
+
if "templates" not in st.session_state.codebook:
|
31 |
+
|
32 |
+
st.warning("No codebook loaded")
|
33 |
+
st.stop()
|
34 |
+
|
35 |
+
st.write("## Codebook: Template")
|
36 |
+
|
37 |
+
|
38 |
+
with st.expander("Templates"):
|
39 |
+
|
40 |
+
template = st.text_input(
|
41 |
+
"Template with a mask [Z].",
|
42 |
+
"This event involves [Z].",
|
43 |
+
key="add_template_text_input",
|
44 |
+
)
|
45 |
+
st.button(
|
46 |
+
"Add template",
|
47 |
+
on_click=callback_add_to_multiselect,
|
48 |
+
args=(
|
49 |
+
template,
|
50 |
+
"multiselect_templates",
|
51 |
+
"add_template_text_input",
|
52 |
+
"codebook",
|
53 |
+
"templates",
|
54 |
+
),
|
55 |
+
)
|
56 |
+
|
57 |
+
if "multiselect_templates" not in st.session_state:
|
58 |
+
st.session_state["multiselect_templates"] = st.session_state.codebook[
|
59 |
+
"templates"
|
60 |
+
]
|
61 |
+
|
62 |
+
st.write("Removed templates will be removed from the codebook.")
|
63 |
+
templates = st.multiselect(
|
64 |
+
"Templates",
|
65 |
+
set(st.session_state.codebook["templates"]),
|
66 |
+
st.session_state["multiselect_templates"],
|
67 |
+
key="multiselect_templates",
|
68 |
+
)
|
69 |
+
st.session_state.codebook["templates"] = templates
|
70 |
+
|
71 |
+
st.write("## Codebook: Additional Answer Candidates")
|
72 |
+
st.write(
|
73 |
+
"""
|
74 |
+
You can manually add answer candidates. Then they will be tested for entailment on every event
|
75 |
+
description and every template even if they are not present in the prompting results.
|
76 |
+
This is intended for case when the event that you try to describe is quite rare (e.g. shelling, missiles).
|
77 |
+
|
78 |
+
**Caution**: Each word added will increase the computation time (about +3%).
|
79 |
+
|
80 |
+
**Caution**: The PR-ENT model will always try to output the singular form of the word.
|
81 |
+
"""
|
82 |
+
)
|
83 |
+
with st.expander("Add answer candidates"):
|
84 |
+
|
85 |
+
new_word = st.text_input(
|
86 |
+
"Answer Candidate (1 word)", "", key="add_words_text_input"
|
87 |
+
)
|
88 |
+
st.button(
|
89 |
+
"Add Word",
|
90 |
+
on_click=callback_add_to_multiselect,
|
91 |
+
args=(
|
92 |
+
new_word,
|
93 |
+
"multiselect_addwords",
|
94 |
+
"add_words_text_input",
|
95 |
+
"codebook",
|
96 |
+
"add_words",
|
97 |
+
),
|
98 |
+
)
|
99 |
+
|
100 |
+
if "add_words" not in st.session_state.codebook:
|
101 |
+
st.session_state.codebook["add_words"] = []
|
102 |
+
|
103 |
+
if "multiselect_addwords" not in st.session_state:
|
104 |
+
st.session_state["multiselect_addwords"] = st.session_state.codebook[
|
105 |
+
"add_words"
|
106 |
+
]
|
107 |
+
|
108 |
+
templates = st.multiselect(
|
109 |
+
"Add Words",
|
110 |
+
set(st.session_state.codebook["add_words"]),
|
111 |
+
st.session_state["multiselect_addwords"],
|
112 |
+
key="multiselect_addwords",
|
113 |
+
)
|
114 |
+
st.session_state.codebook["add_words"] = templates
|
115 |
+
|
116 |
+
# TODO: Change by giving a list of templates and allow only filling a word.
|
117 |
+
st.write("## Codebook: Additional Filled Templates")
|
118 |
+
st.write(
|
119 |
+
"""
|
120 |
+
You can also manually add filled templates to the codebook. This is for the case when you know that a
|
121 |
+
filled template could appear but you don't find corresponding events. This does not increase much the
|
122 |
+
computation time. For example you could add `This event involves kidnapping.` if you have no kidnapping
|
123 |
+
event in your dataset but you know it could happen.
|
124 |
+
|
125 |
+
**Caution**: The PR-ENT model will always try to output the singular form of the word. (e.g. "Protests" -> "Protest")
|
126 |
+
"""
|
127 |
+
)
|
128 |
+
class_list = list(st.session_state.codebook["events"].keys())
|
129 |
+
|
130 |
+
|
131 |
+
if "filled_templates" not in st.session_state:
|
132 |
+
st.session_state["filled_templates"] = []
|
133 |
+
|
134 |
+
|
135 |
+
with st.expander("Add Filled Template"):
|
136 |
+
|
137 |
+
template_chosen = st.selectbox(
|
138 |
+
"Choose a template:",
|
139 |
+
st.session_state.codebook["templates"],
|
140 |
+
# index=get_idx_column(template, st.session_state.codebook["templates"]),
|
141 |
+
key="template_sct",
|
142 |
+
)
|
143 |
+
|
144 |
+
def add_template_with_word(template_chosen, new_word, key_text_input):
|
145 |
+
if len(new_word) == 0:
|
146 |
+
st.warning("Word is empty, did you press Enter on the field text?")
|
147 |
+
else:
|
148 |
+
st.session_state["filled_templates"].append(
|
149 |
+
template_chosen.replace("[Z]", new_word)
|
150 |
+
)
|
151 |
+
st.session_state[key_text_input] = ""
|
152 |
+
|
153 |
+
new_word = st.text_input("1 Word Mask", "", key="filled_template_text_input")
|
154 |
+
if st.button(
|
155 |
+
"Add Filled Template",
|
156 |
+
on_click=add_template_with_word,
|
157 |
+
args=(template_chosen, new_word, "filled_template_text_input"),
|
158 |
+
):
|
159 |
+
st.write("Filled template added.")
|
160 |
+
st.write("The template can then be selected for each class below.")
|
161 |
+
|
162 |
+
|
163 |
+
st.write("## Codebook: Event Types")
|
164 |
+
|
165 |
+
st.write(
|
166 |
+
"""
|
167 |
+
Here you have access to all filled templates independently of the template. You can add/remove some of them for
|
168 |
+
each event type.
|
169 |
+
"""
|
170 |
+
)
|
171 |
+
|
172 |
+
|
173 |
+
for event_type in st.session_state.codebook["events"].keys():
|
174 |
+
for any_not_all in st.session_state.codebook["events"][event_type].keys():
|
175 |
+
if (any_not_all == "all_any_rel") or (any_not_all == "not_all_any_rel"):
|
176 |
+
pass
|
177 |
+
else:
|
178 |
+
st.session_state["filled_templates"].extend(
|
179 |
+
st.session_state.codebook["events"][event_type][any_not_all]
|
180 |
+
)
|
181 |
+
|
182 |
+
for event_type in class_list:
|
183 |
+
st.session_state.codebook["events"].setdefault(event_type, {})
|
184 |
+
event_type_chosen = event_type
|
185 |
+
with st.expander(event_type):
|
186 |
+
|
187 |
+
def declare_ms_codebook_edit(widget_key, codebook_key, widget_display):
|
188 |
+
if widget_key not in st.session_state:
|
189 |
+
st.session_state[widget_key] = st.session_state.codebook["events"][
|
190 |
+
event_type_chosen
|
191 |
+
].setdefault(codebook_key, [])
|
192 |
+
|
193 |
+
tokens_all = st.multiselect(
|
194 |
+
widget_display,
|
195 |
+
set(st.session_state["filled_templates"]),
|
196 |
+
st.session_state[widget_key],
|
197 |
+
key=widget_key,
|
198 |
+
)
|
199 |
+
st.session_state.codebook["events"][event_type_chosen][
|
200 |
+
codebook_key
|
201 |
+
] = tokens_all
|
202 |
+
|
203 |
+
declare_ms_codebook_edit("ms_all_{}".format(event_type_chosen), "all", "ALL")
|
204 |
+
|
205 |
+
st.session_state.codebook["events"][event_type_chosen][
|
206 |
+
"all_any_rel"
|
207 |
+
] = st.selectbox(
|
208 |
+
"Relation",
|
209 |
+
["AND", "OR"],
|
210 |
+
index=get_idx_column(
|
211 |
+
st.session_state.codebook["events"][event_type_chosen].setdefault(
|
212 |
+
"all_any_rel", "OR"
|
213 |
+
),
|
214 |
+
["AND", "OR"],
|
215 |
+
),
|
216 |
+
key="select_relation_any_all_{}".format(event_type_chosen),
|
217 |
+
)
|
218 |
+
|
219 |
+
declare_ms_codebook_edit("ms_any_{}".format(event_type_chosen), "any", "ANY")
|
220 |
+
declare_ms_codebook_edit(
|
221 |
+
"ms_not_all_{}".format(event_type_chosen), "not_all", "NOT ALL"
|
222 |
+
)
|
223 |
+
|
224 |
+
st.session_state.codebook["events"][event_type_chosen][
|
225 |
+
"not_all_any_rel"
|
226 |
+
] = st.selectbox(
|
227 |
+
"Relation",
|
228 |
+
["AND", "OR"],
|
229 |
+
index=get_idx_column(
|
230 |
+
st.session_state.codebook["events"][event_type_chosen].setdefault(
|
231 |
+
"not_all_any_rel", "OR"
|
232 |
+
),
|
233 |
+
["AND", "OR"],
|
234 |
+
),
|
235 |
+
key="select_relation_not_any_all_{}".format(event_type_chosen),
|
236 |
+
)
|
237 |
+
declare_ms_codebook_edit(
|
238 |
+
"ms_not_any_{}".format(event_type_chosen), "not_any", "NOT ANY"
|
239 |
+
)
|
240 |
+
|
241 |
+
if st.button("Remove Class", key="remove_class_{}".format(event_type_chosen)):
|
242 |
+
del st.session_state.codebook["events"][event_type_chosen]
|
243 |
+
st.write("## Codebook: Download")
|
244 |
+
|
245 |
+
|
246 |
+
st.download_button(
|
247 |
+
label="Download codebook as JSON",
|
248 |
+
data=json.dumps(st.session_state.codebook, indent=3).encode("ASCII"),
|
249 |
+
file_name="codebook.json",
|
250 |
+
mime="application/json",
|
251 |
+
)
|
pages/3_Apply_Codebook.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
import streamlit as st
|
7 |
+
|
8 |
+
current = os.path.dirname(os.path.realpath(__file__))
|
9 |
+
parent = os.path.dirname(current)
|
10 |
+
sys.path.append(parent)
|
11 |
+
from helpers import (
|
12 |
+
apply_style,
|
13 |
+
find_event_types,
|
14 |
+
get_additional_words,
|
15 |
+
get_nli_limit,
|
16 |
+
get_num_sentences_in_list_text,
|
17 |
+
get_top_k,
|
18 |
+
run_prent,
|
19 |
+
)
|
20 |
+
|
21 |
+
### Styling
|
22 |
+
apply_style()
|
23 |
+
|
24 |
+
|
25 |
+
TOP_K = get_top_k()
|
26 |
+
NLI_LIMIT = get_nli_limit()
|
27 |
+
|
28 |
+
|
29 |
+
### Initialize session state variables
|
30 |
+
if "codebook" not in st.session_state:
|
31 |
+
st.session_state.codebook = {}
|
32 |
+
st.session_state.codebook.setdefault("events", {})
|
33 |
+
|
34 |
+
if "text" not in st.session_state:
|
35 |
+
st.session_state.text = ""
|
36 |
+
|
37 |
+
if "res" not in st.session_state:
|
38 |
+
st.session_state.res = None
|
39 |
+
|
40 |
+
if "accept_reject_text_perm" not in st.session_state:
|
41 |
+
st.session_state.accept_reject_text_perm = None
|
42 |
+
|
43 |
+
if "validated_data" not in st.session_state:
|
44 |
+
st.session_state["validated_data"] = {}
|
45 |
+
|
46 |
+
if "time_comput" not in st.session_state:
|
47 |
+
st.session_state.time_comput = 20
|
48 |
+
|
49 |
+
if "rerun" not in st.session_state:
|
50 |
+
st.session_state.rerun = False
|
51 |
+
|
52 |
+
if "label_res" not in st.session_state:
|
53 |
+
st.session_state.label_res = {}
|
54 |
+
|
55 |
+
if "filtered_df" not in st.session_state:
|
56 |
+
st.session_state["filtered_df"] = pd.DataFrame()
|
57 |
+
|
58 |
+
if len(st.session_state["filtered_df"]) == 0:
|
59 |
+
st.warning("No data loaded.")
|
60 |
+
|
61 |
+
|
62 |
+
def reset_computation_results():
|
63 |
+
st.session_state.res = {}
|
64 |
+
st.session_state.recompute_all_templates = True
|
65 |
+
st.session_state["accept_reject_text_perm"] = "Ignore"
|
66 |
+
st.session_state.rerun = True
|
67 |
+
|
68 |
+
|
69 |
+
with st.sidebar:
|
70 |
+
st.markdown(
|
71 |
+
"Clicking any of these button during labeling will pause the process and download the latest version."
|
72 |
+
)
|
73 |
+
dl_labeled_button = st.empty()
|
74 |
+
dl_labeled_button.download_button(
|
75 |
+
label="Download Labeled Data",
|
76 |
+
data=st.session_state["filtered_df"].to_csv(sep=";").encode("utf-8"),
|
77 |
+
file_name="labeled_data.csv",
|
78 |
+
mime="text/csv",
|
79 |
+
)
|
80 |
+
|
81 |
+
dl_prent_button = st.empty()
|
82 |
+
dl_prent_button.download_button(
|
83 |
+
label="Download PR-ENT results",
|
84 |
+
data=json.dumps(st.session_state["label_res"], indent=3).encode("ASCII"),
|
85 |
+
file_name="prent_results.json",
|
86 |
+
mime="application/json",
|
87 |
+
)
|
88 |
+
|
89 |
+
|
90 |
+
st.markdown(
|
91 |
+
"""# Apply codebook to the dataset
|
92 |
+
The currently loaded codebook will be used to find the event types of all event description in the currently loaded dataset. This can take some time (minutes to hours) depending on the size of the dataset (number of events, length of text).
|
93 |
+
|
94 |
+
|
95 |
+
"""
|
96 |
+
)
|
97 |
+
|
98 |
+
markdown_num_events = st.empty()
|
99 |
+
|
100 |
+
label_button = st.empty()
|
101 |
+
st.markdown("#### Main progress bar")
|
102 |
+
main_progress_bar = st.empty()
|
103 |
+
main_progress_bar = main_progress_bar.progress(0)
|
104 |
+
|
105 |
+
st.markdown("#### Last labeled event")
|
106 |
+
temp_text = st.empty()
|
107 |
+
temp_class = st.empty()
|
108 |
+
temp_text.markdown("**Event Descriptions:** {}".format(""))
|
109 |
+
temp_class.markdown("**Event Types Classification**: {}".format(""))
|
110 |
+
st.markdown(
|
111 |
+
"""#### Pause/Stop the event coding
|
112 |
+
Pressing the button once will stop the process at the next iteration."""
|
113 |
+
)
|
114 |
+
stop_button = st.button("Stop")
|
115 |
+
|
116 |
+
for event_type in st.session_state.codebook["events"]:
|
117 |
+
if event_type not in st.session_state.filtered_df.columns:
|
118 |
+
st.session_state.filtered_df[event_type] = 0
|
119 |
+
|
120 |
+
expected_time = 0
|
121 |
+
num_sentences = 0
|
122 |
+
for idx in st.session_state.filtered_df.index:
|
123 |
+
subsampled_data = st.session_state.filtered_df.loc[idx:idx]
|
124 |
+
list_text = subsampled_data[st.session_state["text_column_design_perm"]].values[:1]
|
125 |
+
list_index = subsampled_data.index[:1]
|
126 |
+
if list_text[0] != st.session_state.text:
|
127 |
+
reset_computation_results()
|
128 |
+
st.session_state.text = list_text[0]
|
129 |
+
num_sentences += get_num_sentences_in_list_text([st.session_state.text])
|
130 |
+
expected_time += st.session_state.time_comput * get_num_sentences_in_list_text(
|
131 |
+
[st.session_state.text]
|
132 |
+
)
|
133 |
+
|
134 |
+
markdown_num_events.markdown(
|
135 |
+
"Number of events: {} ¦ Number of sentences: {}".format(
|
136 |
+
len(st.session_state.filtered_df.index), num_sentences
|
137 |
+
)
|
138 |
+
)
|
139 |
+
|
140 |
+
|
141 |
+
if label_button.button(
|
142 |
+
"Label Data", disabled=len(st.session_state["filtered_df"]) == 0
|
143 |
+
):
|
144 |
+
num_text = 0
|
145 |
+
main_progress_bar.progress(num_text)
|
146 |
+
temp_text.markdown("")
|
147 |
+
temp_class.markdown("")
|
148 |
+
tot_num_text = len(st.session_state.filtered_df.index)
|
149 |
+
|
150 |
+
for idx in st.session_state.filtered_df.index:
|
151 |
+
subsampled_data = st.session_state.filtered_df.loc[idx:idx]
|
152 |
+
list_text = subsampled_data[st.session_state["text_column_design_perm"]].values[
|
153 |
+
:1
|
154 |
+
]
|
155 |
+
list_index = subsampled_data.index[:1]
|
156 |
+
if list_text[0] != st.session_state.text:
|
157 |
+
reset_computation_results()
|
158 |
+
st.session_state.text = list_text[0]
|
159 |
+
st.session_state.text_idx = list_index[0]
|
160 |
+
st.session_state.template_list = []
|
161 |
+
st.session_state.text_display = st.session_state.text
|
162 |
+
|
163 |
+
st.session_state.res = {}
|
164 |
+
res, time_comput = run_prent(
|
165 |
+
st.session_state.text,
|
166 |
+
st.session_state.codebook["templates"],
|
167 |
+
get_additional_words(),
|
168 |
+
progress=False,
|
169 |
+
display_text=False,
|
170 |
+
)
|
171 |
+
st.session_state.res = res
|
172 |
+
|
173 |
+
list_filled_templates = []
|
174 |
+
for template in st.session_state.res:
|
175 |
+
tmp = template.replace("[Z]", "{}")
|
176 |
+
list_filled_templates.extend(
|
177 |
+
[tmp.format(x) for x in st.session_state.res[template]]
|
178 |
+
)
|
179 |
+
list_event_type = find_event_types(
|
180 |
+
st.session_state.codebook, list_filled_templates
|
181 |
+
)
|
182 |
+
for event_type in list_event_type:
|
183 |
+
st.session_state.filtered_df.loc[idx, event_type] = 1
|
184 |
+
temp_text.markdown(
|
185 |
+
"**Event Descriptions:** {}".format(st.session_state.text_display)
|
186 |
+
)
|
187 |
+
temp_class.markdown(
|
188 |
+
"**Event Types Classification**: {}".format("; ".join(list_event_type))
|
189 |
+
)
|
190 |
+
|
191 |
+
# Save results
|
192 |
+
st.session_state.label_res[st.session_state.text_display] = {}
|
193 |
+
st.session_state.label_res[st.session_state.text_display][
|
194 |
+
"prent_results"
|
195 |
+
] = st.session_state.res
|
196 |
+
st.session_state.label_res[st.session_state.text_display]["prent_params"] = (
|
197 |
+
TOP_K,
|
198 |
+
NLI_LIMIT,
|
199 |
+
)
|
200 |
+
st.session_state.label_res[st.session_state.text_display][
|
201 |
+
"event_types"
|
202 |
+
] = list_event_type
|
203 |
+
|
204 |
+
num_text += 1
|
205 |
+
main_progress_bar.progress(num_text / tot_num_text)
|
206 |
+
|
207 |
+
# Need to update the buttons otherwise it doesn't update the downloaded file
|
208 |
+
# and the user would need to click two times
|
209 |
+
dl_labeled_button.download_button(
|
210 |
+
label="Download Labeled Data",
|
211 |
+
data=st.session_state["filtered_df"].to_csv(sep=";").encode("utf-8"),
|
212 |
+
file_name="labeled_data.csv",
|
213 |
+
mime="text/csv",
|
214 |
+
key="tmp",
|
215 |
+
)
|
216 |
+
|
217 |
+
dl_prent_button.download_button(
|
218 |
+
label="Download PR-ENT results",
|
219 |
+
data=json.dumps(st.session_state["label_res"], indent=3).encode("ASCII"),
|
220 |
+
file_name="prent_results.json",
|
221 |
+
mime="application/json",
|
222 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# see environments.yml
|
2 |
+
numpy==1.23.2
|
3 |
+
pandas==1.4.2
|
4 |
+
spacy==3.2.3
|
5 |
+
https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.2.0/en_core_web_lg-3.2.0-py3-none-any.whl
|
6 |
+
transformers[torch]==4.22.1
|
7 |
+
nltk==3.7
|
8 |
+
streamlit==1.10.0
|
9 |
+
streamlit-aggrid==0.2.3.post2
|
10 |
+
inflect==6.0.0
|