Spaces:
Runtime error
Runtime error
| '''Template for the two classes hosts should customize for each competition.''' | |
| import abc | |
| import os | |
| import pathlib | |
| import polars as pl | |
| import time | |
| import sys | |
| import traceback | |
| import warnings | |
| from typing import Callable, Generator, Tuple | |
| import kaggle_evaluation.core.base_gateway | |
| import kaggle_evaluation.core.relay | |
# Wall-clock time captured at module import — presumably used elsewhere to measure
# how long competitor code takes to start up. TODO confirm against the full file.
_initial_import_time = time.time()
# Guard flag so a slow-startup warning is only issued once; the code that flips it
# is not visible in this chunk.
_issued_startup_time_warning = False
class Gateway(kaggle_evaluation.core.base_gateway.BaseGateway, abc.ABC):
    '''
    Starting template for writing a new competition gateway.

    In most cases hosts only need to implement `get_all_predictions`. Two ways of
    sending data to the inference server are worth understanding:

    - Small datasets: use `self.predict`. Competitors receive the arguments passed to
      `self.predict` as in-memory Python objects. It is a thin wrapper around
      `self.client.send()`; write additional wrappers if you need them.
    - Large datasets: `self.share_files` is much faster — it exposes files to the
      container via symlink. See `base_gateway.BaseGateway.share_files` for details.
    '''

    def generate_data_batches(self) -> Generator:
        '''Yield `(data_batch, row_ids)` pairs for the evaluation loop.

        The default `get_all_predictions` consumes this generator so that
        `validate_prediction_batch` runs on every call to `predict`. Each item must
        contain both the batch to forward to `predict` and the row IDs used for
        validation. Subclasses must override this.
        '''
        raise NotImplementedError

    def get_all_predictions(self):
        '''Run `predict` over every batch, validating each one; returns the
        accumulated prediction Series and row-ID batches as two parallel lists.'''
        prediction_batches = []
        row_id_batches = []
        for data_batch, row_ids in self.generate_data_batches():
            batch_preds = pl.Series(self.target_column_name, self.predict(*data_batch))
            self.validate_prediction_batch(batch_preds, row_ids)
            prediction_batches.append(batch_preds)
            row_id_batches.append(row_ids)
        return prediction_batches, row_id_batches

    def predict(self, *args, **kwargs):
        '''Send every positional and keyword argument to the user container and
        instruct it to generate a `predict` response.
        '''
        try:
            return self.client.send('predict', *args, **kwargs)
        except Exception as exc:
            self.handle_server_error(exc, 'predict')

    def set_response_timeout_seconds(self, timeout_seconds: float):
        '''Configure the per-response deadline for the user container.'''
        # Keep a copy in an easy place for the competitor to access.
        self.timeout_seconds = timeout_seconds
        # The deadline only starts applying after the very first response.
        self.client.endpoint_deadline_seconds = timeout_seconds

    def run(self) -> pathlib.Path:
        '''Execute the full evaluation loop, then shut down the client and server.

        Returns the submission path on success. On a competition rerun the error (if
        any) is written out via `write_result`; locally it is re-raised instead.
        '''
        submission_path = None
        error = None
        try:
            predictions, row_ids = self.get_all_predictions()
            submission_path = self.write_submission(predictions, row_ids)
        except kaggle_evaluation.core.base_gateway.GatewayRuntimeError as gre:
            error = gre
        except Exception:
            # Preserve the complete stack trace in the wrapped error.
            error = kaggle_evaluation.core.base_gateway.GatewayRuntimeError(
                kaggle_evaluation.core.base_gateway.GatewayRuntimeErrorType.GATEWAY_RAISED_EXCEPTION,
                traceback.format_exc(),
            )

        self.client.close()
        if self.server:
            self.server.stop(0)

        if kaggle_evaluation.core.base_gateway.IS_RERUN:
            self.write_result(error)
        elif error:
            # For local testing: surface the failure directly.
            raise error
        return submission_path
class InferenceServer(abc.ABC):
    '''
    Base class competition participants inherit from when writing their submission.

    Most users only need to implement a `predict` function (or other endpoint
    callables) and pass it to this constructor; hosts supply a mock Gateway for
    local testing.
    '''

    def __init__(self, endpoint_listeners: Tuple[Callable]):
        self.server = kaggle_evaluation.core.relay.define_server(endpoint_listeners)
        # The inference server may hold a client, but typically does not need one.
        self.client = None

    def serve(self):
        '''Start the server; block indefinitely when running as a competition rerun.'''
        self.server.start()
        if os.getenv('KAGGLE_IS_COMPETITION_RERUN') is not None:
            # Blocks all further code from running.
            self.server.wait_for_termination()