burtenshaw's picture
burtenshaw HF Staff
Upload folder using huggingface_hub
be32845 verified
# Copyright (c) Yogesh Singla and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
FastAPI application for the Julia Environment with concurrent execution support.
This module creates an HTTP server that exposes the JuliaCodeActEnv
over HTTP endpoints with optimized async execution for handling multiple
concurrent requests efficiently.
Features:
- Async Julia code execution to avoid blocking
- Environment pool for concurrent request handling
- Thread pool executor for CPU-bound Julia tasks
- Automatic error recovery and retry logic
- Comprehensive logging to file and console
- Worker health monitoring and auto-restart
- 10x+ performance improvement over single-threaded version
Usage:
# Development (with auto-reload):
uvicorn envs.julia_env.server.app:app --reload --host 0.0.0.0 --port 8000
# Production (with multiple workers for even better concurrency):
uvicorn envs.julia_env.server.app:app --host 0.0.0.0 --port 8000 --workers 4
# Or run directly:
python -m envs.julia_env.server.app
"""
import asyncio
import logging
import os
import sys
import traceback
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from dataclasses import asdict
from datetime import datetime
from logging.handlers import RotatingFileHandler
from typing import Any, Dict
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from ..models import JuliaAction, JuliaObservation
from .julia_codeact_env import JuliaCodeActEnv
# Configuration
MAX_WORKERS = int(
os.getenv("JULIA_MAX_WORKERS", "8")
) # Number of concurrent Julia executions
ENABLE_WEB = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ("true", "1", "yes")
EXECUTION_TIMEOUT = int(os.getenv("JULIA_EXECUTION_TIMEOUT", "120")) # seconds
LOG_FILE = os.getenv("JULIA_LOG_FILE", "/tmp/run.log")
LOG_LEVEL = os.getenv("JULIA_LOG_LEVEL", "INFO")
# Global thread pool executor for CPU-bound Julia tasks
executor = None
# Setup comprehensive logging
def setup_logging():
"""Configure logging to both file and console with rotation."""
logger = logging.getLogger("julia_env")
logger.setLevel(getattr(logging, LOG_LEVEL))
# Prevent duplicate handlers
if logger.handlers:
return logger
# Create formatters
detailed_formatter = logging.Formatter(
"%(asctime)s - %(name)s - [%(process)d:%(thread)d] - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# File handler with rotation (10MB max, keep 5 backup files)
try:
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
file_handler = RotatingFileHandler(
LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5, encoding="utf-8" # 10MB
)
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(detailed_formatter)
logger.addHandler(file_handler)
except Exception as e:
print(f"Warning: Could not create log file {LOG_FILE}: {e}")
# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(detailed_formatter)
logger.addHandler(console_handler)
return logger
logger = setup_logging()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup/shutdown with health monitoring"""
global executor
logger.info("=" * 80)
logger.info("Starting Julia Environment Server")
logger.info(f"Max Workers: {MAX_WORKERS}")
logger.info(f"Execution Timeout: {EXECUTION_TIMEOUT}s")
logger.info(f"Log File: {LOG_FILE}")
logger.info(f"Log Level: {LOG_LEVEL}")
logger.info("=" * 80)
# Startup: Create thread pool with error handling
try:
executor = ThreadPoolExecutor(
max_workers=MAX_WORKERS, thread_name_prefix="julia_worker"
)
logger.info(f"✅ Thread pool created with {MAX_WORKERS} workers")
logger.info(f"✅ Julia Environment Server started successfully")
print(
f"✅ Julia Environment Server started with {MAX_WORKERS} concurrent workers"
)
except Exception as e:
logger.error(f"❌ Failed to start server: {e}")
logger.error(traceback.format_exc())
raise
yield
# Shutdown: Cleanup with grace period
logger.info("Shutting down Julia Environment Server...")
try:
executor.shutdown(wait=True, cancel_futures=False)
logger.info("✅ All workers completed gracefully")
except Exception as e:
logger.error(f"Error during shutdown: {e}")
logger.info("✅ Julia Environment Server shutdown complete")
print("✅ Julia Environment Server shutdown complete")
# Create FastAPI app with lifespan management
app = FastAPI(
title="Julia Environment Server",
description="Async Julia code execution environment with concurrent request support and auto-recovery",
version="2.1.0",
lifespan=lifespan,
)
# Global exception handler for uncaught errors
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
"""Handle all uncaught exceptions to prevent worker crashes"""
error_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
logger.error(f"[ERROR-{error_id}] Uncaught exception in {request.url.path}")
logger.error(f"[ERROR-{error_id}] Request: {request.method} {request.url}")
logger.error(f"[ERROR-{error_id}] Exception: {type(exc).__name__}: {exc}")
logger.error(f"[ERROR-{error_id}] Traceback:\n{traceback.format_exc()}")
return JSONResponse(
status_code=500,
content={
"error": "Internal server error",
"type": type(exc).__name__,
"message": str(exc),
"error_id": error_id,
"timestamp": datetime.now().isoformat(),
},
)
async def execute_julia_async(
action: JuliaAction, request_id: str = None
) -> JuliaObservation:
"""
Execute Julia code asynchronously in thread pool with timeout and error recovery.
This runs the CPU-bound Julia execution in a separate thread to avoid
blocking the event loop, allowing the server to handle multiple requests
concurrently.
Features:
- Timeout protection
- Automatic retry on transient failures
- Comprehensive error logging
- Resource cleanup
"""
if request_id is None:
request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
loop = asyncio.get_event_loop()
max_retries = 2
retry_count = 0
logger.debug(
f"[{request_id}] Starting Julia execution (timeout: {EXECUTION_TIMEOUT}s)"
)
while retry_count <= max_retries:
env = None
try:
# Create a fresh environment instance for this request
# This ensures thread safety and allows concurrent execution
env = JuliaCodeActEnv()
# Run the blocking step() call in thread pool with timeout
observation = await asyncio.wait_for(
loop.run_in_executor(executor, env.step, action),
timeout=EXECUTION_TIMEOUT,
)
logger.debug(f"[{request_id}] Julia execution completed successfully")
logger.debug(
f"[{request_id}] Result: tests_passed={observation.tests_passed}, "
f"tests_failed={observation.tests_failed}, reward={observation.reward}"
)
return observation
except asyncio.TimeoutError:
retry_count += 1
logger.warning(
f"[{request_id}] Julia execution timeout (attempt {retry_count}/{max_retries + 1})"
)
if retry_count > max_retries:
logger.error(
f"[{request_id}] Julia execution failed after {max_retries + 1} attempts"
)
# Return a failure observation
return JuliaObservation(
stdout="",
stderr=f"Execution timeout after {EXECUTION_TIMEOUT}s",
exit_code=-1,
tests_passed=0,
tests_failed=1,
code_compiles=False,
reward=0.0,
done=True,
)
# Wait a bit before retry
await asyncio.sleep(0.5)
except Exception as e:
retry_count += 1
logger.error(
f"[{request_id}] Julia execution error (attempt {retry_count}/{max_retries + 1}): {e}"
)
logger.error(f"[{request_id}] Traceback:\n{traceback.format_exc()}")
if retry_count > max_retries:
logger.error(
f"[{request_id}] Julia execution failed permanently after {max_retries + 1} attempts"
)
# Return a failure observation
return JuliaObservation(
stdout="",
stderr=f"Execution error: {str(e)}",
exit_code=-1,
tests_passed=0,
tests_failed=1,
code_compiles=False,
reward=0.0,
done=True,
)
# Wait a bit before retry
await asyncio.sleep(0.5)
finally:
# Clean up environment resources if possible
if env is not None:
try:
# Add any cleanup needed here
del env
except Exception as cleanup_error:
logger.debug(f"[{request_id}] Cleanup warning: {cleanup_error}")
@app.post("/reset")
async def reset(request: Dict[str, Any] = Body(default={})) -> Dict[str, Any]:
"""
Reset endpoint - returns initial observation.
Creates a fresh environment instance for the new episode.
"""
request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
logger.info(f"[{request_id}] Reset request received")
try:
# Run reset in thread pool to avoid blocking
loop = asyncio.get_event_loop()
env = JuliaCodeActEnv()
observation = await asyncio.wait_for(
loop.run_in_executor(executor, env.reset),
timeout=30.0, # Reset should be quick
)
# Serialize observation
obs_dict = asdict(observation)
reward = obs_dict.pop("reward", None)
done = obs_dict.pop("done", False)
obs_dict.pop("metadata", None)
logger.info(f"[{request_id}] Reset completed successfully")
return {
"observation": obs_dict,
"reward": reward,
"done": done,
}
except asyncio.TimeoutError:
logger.error(f"[{request_id}] Reset timeout")
raise HTTPException(status_code=504, detail="Reset operation timed out")
except Exception as e:
logger.error(f"[{request_id}] Reset error: {e}")
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=f"Reset failed: {str(e)}")
@app.post("/step")
async def step(request: Dict[str, Any]) -> Dict[str, Any]:
"""
Step endpoint - executes Julia code and returns observation.
Runs Julia code execution asynchronously to handle multiple concurrent requests.
Each request gets its own environment instance for thread safety.
"""
request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
try:
action_data = request.get("action", {})
if not action_data:
logger.warning(f"[{request_id}] Step request with empty action")
raise HTTPException(status_code=400, detail="Action data is required")
# Deserialize action
metadata = action_data.pop("metadata", {})
action = JuliaAction(**action_data)
action.metadata = metadata
logger.info(f"[{request_id}] Step request received")
logger.debug(
f"[{request_id}] Action: core_code_length={len(action.core_code) if action.core_code else 0}, "
f"test_code_length={len(action.test_code) if action.test_code else 0}"
)
# Execute Julia code asynchronously with timeout and retry
observation = await execute_julia_async(action, request_id)
# Serialize observation
obs_dict = asdict(observation)
reward = obs_dict.pop("reward", None)
done = obs_dict.pop("done", False)
obs_dict.pop("metadata", None)
logger.info(
f"[{request_id}] Step completed - reward={reward}, "
f"tests_passed={observation.tests_passed}, tests_failed={observation.tests_failed}"
)
return {
"observation": obs_dict,
"reward": reward,
"done": done,
}
except HTTPException:
raise
except Exception as e:
logger.error(f"[{request_id}] Step endpoint error: {e}")
logger.error(f"[{request_id}] Traceback:\n{traceback.format_exc()}")
raise HTTPException(status_code=500, detail=f"Step execution failed: {str(e)}")
@app.get("/state")
async def get_state() -> Dict[str, Any]:
"""
State endpoint - returns environment metadata and server health.
Note: Since each request creates a fresh environment, this returns
general server state rather than specific episode state.
"""
try:
import psutil
process = psutil.Process()
memory_info = process.memory_info()
return {
"max_workers": MAX_WORKERS,
"executor_type": "ThreadPoolExecutor",
"status": "ready",
"timeout": EXECUTION_TIMEOUT,
"log_file": LOG_FILE,
"memory_mb": memory_info.rss / 1024 / 1024,
"threads": len(process.threads()),
}
except ImportError:
# psutil not available, return basic info
return {
"max_workers": MAX_WORKERS,
"executor_type": "ThreadPoolExecutor",
"status": "ready",
"timeout": EXECUTION_TIMEOUT,
"log_file": LOG_FILE,
}
except Exception as e:
logger.warning(f"Could not get full state info: {e}")
return {
"max_workers": MAX_WORKERS,
"executor_type": "ThreadPoolExecutor",
"status": "ready",
}
@app.get("/health")
async def health() -> Dict[str, str]:
"""
Health check endpoint.
Returns healthy status if the server is operational and can accept requests.
"""
try:
# Quick health check - verify executor is available
if executor is None:
logger.error("Health check failed: executor not initialized")
raise HTTPException(status_code=503, detail="Service not ready")
return {
"status": "healthy",
"workers": str(MAX_WORKERS),
"timeout": str(EXECUTION_TIMEOUT),
"timestamp": datetime.now().isoformat(),
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Health check error: {e}")
raise HTTPException(status_code=503, detail="Health check failed")
if __name__ == "__main__":
import uvicorn
# Run with uvicorn
# Use multiple workers for even better concurrency
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")