demo / backend /are.py
Pierre Andrews
use custom token for inference
6ac5bb0
import datetime
import logging
import os
import random
import subprocess
import uuid
import psutil
from backend.globals import FREE_PORTS_POOL, ORG, SPACE, STORAGE_PATH
from backend.session import UserSession
logger = logging.getLogger(__name__)
def start_are_process_and_session_lite(
    model: str,
    provider: str,
    username: str,
    bearer_token: str | None,
    user_token: str | None,
    app_path: str | None,
) -> UserSession:
    """Launch an ARE GUI server subprocess and return a session describing it.

    A port is drawn at random from ``FREE_PORTS_POOL``. The server is started
    through ``bash`` with its combined stdout/stderr piped into ``tee``, which
    writes a per-port log file under ``STORAGE_PATH``. Once the process has
    been spawned, the port is removed from the pool.

    Args:
        model: Model identifier, passed to the CLI via ``-m``.
        provider: Inference provider, passed to the CLI via ``--provider``.
        username: User owning the session; stored on the returned session.
        bearer_token: Stored on the returned session as ``sign``.
        user_token: Hugging Face token of the user. Used as the fallback value
            for ``HF_TOKEN`` / ``HF_INFERENCE_TOKEN`` when ``HF_DATASET_TOKEN``
            / ``HF_INFERENCE_TOKEN`` are not set in the current environment.
        app_path: When truthy, exported to the subprocess as
            ``MCP_APPS_JSON_PATH``.

    Returns:
        A ``UserSession`` carrying the port, the PID of the spawned shell, a
        fresh UUID4 session id, the log path and the start time.

    Raises:
        ValueError: If ``user_token`` is missing/empty, or if no port is left
            in ``FREE_PORTS_POOL``.
    """
    global FREE_PORTS_POOL
    import shlex  # local import: only needed to build the shell command below

    if not user_token:
        error_msg = (
            f"HF_TOKEN (user_token) is None for user {username}. "
            "Cannot start ARE session without Hugging Face token."
        )
        raise ValueError(error_msg)

    # Fail with an explicit message instead of random's cryptic
    # "Sample larger than population" when every port is in use.
    if not FREE_PORTS_POOL:
        raise ValueError(
            f"No free port available to start an ARE session for user {username}."
        )

    port = random.choice(FREE_PORTS_POOL)
    log_path = f"{STORAGE_PATH}/log_{port}.log"

    # The subprocess inherits the current environment plus the overrides below.
    env_vars = dict(os.environ)
    env_vars["ARE_SERVER_HOSTNAME"] = "0.0.0.0"
    env_vars["ARE_SIMULATION_SERVER_HOSTNAME"] = "0.0.0.0"
    env_vars["ARE_SERVER_PORT"] = str(port)
    env_vars["ARE_SIMULATION_SERVER_PORT"] = str(port)
    # Server-level tokens take precedence; the user's token is the fallback.
    env_vars["HF_TOKEN"] = os.environ.get("HF_DATASET_TOKEN", user_token)
    env_vars["HF_INFERENCE_TOKEN"] = os.environ.get("HF_INFERENCE_TOKEN", user_token)
    env_vars["HF_DEMO_UNIVERSE"] = "universe_hf_0"  # universe_hf"

    bill_to = os.environ.get("HF_BILL_TO")
    if bill_to:
        env_vars["HF_BILL_TO"] = bill_to

    llama_key = os.environ.get("LLAMA_API_KEY")
    if llama_key:
        env_vars["LLAMA_API_KEY"] = llama_key

    env_vars["INTERACTIVE_SCENARIOS_TREE"] = "/app/mcp_demo_prompts.json"

    if app_path:
        env_vars["MCP_APPS_JSON_PATH"] = app_path

    # Quote every argument: the command runs through a shell (needed for the
    # `2>&1 | tee` pipeline), so unquoted model/provider values containing
    # spaces or shell metacharacters would break or hijack the command line.
    server_cmd = shlex.join(
        [
            "python", "-u", "-m", "are.simulation.gui.cli", "-a", "default",
            "-m", model,
            "--provider", provider,
            "-s",
            "scenario_hf_demo_mcp",
            # "hf://datasets/meta-agents-research-environments/gaia2/demo/validation/universe_hf",
            "--ui_view",
            "playground",
        ]  # scenario_universe_hf_0 or "scenario_hf_0" or "universe_hf_0"
    )
    command = f"{server_cmd} 2>&1 | tee {shlex.quote(log_path)}"

    process = subprocess.Popen(
        command,
        env=env_vars,
        shell=True,
        executable="/bin/bash",
    )

    # Reserve the port only once the process has actually been spawned.
    FREE_PORTS_POOL = [free for free in FREE_PORTS_POOL if free != port]

    user_session = UserSession(
        port=int(port),
        pid=process.pid,
        sid=str(uuid.uuid4()),
        model=model,
        provider=provider,
        log_path=log_path,
        start_time=str(datetime.datetime.now()),
        user=username,
        sign=bearer_token,
    )
    return user_session
