# sociofillmore_public/sociofillmore/webapp/query_frame_samples.py
# Author: Gosse Minnema
# Commit b11ac48: Add sociofillmore code, load dataset via private dataset repo
import json
import os
import sys

import pandas as pd
import requests
# Base URL of the SocioFillmore web API (defaults to a locally running server).
# Override with the SOCIOFILLMORE_API environment variable.
SOCIOFILLMORE_API = os.environ.get("SOCIOFILLMORE_API", "http://127.0.0.1:5000")
# Key sent as `auth_key` with every query. Prefer supplying it through the
# SOCIOFILLMORE_AUTH_KEY environment variable: the literal fallback is kept only
# so existing setups keep working, and it is exposed in version control.
AUTH_KEY = os.environ.get("SOCIOFILLMORE_AUTH_KEY", "3TrJ397oh#^")
def get_sample(s, dataset, n_samples, frame, construction, role, dependency, model="lome_0shot"):
    """Fetch example sentences for one (frame, role, dependency) combination.

    Switches the server's active dataset, queries ``/sample_frame`` and
    flattens the response into one row per returned frame structure of type
    ``frame`` that has a ``role`` frame element.

    Args:
        s: ``requests.Session`` used for both requests. The dataset switch is
            presumably tracked per session by the server (TODO confirm).
        dataset: dataset identifier, e.g. ``"femicides/rai"``.
        n_samples: number of sentences requested (sent as ``n``).
        frame: frame name to query and to filter returned structures by.
        construction: construction filter forwarded to the API.
        role: frame-element label. Structures without it are skipped.
        dependency: dependency label forwarded to the API and copied into
            every output row.
        model: annotation model name forwarded to the API.

    Returns:
        A list of dicts with keys ``dataset``, ``sentence``, ``frame``,
        ``target``, ``role_label``, ``role_span`` and ``dependency``.

    Raises:
        requests.HTTPError: if either request returns an error status.
    """
    r_switch = s.get(SOCIOFILLMORE_API + "/switch_dataset", params={"dataset": dataset})
    # A failed switch would otherwise go unnoticed: sampling would hit the
    # previously active dataset while the rows below are labelled `dataset`.
    r_switch.raise_for_status()

    r_q = s.get(
        SOCIOFILLMORE_API + "/sample_frame",
        params={
            "auth_key": AUTH_KEY,
            "frame": frame,
            "construction": construction,
            "role": role,
            "dependency": dependency,
            "model": model,
            "n": n_samples,
        },
    )
    # Fail with the HTTP error instead of a confusing decode/TypeError later.
    r_q.raise_for_status()
    data = r_q.json()

    rows_out = []
    for sent in data:
        for fns in sent["fn_structures"]:
            if fns["frame"] != frame:
                continue
            # Each role entry is indexed as [label, span]; only the first
            # span carrying `role` is reported.
            target_role = next((r for r in fns["roles"] if r[0] == role), None)
            if target_role is None:
                continue
            rows_out.append(
                {
                    "dataset": dataset,
                    "sentence": " ".join(sent["sentence"]),
                    "frame": frame,
                    "target": " ".join(fns["target"]["tokens_str"]),
                    "role_label": role,
                    "role_span": " ".join(target_role[1]["tokens_str"]),
                    "dependency": dependency,
                }
            )
    return rows_out
def get_labels(s, dataset, frame, model="lome_0shot"):
    """Return the set of labels observed for ``frame`` in ``dataset``.

    Switches the server's active dataset, queries ``/frame_freq`` with
    ``group_by_role_expr=2`` and extracts the third ``"::"``-separated field
    of every x-axis key in ``relevant_frame_counts``. ``main`` passes these
    labels on as the ``dependency`` argument of ``get_sample``.

    Args:
        s: ``requests.Session`` used for both requests.
        dataset: dataset identifier, e.g. ``"crashes/thecrashes"``.
        frame: frame name to count.
        model: annotation model name forwarded to the API.

    Returns:
        A set of label strings.

    Raises:
        requests.HTTPError: if either request returns an error status.
        IndexError: if an x-axis key has fewer than three ``"::"`` fields.
    """
    r_switch = s.get(SOCIOFILLMORE_API + "/switch_dataset", params={"dataset": dataset})
    # Without this check, a failed switch would silently report labels from
    # whatever dataset was previously active.
    r_switch.raise_for_status()

    r_q = s.get(
        SOCIOFILLMORE_API + "/frame_freq",
        params={
            "auth_key": AUTH_KEY,
            "model": model,
            "frames": frame,
            "constructions": "",
            "group_by_cat": "n",
            "group_by_constr": "n",
            "group_by_role_expr": 2,
            "relative": "y",
            "plot_over_days_post": "n",
        },
    )
    r_q.raise_for_status()
    data = r_q.json()
    return {key.split("::")[2] for key in data["relevant_frame_counts"]["x"]}
def main(language, n_samples=2):
    """Sample sentences per dependency label and write them to a CSV file.

    For the chosen language, fetch the labels observed for the language's
    target frame, then request ``n_samples`` sentences per label for each
    target role and save all rows under
    ``output/common/query_frame_samples/``.

    Args:
        language: ``"it"`` (femicides/rai, frame Killing, roles Killer and
            Victim) or ``"nl"`` (crashes/thecrashes, frame Cause_harm, roles
            Agent and Victim).
        n_samples: sentences requested per (label, role) pair.

    Raises:
        ValueError: if ``language`` is not a supported code.
    """
    # language -> (dataset, frame, roles to sample in order, output CSV path)
    configs = {
        "it": (
            "femicides/rai",
            "Killing",
            ("Killer", "Victim"),
            "output/common/query_frame_samples/it_dep_samples.csv",
        ),
        "nl": (
            "crashes/thecrashes",
            "Cause_harm",
            ("Agent", "Victim"),
            "output/common/query_frame_samples/nl_dep_samples.csv",
        ),
    }
    if language not in configs:
        # Previously an unknown code silently did nothing; fail loudly instead.
        raise ValueError(f"Unsupported language {language!r}; expected one of {sorted(configs)}")
    dataset, frame, roles, out_path = configs[language]
    tag = language.upper()

    # The session carries the server-side dataset selection between calls;
    # the context manager closes its pooled connections when done.
    with requests.Session() as s:
        print(f"Finding {tag} labels...")
        labels = get_labels(s, dataset, frame)
        sample_rows = []
        for label in sorted(labels):
            if label == "_UNK_DEP":  # skip the unknown-dependency bucket
                continue
            print(f"Label ({tag}): {label}")
            for role in roles:
                sample_rows.extend(get_sample(s, dataset, n_samples, frame, "*", role, label))

    # to_csv does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    pd.DataFrame(sample_rows).to_csv(out_path)
if __name__ == "__main__":
    # Expect the language code as the single positional argument.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} {{it|nl}}")
    main(language=sys.argv[1])