# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import asyncio import time import traceback from dataclasses import dataclass, field from enum import Enum from typing import TYPE_CHECKING, Awaitable, Dict, NamedTuple, Optional, Tuple, Type from typing_extensions import Final from streamlit import config from streamlit.logger import get_logger from streamlit.proto.BackMsg_pb2 import BackMsg from streamlit.proto.ForwardMsg_pb2 import ForwardMsg from streamlit.runtime.app_session import AppSession from streamlit.runtime.caching import ( get_data_cache_stats_provider, get_resource_cache_stats_provider, ) from streamlit.runtime.caching.storage.local_disk_cache_storage import ( LocalDiskCacheStorageManager, ) from streamlit.runtime.forward_msg_cache import ( ForwardMsgCache, create_reference_msg, populate_hash_if_needed, ) from streamlit.runtime.legacy_caching.caching import _mem_caches from streamlit.runtime.media_file_manager import MediaFileManager from streamlit.runtime.media_file_storage import MediaFileStorage from streamlit.runtime.memory_session_storage import MemorySessionStorage from streamlit.runtime.runtime_util import is_cacheable_msg from streamlit.runtime.script_data import ScriptData from streamlit.runtime.scriptrunner.script_cache import ScriptCache from streamlit.runtime.session_manager import ( ActiveSessionInfo, SessionClient, SessionClientDisconnectedError, SessionManager, SessionStorage, ) from streamlit.runtime.state import ( SCRIPT_RUN_WITHOUT_ERRORS_KEY, SessionStateStatProvider, ) from streamlit.runtime.stats import StatsManager from streamlit.runtime.uploaded_file_manager import UploadedFileManager from streamlit.runtime.websocket_session_manager import WebsocketSessionManager from streamlit.watcher import LocalSourcesWatcher if TYPE_CHECKING: from streamlit.runtime.caching.storage import CacheStorageManager # Wait for the script run result for 60s and if no result is available give up SCRIPT_RUN_CHECK_TIMEOUT: Final = 60 LOGGER: Final = get_logger(__name__) class RuntimeStoppedError(Exception): """Raised by operations on a Runtime instance that is stopped.""" @dataclass(frozen=True) class RuntimeConfig: """Config options for StreamlitRuntime.""" # The filesystem path of the Streamlit script to run. script_path: str # The (optional) command line that Streamlit was started with # (e.g. "streamlit run app.py") command_line: Optional[str] # The storage backend for Streamlit's MediaFileManager. media_file_storage: MediaFileStorage # The upload file manager uploaded_file_manager: UploadedFileManager # The cache storage backend for Streamlit's st.cache_data. cache_storage_manager: CacheStorageManager = field( default_factory=LocalDiskCacheStorageManager ) # The SessionManager class to be used. session_manager_class: Type[SessionManager] = WebsocketSessionManager # The SessionStorage instance for the SessionManager to use. session_storage: SessionStorage = field(default_factory=MemorySessionStorage) class RuntimeState(Enum): INITIAL = "INITIAL" NO_SESSIONS_CONNECTED = "NO_SESSIONS_CONNECTED" ONE_OR_MORE_SESSIONS_CONNECTED = "ONE_OR_MORE_SESSIONS_CONNECTED" STOPPING = "STOPPING" STOPPED = "STOPPED" class AsyncObjects(NamedTuple): """Container for all asyncio objects that Runtime manages. These cannot be initialized until the Runtime's eventloop is assigned. """ # The eventloop that Runtime is running on. eventloop: asyncio.AbstractEventLoop # Set after Runtime.stop() is called. Never cleared. must_stop: asyncio.Event # Set when a client connects; cleared when we have no connected clients. has_connection: asyncio.Event # Set after a ForwardMsg is enqueued; cleared when we flush ForwardMsgs. need_send_data: asyncio.Event # Completed when the Runtime has started. started: asyncio.Future[None] # Completed when the Runtime has stopped. stopped: asyncio.Future[None] class Runtime: _instance: Optional[Runtime] = None @classmethod def instance(cls) -> Runtime: """Return the singleton Runtime instance. Raise an Error if the Runtime hasn't been created yet. """ if cls._instance is None: raise RuntimeError("Runtime hasn't been created!") return cls._instance @classmethod def exists(cls) -> bool: """True if the singleton Runtime instance has been created. When a Streamlit app is running in "raw mode" - that is, when the app is run via `python app.py` instead of `streamlit run app.py` - the Runtime will not exist, and various Streamlit functions need to adapt. """ return cls._instance is not None def __init__(self, config: RuntimeConfig): """Create a Runtime instance. It won't be started yet. Runtime is *not* thread-safe. Its public methods are generally safe to call only on the same thread that its event loop runs on. Parameters ---------- config Config options. """ if Runtime._instance is not None: raise RuntimeError("Runtime instance already exists!") Runtime._instance = self # Will be created when we start. self._async_objs: Optional[AsyncObjects] = None # The task that runs our main loop. We need to save a reference # to it so that it doesn't get garbage collected while running. self._loop_coroutine_task: Optional[asyncio.Task[None]] = None self._main_script_path = config.script_path self._command_line = config.command_line or "" self._state = RuntimeState.INITIAL # Initialize managers self._message_cache = ForwardMsgCache() self._uploaded_file_mgr = config.uploaded_file_manager self._media_file_mgr = MediaFileManager(storage=config.media_file_storage) self._cache_storage_manager = config.cache_storage_manager self._script_cache = ScriptCache() self._session_mgr = config.session_manager_class( session_storage=config.session_storage, uploaded_file_manager=self._uploaded_file_mgr, script_cache=self._script_cache, message_enqueued_callback=self._enqueued_some_message, ) self._stats_mgr = StatsManager() self._stats_mgr.register_provider(get_data_cache_stats_provider()) self._stats_mgr.register_provider(get_resource_cache_stats_provider()) self._stats_mgr.register_provider(_mem_caches) self._stats_mgr.register_provider(self._message_cache) self._stats_mgr.register_provider(self._uploaded_file_mgr) self._stats_mgr.register_provider(SessionStateStatProvider(self._session_mgr)) @property def state(self) -> RuntimeState: return self._state @property def message_cache(self) -> ForwardMsgCache: return self._message_cache @property def uploaded_file_mgr(self) -> UploadedFileManager: return self._uploaded_file_mgr @property def cache_storage_manager(self) -> CacheStorageManager: return self._cache_storage_manager @property def media_file_mgr(self) -> MediaFileManager: return self._media_file_mgr @property def stats_mgr(self) -> StatsManager: return self._stats_mgr @property def stopped(self) -> Awaitable[None]: """A Future that completes when the Runtime's run loop has exited.""" return self._get_async_objs().stopped # NOTE: A few Runtime methods listed as threadsafe (get_client and # is_active_session) currently rely on the implementation detail that # WebsocketSessionManager's get_active_session_info and is_active_session methods # happen to be threadsafe. This may change with future SessionManager implementations, # at which point we'll need to formalize our thread safety rules for each # SessionManager method. def get_client(self, session_id: str) -> Optional[SessionClient]: """Get the SessionClient for the given session_id, or None if no such session exists. Notes ----- Threading: SAFE. May be called on any thread. """ session_info = self._session_mgr.get_active_session_info(session_id) if session_info is None: return None return session_info.client async def start(self) -> None: """Start the runtime. This must be called only once, before any other functions are called. When this coroutine returns, Streamlit is ready to accept new sessions. Notes ----- Threading: UNSAFE. Must be called on the eventloop thread. """ # Create our AsyncObjects. We need to have a running eventloop to # instantiate our various synchronization primitives. async_objs = AsyncObjects( eventloop=asyncio.get_running_loop(), must_stop=asyncio.Event(), has_connection=asyncio.Event(), need_send_data=asyncio.Event(), started=asyncio.Future(), stopped=asyncio.Future(), ) self._async_objs = async_objs self._loop_coroutine_task = asyncio.create_task( self._loop_coroutine(), name="Runtime.loop_coroutine" ) await async_objs.started def stop(self) -> None: """Request that Streamlit close all sessions and stop running. Note that Streamlit won't stop running immediately. Notes ----- Threading: SAFE. May be called from any thread. """ async_objs = self._get_async_objs() def stop_on_eventloop(): if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED): return LOGGER.debug("Runtime stopping...") self._set_state(RuntimeState.STOPPING) async_objs.must_stop.set() async_objs.eventloop.call_soon_threadsafe(stop_on_eventloop) def is_active_session(self, session_id: str) -> bool: """True if the session_id belongs to an active session. Notes ----- Threading: SAFE. May be called on any thread. """ return self._session_mgr.is_active_session(session_id) def connect_session( self, client: SessionClient, user_info: Dict[str, Optional[str]], existing_session_id: Optional[str] = None, session_id_override: Optional[str] = None, ) -> str: """Create a new session (or connect to an existing one) and return its unique ID. Parameters ---------- client A concrete SessionClient implementation for communicating with the session's client. user_info A dict that contains information about the session's user. For now, it only (optionally) contains the user's email address. { "email": "example@example.com" } existing_session_id The ID of an existing session to reconnect to. If one is not provided, a new session is created. Note that whether the Runtime's SessionManager supports reconnecting to an existing session depends on the SessionManager that this runtime is configured with. session_id_override The ID to assign to a new session being created with this method. Setting this can be useful when the service that a Streamlit Runtime is running in wants to tie the lifecycle of a Streamlit session to some other session-like object that it manages. Only one of existing_session_id and session_id_override should be set. Returns ------- str The session's unique string ID. Notes ----- Threading: UNSAFE. Must be called on the eventloop thread. """ assert not ( existing_session_id and session_id_override ), "Only one of existing_session_id and session_id_override should be set!" if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED): raise RuntimeStoppedError(f"Can't connect_session (state={self._state})") session_id = self._session_mgr.connect_session( client=client, script_data=ScriptData(self._main_script_path, self._command_line or ""), user_info=user_info, existing_session_id=existing_session_id, session_id_override=session_id_override, ) self._set_state(RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED) self._get_async_objs().has_connection.set() return session_id def create_session( self, client: SessionClient, user_info: Dict[str, Optional[str]], existing_session_id: Optional[str] = None, session_id_override: Optional[str] = None, ) -> str: """Create a new session (or connect to an existing one) and return its unique ID. Notes ----- This method is simply an alias for connect_session added for backwards compatibility. """ LOGGER.warning("create_session is deprecated! Use connect_session instead.") return self.connect_session( client=client, user_info=user_info, existing_session_id=existing_session_id, session_id_override=session_id_override, ) def close_session(self, session_id: str) -> None: """Close and completely shut down a session. This differs from disconnect_session in that it always completely shuts down a session, permanently losing any associated state (session state, uploaded files, etc.). This function may be called multiple times for the same session, which is not an error. (Subsequent calls just no-op.) Parameters ---------- session_id The session's unique ID. Notes ----- Threading: UNSAFE. Must be called on the eventloop thread. """ session_info = self._session_mgr.get_session_info(session_id) if session_info: self._message_cache.remove_refs_for_session(session_info.session) self._session_mgr.close_session(session_id) self._on_session_disconnected() def disconnect_session(self, session_id: str) -> None: """Disconnect a session. It will stop producing ForwardMsgs. Differs from close_session because disconnected sessions can be reconnected to for a brief window (depending on the SessionManager/SessionStorage implementations used by the runtime). This function may be called multiple times for the same session, which is not an error. (Subsequent calls just no-op.) Parameters ---------- session_id The session's unique ID. Notes ----- Threading: UNSAFE. Must be called on the eventloop thread. """ session_info = self._session_mgr.get_active_session_info(session_id) if session_info: # NOTE: Ideally, we'd like to keep ForwardMsgCache refs for a session around # when a session is disconnected (and defer their cleanup until the session # is garbage collected), but this would be difficult to do as the # ForwardMsgCache is not thread safe, and we have no guarantee that the # garbage collector will only run on the eventloop thread. Because of this, # we clean up refs now and accept the risk that we're deleting cache entries # that will be useful once the browser tab reconnects. self._message_cache.remove_refs_for_session(session_info.session) self._session_mgr.disconnect_session(session_id) self._on_session_disconnected() def handle_backmsg(self, session_id: str, msg: BackMsg) -> None: """Send a BackMsg to an active session. Parameters ---------- session_id The session's unique ID. msg The BackMsg to deliver to the session. Notes ----- Threading: UNSAFE. Must be called on the eventloop thread. """ if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED): raise RuntimeStoppedError(f"Can't handle_backmsg (state={self._state})") session_info = self._session_mgr.get_active_session_info(session_id) if session_info is None: LOGGER.debug( "Discarding BackMsg for disconnected session (id=%s)", session_id ) return session_info.session.handle_backmsg(msg) def handle_backmsg_deserialization_exception( self, session_id: str, exc: BaseException ) -> None: """Handle an Exception raised during deserialization of a BackMsg. Parameters ---------- session_id The session's unique ID. exc The Exception. Notes ----- Threading: UNSAFE. Must be called on the eventloop thread. """ if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED): raise RuntimeStoppedError( f"Can't handle_backmsg_deserialization_exception (state={self._state})" ) session_info = self._session_mgr.get_active_session_info(session_id) if session_info is None: LOGGER.debug( "Discarding BackMsg Exception for disconnected session (id=%s)", session_id, ) return session_info.session.handle_backmsg_exception(exc) @property async def is_ready_for_browser_connection(self) -> Tuple[bool, str]: if self._state not in ( RuntimeState.INITIAL, RuntimeState.STOPPING, RuntimeState.STOPPED, ): return True, "ok" return False, "unavailable" async def does_script_run_without_error(self) -> Tuple[bool, str]: """Load and execute the app's script to verify it runs without an error. Returns ------- (True, "ok") if the script completes without error, or (False, err_msg) if the script raises an exception. Notes ----- Threading: UNSAFE. Must be called on the eventloop thread. """ # NOTE: We create an AppSession directly here instead of using the # SessionManager intentionally. This isn't a "real" session and is only being # used to test that the script runs without error. session = AppSession( script_data=ScriptData(self._main_script_path, self._command_line), uploaded_file_manager=self._uploaded_file_mgr, script_cache=self._script_cache, message_enqueued_callback=self._enqueued_some_message, local_sources_watcher=LocalSourcesWatcher(self._main_script_path), user_info={"email": "test@test.com"}, ) try: session.request_rerun(None) now = time.perf_counter() while ( SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state and (time.perf_counter() - now) < SCRIPT_RUN_CHECK_TIMEOUT ): await asyncio.sleep(0.1) if SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state: return False, "timeout" ok = session.session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] msg = "ok" if ok else "error" return ok, msg finally: session.shutdown() def _set_state(self, new_state: RuntimeState) -> None: LOGGER.debug("Runtime state: %s -> %s", self._state, new_state) self._state = new_state async def _loop_coroutine(self) -> None: """The main Runtime loop. This function won't exit until `stop` is called. Notes ----- Threading: UNSAFE. Must be called on the eventloop thread. """ async_objs = self._get_async_objs() try: if self._state == RuntimeState.INITIAL: self._set_state(RuntimeState.NO_SESSIONS_CONNECTED) elif self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED: pass else: raise RuntimeError(f"Bad Runtime state at start: {self._state}") # Signal that we're started and ready to accept sessions async_objs.started.set_result(None) while not async_objs.must_stop.is_set(): if self._state == RuntimeState.NO_SESSIONS_CONNECTED: # type: ignore[comparison-overlap] # mypy 1.4 incorrectly thinks this if-clause is unreachable, # because it thinks self._state must be INITIAL | ONE_OR_MORE_SESSIONS_CONNECTED. await asyncio.wait( # type: ignore[unreachable] ( asyncio.create_task(async_objs.must_stop.wait()), asyncio.create_task(async_objs.has_connection.wait()), ), return_when=asyncio.FIRST_COMPLETED, ) elif self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED: async_objs.need_send_data.clear() for active_session_info in self._session_mgr.list_active_sessions(): msg_list = active_session_info.session.flush_browser_queue() for msg in msg_list: try: self._send_message(active_session_info, msg) except SessionClientDisconnectedError: self._session_mgr.disconnect_session( active_session_info.session.id ) # Yield for a tick after sending a message. await asyncio.sleep(0) # Yield for a few milliseconds between session message # flushing. await asyncio.sleep(0.01) else: # Break out of the thread loop if we encounter any other state. break await asyncio.wait( ( asyncio.create_task(async_objs.must_stop.wait()), asyncio.create_task(async_objs.need_send_data.wait()), ), return_when=asyncio.FIRST_COMPLETED, ) # Shut down all AppSessions. for session_info in self._session_mgr.list_sessions(): # NOTE: We want to fully shut down sessions when the runtime stops for # now, but this may change in the future if/when our notion of a session # is no longer so tightly coupled to a browser tab. self._session_mgr.close_session(session_info.session.id) self._set_state(RuntimeState.STOPPED) async_objs.stopped.set_result(None) except Exception as e: async_objs.stopped.set_exception(e) traceback.print_exc() LOGGER.info( """ Please report this bug at https://github.com/streamlit/streamlit/issues. """ ) def _send_message(self, session_info: ActiveSessionInfo, msg: ForwardMsg) -> None: """Send a message to a client. If the client is likely to have already cached the message, we may instead send a "reference" message that contains only the hash of the message. Parameters ---------- session_info : ActiveSessionInfo The ActiveSessionInfo associated with websocket msg : ForwardMsg The message to send to the client Notes ----- Threading: UNSAFE. Must be called on the eventloop thread. """ msg.metadata.cacheable = is_cacheable_msg(msg) msg_to_send = msg if msg.metadata.cacheable: populate_hash_if_needed(msg) if self._message_cache.has_message_reference( msg, session_info.session, session_info.script_run_count ): # This session has probably cached this message. Send # a reference instead. LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash) msg_to_send = create_reference_msg(msg) # Cache the message so it can be referenced in the future. # If the message is already cached, this will reset its # age. LOGGER.debug("Caching message (hash=%s)", msg.hash) self._message_cache.add_message( msg, session_info.session, session_info.script_run_count ) # If this was a `script_finished` message, we increment the # script_run_count for this session, and update the cache if ( msg.WhichOneof("type") == "script_finished" and msg.script_finished == ForwardMsg.FINISHED_SUCCESSFULLY ): LOGGER.debug( "Script run finished successfully; " "removing expired entries from MessageCache " "(max_age=%s)", config.get_option("global.maxCachedMessageAge"), ) session_info.script_run_count += 1 self._message_cache.remove_expired_entries_for_session( session_info.session, session_info.script_run_count ) # Ship it off! session_info.client.write_forward_msg(msg_to_send) def _enqueued_some_message(self) -> None: """Callback called by AppSession after the AppSession has enqueued a message. Sets the "needs_send_data" event, which causes our core loop to wake up and flush client message queues. Notes ----- Threading: SAFE. May be called on any thread. """ async_objs = self._get_async_objs() async_objs.eventloop.call_soon_threadsafe(async_objs.need_send_data.set) def _get_async_objs(self) -> AsyncObjects: """Return our AsyncObjects instance. If the Runtime hasn't been started, this will raise an error. """ if self._async_objs is None: raise RuntimeError("Runtime hasn't started yet!") return self._async_objs def _on_session_disconnected(self) -> None: """Set the runtime state to NO_SESSIONS_CONNECTED if the last active session was disconnected. """ if ( self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED and self._session_mgr.num_active_sessions() == 0 ): self._get_async_objs().has_connection.clear() self._set_state(RuntimeState.NO_SESSIONS_CONNECTED)