|
|
import datetime
import logging
import os
import random
import shlex
import subprocess
import uuid
from urllib.parse import urlencode

import psutil

from backend.globals import FREE_PORTS_POOL, ORG, SPACE, STORAGE_PATH
from backend.session import UserSession
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def _build_are_command(model: str, provider: str, log_path: str) -> str:
    """Return the bash command line that runs the ARE GUI and tees its output.

    The command is executed through ``/bin/bash`` (the ``2>&1 | tee`` pipeline
    needs a shell), so every argument is shell-quoted: otherwise a ``model``,
    ``provider`` or ``log_path`` containing shell metacharacters would be
    interpreted by the shell (command injection if any of them come from
    untrusted input). Values made only of shell-safe characters are emitted
    unchanged.
    """
    are_args = [
        "python",
        "-u",  # unbuffered stdout/stderr so the log is written as the process runs
        "-m",
        "are.simulation.gui.cli",
        "-a",
        "default",
        "-m",
        model,
        "--provider",
        provider,
        "-s",
        "scenario_hf_demo_mcp",
        "--ui_view",
        "playground",
    ]
    # Merge stderr into stdout and duplicate the combined stream into log_path.
    return f"{shlex.join(are_args)} 2>&1 | tee {shlex.quote(log_path)}"


def start_are_process_and_session_lite(
    model: str,
    provider: str,
    username: str,
    bearer_token: str | None,
    user_token: str | None,
    app_path: str | None,
) -> UserSession:
    """Start an ARE GUI process on a free port and return its session record.

    A port is picked at random from ``FREE_PORTS_POOL``. The ARE CLI is then
    launched through ``/bin/bash``, with its combined stdout/stderr tee'd to
    ``<STORAGE_PATH>/log_<port>.log``. Finally the port is removed from the pool.

    Args:
        model: Model name passed to the ARE CLI (``-m``).
        provider: Provider name passed to the ARE CLI (``--provider``).
        username: User owning the session (stored as ``user``).
        bearer_token: Stored on the session as ``sign``.
        user_token: The user's Hugging Face token. It is exported as
            ``HF_TOKEN`` / ``HF_INFERENCE_TOKEN``, unless ``HF_DATASET_TOKEN`` /
            ``HF_INFERENCE_TOKEN`` are set in this process's environment.
        app_path: Optional path exported as ``MCP_APPS_JSON_PATH``.

    Returns:
        UserSession: The new session. Its ``pid`` is that of the ``/bin/bash``
        wrapper, not of the Python ARE process the shell spawns.

    Raises:
        ValueError: If ``user_token`` is empty or None, or if no free port is left.
    """
    if not user_token:
        error_msg = (
            f"HF_TOKEN (user_token) is None for user {username}. "
            "Cannot start ARE session without Hugging Face token."
        )
        raise ValueError(error_msg)

    # Explicit error instead of the opaque one random raises on an empty pool.
    if not FREE_PORTS_POOL:
        raise ValueError(
            f"No free port left to start an ARE session for user {username}."
        )
    # NOTE(review): choosing and removing the port is not atomic; concurrent
    # callers could pick the same port.
    port = random.choice(FREE_PORTS_POOL)

    log_path = f"{STORAGE_PATH}/log_{port}.log"

    # Copying os.environ already forwards optional settings such as
    # HF_BILL_TO and LLAMA_API_KEY when they are set.
    env_vars = dict(os.environ)
    env_vars.update(
        {
            "ARE_SERVER_HOSTNAME": "0.0.0.0",
            "ARE_SIMULATION_SERVER_HOSTNAME": "0.0.0.0",
            "ARE_SERVER_PORT": str(port),
            "ARE_SIMULATION_SERVER_PORT": str(port),
            # Server-side tokens take precedence over the user's token.
            "HF_TOKEN": os.environ.get("HF_DATASET_TOKEN", user_token),
            "HF_INFERENCE_TOKEN": os.environ.get("HF_INFERENCE_TOKEN", user_token),
            "HF_DEMO_UNIVERSE": "universe_hf_0",
            "INTERACTIVE_SCENARIOS_TREE": "/app/mcp_demo_prompts.json",
        }
    )
    if app_path:
        env_vars["MCP_APPS_JSON_PATH"] = app_path

    process = subprocess.Popen(
        _build_are_command(model, provider, log_path),
        env=env_vars,
        shell=True,
        executable="/bin/bash",
    )

    # Update the shared list in place rather than rebinding the module-level
    # name. Rebinding would leave backend.globals.FREE_PORTS_POOL (and any
    # other module holding that list) still listing this port as free.
    FREE_PORTS_POOL[:] = [
        free_port for free_port in FREE_PORTS_POOL if free_port != port
    ]

    return UserSession(
        port=int(port),
        pid=process.pid,
        sid=str(uuid.uuid4()),
        model=model,
        provider=provider,
        log_path=log_path,
        start_time=str(datetime.datetime.now()),
        user=username,
        sign=bearer_token,
    )
