bjorn-hommel
committed on
Commit
•
4832fb3
1
Parent(s):
b7cee2e
init commit
Browse files- .env +1 -0
- .gitignore +1 -0
- app.py +186 -0
- init.json +3 -0
- logo-130x130.svg +35 -0
- modeling.py +139 -0
- requirements.txt +10 -0
- sample_input.yaml +12 -0
- utils.py +27 -0
.env
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
model_path="/nlp/models/published/surveybot3000"
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__
|
app.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import logging
|
4 |
+
import json
|
5 |
+
import yaml
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
from dotenv import load_dotenv
|
9 |
+
|
10 |
+
import modeling
|
11 |
+
|
12 |
+
def show_launch(placeholder):
    """Render the disclaimer screen inside *placeholder*.

    Shows the disclaimer text and an "Accept Disclaimer" button. When the
    button is pressed, ``st.session_state.show_launch`` is set to False and
    both the container and the button are cleared so the demo can be shown.
    """
    with placeholder.container():
        st.divider()
        st.markdown("""
        ## Before Using the App
        ### Disclaimer
        This application is provided as-is, without any warranty or guarantee of any kind, expressed or implied. It is intended for educational, non-commercial use only.
        The developers of this app shall not be held liable for any damages or losses incurred from its use. By using this application, you agree to the terms and conditions
        outlined herein and acknowledge that any commercial use or reliance on its functionality is strictly prohibited.

        Furthermore, by using this application, you consent to the collection of anonymous usage data. This data will be used for research purposes and to improve the
        application's functionality. No personal information will be recorded or stored.
        """, unsafe_allow_html=True)

        button_placeholder = st.empty()

        # The click triggers a Streamlit re-run; clearing both placeholders
        # removes the disclaimer within the same run.
        if button_placeholder.button(label='Accept Disclaimer', type='primary', use_container_width=True):
            st.session_state.show_launch = False
            placeholder.empty()
            button_placeholder.empty()
|
32 |
+
|
33 |
+
def show_demo(placeholder):
    """Render the interactive demo inside *placeholder*.

    Lets the user submit a YAML-formatted scale structure, pipes it through
    the ``modeling`` module (model loading, input parsing, encoding) and
    displays synthetic item correlations, scale correlations and reliability
    estimates in three tabs. Each processing stage reports its own error and
    echoes the parsed input back for debugging.
    """
    with placeholder:
        with st.container():
            st.divider()
            st.markdown("""
            ## Try it yourself!
            Our recent research shows that sentence transformer ("AI" models)
            can predict respondent patterns in survey data! The model accurately
            infers item-correlation with *r* = **.71** 🧨, and shows even higher
            precision for scale correlations (*r* = **.89** 💥) and reliability
            coefficients (*r* = **.86** 💣)!

            Try it yourself by defining a scale structure using the input field
            below and let the **SurveyBot3000** predict the expected response
            pattern. Use the [YAML](https://yaml.org/) format or follow the structure
            outlined by the preset example.
            """)

            with st.form("my_form"):

                input_yaml = st.text_area(
                    label="Questionnaire Structure (YAML-Formatted)",
                    value=st.session_state['input_yaml'],
                    height=250
                )

                st.session_state.results_as_matrix = st.checkbox(
                    label="Result as matrix",
                    help="Results will be list-formated (long) by default. Enable to get (wide-format) matrices."
                )

                submitted = st.form_submit_button(
                    label="Get Synthetic Estimates",
                    type="primary",
                    use_container_width=True
                )

            if submitted:

                # Stage 1: parse the YAML. Abort with a hint on syntax errors.
                try:
                    yaml_dict = yaml.safe_load(input_yaml)
                except yaml.YAMLError:
                    st.error("Yikes, you better get your YAML straight! Check https://yaml.org/ for help!")
                    return None

                # Stage 2: make sure the sentence-transformer model is loaded.
                try:
                    modeling.load_model()
                except Exception as error:
                    st.error(f"Error while loading model: {error}")
                    st.json(yaml_dict)
                    return None

                # Stage 3: flatten {scale: [items]} into a long-format frame.
                try:
                    st.session_state.input_data = modeling.process_yaml_input(yaml_dict)
                except Exception as error:
                    st.error(error)
                    st.json(yaml_dict)
                    return None

                # Stage 4: embed every item text with the model.
                try:
                    st.session_state.input_data = modeling.encode_input_data()
                except Exception as error:
                    st.error(error)
                    st.json(yaml_dict)
                    return None

            if 'input_data' in st.session_state:

                tab1, tab2, tab3 = st.tabs(["Item Correlations", "Scale Correlations", "Scale Reliabilities"])

                with tab1:
                    st.markdown("Θ = Synthetic Item Correlation")
                    df = modeling.synthetic_item_correlations()
                    st.dataframe(df, use_container_width=True)

                with tab2:
                    st.markdown("Θ = Synthetic Scale Correlation")
                    df = modeling.synthetic_scale_correlations()
                    st.dataframe(df, use_container_width=True)

                with tab3:
                    st.markdown("alpha (Θ) = Synthetic Reliability Estimate")
                    # Cronbach's alpha needs at least 3 items per scale.
                    if np.min(modeling.get_items_per_scale()) < 3:
                        st.error("Please make sure that each scale consists of at least 3 items!")
                    else:
                        df = modeling.synthetic_reliabilities()
                        st.dataframe(df, use_container_width=True)

                # yaml_dict only exists on the run in which the form was
                # submitted; on later re-runs the echo section is skipped.
                if 'yaml_dict' in locals():
                    st.markdown("### Input Structure:")
                    st.json(yaml_dict)
