|
"""Adds per-session state to Streamlit.
|
|
|
This file is borrowed from |
|
https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92 |
|
""" |
|
|
|
|
|
|
|
try: |
|
import streamlit.ReportThread as ReportThread |
|
from streamlit.server.Server import Server |
|
except ModuleNotFoundError: |
|
|
|
import streamlit.report_thread as ReportThread |
|
from streamlit.server.server import Server |
|
|
|
|
|
class SessionState:
    """Hack to add per-session state to Streamlit.

    Usage
    -----

    >>> import SessionState
    >>>
    >>> session_state = SessionState.get(user_name='', favorite_color='black')
    >>> session_state.user_name
    ''
    >>> session_state.user_name = 'Mary'
    >>> session_state.favorite_color
    'black'

    Since you set user_name above, next time your script runs this will be the
    result:

    >>> session_state = SessionState.get(user_name='', favorite_color='black')
    >>> session_state.user_name
    'Mary'

    """

    def __init__(self, **kwargs):
        """A new SessionState object.

        Parameters
        ----------
        **kwargs : any
            Default values for the session state. Each keyword argument
            becomes an instance attribute of the same name and value.

        Example
        -------
        >>> session_state = SessionState(user_name='', favorite_color='black')
        >>> session_state.user_name
        ''
        >>> session_state.user_name = 'Mary'
        >>> session_state.user_name
        'Mary'
        >>> session_state.favorite_color
        'black'

        """
        # setattr (rather than touching __dict__ directly) keeps normal
        # attribute-assignment semantics for each default.
        for key, val in kwargs.items():
            setattr(self, key, val)
|
|
|
|
|
def _session_infos(server):
    """Return a list snapshot of the server's session-info objects.

    Handles both private attribute names this module knows about,
    ``_session_infos`` and ``_session_info_by_id`` (presumably from different
    Streamlit versions). Both are Streamlit internals, not public API.
    """
    if hasattr(server, '_session_infos'):
        return list(server._session_infos.values())
    return list(server._session_info_by_id.values())


def _find_session(ctx, server):
    """Return the Streamlit session object that owns ``ctx``, or None."""
    # Preferred: exact lookup by session id. getattr guards let contexts or
    # servers lacking either attribute fall through to the checks below.
    # NOTE(review): assumes _session_info_by_id is keyed by session id, as its
    # name suggests; a miss simply falls back to the heuristics.
    session_id = getattr(ctx, 'session_id', None)
    info_by_id = getattr(server, '_session_info_by_id', None)
    if session_id is not None and info_by_id is not None:
        info = info_by_id.get(session_id)
        if info is not None:
            return info.session

    infos = _session_infos(server)

    # Session-specific identity checks; first match wins.
    for info in infos:
        candidate = info.session
        if hasattr(candidate, '_main_dg'):
            if candidate._main_dg == ctx.main_dg:
                return candidate
        elif candidate.enqueue == ctx.enqueue:
            return candidate

    # Last resort: match on the uploaded-file manager. If that manager is
    # shared between sessions it matches every session and cannot identify
    # the caller, so it is consulted only when nothing more specific matched.
    # A missing manager on the context never matches (None == None is not
    # evidence of ownership).
    ctx_file_mgr = getattr(ctx, 'uploaded_file_mgr', None)
    if ctx_file_mgr is None:
        return None
    for info in infos:
        candidate = info.session
        if (not hasattr(candidate, '_main_dg') and
                candidate._uploaded_file_mgr == ctx_file_mgr):
            return candidate
    return None


def get(**kwargs):
    """Gets a SessionState object for the current session.

    Creates a new object if necessary.

    Parameters
    ----------
    **kwargs : any
        Default values you want to add to the session state, if we're creating a
        new one.

    Returns
    -------
    SessionState
        The object attached to the current Streamlit session. The same object
        is returned on later calls from that session, and ``kwargs`` are then
        ignored.

    Raises
    ------
    RuntimeError
        If no report context is available or no session owning it is found.

    Example
    -------
    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    ''
    >>> session_state.user_name = 'Mary'
    >>> session_state.favorite_color
    'black'

    Since you set user_name above, next time your script runs this will be the
    result:

    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    'Mary'

    """
    ctx = ReportThread.get_report_ctx()

    this_session = None
    # ctx may be None, presumably when called from a thread that is not
    # running a Streamlit script; report that via the RuntimeError below.
    if ctx is not None:
        this_session = _find_session(ctx, Server.get_current())

    if this_session is None:
        raise RuntimeError(
            "Oh noes. Couldn't get your Streamlit Session object. "
            'Are you doing something fancy with threads?')

    # Stash the state on the session object itself so the same SessionState
    # is returned on subsequent calls for this session.
    if not hasattr(this_session, '_custom_session_state'):
        this_session._custom_session_state = SessionState(**kwargs)

    return this_session._custom_session_state
|
|
|
|
|
|