def kill_are_process(session: UserSession) -> None:
    """Kill the ARE process of a session and all of its descendants.

    Child processes (collected recursively) are killed and waited on first,
    then the main process. The session's port is put back into
    ``FREE_PORTS_POOL`` when the main process was killed or no longer exists;
    it is NOT released when the kill fails with an ``OSError``.

    Args:
        session: Session whose ``pid`` is killed and whose ``port`` is
            returned to the pool.
    """

    def release_port() -> None:
        # This function may run more than once for the same session (e.g. a
        # second call hits the NoSuchProcess path). Appending blindly would
        # put the port in the pool twice and let two sessions share it.
        if session.port not in FREE_PORTS_POOL:
            FREE_PORTS_POOL.append(session.port)

    try:
        # Get the main process
        main_process = psutil.Process(session.pid)

        # Get all child processes recursively
        children = main_process.children(recursive=True)

        # Kill all child processes first
        for child in children:
            try:
                child.kill()
                logger.info(f"Killed child process PID {child.pid}")
            except psutil.NoSuchProcess:
                logger.info(f"Child process PID {child.pid} already terminated")
            except OSError:
                logger.warning(f"Child process PID {child.pid} not found")

        # Wait for child processes to terminate
        for child in children:
            try:
                child.wait(timeout=5)
            except psutil.TimeoutExpired:
                logger.warning(
                    f"Child process PID {child.pid} did not terminate within timeout"
                )
            except psutil.NoSuchProcess:
                pass

        # Kill the main process
        main_process.kill()
        logger.info(f"Sent SIGKILL to main PID {session.pid}")

        # Wait for main process to terminate
        try:
            main_process.wait(timeout=5)
        except psutil.TimeoutExpired:
            logger.warning(
                f"Main process PID {session.pid} did not terminate within timeout"
            )
        except psutil.NoSuchProcess:
            pass

        release_port()
        logger.info(
            f"Killed session {session.sid} PID {session.pid} on port {session.port} for user {session.user}"
        )
    except psutil.NoSuchProcess:
        # The main process is already gone: nothing to kill, just free the port.
        logger.info(f"Process PID {session.pid} not found - may already be terminated")
        release_port()
    except OSError:
        # The port is deliberately kept out of the pool here: the process may
        # still be running and bound to it.
        logger.error(
            f"COULD NOT KILL ARE on port {session.port} for user {session.user}",
            exc_info=True,
        )
def get_are_url(session: UserSession, server: str) -> str:
    """Build the URL under which a session's ARE server can be reached.

    The host depends on the ``FLASK_ENV`` environment variable (treated as
    ``"production"`` when unset): in ``"development"`` the URL points at
    ``localhost`` on the session's port, otherwise at the Hugging Face Space
    host derived from ``ORG``, ``SPACE`` and the session's port.

    Args:
        session: Session providing ``port``, ``sid`` and ``sign``; ``sid`` and
            ``sign`` are inserted verbatim as the ``sid`` and ``__sign`` query
            parameters.
        server: Path segment of the URL; expected to be either "are" or
            "graphql" (not validated here).

    Returns:
        str: The url to look at
    """
    # Path + query string are the same in both environments.
    path_and_query = f"{server}?sid={session.sid}&__sign={session.sign}"

    in_development = os.environ.get("FLASK_ENV", "production") == "development"
    if in_development:
        # Development: talk to the ARE port directly on localhost.
        return f"http://localhost:{session.port}/{path_and_query}"

    # Production: go through the Hugging Face Space URL for this port.
    space_host = f"{ORG.lower()}-{SPACE.lower()}--{session.port}.hf.space"
    return f"https://{space_host}/{path_and_query}"