|
124 |
+
|
125 |
+
def handle_checkbox_change():
    """Callback that flips the stored checkbox flag in the session state."""
    # session_state supports both attribute and mapping access; the mapping
    # form is used here. Additional checkbox-triggered actions can be added
    # below.
    st.session_state['checkbox_state'] = not st.session_state['checkbox_state']
|
129 |
+
def initialize():
    """One-time app setup: env vars, logging, and default session state.

    Loads variables from the local ``.env`` file (e.g. ``model_path``) and
    configures root logging. On the first run of a session, seeds
    ``st.session_state`` with the defaults stored in ``init.json`` and marks
    the state as loaded so this is not repeated on re-runs.
    """
    load_dotenv()
    logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)

    if 'state_loaded' not in st.session_state:

        st.session_state['state_loaded'] = True
        with open('init.json') as json_data:
            st.session_state.update(json.load(json_data))
|
138 |
+
|
139 |
+
def main():
    """Build the page: header, links, and either the disclaimer or the demo."""
    st.set_page_config(page_title='Synthetic Correlations')

    # Two-column header: logo on the left, title on the right.
    col1, col2 = st.columns([2, 5])
    with col1:
        st.image('logo-130x130.svg')

    with col2:
        st.markdown("# Synthetic Correlations")
        st.markdown("#### Estimate Item and Scale Correlations, as well as Reliability Coefficients based on nothing but Text!")

    st.markdown("""

    📖 **Preprint (Open Access)**: https://osf.io/preprints/psyarxiv/kjuce

    🖊️ **Cite**: *Hommel, B. E., & Arslan, R. C. (2024). Language models accurately infer correlations between psychological items and scales from text alone. https://doi.org/10.31234/osf.io/kjuce*

    🌐 **Project website**: https://synth-science.github.io/surveybot3000/

    💾 **Data**: https://osf.io/z47qs/

    #️⃣ **Social Media**:
    - [Björn Hommel on X/Twitter](https://twitter.com/BjoernHommel)
    - [Ruben Arslan on X/Twitter](https://twitter.com/rubenarslan/)

    The web application is maintained by [magnolia psychometrics](https://www.magnolia-psychometrics.com/).
    """, unsafe_allow_html=True)

    placeholder_launch = st.empty()
    placeholder_demo = st.empty()

    # Seed the YAML editor with the bundled example on the first run.
    if 'input_yaml' not in st.session_state:

        with open('sample_input.yaml', 'r') as file:
            try:
                st.session_state['input_yaml'] = file.read()
            except Exception as error:
                # NOTE(review): errors are only printed to stdout here, not
                # surfaced in the UI — presumably best-effort by design.
                print(error)

    # First run of a session shows the disclaimer; the acceptance click
    # re-runs the script, at which point the demo branch is taken.
    if 'disclaimer' not in st.session_state:
        show_launch(placeholder_launch)
        st.session_state['disclaimer'] = True
    else:
        show_demo(placeholder_demo)
|
183 |
+
|
184 |
+
# Script entry point: set up session state, then render the page.
if __name__ == '__main__':
    initialize()
    main()
|
init.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"results_as_matrix" : false
|
3 |
+
}
|
logo-130x130.svg
ADDED
modeling.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import streamlit as st
|
4 |
+
import pandas as pd
|
5 |
+
import pingouin as pg
|
6 |
+
from sentence_transformers import SentenceTransformer, util
|
7 |
+
|
8 |
+
def load_model():
    """Load the SurveyBot3000 sentence transformer into the session state.

    Idempotent per session: does nothing if ``st.session_state.model`` is
    already set. The model location is taken from the ``remote_model_path``
    environment variable if present, otherwise from ``model_path`` (supplied
    via ``.env``).
    """
    if st.session_state.get('model') is None:
        with st.spinner('Loading the model might take a couple of seconds...'):

            # Prefer a remote path (e.g. a hub model id) over the local one.
            if os.environ.get('remote_model_path'):
                model_path = os.environ.get('remote_model_path')
            else:
                model_path = os.getenv('model_path')

            st.session_state.model = SentenceTransformer(
                model_name_or_path=model_path#,
                #use_auth_token=  # NOTE(review): auth token hookup left for private models
            )

            logging.info('Loaded SurveyBot3000!')
