from flask import Flask, request, jsonify
import uuid
import time
import docker
import requests
import atexit
import socket
import argparse
import logging
from pydantic import BaseModel, Field, ValidationError

app = Flask(__name__)
app.logger.setLevel(logging.INFO)


# CLI function to parse arguments
def parse_args():
    """Parse the command-line options of the Jupyter server.

    Returns:
        argparse.Namespace with ``n_instances``, ``n_cpus``, ``mem``,
        ``execution_timeout`` and ``port``.
    """
    parser = argparse.ArgumentParser(description="Jupyter server.")
    # Required: without a value, range(None) would fail much later, after the
    # docker image has already been built.
    parser.add_argument('--n_instances', type=int, required=True, help="Number of Jupyter instances.")
    parser.add_argument('--n_cpus', type=int, default=2, help="Number of CPUs per Jupyter instance.")
    parser.add_argument('--mem', type=str, default="2g", help="Amount of memory per Jupyter instance.")
    parser.add_argument('--execution_timeout', type=int, default=10, help="Timeout period for a code execution.")
    parser.add_argument('--port', type=int, default=5001, help="Port of main server")
    return parser.parse_args()


def get_unused_port(start=50000, end=65535, exclusion=None):
    """Return the first port in ``[start, end]`` that can currently be bound.

    Args:
        start: First port number to try (inclusive).
        end: Last port number to try (inclusive).
        exclusion: Optional iterable of port numbers that must be skipped.

    Raises:
        IOError: If no port in the range could be bound.
    """
    excluded = frozenset(exclusion or ())
    for port in range(start, end + 1):
        if port in excluded:
            continue
        try:
            # The context manager closes the probe socket even when bind() or
            # listen() fails.
            with socket.socket() as sock:
                sock.bind(("", port))
                sock.listen(1)
            return port
        except OSError:
            continue
    raise IOError("No free ports available in range {}-{}".format(start, end))


def _stop_containers(instances):
    """Best-effort stop and removal of every container entry in ``instances``."""
    for instance in instances:
        container = instance['container']
        try:
            container.stop()
            try:
                container.remove()
            except docker.errors.NotFound:
                # Containers are started with remove=True, so the container
                # may already be gone once it has been stopped.
                pass
        except Exception as e:
            app.logger.info(f"Error stopping/removing container: {str(e)}")


def _wait_for_containers(containers, startup_timeout=60, poll_interval=0.5, request_timeout=2):
    """Poll each container's ``/health`` URL until all of them answer with 200.

    Raises:
        TimeoutError: If some container is still not ready after
            ``startup_timeout`` seconds.
    """
    start_time = time.time()
    pending = list(range(len(containers)))
    while pending:
        app.logger.info("Pinging Jupyter containers to check readiness.")
        if time.time() - start_time > startup_timeout:
            raise TimeoutError("Container took too long to startup.")
        still_pending = []
        for i in pending:
            url = f"http://localhost:{containers[i]['port']}/health"
            try:
                # TODO: dedicated health endpoint
                # A per-request timeout keeps one hanging connection from
                # blocking the overall startup deadline.
                response = requests.get(url, timeout=request_timeout)
                ready = response.status_code == 200
            except requests.RequestException:
                # Not reachable yet; try again on the next round.
                ready = False
            if not ready:
                still_pending.append(i)
        pending = still_pending
        if pending:
            time.sleep(poll_interval)