|
|
|
|
|
|
|
|
def _return_port_to_pool(port: int) -> None:
    """Put ``port`` back into ``FREE_PORTS_POOL`` unless it is already there."""
    # Guard against duplicates when the same session is killed more than once.
    if port not in FREE_PORTS_POOL:
        FREE_PORTS_POOL.append(port)


def kill_are_process(session: UserSession) -> None:
    """Kill the process tree of an ARE session and release its port.

    Every descendant of ``session.pid`` is sent SIGKILL first, followed by the
    main process itself. The port is returned to ``FREE_PORTS_POOL`` after the
    kill has been issued, or immediately if the main process no longer exists.

    If the kill fails with a permission or OS error, the error is logged and
    the port is NOT released. The process may still be running on it.

    Args:
        session: The session whose process (``session.pid``) should be killed.
    """
    try:
        main_process = psutil.Process(session.pid)

        # Snapshot descendants before killing the parent, since they can no
        # longer be discovered through it afterwards.
        children = main_process.children(recursive=True)

        for child in children:
            try:
                child.kill()
                logger.info("Killed child process PID %s", child.pid)
            except psutil.NoSuchProcess:
                logger.info("Child process PID %s already terminated", child.pid)
            except psutil.AccessDenied:
                # psutil.AccessDenied is not an OSError; without this clause a
                # single protected child aborted the whole cleanup.
                logger.warning(
                    "Not permitted to kill child process PID %s", child.pid
                )
            except OSError:
                logger.warning("Child process PID %s not found", child.pid)

        # One overall timeout for all children instead of up to 5 s each.
        _, still_alive = psutil.wait_procs(children, timeout=5)
        for child in still_alive:
            logger.warning(
                "Child process PID %s did not terminate within timeout", child.pid
            )

        main_process.kill()
        logger.info("Sent SIGKILL to main PID %s", session.pid)

        try:
            main_process.wait(timeout=5)
        except psutil.TimeoutExpired:
            logger.warning(
                "Main process PID %s did not terminate within timeout", session.pid
            )
        except psutil.NoSuchProcess:
            pass

        _return_port_to_pool(session.port)
        logger.info(
            "Killed session %s PID %s on port %s for user %s",
            session.sid,
            session.pid,
            session.port,
            session.user,
        )

    except psutil.NoSuchProcess:
        logger.info(
            "Process PID %s not found - may already be terminated", session.pid
        )
        _return_port_to_pool(session.port)
    except (psutil.AccessDenied, OSError):
        logger.error(
            "COULD NOT KILL ARE on port %s for user %s",
            session.port,
            session.user,
            exc_info=True,
        )
|
|
|
|
|
|
|
|
def get_are_url(session: UserSession, server: str) -> str:
    """Build the URL at which a user's ARE session is reachable.

    When ``FLASK_ENV`` is ``"development"``, the URL points at the session's
    port on localhost. Otherwise (the default is ``"production"``) it points at
    the port-specific Hugging Face Space host derived from ``ORG`` and ``SPACE``.

    Args:
        session: The user's session. Its ``port``, ``sid`` and ``sign``
            attributes are used to build the URL.
        server: Path component after the host. It is expected to be ``"are"``
            or ``"graphql"``, but this is not validated here.

    Returns:
        str: The URL, carrying ``sid`` and ``__sign`` as query parameters.
    """
    # Percent-encode the query values so a token containing characters such as
    # '+', '/', '&' or '=' survives the round trip. URL-safe values (uuids,
    # typical tokens) are emitted unchanged. A None sign still renders as
    # "None", as before.
    query = urlencode({"sid": session.sid, "__sign": session.sign})

    flask_env = os.environ.get("FLASK_ENV", "production")
    if flask_env == "development":
        return f"http://localhost:{session.port}/{server}?{query}"

    host = f"{ORG.lower()}-{SPACE.lower()}--{session.port}.hf.space"
    return f"https://{host}/{server}?{query}"
|
|
|