|
24 |
+
|
25 |
+
def process_yaml_input(yaml_dict):
    """Flatten a ``{scale: [item, ...]}`` mapping into a long-format frame.

    Returns a DataFrame with two columns, ``scale`` and ``item``, one row per
    item. Scales of unequal length are handled by NaN-padding the wide frame
    and dropping the padding when stacking.
    """
    # Wide frame: one column per scale, NaN-padded to the longest scale.
    wide = pd.DataFrame({scale: pd.Series(items) for scale, items in yaml_dict.items()})

    # Stack to long format (drops the NaN padding), then tidy up the
    # auto-generated index columns.
    long = wide.stack().reset_index()
    long = long.drop('level_0', axis=1)
    long = long.rename(columns={'level_1': 'scale', 0: "item"})
    return long
|
36 |
+
|
37 |
+
def get_items_per_scale():
    """Return the number of items per scale in the current input data.

    Reads the long-format frame from ``st.session_state.input_data``; group
    order follows pandas' default groupby sorting (alphabetical by scale).
    """
    counts = st.session_state.input_data.groupby('scale').size()
    return counts.tolist()
|
41 |
+
|
42 |
+
def encode_input_data():
    """Embed every item text and attach the vectors to the input frame.

    Adds an ``embeddings`` column (one numpy vector per item) to the frame in
    ``st.session_state.input_data`` — note the frame is mutated in place as
    well as returned.
    """
    with st.spinner('Encoding items...'):
        input_data = st.session_state.input_data
        # One encode() call per item; fine for the small questionnaires this
        # app expects.
        input_data['embeddings'] = input_data.item.apply(lambda x: st.session_state.model.encode(
            sentences=x,
            convert_to_numpy=True
        ))

    return(input_data)
|
52 |
+
|
53 |
+
def synthetic_item_correlations():
    """Return pairwise synthetic item correlations (cosine similarities).

    Builds an item-by-item cosine-similarity matrix from the stored
    embeddings, rounded to 2 decimals. Unless
    ``st.session_state.results_as_matrix`` is set, the matrix is melted into a
    long frame with columns ``item_a``, ``item_b``, ``Θ``, keeping each
    unordered pair once (lexicographic filter ``item_a < item_b``).
    """
    df = pd.DataFrame(
        data = util.cos_sim(
            a=st.session_state.input_data.embeddings,
            b=st.session_state.input_data.embeddings
        ),
        columns=st.session_state.input_data.item,
        index=st.session_state.input_data.item
    ).round(2)

    if st.session_state.results_as_matrix is False:
        df = (
            df
            .reset_index()
            .melt(id_vars=['item'], var_name='item_b', value_name='Θ')
            .rename(columns={'item': 'item_a'})
            # Keep one row per pair and drop self-correlations.
            .query('item_a < item_b')
        )

    return(df)
|
74 |
+
|
75 |
+
|
76 |
+
def synthetic_scale_correlations():
    """Return pairwise synthetic scale correlations (cosine similarities).

    A scale is represented by the mean of its item embeddings; scale-level
    correlations are the cosine similarities between those mean vectors,
    rounded to 2 decimals. Unless ``st.session_state.results_as_matrix`` is
    set, the matrix is melted into a long frame with columns ``scale_a``,
    ``scale_b``, ``Θ``, keeping each unordered pair once.
    """
    scales = st.session_state.input_data.scale
    embeddings = st.session_state.input_data.embeddings.apply(pd.Series)

    # Mean embedding per scale (one row per scale, one column per dimension).
    data = (
        pd
        .concat([scales, embeddings], axis=1)
        .groupby('scale')
        .mean()
        .reset_index()
    )

    # Collect each row's embedding dimensions back into a single vector.
    mean_embeddings = data.apply(lambda row: [row[col] for col in data.columns if col != 'scale'], axis=1)
    matrix = util.cos_sim(a=mean_embeddings, b=mean_embeddings)
    df = pd.DataFrame(
        data=matrix,
        columns = data.scale.tolist(),
        index=data.scale.tolist()
    ).round(2)

    if st.session_state.results_as_matrix is False:
        df = (
            df
            .reset_index()
            .melt(id_vars='index', var_name='scale_b', value_name='Θ')
            .rename(columns={'index': 'scale_a'})
            # Keep one row per pair and drop self-correlations.
            .query('scale_a < scale_b')
        )

    return(df)
