bjorn-hommel commited on
Commit
4832fb3
1 Parent(s): b7cee2e

init commit

Browse files
Files changed (9) hide show
  1. .env +1 -0
  2. .gitignore +1 -0
  3. app.py +186 -0
  4. init.json +3 -0
  5. logo-130x130.svg +35 -0
  6. modeling.py +139 -0
  7. requirements.txt +10 -0
  8. sample_input.yaml +12 -0
  9. utils.py +27 -0
.env ADDED
@@ -0,0 +1 @@
 
 
1
+ model_path="/nlp/models/published/surveybot3000"
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import logging
4
+ import json
5
+ import yaml
6
+ import pandas as pd
7
+ import numpy as np
8
+ from dotenv import load_dotenv
9
+
10
+ import modeling
11
+
12
+ def show_launch(placeholder):
13
+ with placeholder.container():
14
+ st.divider()
15
+ st.markdown("""
16
+ ## Before Using the App
17
+ ### Disclaimer
18
+ This application is provided as-is, without any warranty or guarantee of any kind, expressed or implied. It is intended for educational, non-commercial use only.
19
+ The developers of this app shall not be held liable for any damages or losses incurred from its use. By using this application, you agree to the terms and conditions
20
+ outlined herein and acknowledge that any commercial use or reliance on its functionality is strictly prohibited.
21
+
22
+ Furthermore, by using this application, you consent to the collection of anonymous usage data. This data will be used for research purposes and to improve the
23
+ application's functionality. No personal information will be recorded or stored.
24
+ """, unsafe_allow_html=True)
25
+
26
+ button_placeholder = st.empty()
27
+
28
+ if button_placeholder.button(label='Accept Disclaimer', type='primary', use_container_width=True):
29
+ st.session_state.show_launch = False
30
+ placeholder.empty()
31
+ button_placeholder.empty()
32
+
33
+ def show_demo(placeholder):
34
+
35
+ with placeholder:
36
+ with st.container():
37
+ st.divider()
38
+ st.markdown("""
39
+ ## Try it yourself!
40
+ Our recent research shows that sentence transformer ("AI" models)
41
+ can predict respondent patterns in survey data! The model accurately
42
+ infers item-correlation with *r* = **.71** 🧨, and shows even higher
43
+ precision for scale correlations (*r* = **.89** 💥) and reliability
44
+ coefficients (*r* = **.86** 💣)!
45
+
46
+ Try it yourself by defining a scale structure using the input field
47
+ below and let the **SurveyBot3000** predict the expected response
48
+ pattern. Use the [YAML](https://yaml.org/) format or follow the structure
49
+ outlined by the preset example.
50
+ """)
51
+
52
+ with st.form("my_form"):
53
+
54
+ input_yaml = st.text_area(
55
+ label="Questionnaire Structure (YAML-Formatted)",
56
+ value=st.session_state['input_yaml'],
57
+ height=250
58
+ )
59
+
60
+ st.session_state.results_as_matrix = st.checkbox(
61
+ label="Result as matrix",
62
+ help="Results will be list-formated (long) by default. Enable to get (wide-format) matrices."
63
+ )
64
+
65
+ submitted = st.form_submit_button(
66
+ label="Get Synthetic Estimates",
67
+ type="primary",
68
+ use_container_width=True
69
+ )
70
+ if submitted:
71
+
72
+ try:
73
+ yaml_dict = yaml.safe_load(input_yaml)
74
+ except yaml.YAMLError as e:
75
+ st.error(f"Yikes, you better get your YAML straight! Check https://yaml.org/ for help!")
76
+ return(None)
77
+
78
+ try:
79
+ modeling.load_model()
80
+ except Exception as error:
81
+ st.error(f"Error while loading model: {error}")
82
+ st.json(yaml_dict)
83
+ return(None)
84
+
85
+ try:
86
+ st.session_state.input_data = modeling.process_yaml_input(yaml_dict)
87
+ except Exception as error:
88
+ st.error(error)
89
+ st.json(yaml_dict)
90
+ return(None)
91
+
92
+ try:
93
+ st.session_state.input_data = modeling.encode_input_data()
94
+ except Exception as error:
95
+ st.error(error)
96
+ st.json(yaml_dict)
97
+ return(None)
98
+
99
+ if 'input_data' in st.session_state:
100
+
101
+ tab1, tab2, tab3 = st.tabs(["Item Correlations", "Scale Correlations", "Scale Reliabilities"])
102
+
103
+ with tab1:
104
+ st.markdown("Θ = Synthetic Item Correlation")
105
+ df = modeling.synthetic_item_correlations()
106
+ st.dataframe(df, use_container_width=True)
107
+
108
+ with tab2:
109
+ st.markdown("Θ = Synthetic Scale Correlation")
110
+ df = modeling.synthetic_scale_correlations()
111
+ st.dataframe(df, use_container_width=True)
112
+
113
+ with tab3:
114
+ st.markdown("alpha (Θ) = Synthetic Reliability Estimate")
115
+ if np.min(modeling.get_items_per_scale()) < 3:
116
+ st.error("Please make sure that each scale consits of at least 3 items!")
117
+ else:
118
+ df = modeling.synthetic_reliabilities()
119
+ st.dataframe(df, use_container_width=True)
120
+
121
+ if 'yaml_dict' in locals():
122
+ st.markdown("### Input Structure:")
123
+ st.json(yaml_dict)
124
+
125
+ def handle_checkbox_change():
126
+ # Update session state
127
+ st.session_state.checkbox_state = not st.session_state.checkbox_state
128
+ # You can also add additional actions to be triggered by the checkbox here
129
+ def initialize():
130
+ load_dotenv()
131
+ logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
132
+
133
+ if 'state_loaded' not in st.session_state:
134
+
135
+ st.session_state['state_loaded'] = True
136
+ with open('init.json') as json_data:
137
+ st.session_state.update(json.load(json_data))
138
+
139
+ def main():
140
+ st.set_page_config(page_title='Synthetic Correlations')
141
+
142
+ col1, col2 = st.columns([2, 5])
143
+ with col1:
144
+ st.image('logo-130x130.svg')
145
+
146
+ with col2:
147
+ st.markdown("# Synthetic Correlations")
148
+ st.markdown("#### Estimate Item and Scale Correlations, as well as Reliability Coefficients based on nothing but Text!")
149
+
150
+ st.markdown("""
151
+
152
+ 📖 **Preprint (Open Access)**: https://osf.io/preprints/psyarxiv/kjuce
153
+
154
+ 🖊️ **Cite**: *Hommel, B. E., & Arslan, R. C. (2024). Language models accurately infer correlations between psychological items and scales from text alone. https://doi.org/10.31234/osf.io/kjuce*
155
+
156
+ 🌐 **Project website**: https://synth-science.github.io/surveybot3000/
157
+
158
+ 💾 **Data**: https://osf.io/z47qs/
159
+
160
+ #️⃣ **Social Media**:
161
+ - [Björn Hommel on X/Twitter](https://twitter.com/BjoernHommel)
162
+ - [Ruben Arslan on X/Twitter](https://twitter.com/rubenarslan/)
163
+
164
+ The web application is maintained by [magnolia psychometrics](https://www.magnolia-psychometrics.com/).
165
+ """, unsafe_allow_html=True)
166
+
167
+ placeholder_launch = st.empty()
168
+ placeholder_demo = st.empty()
169
+
170
+ if 'input_yaml' not in st.session_state:
171
+
172
+ with open('sample_input.yaml', 'r') as file:
173
+ try:
174
+ st.session_state['input_yaml'] = file.read()
175
+ except Exception as error:
176
+ print(error)
177
+
178
+ if 'disclaimer' not in st.session_state:
179
+ show_launch(placeholder_launch)
180
+ st.session_state['disclaimer'] = True
181
+ else:
182
+ show_demo(placeholder_demo)
183
+
184
+ if __name__ == '__main__':
185
+ initialize()
186
+ main()
init.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "results_as_matrix" : false
3
+ }
logo-130x130.svg ADDED
modeling.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import streamlit as st
4
+ import pandas as pd
5
+ import pingouin as pg
6
+ from sentence_transformers import SentenceTransformer, util
7
+
8
+ def load_model():
9
+
10
+ if st.session_state.get('model') is None:
11
+ with st.spinner('Loading the model might take a couple of seconds...'):
12
+
13
+ if os.environ.get('remote_model_path'):
14
+ model_path = os.environ.get('remote_model_path')
15
+ else:
16
+ model_path = os.getenv('model_path')
17
+
18
+ st.session_state.model = SentenceTransformer(
19
+ model_name_or_path=model_path#,
20
+ #use_auth_token=
21
+ )
22
+
23
+ logging.info('Loaded SurveyBot3000!')
24
+
25
+ def process_yaml_input(yaml_dict):
26
+
27
+ input_data = pd.DataFrame({k: pd.Series(v) for k, v in yaml_dict.items()})
28
+ df = (
29
+ input_data
30
+ .stack()
31
+ .reset_index()
32
+ .drop('level_0', axis=1)
33
+ .rename(columns={'level_1': 'scale', 0: "item"})
34
+ )
35
+ return df
36
+
37
+ def get_items_per_scale():
38
+ input_data = st.session_state.input_data
39
+ items_per_scale = input_data.groupby('scale').size().tolist()
40
+ return(items_per_scale)
41
+
42
+ def encode_input_data():
43
+
44
+ with st.spinner('Encoding items...'):
45
+ input_data = st.session_state.input_data
46
+ input_data['embeddings'] = input_data.item.apply(lambda x: st.session_state.model.encode(
47
+ sentences=x,
48
+ convert_to_numpy=True
49
+ ))
50
+
51
+ return(input_data)
52
+
53
+ def synthetic_item_correlations():
54
+
55
+ df = pd.DataFrame(
56
+ data = util.cos_sim(
57
+ a=st.session_state.input_data.embeddings,
58
+ b=st.session_state.input_data.embeddings
59
+ ),
60
+ columns=st.session_state.input_data.item,
61
+ index=st.session_state.input_data.item
62
+ ).round(2)
63
+
64
+ if st.session_state.results_as_matrix is False:
65
+ df = (
66
+ df
67
+ .reset_index()
68
+ .melt(id_vars=['item'], var_name='item_b', value_name='Θ')
69
+ .rename(columns={'item': 'item_a'})
70
+ .query('item_a < item_b')
71
+ )
72
+
73
+ return(df)
74
+
75
+
76
+ def synthetic_scale_correlations():
77
+
78
+ scales = st.session_state.input_data.scale
79
+ embeddings = st.session_state.input_data.embeddings.apply(pd.Series)
80
+
81
+ def func(group_data):
82
+ return(group_data.T.iloc[1:,:].mean(axis=1))
83
+ x = pd.concat([scales, embeddings], axis=1).groupby('scale').apply(lambda group: func(group))
84
+
85
+ print(x.T.corr())
86
+
87
+ data = (
88
+ pd
89
+ .concat([scales, embeddings], axis=1)
90
+ .groupby('scale')
91
+ .mean()
92
+ .reset_index()
93
+ )
94
+
95
+ mean_embeddings = data.apply(lambda row: [row[col] for col in data.columns if col != 'scale'], axis=1)
96
+ matrix = util.cos_sim(a=mean_embeddings, b=mean_embeddings)
97
+ df = pd.DataFrame(
98
+ data=matrix,
99
+ columns = data.scale.tolist(),
100
+ index=data.scale.tolist()
101
+ ).round(2)
102
+
103
+
104
+ if st.session_state.results_as_matrix is False:
105
+ df = (
106
+ df
107
+ .reset_index()
108
+ .melt(id_vars='index', var_name='scale_b', value_name='Θ')
109
+ .rename(columns={'index': 'scale_a'})
110
+ .query('scale_a < scale_b')
111
+ )
112
+
113
+ return(df)
114
+
115
+ def synthetic_reliabilities():
116
+
117
+ def reliability(group_data):
118
+ group_data = group_data.drop('scale', axis=1).T
119
+ alpha = pg.cronbach_alpha(data=group_data)
120
+ x = [alpha[0], alpha[1][0], alpha[1][1]]
121
+
122
+ return(x)
123
+
124
+ scales = st.session_state.input_data.scale
125
+ embeddings = st.session_state.input_data.embeddings.apply(pd.Series)
126
+
127
+ data = (
128
+ pd
129
+ .concat([scales, embeddings], axis=1)
130
+ .groupby('scale')
131
+ .apply(lambda group: reliability(group))
132
+ )
133
+
134
+ df = pd.DataFrame(
135
+ data=[[v] + data.tolist()[k] for k, v in enumerate(data.index.tolist())],
136
+ columns=['scale', 'alpha (Θ)', 'ci_lower', 'ci_upper']
137
+ ).round(2)
138
+
139
+ return(df)
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ yaml
3
+ json
4
+ pandas==2.2.0
5
+ numpy==1.26.3
6
+ sentence_transformers==2.2.2
7
+ sentencepiece==0.1.99
8
+ altair==4.2.2
9
+ pingouin==0.5.4
10
+ python-dotenv
sample_input.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Buttox-Fixation:
2
+ - I like big butts.
3
+ - But that butt you got makes me so horny.
4
+ - Baby got back.
5
+ - Shake that healthy butt.
6
+ - My anaconda don´t want want none, unless you got buns, hun.
7
+ - So ladies if tha butt is round
8
+ - But please don´ lose that butt
9
+ Delinquent behavior:
10
+ - I cannot lie.
11
+ - I am hooked an´ I cannot stop starin´
12
+ - I am actin´ like an animal.
utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code by Martijn Pieters https://stackoverflow.com/a/23499088/1114975
2
+ from functools import singledispatch, wraps
3
+
4
+ @singledispatch
5
+ def depth(_, _level=1, _memo=None):
6
+ return _level
7
+
8
+ def _protect(f):
9
+ """Protect against circular references"""
10
+ @wraps(f)
11
+ def wrapper(o, _level=1, _memo=None, **kwargs):
12
+ _memo, id_ = _memo or set(), id(o)
13
+ if id_ in _memo: return _level
14
+ _memo.add(id_)
15
+ return f(o, _level=_level, _memo=_memo, **kwargs)
16
+ return wrapper
17
+
18
+ def _protected_register(cls, func=None, _orig=depth.register):
19
+ """Include the _protect decorator when registering"""
20
+ if func is None and isinstance(cls, type):
21
+ return lambda f: _orig(cls, _protect(f))
22
+ return _orig(cls, _protect(func)) if func is not None else _orig(_protect(cls))
23
+ depth.register = _protected_register
24
+
25
+ @depth.register
26
+ def _dict_depth(d: dict, _level=1, **kw):
27
+ return max(depth(v, _level=_level + 1, **kw) for v in d.values())