def create_kernel_containers(n_instances, n_cpus=2, mem="2g", execution_timeout=10):
    """Build the kernel image and start ``n_instances`` containers from it.

    Each container publishes its port 5000 on a distinct free host port and is
    pinned to its own block of ``n_cpus`` consecutive CPU cores. The function
    returns once every container answers its health check.

    Args:
        n_instances: Number of containers to start.
        n_cpus: Number of CPU cores assigned to each container.
        mem: Memory limit passed to docker for each container.
        execution_timeout: Value of the ``EXECUTION_TIMEOUT`` environment
            variable set in each container.

    Returns:
        List of ``{"container": <container>, "port": <host port>}`` dicts, in
        start order.

    Raises:
        TimeoutError: If the containers are not ready within 60 seconds.
    """
    docker_client = docker.from_env()
    app.logger.info("Buidling docker image...")
    docker_client.images.build(path='./', tag='jupyter-kernel:latest')
    app.logger.info("Building docker image complete.")

    containers = []
    port_exclusion = []
    try:
        for i in range(n_instances):
            free_port = get_unused_port(exclusion=port_exclusion)
            # it takes a while to startup so we don't use the same port twice
            port_exclusion.append(free_port)
            app.logger.info(f"Starting container {i} on port {free_port}...")
            container = docker_client.containers.run(
                "jupyter-kernel:latest",
                detach=True,
                mem_limit=mem,
                # Instance i gets cores [i*n_cpus, (i+1)*n_cpus - 1].
                cpuset_cpus=f"{i*n_cpus}-{(i+1)*n_cpus-1}",
                remove=True,
                ports={'5000/tcp': free_port},
                environment={"EXECUTION_TIMEOUT": execution_timeout},
            )
            containers.append({"container": container, "port": free_port})

        _wait_for_containers(containers, startup_timeout=60)
    except BaseException:
        # app.containers is not assigned yet, so the atexit hook cannot see
        # these containers; stop them here instead of leaking them.
        _stop_containers(containers)
        raise

    app.logger.info("Containers ready!")
    return containers


def shutdown_cleanup():
    """Stop and remove all containers registered on ``app`` (atexit hook)."""
    app.logger.info("Shutting down. Stopping and removing all containers...")
    # The hook also runs when the process exits before any container was
    # created (e.g. on --help), in which case app.containers does not exist.
    _stop_containers(getattr(app, "containers", []))
    app.logger.info("All containers stopped and removed.")


class ServerRequest(BaseModel):
    # Source code to run in the kernel.
    code: str = Field(..., example="print('Hello World!')")
    # Index into app.containers selecting the target instance.
    instance_id: int = Field(0, example=0)
    # When true, the kernel is restarted before the code is executed.
    restart: bool = Field(False, example=False)


@app.route('/execute', methods=['POST'])
def execute_code():
    """Forward a code-execution request to the selected kernel container.

    Expects a JSON object matching ``ServerRequest``. Responds with the JSON
    returned by the container's ``/execute`` endpoint, with 400 for an invalid
    request, or with 500 when talking to the container fails.
    """
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        return jsonify({
            'result': 'error',
            'output': 'Request body must be a JSON object.'
        }), 400

    try:
        payload = ServerRequest(**body)
    except ValidationError as e:
        return jsonify(e.errors()), 400

    containers = getattr(app, "containers", [])
    # Reject out-of-range ids explicitly; a negative index would otherwise
    # silently address a different instance.
    if not 0 <= payload.instance_id < len(containers):
        return jsonify({
            'result': 'error',
            'output': f"Unknown instance_id {payload.instance_id}."
        }), 400

    port = containers[payload.instance_id]["port"]
    app.logger.info(f"Received request for instance {payload.instance_id} (port={port}).")

    try:
        if payload.restart:
            response = requests.post(f'http://localhost:{port}/restart', json={})
            if response.status_code == 200:
                app.logger.info(f"Kernel for instance {payload.instance_id} restarted.")
            else:
                # response.text cannot fail on a non-JSON error body.
                app.logger.info(f"Error when restarting kernel of instance {payload.instance_id}: {response.text}.")

        response = requests.post(f'http://localhost:{port}/execute', json={'code': payload.code})
        result = response.json()
        return jsonify(result)
    except Exception as e:
        app.logger.info(f"Error in execute_code: {str(e)}")
        return jsonify({
            'result': 'error',
            'output': str(e)
        }), 500


atexit.register(shutdown_cleanup)

if __name__ == '__main__':
    args = parse_args()
    app.containers = create_kernel_containers(
        args.n_instances,
        n_cpus=args.n_cpus,
        mem=args.mem,
        execution_timeout=args.execution_timeout
    )
    # don't use debug=True --> it will run main twice and thus start double the containers
    app.run(debug=False, host='0.0.0.0', port=args.port)

# TODO:
# How to mount data into the container at runtime? Idea: mount a (read-only)
# folder into the container at startup and copy the data in there. Before
# starting the kernel we could cp the necessary data into the pwd.