|
114 |
+
|
115 |
+
def synthetic_reliabilities():
    """Return synthetic reliability estimates (Cronbach's alpha) per scale.

    For every scale, treats the embedding dimensions as observations of the
    scale's items and computes Cronbach's alpha via pingouin. Returns a frame
    with columns ``scale``, ``alpha (Θ)``, ``ci_lower``, ``ci_upper`` rounded
    to 2 decimals. Requires at least 3 items per scale (enforced by the
    caller).
    """
    def reliability(group_data):
        # Transpose so items become columns, as pg.cronbach_alpha expects.
        group_data = group_data.drop('scale', axis=1).T
        alpha = pg.cronbach_alpha(data=group_data)
        # alpha is (estimate, (ci_lower, ci_upper)); flatten to a 3-list.
        x = [alpha[0], alpha[1][0], alpha[1][1]]

        return(x)

    scales = st.session_state.input_data.scale
    embeddings = st.session_state.input_data.embeddings.apply(pd.Series)

    # One [alpha, ci_lower, ci_upper] triple per scale, indexed by scale name.
    data = (
        pd
        .concat([scales, embeddings], axis=1)
        .groupby('scale')
        .apply(lambda group: reliability(group))
    )

    # Prepend the scale name to each triple to build the result rows.
    df = pd.DataFrame(
        data=[[v] + data.tolist()[k] for k, v in enumerate(data.index.tolist())],
        columns=['scale', 'alpha (Θ)', 'ci_lower', 'ci_upper']
    ).round(2)

    return(df)
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
pyyaml
|
3 |
+
# json is part of the Python standard library; no install required
|
4 |
+
pandas==2.2.0
|
5 |
+
numpy==1.26.3
|
6 |
+
sentence_transformers==2.2.2
|
7 |
+
sentencepiece==0.1.99
|
8 |
+
altair==4.2.2
|
9 |
+
pingouin==0.5.4
|
10 |
+
python-dotenv
|
sample_input.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Buttox-Fixation:
|
2 |
+
- I like big butts.
|
3 |
+
- But that butt you got makes me so horny.
|
4 |
+
- Baby got back.
|
5 |
+
- Shake that healthy butt.
|
6 |
+
- My anaconda don´t want want none, unless you got buns, hun.
|
7 |
+
- So ladies if tha butt is round
|
8 |
+
- But please don´ lose that butt
|
9 |
+
Delinquent behavior:
|
10 |
+
- I cannot lie.
|
11 |
+
- I am hooked an´ I cannot stop starin´
|
12 |
+
- I am actin´ like an animal.
|
utils.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# code by Martijn Pieters https://stackoverflow.com/a/23499088/1114975
|
2 |
+
from functools import singledispatch, wraps
|
3 |
+
|
4 |
+
@singledispatch
def depth(_, _level=1, _memo=None):
    """Fallback: a non-container object sits at the current nesting level."""
    return _level
|
7 |
+
|
8 |
+
def _protect(f):
|
9 |
+
"""Protect against circular references"""
|
10 |
+
@wraps(f)
|
11 |
+
def wrapper(o, _level=1, _memo=None, **kwargs):
|
12 |
+
_memo, id_ = _memo or set(), id(o)
|
13 |
+
if id_ in _memo: return _level
|
14 |
+
_memo.add(id_)
|
15 |
+
return f(o, _level=_level, _memo=_memo, **kwargs)
|
16 |
+
return wrapper
|
17 |
+
|
18 |
+
def _protected_register(cls, func=None, _orig=depth.register):
    """Include the _protect decorator when registering"""
    # _orig captures the original singledispatch register at definition time,
    # before it is monkey-patched below.
    if func is None and isinstance(cls, type):
        # Used as @depth.register(SomeType): return a decorator.
        return lambda f: _orig(cls, _protect(f))
    # Two-arg form depth.register(cls, func), or bare @depth.register on an
    # annotated function (cls is then the function itself).
    return _orig(cls, _protect(func)) if func is not None else _orig(_protect(cls))
# Monkey-patch so every future registration is circularity-protected.
depth.register = _protected_register
|
24 |
+
|
25 |
+
@depth.register
def _dict_depth(d: dict, _level=1, **kw):
    """Depth of a dict: one more than its deepest value.

    ``default=_level`` handles the empty dict, which previously raised
    ``ValueError`` from ``max()`` on an empty sequence.
    """
    return max((depth(v, _level=_level + 1, **kw) for v in d.values()), default=_level)
|