Spaces:
Build error
Build error
#
# Code for managing session state, which is needed for multi-input forms.
# See https://github.com/streamlit/streamlit/issues/1557
#
# This code is taken from
# https://gist.github.com/okld/0aba4869ba6fdc8d49132e6974e2e662
#
from streamlit.hashing import _CodeHasher | |
from streamlit.report_thread import get_report_ctx | |
from streamlit.server.server import Server | |
class _SessionState:
    """Key/value store bound to a session, readable by attribute or by item.

    All bookkeeping lives in a single private ``_state`` dict placed directly
    in the instance ``__dict__``; user values live under ``_state["data"]``.
    Reading a key that was never set yields ``None`` rather than raising.
    """

    def __init__(self, session, hash_funcs):
        """Initialize SessionState instance.

        Args:
            session: Object whose ``request_rerun`` method is called by
                ``clear`` and ``sync``.
            hash_funcs: Forwarded unchanged to ``_CodeHasher``.
        """
        # Written through __dict__ on purpose: __setattr__ is overridden and
        # would otherwise file "_state" away as a user value.
        self.__dict__["_state"] = {
            "data": {},
            "hash": None,
            "hasher": _CodeHasher(hash_funcs),
            "is_rerun": False,
            "session": session,
        }

    def __call__(self, **kwargs):
        """Initialize state data once.

        Each keyword is stored only if no value exists yet under that name,
        so repeated calls never overwrite values already present.
        """
        for item, value in kwargs.items():
            if item not in self._state["data"]:
                self._state["data"][item] = value

    def __getitem__(self, item):
        """Return a saved state value, None if item is undefined."""
        return self._state["data"].get(item, None)

    def __getattr__(self, item):
        """Return a saved state value, None if item is undefined.

        Raises:
            AttributeError: if the instance has no ``_state`` yet (for example
                an object created with ``__new__`` during copy/unpickle), or
                if ``item`` is a dunder name that was never stored.
        """
        # Go through __dict__ rather than ``self._state``: when "_state" is
        # absent, ``self._state`` would re-enter __getattr__ and recurse
        # until RecursionError.
        state = self.__dict__.get("_state")
        if state is None:
            raise AttributeError(item)

        data = state["data"]
        if item in data:
            return data[item]

        # Protocol probes (__deepcopy__, __getstate__, __setstate__, ...)
        # must see a missing attribute, not a None that looks like a hook.
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(item)
        return None

    def __setitem__(self, item, value):
        """Set state value."""
        self._state["data"][item] = value

    def __setattr__(self, item, value):
        """Set state value."""
        self._state["data"][item] = value

    def clear(self):
        """Clear session state and request a rerun."""
        self._state["data"].clear()
        self._state["session"].request_rerun(None)

    def sync(self):
        """
        Rerun the app with all state values up to date from the beginning to
        fix rollbacks.
        """
        data_to_bytes = self._state["hasher"].to_bytes(self._state["data"], None)

        # Ensure to rerun only once to avoid infinite loops
        # caused by a constantly changing state value at each run.
        #
        # Example: state.value += 1
        if self._state["is_rerun"]:
            # This run is the rerun we asked for; just reset the flag.
            self._state["is_rerun"] = False
        elif self._state["hash"] is not None:
            # A previous digest exists: rerun only if the data changed.
            if self._state["hash"] != data_to_bytes:
                self._state["is_rerun"] = True
                self._state["session"].request_rerun(None)

        # Remember the current digest for the next comparison.
        self._state["hash"] = data_to_bytes
def _get_session():
    """Return the session object belonging to the current report context.

    Raises:
        RuntimeError: if the server has no session info for the current
            session id.
    """
    current_id = get_report_ctx().session_id
    server = Server.get_current()
    info = server._get_session_info(current_id)

    if info is not None:
        return info.session

    raise RuntimeError("Couldn't get your Streamlit Session object.")
def _get_state(hash_funcs=None):
    """Return the ``_SessionState`` attached to the current session.

    The state object is created on first use and cached on the session as
    ``_custom_session_state``; ``hash_funcs`` is only used at that creation.
    """
    current = _get_session()
    try:
        return current._custom_session_state
    except AttributeError:
        # First call for this session: build the state and attach it.
        current._custom_session_state = _SessionState(current, hash_funcs)
        return current._custom_session_state