diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..f9e8e050c0ba4079f786018081a7d03d8c6a8273 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,21 @@ +.git +.github +.pytest_cache +__pycache__ +*.pyc +*.pyo +*.egg-info +.eggs +dist +build +.mypy_cache +.ruff_cache +.venv +venv +documents/ +tests/ +docs/ +*.pdf +.claude/ +# OpenRA submodule (cloned from GitHub during Docker build) +OpenRA/ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..f13f04a111e42da4cbb6817cd7aa285df5c7f9cb --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,32 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: pip install -e ".[dev]" + + - name: Run tests + run: pytest tests/ -v + + - name: Lint + run: ruff check openra_env/ diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml new file mode 100644 index 0000000000000000000000000000000000000000..e22bcd698e406f45c299e22dba6b7aa79bb242c2 --- /dev/null +++ b/.github/workflows/docker-publish.yml @@ -0,0 +1,52 @@ +name: Docker Publish + +on: + push: + tags: ["v*"] + release: + types: [published] + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +jobs: + build-and-push: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + + - uses: docker/setup-qemu-action@v3 + + - uses: docker/setup-buildx-action@v3 + + - uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ 
secrets.GITHUB_TOKEN }} + + - uses: docker/metadata-action@v5 + id: meta + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=raw,value=latest,enable={{is_default_branch}} + + - uses: docker/build-push-action@v6 + with: + context: . + platforms: linux/amd64,linux/arm64 + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml new file mode 100644 index 0000000000000000000000000000000000000000..850e680d1e068b43b0079021cf8159c412bdc0fe --- /dev/null +++ b/.github/workflows/pypi-publish.yml @@ -0,0 +1,26 @@ +name: PyPI Publish + +on: + release: + types: [published] + +permissions: + id-token: write + +jobs: + publish: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Build package + run: | + pip install build + python -m build + + - uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/sync-to-hf.yml b/.github/workflows/sync-to-hf.yml new file mode 100644 index 0000000000000000000000000000000000000000..4c9fb9b4985fe661c48c316992529cf9faf833b0 --- /dev/null +++ b/.github/workflows/sync-to-hf.yml @@ -0,0 +1,25 @@ +name: Sync to Hugging Face Space +on: + push: + branches: [main] + workflow_dispatch: + +jobs: + sync-to-hub: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 1 + lfs: true + - name: Push to Hugging Face Space + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + # Create an orphan branch with just the current tree (no history) + git checkout --orphan hf-sync + git commit -m "Sync from GitHub ${GITHUB_SHA::7}" + git remote 
add hf https://openra-rl:$HF_TOKEN@huggingface.co/spaces/openra-rl/OpenRA-RL + git push hf hf-sync:main --force diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..15f55d822926adde7af420f69d735f593bd8e7cc --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +__pycache__/ +*.py[cod] +*$py.class +*.egg-info/ +dist/ +build/ +.eggs/ +*.egg +.venv/ +venv/ +.env +*.log +.pytest_cache/ +.ruff_cache/ +.mypy_cache/ +.DS_Store +replays/ +documents/ +*.orarep diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..bdecf93e2bca92efa30537a231609ab10a3f918d --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "OpenRA"] + path = OpenRA + url = https://github.com/yxc20089/OpenRA.git diff --git a/.openenvignore b/.openenvignore new file mode 100644 index 0000000000000000000000000000000000000000..8b13e7a4138f628e5c477698b3cf7f64fdda5128 --- /dev/null +++ b/.openenvignore @@ -0,0 +1,28 @@ +# Build artifacts (Dockerfile builds fresh from source) +OpenRA/bin/ +OpenRA/obj/ + +# Replay files +*.orarep +replays/ + +# Log files +*.log + +# Documents +documents/ + +# Dev/test artifacts +.pytest_cache/ +.ruff_cache/ +.mypy_cache/ +.venv/ +venv/ +.env +.eggs/ +*.egg-info/ +dist/ +build/ + +# IDE +.DS_Store diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..382d550b42c7f77d51c636248fe899efcc76e73b --- /dev/null +++ b/Dockerfile @@ -0,0 +1,149 @@ +# ============================================================================== +# Stage 1: Build OpenRA from source (C#/.NET 8.0) +# ============================================================================== +FROM mcr.microsoft.com/dotnet/sdk:8.0-bookworm-slim AS openra-build + +RUN apt-get update && apt-get install -y --no-install-recommends \ + make \ + git \ + libsdl2-dev \ + libopenal-dev \ + libfreetype-dev \ + liblua5.1-0-dev \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* 
+ +# Clone OpenRA source from GitHub (works on HF Spaces where submodules aren't initialized) +ARG OPENRA_REPO=https://github.com/yxc20089/OpenRA.git +RUN git clone --depth=1 "$OPENRA_REPO" /src/openra +WORKDIR /src/openra + +# Fix Windows CRLF line endings in shell scripts (git autocrlf on Windows adds \r) +RUN find . -name '*.sh' -exec sed -i 's/\r$//' {} + && \ + find . -name '*.sh' -exec chmod +x {} + + +# Build with system libraries (unix-generic avoids bundled native binaries) +# SKIP_PROTOC=true uses pre-generated protobuf C# files (avoids protoc arm64 crash in Docker) +ENV SKIP_PROTOC=true +RUN make TARGETPLATFORM=unix-generic CONFIGURATION=Release + +# Verify critical output (includes Null platform for headless RL operation) +RUN test -f bin/OpenRA.dll && \ + test -f bin/OpenRA.Game.dll && \ + test -f bin/OpenRA.Mods.Common.dll && \ + test -f bin/OpenRA.Platforms.Null.dll + +# ============================================================================== +# Stage 2: Install Python dependencies +# ============================================================================== +FROM python:3.11-slim-bookworm AS python-build + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app +COPY pyproject.toml /app/ +COPY openra_env/ /app/openra_env/ +COPY proto/ /app/proto/ +COPY README.md /app/ + +RUN pip install --upgrade pip && \ + pip install --no-cache-dir . 
+ +# ============================================================================== +# Stage 3: Runtime image +# ============================================================================== +FROM mcr.microsoft.com/dotnet/aspnet:8.0-bookworm-slim AS dotnet-runtime + +FROM python:3.11-slim-bookworm + +LABEL maintainer="OpenRA-RL" +LABEL description="OpenRA RL Environment - headless game engine with gRPC bridge + OpenEnv API" + +# Copy ASP.NET Core runtime from official Microsoft image +COPY --from=dotnet-runtime /usr/share/dotnet /usr/share/dotnet +RUN ln -s /usr/share/dotnet/dotnet /usr/bin/dotnet + +# Install runtime dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + xvfb \ + libgl1-mesa-dri \ + libgl1-mesa-glx \ + libegl-mesa0 \ + mesa-vulkan-drivers \ + libvulkan1 \ + libsdl2-2.0-0 \ + libopenal1 \ + libfreetype6 \ + liblua5.1-0 \ + libicu72 \ + curl procps \ + x11vnc novnc websockify \ + && rm -rf /var/lib/apt/lists/* + +# Copy Python packages from builder +COPY --from=python-build /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages +COPY --from=python-build /usr/local/bin /usr/local/bin + +# Copy built OpenRA (bin, mods, glsl shaders, and global mix database for content resolution) +COPY --from=openra-build /src/openra/bin /opt/openra/bin +COPY --from=openra-build /src/openra/mods /opt/openra/mods +COPY --from=openra-build /src/openra/glsl /opt/openra/glsl +COPY --from=openra-build ["/src/openra/global mix database.dat", "/opt/openra/global mix database.dat"] + +# Create native library symlinks that OpenRA expects +# (configure-system-libraries.sh points these to system lib paths) +RUN LIBDIR=$( [ "$(dpkg --print-architecture)" = "arm64" ] && echo "/usr/lib/aarch64-linux-gnu" || echo "/usr/lib/x86_64-linux-gnu" ) && \ + ln -sf "$LIBDIR/libSDL2-2.0.so.0" /opt/openra/bin/SDL2.so && \ + ln -sf "$LIBDIR/libopenal.so.1" /opt/openra/bin/soft_oal.so && \ + ln -sf "$LIBDIR/libfreetype.so.6" 
/opt/openra/bin/freetype6.so && \ + ln -sf "$LIBDIR/liblua5.1.so.0" /opt/openra/bin/lua51.so + +# Copy Python application code +COPY openra_env/ /app/openra_env/ +COPY proto/ /app/proto/ +COPY pyproject.toml /app/ + +# Create OpenRA support directory and pre-install RA game content (best-effort). +# Only needed for the replay viewer (Game.Platform=Default with full UI). +# The RL environment works without this content (headless mode). +RUN mkdir -p /root/.config/openra/Content/ra/v2/expand /root/.config/openra/Content/ra/v2/cnc && \ + ( curl -sfL --max-time 30 -o /tmp/ra-quickinstall.zip \ + https://openra.baxxster.no/openra/ra-quickinstall.zip && \ + apt-get update && apt-get install -y --no-install-recommends unzip && \ + unzip -o /tmp/ra-quickinstall.zip -d /tmp/ra-content && \ + cp /tmp/ra-content/*.mix /root/.config/openra/Content/ra/v2/ && \ + cp /tmp/ra-content/expand/* /root/.config/openra/Content/ra/v2/expand/ && \ + cp /tmp/ra-content/cnc/* /root/.config/openra/Content/ra/v2/cnc/ && \ + rm -rf /tmp/ra-quickinstall.zip /tmp/ra-content && \ + apt-get purge -y unzip && apt-get autoremove -y && rm -rf /var/lib/apt/lists/* \ + ) || echo "WARNING: RA content download failed (replay viewer will be unavailable)" + +# Copy entrypoints (fix Windows CRLF line endings) +COPY docker/entrypoint.sh /entrypoint.sh +COPY docker/replay-viewer.sh /replay-viewer.sh +RUN sed -i 's/\r$//' /entrypoint.sh /replay-viewer.sh && \ + chmod +x /entrypoint.sh /replay-viewer.sh + +# Environment +ENV OPENRA_PATH=/opt/openra +ENV PYTHONPATH=/app +ENV PYTHONUNBUFFERED=1 +ENV DISPLAY=:99 +ENV DOTNET_CLI_TELEMETRY_OPTOUT=1 +ENV DOTNET_ROLL_FORWARD=LatestMajor +ENV LIBGL_ALWAYS_SOFTWARE=1 +ENV MESA_GL_VERSION_OVERRIDE=3.3 +# Game configuration (override at runtime with -e) +ENV AI_SLOT=Multi0 +ENV BOT_TYPE=normal +ENV RECORD_REPLAYS=true + +EXPOSE 8000 + +HEALTHCHECK --interval=30s --timeout=5s --start-period=30s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + 
+ENTRYPOINT ["/entrypoint.sh"] +CMD ["python", "-m", "openra_env.server.app"] diff --git a/Dockerfile.agent b/Dockerfile.agent new file mode 100644 index 0000000000000000000000000000000000000000..c92f64b45bf11776fbc982279fe0b1f1584f77b9 --- /dev/null +++ b/Dockerfile.agent @@ -0,0 +1,32 @@ +# ============================================================================== +# Lightweight agent container for OpenRA-RL +# +# Runs the LLM agent (or MCP bot) that connects to the OpenRA-RL game server. +# Does NOT include the game engine — only the Python client and agent code. +# +# Usage: +# docker build -f Dockerfile.agent -t openra-rl-agent . +# docker run -e OPENROUTER_API_KEY=sk-or-... openra-rl-agent +# ============================================================================== +FROM python:3.11-slim-bookworm + +LABEL description="OpenRA-RL Agent - LLM/MCP bot that plays Red Alert" + +WORKDIR /app + +# Install Python dependencies +COPY pyproject.toml README.md /app/ +COPY openra_env/ /app/openra_env/ +COPY proto/ /app/proto/ + +RUN pip install --no-cache-dir --upgrade pip && \ + pip install --no-cache-dir . httpx + +# Copy agent scripts +COPY examples/ /app/examples/ + +ENV PYTHONPATH=/app +ENV PYTHONUNBUFFERED=1 + +# Default: run LLM agent +CMD ["python", "examples/llm_agent.py"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f288702d2fa16d3cdf0035b15a9fcbc552cd88e7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works.
By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. 
+ + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. 
+ + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. 
+ + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. 
The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. 
+ + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. 
+ + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. 
This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. 
For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. 
Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. 
+ + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. 
Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. 
+ + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. 
+ + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. 
The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. 
THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. 
+ + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + <program> Copyright (C) <year> <name of author> + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +<https://www.gnu.org/licenses/>. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +<https://www.gnu.org/licenses/why-not-lgpl.html>. 
diff --git a/OpenRA b/OpenRA new file mode 160000 index 0000000000000000000000000000000000000000..de92f675141c8ceff6621417ce74f82497765698 --- /dev/null +++ b/OpenRA @@ -0,0 +1 @@ +Subproject commit de92f675141c8ceff6621417ce74f82497765698 diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fdbf00a68211054017616d9fa1cc5dce3a772b28 --- /dev/null +++ b/README.md @@ -0,0 +1,479 @@ +--- +title: OpenRA-RL +emoji: 🎮 +colorFrom: red +colorTo: blue +sdk: docker +app_port: 8000 +tags: + - openenv + - reinforcement-learning + - rts +models: [] +datasets: [] +pinned: false +--- + +# OpenRA-RL + +Play [Red Alert](https://www.openra.net/) with AI agents. LLMs, scripted bots, or RL — your agent commands armies in the classic RTS through a Python API. + +``` +┌──────────────────┐ HTTP / WS :8000 ┌──────────────────────────────┐ +│ Your Agent │ ◄────────────────────────► │ OpenRA-RL Server (Docker) │ +│ │ gRPC :9999 │ FastAPI + gRPC bridge │ +│ LLM / Bot / RL │ ◄────────────────────────► │ OpenRA engine (headless) │ +└──────────────────┘ └──────────────────────────────┘ +``` + +## Quick Start + +```bash +pip install openra-rl +openra-rl play +``` + +On first run, an interactive wizard helps you configure your LLM provider (OpenRouter, Ollama, or LM Studio). The CLI pulls the game server Docker image and starts everything automatically. + +### Skip the wizard + +```bash +# Cloud (OpenRouter) +openra-rl play --provider openrouter --api-key sk-or-... 
--model anthropic/claude-sonnet-4-20250514 + +# Local (Ollama — free, no API key) +openra-rl play --provider ollama --model qwen3:32b + +# Developer mode (skip Docker, run server locally) +openra-rl play --local --provider ollama --model qwen3:32b + +# Reconfigure later +openra-rl config +``` + +### Prerequisites + +- **Docker** — the game server runs in a container +- **Python 3.10+** +- An LLM endpoint (cloud API key or local model server) + +## CLI Reference + +``` +openra-rl play Run the LLM agent (wizard on first use) +openra-rl config Re-run the setup wizard +openra-rl server start | stop | status | logs +openra-rl replay watch | list | copy | stop +openra-rl bench submit Upload results to the leaderboard +openra-rl mcp-server Start MCP stdio server (for OpenClaw / Claude Desktop) +openra-rl doctor Check system prerequisites +openra-rl version Print version +``` + +## MCP Server (OpenClaw / Claude Desktop) + +OpenRA-RL exposes all 48 game tools as a standard MCP server: + +```bash +openra-rl mcp-server +``` + +Add to your MCP client config (e.g. `~/.openclaw/openclaw.json`): + +```json +{ + "mcpServers": { + "openra-rl": { + "command": "openra-rl", + "args": ["mcp-server"] + } + } +} +``` + +Then chat: _"Start a game of Red Alert on easy difficulty, build a base, and defeat the enemy."_ + +## Architecture + +| Component | Language | Role | +|-----------|----------|------| +| **OpenRA-RL** | Python | Environment wrapper, agents, HTTP/WebSocket API | +| **OpenRA** (submodule) | C# | Modified game engine with embedded gRPC server | +| **OpenEnv** (pip dep) | Python | Standardized Gymnasium-style environment interface | + +**Data flow:** Agent <-> FastAPI (port 8000) <-> gRPC bridge (port 9999) <-> OpenRA game engine + +The game runs at ~25 ticks/sec independent of agent speed. Observations use a DropOldest channel so the agent always sees the latest game state, even if it's slower than real time. 
+ +## Example Agents + +### Scripted Bot + +A hardcoded state-machine bot that demonstrates all action types. Deploys MCV, builds a base, trains infantry, and attacks. + +```bash +python examples/scripted_bot.py --url http://localhost:8000 --verbose --max-steps 2000 +``` + +### MCP Bot + +A planning-aware bot that uses game knowledge tools (tech tree lookups, faction briefings, map analysis) to formulate strategy before playing. + +```bash +python examples/mcp_bot.py --url http://localhost:8000 --verbose --max-turns 3000 +``` + +### LLM Agent + +An AI agent powered by any OpenAI-compatible model. Supports cloud APIs (OpenRouter, OpenAI) and local model servers (Ollama, LM Studio). + +```bash +python examples/llm_agent.py \ + --config examples/config-openrouter.yaml \ + --api-key sk-or-... \ + --verbose \ + --log-file game.log +``` + +CLI flags override config file values. See `python examples/llm_agent.py --help` for all options. + +## Configuration + +OpenRA-RL uses a unified YAML config system. 
Settings are resolved with this precedence: + +**CLI flags > Environment variables > Config file > Built-in defaults** + +### Config file + +Copy and edit the default config: + +```bash +cp config.yaml my-config.yaml +# Edit my-config.yaml, then: +python examples/llm_agent.py --config my-config.yaml +``` + +Key sections: + +```yaml +game: + openra_path: "/opt/openra" # Path to OpenRA installation + map_name: "singles.oramap" # Map to play + headless: true # No GPU rendering + record_replays: false # Save .orarep replay files + +opponent: + bot_type: "normal" # AI difficulty: easy, normal, hard + ai_slot: "Multi0" # AI player slot + +planning: + enabled: true # Pre-game planning phase + max_turns: 10 # Max planning turns + max_time_s: 60.0 # Planning time limit + +llm: + base_url: "https://openrouter.ai/api/v1/chat/completions" + model: "qwen/qwen3-coder-next" + max_tokens: 1500 + temperature: null # null = provider default + +tools: + categories: # Toggle tool groups on/off + read: true + knowledge: true + movement: true + production: true + # ... see config.yaml for all categories + disabled: [] # Disable specific tools by name + +alerts: + under_attack: true + low_power: true + idle_production: true + no_scouting: true + # ... see config.yaml for all alerts +``` + +### Example configs + +| File | Use case | +|------|----------| +| `examples/config-openrouter.yaml` | Cloud LLM via OpenRouter (Claude, GPT, etc.) 
 | +| `examples/config-ollama.yaml` | Local LLM via Ollama | +| `examples/config-lmstudio.yaml` | Local LLM via LM Studio | +| `examples/config-minimal.yaml` | Reduced tool set for limited-context models | + +### Environment variables + +| Variable | Config path | Description | +|----------|-------------|-------------| +| `OPENROUTER_API_KEY` | `llm.api_key` | API key for OpenRouter | +| `LLM_API_KEY` | `llm.api_key` | Generic LLM API key (overrides OpenRouter key) | +| `LLM_BASE_URL` | `llm.base_url` | LLM endpoint URL | +| `LLM_MODEL` | `llm.model` | Model identifier | +| `BOT_TYPE` | `opponent.bot_type` | AI difficulty: easy, normal, hard | +| `OPENRA_PATH` | `game.openra_path` | Path to OpenRA installation | +| `RECORD_REPLAYS` | `game.record_replays` | Save replay files (true/false) | +| `PLANNING_ENABLED` | `planning.enabled` | Enable planning phase (true/false) | + +## Using Local Models + +### Ollama + +```bash +# Pull a model with tool-calling support +ollama pull qwen3:32b + +# For models that need more context (default is often 2048-4096 tokens): +cat > /tmp/Modelfile <<EOF +FROM qwen3:32b +PARAMETER num_ctx 32768 +EOF +ollama create qwen3-32b-32k -f /tmp/Modelfile +``` + +> **Note:** Not all Ollama models support tool calling. Check with `ollama show <model>` — the template must include a `tools` block. Models known to work: `qwen3:32b`, `qwen3:4b`. + +### LM Studio + +1. Load a model in LM Studio and start the local server (default port 1234) +2. 
Run: + +```bash +openra-rl play --provider lmstudio --model <model-name> +``` + +## Docker + +### Server management + +```bash +openra-rl server start # Start game server container +openra-rl server start --port 9000 # Custom port +openra-rl server status # Check if running +openra-rl server logs --follow # Tail logs +openra-rl server stop # Stop container +``` + +### Docker Compose (development) + +| Service | Command | Description | +|---------|---------|-------------| +| `openra-rl` | `docker compose up openra-rl` | Headless game server (ports 8000, 9999) | +| `agent` | `docker compose up agent` | LLM agent (requires `OPENROUTER_API_KEY`) | +| `mcp-bot` | `docker compose run mcp-bot` | MCP bot | + +```bash +# LLM agent via Docker Compose +OPENROUTER_API_KEY=sk-or-... docker compose up agent +``` + +### Replays + +After each game, replays are automatically copied to `~/.openra-rl/replays/`. Watch them in your browser: + +```bash +openra-rl replay watch # Watch the latest replay (opens browser via VNC) +openra-rl replay watch <file> # Watch a specific .orarep file +openra-rl replay list # List replays (Docker + local) +openra-rl replay copy # Copy replays from Docker to local +openra-rl replay stop # Stop the replay viewer +``` + +The replay viewer runs inside Docker using the same engine that recorded the game, so replays always play back correctly. The browser connects via noVNC — no local game install needed. + +> **Version tracking:** Each replay records which Docker image version was used. When you upgrade, old replays are still viewable using their original engine version. 
+ +## Local Development (without Docker) + +For running the game server natively (macOS/Linux): + +### Install dependencies + +```bash +# Python +pip install -e ".[dev]" + +# .NET 8.0 SDK +# macOS: brew install dotnet@8 +# Ubuntu: sudo apt install dotnet-sdk-8.0 + +# Native libraries (macOS arm64) +brew install sdl2 openal-soft freetype luajit +cp $(brew --prefix sdl2)/lib/libSDL2.dylib OpenRA/bin/SDL2.dylib +cp $(brew --prefix openal-soft)/lib/libopenal.dylib OpenRA/bin/soft_oal.dylib +cp $(brew --prefix freetype)/lib/libfreetype.dylib OpenRA/bin/freetype6.dylib +cp $(brew --prefix luajit)/lib/libluajit-5.1.dylib OpenRA/bin/lua51.dylib +``` + +### Build OpenRA + +```bash +cd OpenRA && make && cd .. +``` + +### Start the server + +```bash +python openra_env/server/app.py +``` + +### Run tests + +```bash +pytest +``` + +## Observation Space + +Each tick, the agent receives structured game state: + +| Field | Description | +|-------|-------------| +| `tick` | Current game tick | +| `cash`, `ore`, `power_provided`, `power_drained` | Economy | +| `units` | Own units with position, health, type, facing, stance, speed, attack range | +| `buildings` | Own buildings with production queues, power, rally points | +| `visible_enemies`, `visible_enemy_buildings` | Fog-of-war limited enemy intel | +| `spatial_map` | 9-channel spatial tensor (terrain, height, resources, passability, fog, own buildings, own units, enemy buildings, enemy units) | +| `military` | Kill/death costs, asset value, experience, order count | +| `available_production` | What can currently be built | + +## Action Space + +18 action types available through the command API: + +| Category | Actions | +|----------|---------| +| **Movement** | `move`, `attack_move`, `attack`, `stop` | +| **Production** | `produce`, `cancel_production` | +| **Building** | `place_building`, `sell`, `repair`, `power_down`, `set_rally_point`, `set_primary` | +| **Unit control** | `deploy`, `guard`, `set_stance`, `enter_transport`, 
`unload`, `harvest` | + +## MCP Tools + +The LLM agent interacts through 48 MCP (Model Context Protocol) tools organized into categories: + +| Category | Tools | Purpose | +|----------|-------|---------| +| **Read** | `get_game_state`, `get_economy`, `get_units`, `get_buildings`, `get_enemies`, `get_production`, `get_map_info`, `get_exploration_status` | Query current game state | +| **Knowledge** | `lookup_unit`, `lookup_building`, `lookup_tech_tree`, `lookup_faction` | Static game data reference | +| **Bulk Knowledge** | `get_faction_briefing`, `get_map_analysis`, `batch_lookup` | Efficient batch queries | +| **Planning** | `start_planning_phase`, `end_planning_phase`, `get_opponent_intel`, `get_planning_status` | Pre-game strategy planning | +| **Game Control** | `advance` | Advance game ticks | +| **Movement** | `move_units`, `attack_move`, `attack_target`, `stop_units` | Unit movement commands | +| **Production** | `build_unit`, `build_structure`, `build_and_place` | Build units and structures | +| **Building Actions** | `place_building`, `cancel_production`, `deploy_unit`, `sell_building`, `repair_building`, `set_rally_point`, `guard_target`, `set_stance`, `harvest`, `power_down`, `set_primary` | Building and unit management | +| **Placement** | `get_valid_placements` | Query valid building locations | +| **Unit Groups** | `assign_group`, `add_to_group`, `get_groups`, `command_group` | Group management | +| **Compound** | `batch`, `plan` | Multi-action sequences | +| **Utility** | `get_replay_path`, `surrender` | Misc | +| **Terrain** | `get_terrain_at` | Terrain queries | + +Tools can be toggled per-category or individually via `config.yaml`. + +## Benchmark & Leaderboard + +Game results are automatically submitted to the [OpenRA-Bench leaderboard](https://huggingface.co/spaces/openra-rl/OpenRA-Bench) after each game. Disable with `BENCH_UPLOAD=false` or `bench_upload: false` in config. 
+ +### Agent identity + +Customize how your agent appears on the leaderboard: + +```bash +# Environment variables +AGENT_NAME="DeathBot-9000" AGENT_TYPE="RL" openra-rl play + +# Or in config.yaml +agent: + agent_name: "DeathBot-9000" + agent_type: "RL" + agent_url: "https://github.com/user/deathbot" # shown as link on leaderboard +``` + +| Variable | Config path | Description | +|----------|-------------|-------------| +| `AGENT_NAME` | `agent.agent_name` | Display name (default: model name) | +| `AGENT_TYPE` | `agent.agent_type` | Scripted / LLM / RL (default: auto-detect) | +| `AGENT_URL` | `agent.agent_url` | GitHub/project URL shown on leaderboard | +| `BENCH_UPLOAD` | `agent.bench_upload` | Auto-upload after each game (default: true) | +| `BENCH_URL` | `agent.bench_url` | Leaderboard URL | + +### Manual submission + +Upload a saved result (with optional replay file): + +```bash +openra-rl bench submit result.json +openra-rl bench submit result.json --replay game.orarep --agent-name "MyBot" +``` + +### Custom agents + +If you're building your own agent (RL, CNN, multi-agent, etc.) 
that doesn't use the built-in LLM agent, use `build_bench_export()` to create a leaderboard submission from a final observation: + +```python +from openra_env.bench_export import build_bench_export + +# obs = final observation from env.step() +export = build_bench_export( + obs, + agent_name="DeathBot-9000", + agent_type="RL", + opponent="Normal", + agent_url="https://github.com/user/deathbot", + replay_path="/path/to/replay.orarep", +) +# Saves JSON to ~/.openra-rl/bench-exports/ and returns dict with "path" key +``` + +Then submit: + +```bash +openra-rl bench submit ~/.openra-rl/bench-exports/bench-DeathBot-9000-*.json --replay game.orarep +``` + +## Project Structure + +``` +OpenRA-RL/ +├── OpenRA/ # Game engine (git submodule, C#) +├── openra_env/ # Python package +│ ├── cli/ # CLI entry point (openra-rl command) +│ ├── mcp_server.py # Standard MCP server (stdio transport) +│ ├── client.py # WebSocket client +│ ├── config.py # Unified YAML configuration +│ ├── models.py # Pydantic data models +│ ├── game_data.py # Unit/building stats, tech tree +│ ├── reward.py # Multi-component reward function +│ ├── bench_export.py # Build leaderboard submissions from observations +│ ├── bench_submit.py # Upload results to OpenRA-Bench leaderboard +│ ├── opponent_intel.py # AI opponent profiles +│ ├── mcp_ws_client.py # MCP WebSocket client +│ ├── server/ +│ │ ├── app.py # FastAPI application +│ │ ├── openra_environment.py # OpenEnv environment (reset/step/state) +│ │ ├── bridge_client.py # Async gRPC client +│ │ └── openra_process.py # OpenRA subprocess manager +│ └── generated/ # Auto-generated protobuf stubs +├── examples/ +│ ├── scripted_bot.py # Hardcoded strategy bot +│ ├── mcp_bot.py # MCP tool-based bot +│ ├── llm_agent.py # LLM-powered agent +│ └── config-*.yaml # Example configs (ollama, lmstudio, openrouter, minimal) +├── skill/ # OpenClaw skill definition +├── proto/ # Protobuf definitions (rl_bridge.proto) +├── tests/ # Test suite +├── .github/workflows/ # CI, 
# OpenRA-RL Configuration
# ========================
# Commented-out values show the built-in defaults; uncomment and change any
# value to override. Values that are NOT commented out below are active
# overrides shipped with this repository.
# Environment variables always take highest priority (see docs for mapping).
#
# Precedence: env vars > CLI args > constructor args > this file > defaults

# ── Game Engine ───────────────────────────────────────────────────────
game:
  # NOTE: do not commit a machine-specific absolute path here. Leave this
  # commented out (or set $OPENRA_PATH) so the file works on every checkout.
#  openra_path: "/path/to/OpenRA"      # Path to OpenRA installation ($OPENRA_PATH)
#  mod: "ra"                           # Game mod (ra, cnc, d2k)
#  map_name: "singles.oramap"          # Map to play
#  grpc_port: 9999                     # gRPC bridge port
#  headless: true                      # Use Null renderer (no GPU)
  record_replays: true                 # Save .orarep replay files ($RECORD_REPLAYS)
#  seed: null                          # RNG seed for reproducibility (null = random)
#  max_ticks: 0                        # End game after N ticks (0 = unlimited)
#  max_wall_time_s: 0                  # End game after N seconds (0 = unlimited)

# ── Opponent ──────────────────────────────────────────────────────────
# Enemy bot always spawns by default. Set ai_slot to "" to disable.
# Difficulty tiers: beginner / easy / medium / hard / brutal
# Raw play styles also accepted: rush / normal / turtle / naval
opponent:
  bot_type: "beginner"                 # Difficulty tier ($BOT_TYPE)
  ai_slot: "Multi0"                    # AI player slot; "" to disable enemy ($AI_SLOT)

# ── Planning Phase ────────────────────────────────────────────────────
# planning:
#   enabled: true                      # Enable pre-game planning phase ($PLANNING_ENABLED)
#   max_turns: 10                      # Max planning turns ($PLANNING_MAX_TURNS)
#   max_time_s: 60.0                   # Max planning seconds ($PLANNING_MAX_TIME)

# ── Reward Function ───────────────────────────────────────────────────
# reward:
#   survival: 0.001                    # Per-tick survival bonus
#   economic_efficiency: 0.01          # Cash delta reward
#   aggression: 0.1                    # Kill reward multiplier
#   defense: 0.05                      # Loss penalty multiplier
#   victory: 1.0                       # Terminal win reward
#   defeat: -1.0                       # Terminal loss penalty

# ── Reward Vector ────────────────────────────────────────────────────
# 8-dimensional skill signal computed per step alongside the scalar reward.
# Dimensions: combat, economy, infrastructure, intelligence, composition,
#             tempo, disruption, outcome
# reward_vector:
#   enabled: true                      # Enabled by default
#   weights:                           # Per-dimension weights (for weighted sum)
#     combat: 0.30
#     economy: 0.15
#     infrastructure: 0.10
#     intelligence: 0.10
#     composition: 0.10
#     tempo: 0.10
#     disruption: 0.15
#     outcome: 1.00

# ── MCP Tools ─────────────────────────────────────────────────────────
# tools:
#   categories:                        # Toggle tool groups (true/false)
#     read: true                       # get_game_state, get_economy, get_units, etc.
#     knowledge: true                  # lookup_unit, lookup_building, etc.
#     bulk_knowledge: true             # get_faction_briefing, get_map_analysis, batch_lookup
#     planning: true                   # start/end_planning_phase, get_opponent_intel, etc.
#     game_control: true               # advance
#     movement: true                   # move_units, attack_move, attack_target, stop_units
#     production: true                 # build_unit, build_structure, build_and_place
#     building_actions: true           # place, cancel, deploy, sell, repair, rally, etc.
#     placement: true                  # get_valid_placements
#     unit_groups: true                # assign_group, command_group, etc.
#     compound: true                   # batch, plan
#     utility: true                    # get_replay_path, surrender
#     terrain: true                    # get_terrain_at
#   disabled: []                       # Disable specific tools by name

# ── Alerts ────────────────────────────────────────────────────────────
# alerts:
#   under_attack: true
#   damaged_building: true
#   low_power: true
#   idle_funds: true
#   ore_full: true
#   idle_production: true
#   production_stalled: true
#   building_ready: true
#   stance_warning: true
#   idle_army: true
#   no_defenses: true
#   no_scouting: true
#   loss_tracking: true
#   minimap: true                      # Show ASCII minimap in turn briefing
#   max_alerts: 0                      # Max alerts per turn (0 = unlimited)

# ── LLM Model ────────────────────────────────────────────────────────
# llm:
#   base_url: "https://openrouter.ai/api/v1/chat/completions"
#   api_key: ""                        # Empty = not required (local models)
#   model: "qwen/qwen3-coder-next"
#   max_tokens: 1500
#   temperature: null                  # null = provider default
#   top_p: null                        # null = provider default
#   keep_last_messages: 40             # Messages to keep after compression
#   compression_strategy: "sliding_window"  # "sliding_window" or "none"
#   compression_trigger: 0             # Compress at this count (0 = keep_last * 2)
#   max_retries: 4                     # Retry on transient errors
#   retry_backoff_s: 10                # Base backoff (multiplied by attempt)
#   request_timeout_s: 120.0           # HTTP timeout per request
#   extra_headers:                     # Custom headers (OpenRouter-specific)
#     HTTP-Referer: "https://github.com/openra-rl"
#     X-Title: "OpenRA-RL Agent"

# ── Agent Runtime ─────────────────────────────────────────────────────
# agent:
#   server_url: "http://localhost:8000"  # OpenRA-RL server ($OPENRA_URL)
#   max_turns: 0                       # 0 = unlimited
#   max_time_s: 1800                   # 30 minutes ($MAX_TIME)
#   verbose: false
#   log_file: ""                       # Log file path ($LLM_AGENT_LOG)
#   agent_name: ""                     # Leaderboard display name ($AGENT_NAME); empty = model name
#   agent_type: ""                     # Scripted/LLM/RL ($AGENT_TYPE); empty = auto-detect
#   agent_url: ""                      # GitHub/project URL shown on leaderboard ($AGENT_URL)
#   bench_upload: true                 # Auto-upload results after each game ($BENCH_UPLOAD)
#   bench_url: "https://openra-rl-openra-bench.hf.space"  # Leaderboard URL ($BENCH_URL)

# ── Prompts ──────────────────────────────────────────────────────────
# All LLM-facing text. Override individual fields here, or point
# prompts_file to a separate YAML (copy openra_env/prompts/default_prompts.yaml).
# Templates use Python str.format() placeholders: {variable_name}
# prompts:
#   system_prompt: ""                  # Inline system prompt (overrides built-in)
#   system_prompt_file: ""             # Path to .txt system prompt ($SYSTEM_PROMPT_FILE)
#   prompts_file: ""                   # Path to prompts YAML ($PROMPTS_FILE)
#   planning_nudge: "Call end_planning_phase(strategy='...') when ready to start."
#   planning_complete: "Planning complete. Game is now live."
#   no_tool_nudge: "No tool was called. A tool call is required each turn."
#   continue_nudge: "The game is still in progress."
#   alerts:                            # Alert message templates
#     low_power: "LOW POWER: {balance} — production runs at 1/3 speed"
#     idle_army: "IDLE ARMY: {count} combat units idle"
#     # ... see openra_env/prompts/default_prompts.yaml for all fields
#!/bin/bash
# Build the OpenRA-RL Docker image.
#
# Docker cannot read files outside its build context, so when the OpenRA
# source lives somewhere other than the in-repo OpenRA/ directory it is
# mirrored into that directory before the image is built.
#
# Usage:
#   ./docker/build.sh                              # Use the in-repo OpenRA/ checkout
#   OPENRA_DIR=/path/to/OpenRA ./docker/build.sh   # Specify OpenRA path
#
# Any extra arguments are forwarded to `docker build`.

set -euo pipefail

script_dir="$(cd "$(dirname "$0")" && pwd)"
project_dir="$(cd "$script_dir/.." && pwd)"
OPENRA_DIR="${OPENRA_DIR:-$project_dir/OpenRA}"

# Validate the source tree before doing any work.
if [[ ! -d "$OPENRA_DIR" ]]; then
    echo "ERROR: OpenRA source not found at $OPENRA_DIR"
    echo "Run: git submodule update --init"
    exit 1
fi

if [[ ! -f "$OPENRA_DIR/OpenRA.sln" ]]; then
    echo "ERROR: $OPENRA_DIR doesn't look like an OpenRA repo (no OpenRA.sln)"
    exit 1
fi

echo "=== OpenRA-RL Docker Build ==="
echo "OpenRA source: $OPENRA_DIR"
echo "Project dir:   $project_dir"
echo ""

# Compare resolved paths: only an external source tree needs to be copied
# into the build context.
resolved_source="$(cd "$OPENRA_DIR" && pwd)"
resolved_in_repo="$(cd "$project_dir/OpenRA" 2>/dev/null && pwd || echo "")"
if [[ "$resolved_source" != "$resolved_in_repo" ]]; then
    echo "Copying OpenRA source into build context..."
    # Skip VCS metadata and build outputs; they are not needed in the image.
    rsync_excludes=(
        --exclude='.git'
        --exclude='bin/'
        --exclude='*/obj/'
        --exclude='*.user'
    )
    rsync -a --delete "${rsync_excludes[@]}" "$OPENRA_DIR/" "$project_dir/OpenRA/"
fi

echo "Building Docker image..."
docker build -t openra-rl "$project_dir" "$@"

echo ""
echo "=== Build complete ==="
echo "Run with: docker run -p 8000:8000 openra-rl"
#!/bin/bash
# Replay viewer entrypoint.
#
# Starts Xvfb, x11vnc and a noVNC websocket proxy, then launches OpenRA on
# the given replay file so it can be watched from a browser.
#
# Usage: /replay-viewer.sh <replay-file>
set -e

# The base image sets LIBGL_ALWAYS_SOFTWARE=1 for the headless game server.
# The replay viewer needs GPU rendering, so unset it.
unset LIBGL_ALWAYS_SOFTWARE

REPLAY_FILE="${1:-}"
if [ -z "$REPLAY_FILE" ]; then
    echo "Usage: /replay-viewer.sh <replay-file>"
    exit 1
fi

if [ ! -f "$REPLAY_FILE" ]; then
    echo "ERROR: Replay file not found: $REPLAY_FILE"
    exit 1
fi

# Tunable settings via environment variables (set by docker_manager.py)
REPLAY_RESOLUTION="${OPENRA_RL_REPLAY_RESOLUTION:-1280x960}"
REPLAY_WIDTH="${REPLAY_RESOLUTION%x*}"
REPLAY_HEIGHT="${REPLAY_RESOLUTION#*x}"
REPLAY_UI_SCALE="${OPENRA_RL_REPLAY_UI_SCALE:-1}"
REPLAY_VIEWPORT="${OPENRA_RL_REPLAY_VIEWPORT_DISTANCE:-Medium}"
REPLAY_MUTE="${OPENRA_RL_REPLAY_MUTE:-True}"

# Copy replay to the expected directory structure so OpenRA can read metadata
REPLAY_DIR="/root/.config/openra/Replays/ra/{DEV_VERSION}"
mkdir -p "$REPLAY_DIR"
REPLAY_BASENAME=$(basename "$REPLAY_FILE")
cp "$REPLAY_FILE" "$REPLAY_DIR/$REPLAY_BASENAME"
REPLAY_PATH="$REPLAY_DIR/$REPLAY_BASENAME"
echo "Replay copied to: $REPLAY_PATH"

# PIDs of everything this script starts; empty until the process is running.
XVFB_PID=""
VNC_PID=""
NOVNC_PID=""
GAME_PID=""

# Stop every helper process. Runs on normal exit and on SIGTERM/SIGINT.
cleanup() {
    # Disarm the traps first so cleanup cannot run twice.
    trap - EXIT SIGTERM SIGINT
    echo "Shutting down replay viewer..."
    # Unquoted on purpose: empty PIDs (processes never started) are skipped.
    for pid in $GAME_PID $NOVNC_PID $VNC_PID $XVFB_PID; do
        kill "$pid" 2>/dev/null || true
    done
    wait 2>/dev/null || true
}
trap cleanup EXIT
trap 'cleanup; exit 0' SIGTERM SIGINT

# Start Xvfb at configured resolution
echo "Starting Xvfb on display :99 (${REPLAY_WIDTH}x${REPLAY_HEIGHT})..."
Xvfb :99 -screen 0 "${REPLAY_WIDTH}x${REPLAY_HEIGHT}x24" -ac +extension GLX +render -noreset &
XVFB_PID=$!
sleep 2
if ! kill -0 "$XVFB_PID" 2>/dev/null; then
    echo "ERROR: Xvfb failed to start"
    exit 1
fi
export DISPLAY=:99

# Start x11vnc with performance optimizations
echo "Starting VNC server on port 5900..."
x11vnc -display :99 -forever -nopw -shared -rfbport 5900 \
    -noxdamage -wait 50 -defer 50 -quiet &
VNC_PID=$!
sleep 1

# Start noVNC (websockify proxy)
echo "Starting noVNC on port 6080..."
websockify --web /usr/share/novnc 6080 localhost:5900 &
NOVNC_PID=$!
sleep 1

echo ""
echo "=== Replay viewer ready ==="
echo "Open in browser: http://localhost:6080/vnc.html"
echo "Press Ctrl+C to stop"
echo ""

# Launch OpenRA with rendering settings tuned for VNC replay viewing.
# CPU is managed by Docker --cpus limit (set in docker_manager.py).
#
# The game runs as a child process rather than via `exec`: exec would replace
# this shell and discard the traps above, so cleanup would never run.
dotnet /opt/openra/bin/OpenRA.dll \
    Engine.EngineDir=/opt/openra \
    Game.Mod=ra \
    Game.Platform=Default \
    Graphics.Mode=Windowed \
    "Graphics.WindowedSize=${REPLAY_WIDTH},${REPLAY_HEIGHT}" \
    "Graphics.UIScale=${REPLAY_UI_SCALE}" \
    Graphics.VSync=False \
    Graphics.DisableGLDebugMessageCallback=True \
    "Graphics.ViewportDistance=${REPLAY_VIEWPORT}" \
    "Sound.Mute=${REPLAY_MUTE}" \
    "Launch.Replay=$REPLAY_PATH" &
GAME_PID=$!

# Propagate the game's exit status. If a trapped signal interrupts `wait`,
# the signal trap runs immediately afterwards and exits the script itself.
GAME_STATUS=0
wait "$GAME_PID" || GAME_STATUS=$?
exit "$GAME_STATUS"
+ +# Start the OpenRA-RL server (Docker) +docker run -p 8000:8000 openra-rl + +# Or build from source first: +OPENRA_DIR=/path/to/OpenRA ./docker/build.sh +docker run -p 8000:8000 openra-rl +``` + +### Run + +```bash +# Basic run +python examples/scripted_bot.py + +# Custom server URL +python examples/scripted_bot.py --url http://localhost:8000 + +# Verbose mode (prints every bot decision) +python examples/scripted_bot.py --verbose + +# Limit episode length +python examples/scripted_bot.py --max-steps 2000 +``` + +### Output + +``` +Connecting to http://localhost:8000... +Game started! Map: singles +Step 0 | Tick 0 | $ 5000 | Units: 2 (combat: 0) | Buildings: [none] | Phase: deploy_mcv +Step 100 | Tick 100 | $ 4700 | Units: 1 (combat: 0) | Buildings: [fact] | Phase: build_base +Step 200 | Tick 200 | $ 4100 | Units: 1 (combat: 0) | Buildings: [fact, powr] | Phase: build_base +... +Game over: win after 3421 steps (tick 3421) +Total reward: 2.150 +``` diff --git a/examples/config-lmstudio.yaml b/examples/config-lmstudio.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3cdccca7d5e4c1ab38764a54c35dd5a70912b88c --- /dev/null +++ b/examples/config-lmstudio.yaml @@ -0,0 +1,14 @@ +# OpenRA-RL config for LM Studio (local) +# Usage: python examples/llm_agent.py --config examples/config-lmstudio.yaml + +llm: + base_url: "http://localhost:1234/v1/chat/completions" + model: "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF" + api_key: "" # No key needed for LM Studio + max_tokens: 2000 + extra_headers: {} + request_timeout_s: 180.0 + +agent: + max_time_s: 3600 + verbose: true diff --git a/examples/config-minimal.yaml b/examples/config-minimal.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4a700d49bb3a6712d03b3db605f37017ce3b56b6 --- /dev/null +++ b/examples/config-minimal.yaml @@ -0,0 +1,21 @@ +# OpenRA-RL config with minimal tool set +# Reduces tool count for models with limited context or tool-calling ability. 
+# Usage: python examples/llm_agent.py --config examples/config-minimal.yaml + +planning: + enabled: false + +tools: + categories: + knowledge: false # Disable lookup_unit, lookup_building, etc. + bulk_knowledge: false # Disable get_faction_briefing, get_map_analysis, etc. + planning: false # Disabled automatically when planning.enabled=false + unit_groups: false # Disable assign_group, command_group, etc. + terrain: false # Disable get_terrain_at + compound: false # Disable batch, plan + +alerts: + stance_warning: false + idle_army: false + no_scouting: false + no_defenses: false diff --git a/examples/config-ollama.yaml b/examples/config-ollama.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ed3d35db4dd367c2b5b9beb32b4959e9c2cf1010 --- /dev/null +++ b/examples/config-ollama.yaml @@ -0,0 +1,14 @@ +# OpenRA-RL config for Ollama (local) +# Usage: python examples/llm_agent.py --config examples/config-ollama.yaml + +llm: + base_url: "http://localhost:11434/v1/chat/completions" + model: "qwen3:32b" + api_key: "" # No key needed for Ollama + max_tokens: 2000 + extra_headers: {} + request_timeout_s: 300.0 # Local models need more time (auto-set if <= 120) + +agent: + max_time_s: 3600 # 1 hour (local models are slower) + verbose: true diff --git a/examples/config-openrouter.yaml b/examples/config-openrouter.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b78eb35524f2688dccefb59c61c77a2af4dc32a0 --- /dev/null +++ b/examples/config-openrouter.yaml @@ -0,0 +1,13 @@ +# OpenRA-RL config for OpenRouter (cloud) +# Usage: OPENROUTER_API_KEY=sk-or-... 
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the LLM agent example."""
    parser = argparse.ArgumentParser(
        description="LLM agent that plays Red Alert via any OpenAI-compatible model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            "Examples:\n"
            "  %(prog)s --config examples/config-ollama.yaml --verbose\n"
            "  %(prog)s --api-key sk-or-... --verbose\n"
            "  %(prog)s --base-url http://localhost:1234/v1/chat/completions --model my-model\n"
        ),
    )
    parser.add_argument(
        "--config", "-c",
        default=None,
        help="Path to YAML config file (default: auto-discover config.yaml)",
    )
    parser.add_argument(
        "--url",
        default=None,
        help="OpenRA-RL server URL (overrides config agent.server_url)",
    )
    parser.add_argument(
        "--base-url",
        default=None,
        help="LLM API endpoint URL (overrides config llm.base_url)",
    )
    parser.add_argument(
        "--model",
        default=None,
        help="Model ID (overrides config llm.model)",
    )
    parser.add_argument(
        "--api-key",
        default=None,
        help="API key for LLM endpoint (overrides config llm.api_key)",
    )
    parser.add_argument(
        "--max-turns",
        type=int,
        default=None,
        help="Maximum LLM turns, 0 = unlimited (overrides config agent.max_turns)",
    )
    parser.add_argument(
        "--max-time",
        type=int,
        default=None,
        help="Maximum wall-clock time in seconds (overrides config agent.max_time_s)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print detailed LLM reasoning and tool calls",
    )
    parser.add_argument(
        "--log-file",
        default=None,
        help="Write all output to this log file in addition to stdout",
    )
    parser.add_argument(
        "--system-prompt",
        default=None,
        help="Path to a custom system prompt .txt file (overrides built-in default)",
    )
    return parser


def _collect_cli_overrides(args: argparse.Namespace) -> dict:
    """Translate parsed CLI flags into a nested ``{section: {key: value}}`` dict.

    Only flags the user actually supplied are included, so unset flags never
    mask values coming from the YAML file or environment variables.
    """
    # (argparse attribute, config section, config key)
    targets = (
        ("url", "agent", "server_url"),
        ("base_url", "llm", "base_url"),
        ("model", "llm", "model"),
        ("api_key", "llm", "api_key"),
        ("max_turns", "agent", "max_turns"),
        ("max_time", "agent", "max_time_s"),
        ("log_file", "agent", "log_file"),
        # config.yaml documents system_prompt_file under `prompts`, not
        # `agent`; the override must target the same section to take effect.
        ("system_prompt", "prompts", "system_prompt_file"),
    )
    overrides: dict = {}
    for attr, section, key in targets:
        value = getattr(args, attr)
        if value is not None:
            overrides.setdefault(section, {})[key] = value
    # store_true flag: only an explicit --verbose overrides the config value.
    if args.verbose:
        overrides.setdefault("agent", {})["verbose"] = True
    return overrides


def main():
    """Parse CLI arguments, load the layered config, and run the agent.

    Exits with status 1 when a remote LLM endpoint is configured without an
    API key or when the OpenRA-RL server refuses the connection, and with
    status 0 on Ctrl+C.
    """
    args = _build_arg_parser().parse_args()

    # Build config: YAML file + env vars + CLI overrides (CLI wins over .env)
    config = load_config(config_path=args.config, cli_overrides=_collect_cli_overrides(args))
    verbose = config.agent.verbose

    import builtins

    original_print = builtins.print
    log_fh = None

    # Set up logging to file if requested — tee all print() to both stdout and file
    if config.agent.log_file:
        log_fh = open(config.agent.log_file, "w", encoding="utf-8")

        def _tee_print(*pargs, **kwargs):
            original_print(*pargs, **kwargs)
            # The log copy always goes to the log file, whatever `file=` was.
            kwargs.pop("file", None)
            original_print(*pargs, file=log_fh, **kwargs)
            log_fh.flush()

        builtins.print = _tee_print

    try:
        # API key validation: only required for remote endpoints
        is_local = any(h in config.llm.base_url for h in ("localhost", "127.0.0.1", "0.0.0.0"))
        if not config.llm.api_key and not is_local:
            print("Error: API key required for remote LLM endpoints.")
            print("  Set OPENROUTER_API_KEY or LLM_API_KEY environment variable, use --api-key,")
            print("  or use a config file with llm.api_key set.")
            print("  For local models (Ollama, LM Studio), use --base-url http://localhost:...")
            sys.exit(1)

        try:
            asyncio.run(run_agent(config, verbose))
        except KeyboardInterrupt:
            print("\nInterrupted by user")
            sys.exit(0)
        except ConnectionRefusedError:
            print(f"\nCould not connect to {config.agent.server_url}")
            print("Is the OpenRA-RL server running?")
            print("  docker run -p 8000:8000 openra-rl")
            sys.exit(1)
    finally:
        # Runs on every exit path, including sys.exit(): undo the global
        # print() patch and release the log file handle.
        builtins.print = original_print
        if log_fh is not None:
            log_fh.close()
+1,619 @@ +#!/usr/bin/env python3 +"""MCP tool-based Red Alert bot that plays entirely through MCP tools. + +Validates the full MCP integration path: tool discovery, game knowledge +lookups, read tools for state, and action tools for commands. Uses +OpenRAMCPClient to interact with the OpenRA-RL server via WebSocket. + +Exercises the 32 MCP tools listed below (plus the planning-phase tools +start_planning_phase, get_opponent_intel and end_planning_phase when planning is enabled): + - Read tools: get_game_state, get_economy, get_units, get_buildings, + get_enemies, get_production, get_map_info + - Knowledge tools: lookup_unit, lookup_building, lookup_tech_tree, lookup_faction, + get_faction_briefing, get_map_analysis, batch_lookup + - Action tools: advance, deploy_unit, build_structure, place_building, + build_unit, move_units, attack_move, attack_target, stop_units, + set_rally_point, guard_target, set_stance, sell_building, repair_building, + harvest, power_down, set_primary + - Replay tool: get_replay_path + +Usage: + docker run -p 8000:8000 openra-rl + python examples/mcp_bot.py --verbose +""" + +import argparse +import asyncio +import json +import sys +from typing import Any, Optional + +# Line-buffered stdout so output is observable in real time +sys.stdout.reconfigure(line_buffering=True) + +from openra_env.mcp_ws_client import OpenRAMCPClient + + +class MCPBot: + """State-machine bot that plays Red Alert using MCP tool calls. 
+ + Phases: + startup - Look up tech tree and faction info + deploy_mcv - Find and deploy MCV + build_base - Build power/barracks/refinery/war factory + train_army - Train infantry + vehicles, set rally points + attack - Attack-move toward enemy + sustain - Repair, sell damaged, power management + """ + + BARRACKS_TYPES = {"tent", "barr"} + WAR_FACTORY_TYPES = {"weap"} + BUILD_ORDER = ["powr", "barracks", "proc", "weap", "powr"] + INFANTRY_TARGET = 6 + GUARD_COUNT = 2 + COMBAT_TYPES = {"e1", "e2", "e3", "e4", "1tnk", "2tnk", "3tnk", "arty", "jeep", "apc"} + INFANTRY_TYPES = {"e1", "e2", "e3", "e4"} + + def __init__(self, env: OpenRAMCPClient, verbose: bool = False, no_planning: bool = False): + self.env = env + self.verbose = verbose + self.no_planning = no_planning + self.phase = "startup" + self.build_index = 0 + self.placement_count = 0 + self.deploy_issued = False + self._guards_assigned: set[int] = set() + self._stances_set: set[int] = set() + self._rally_set: set[int] = set() + self._repair_issued: set[int] = set() + self._sold: set[int] = set() + self._powered_down: set[int] = set() + self._primary_set: set[int] = set() + self._apc_trained = False + self._tools_exercised: set[str] = set() + + async def call(self, tool_name: str, **kwargs: Any) -> Any: + """Call an MCP tool and track which tools have been exercised.""" + self._tools_exercised.add(tool_name) + result = await self.env.call_tool(tool_name, **kwargs) + return result + + def _log(self, msg: str): + if self.verbose: + print(f" [MCPBot] {msg}") + + # ── Main loop ───────────────────────────────────────────────── + + async def run(self, max_turns: int) -> dict: + """Run the bot for up to max_turns.""" + # Phase: startup — exercise knowledge tools + await self._startup() + + turn = 0 + while turn < max_turns: + state = await self.call("get_game_state") + if state.get("done"): + self._log(f"Game over: {state.get('result', '?')}") + break + + turn += 1 + await self._tick(state, turn) + + if turn % 100 
== 0: + self._print_status(turn, state) + + # End-of-game report + final_state = await self.call("get_game_state") + replay = await self.call("get_replay_path") + self._log(f"Replay: {replay}") + + return { + "turns": turn, + "final_state": final_state, + "replay": replay, + "tools_exercised": sorted(self._tools_exercised), + "tools_count": len(self._tools_exercised), + "planning_strategy": getattr(self, "_planning_strategy", ""), + } + + # ── Startup: knowledge tools ────────────────────────────────── + + async def _startup(self): + """Run planning phase and look up game knowledge at game start.""" + if self.no_planning: + self._log("=== Startup: Planning DISABLED ===") + # Use bulk knowledge tool instead of individual lookups + briefing = await self.call("get_faction_briefing") + self._log(f"Faction briefing: {briefing.get('side', '?')}, " + f"{len(briefing.get('units', {}))} units, " + f"{len(briefing.get('buildings', {}))} buildings") + else: + self._log("=== Startup: Planning Phase ===") + + # Try the planning phase + planning = await self.call("start_planning_phase") + if planning.get("planning_active"): + self._log(f"Planning active — opponent: {planning.get('opponent_summary', '')[:120]}") + + # Use bulk tools for efficient research + briefing = await self.call("get_faction_briefing") + self._log(f"Faction briefing: {briefing.get('side', '?')}, " + f"{len(briefing.get('units', {}))} units, " + f"{len(briefing.get('buildings', {}))} buildings") + + map_analysis = await self.call("get_map_analysis") + self._log(f"Map analysis: {map_analysis.get('map_type', '?')}, " + f"{len(map_analysis.get('resource_patches', []))} resource patches") + + intel = await self.call("get_opponent_intel") + aggressiveness = intel.get("aggressiveness", "unknown") + self._log(f"Opponent aggressiveness: {aggressiveness}") + + # Formulate strategy based on opponent profile + if aggressiveness in ("high", "very_high"): + strategy = ( + "Defensive opening: power plant, barracks, turrets 
at base entrance, " + "then ore refinery for economy. Build war factory for tanks once stable. " + "Scout early to find and deny enemy expansion." + ) + else: + strategy = ( + "Rush opening: power plant, barracks, infantry rush while building " + "ore refinery. Transition to tanks from war factory." + ) + + result = await self.call("end_planning_phase", strategy=strategy) + self._planning_strategy = strategy + self._log(f"Planning complete: {result.get('planning_duration_seconds', '?')}s, strategy: {strategy[:80]}") + else: + # Planning disabled server-side + self._log(f"Planning: {planning.get('message', 'disabled')}") + briefing = await self.call("get_faction_briefing") + self._log(f"Faction briefing: {briefing.get('side', '?')}, " + f"{len(briefing.get('units', {}))} units, " + f"{len(briefing.get('buildings', {}))} buildings") + + map_info = await self.call("get_map_info") + self._log(f"Map: {map_info.get('map_name', '?')} ({map_info.get('width')}x{map_info.get('height')})") + + self.phase = "deploy_mcv" + self._log("Phase → deploy_mcv") + + # ── Per-tick decision ───────────────────────────────────────── + + async def _tick(self, state: dict, turn: int): + """Make decisions for one game tick.""" + # Update phase based on state + await self._update_phase() + + if self.phase == "deploy_mcv": + await self._do_deploy() + elif self.phase == "build_base": + await self._do_build() + elif self.phase == "train_army": + await self._do_build() + await self._do_train() + elif self.phase == "attack": + await self._do_build() + await self._do_train() + await self._do_combat() + await self._do_sustain() + + # Advance game + await self.call("advance", ticks=1) + + async def _update_phase(self): + """Transition phases based on game state.""" + buildings = await self.call("get_buildings") + units = await self.call("get_units") + + has_cy = any(b["type"] == "fact" for b in buildings) + has_barracks = any(b["type"] in self.BARRACKS_TYPES for b in buildings) + combat_units = [u 
for u in units if u["type"] in self.COMBAT_TYPES] + non_guard = [u for u in combat_units if u["actor_id"] not in self._guards_assigned] + + if self.phase == "deploy_mcv" and has_cy: + self.phase = "build_base" + self._log("Phase → build_base") + elif self.phase == "build_base" and self.build_index >= len(self.BUILD_ORDER): + self.phase = "train_army" + self._log("Phase → train_army") + elif self.phase == "train_army" and len(non_guard) >= self.INFANTRY_TARGET: + self.phase = "attack" + self._log(f"Phase → attack ({len(non_guard)} combat units)") + + # ── Deploy MCV ──────────────────────────────────────────────── + + async def _do_deploy(self): + """Find and deploy MCV.""" + if self.deploy_issued: + return + + units = await self.call("get_units") + mcv = next((u for u in units if u["type"] == "mcv"), None) + if mcv: + self._log(f"Deploying MCV (actor {mcv['actor_id']})") + await self.call("deploy_unit", unit_id=mcv["actor_id"]) + self.deploy_issued = True + + # ── Build base ──────────────────────────────────────────────── + + async def _do_build(self): + """Handle building construction and placement.""" + # Check for completed buildings to place + production = await self.call("get_production") + buildings = await self.call("get_buildings") + + for p in production.get("queue", []): + if p["queue_type"] == "Building" and p["progress"] >= 0.99: + cy = next((b for b in buildings if b["type"] == "fact"), None) + if cy: + x, y = self._placement_offset(cy) + self._log(f"Placing {p['item']} at ({x}, {y})") + await self.call("place_building", building_type=p["item"], cell_x=x, cell_y=y) + self.placement_count += 1 + + # Start new building if nothing in queue + if self.build_index >= len(self.BUILD_ORDER): + return + + building_in_queue = any(p["queue_type"] == "Building" for p in production.get("queue", [])) + if building_in_queue: + return + + item = self.BUILD_ORDER[self.build_index] + # Resolve faction-agnostic barracks + if item == "barracks": + available = 
production.get("available", []) + if "tent" in available: + item = "tent" + elif "barr" in available: + item = "barr" + else: + return + + # Check if already built + already = sum(1 for b in buildings if b["type"] == item) + if already > 0 and self.build_index < len(self.BUILD_ORDER) - 1: + # Skip if not a duplicate in build order + count_in_order = sum(1 for x in self.BUILD_ORDER[:self.build_index + 1] + if x == item or (x == "barracks" and item in self.BARRACKS_TYPES)) + if already >= count_in_order: + self.build_index += 1 + return + + available = production.get("available", []) + if item in available: + economy = await self.call("get_economy") + building_stats = await self.call("lookup_building", building_type=item) + cost = building_stats.get("cost", 0) + if economy.get("cash", 0) >= cost: + self._log(f"Building {item} (#{self.build_index + 1}/{len(self.BUILD_ORDER)}, cost=${cost})") + await self.call("build_structure", building_type=item) + self.build_index += 1 + + # Set rally points on production buildings + await self._do_rally_points(buildings) + + async def _do_rally_points(self, buildings: list[dict]): + """Set rally points on barracks and war factories.""" + cy = next((b for b in buildings if b["type"] == "fact"), None) + if not cy: + return + + for b in buildings: + if b["type"] in ("tent", "barr", "weap") and b["actor_id"] not in self._rally_set: + rally_x = cy["cell_x"] if cy["cell_x"] > 0 else cy.get("pos_x", 0) // 1024 + rally_y = cy["cell_y"] if cy["cell_y"] > 0 else cy.get("pos_y", 0) // 1024 + self._log(f"Setting rally on {b['type']} (actor {b['actor_id']}) → ({rally_x}, {rally_y})") + await self.call("set_rally_point", building_id=b["actor_id"], cell_x=rally_x, cell_y=rally_y) + self._rally_set.add(b["actor_id"]) + + def _placement_offset(self, cy: dict) -> tuple[int, int]: + """Calculate placement position relative to CY.""" + cx = cy.get("pos_x", 0) // 1024 if cy.get("cell_x", 0) == 0 else cy["cell_x"] + cy_y = cy.get("pos_y", 0) // 1024 if 
async def _do_train(self):
    """Queue unit production and configure units and production buildings.

    One pass performs, in order:

    1. Queue an ``e1`` when a barracks exists, no infantry item is still in
       progress, fewer than ``INFANTRY_TARGET + GUARD_COUNT`` infantry
       exist, ``e1`` is available and cash is at least $100.
    2. When a war factory exists and no vehicle item is in progress, queue
       the one-off ``apc`` (cash >= $800).  If no APC was queued in this
       pass and the phase is ``attack``, queue a ``1tnk`` (cash >= $700).
    3. Assign idle infantry to guard the construction yard until
       ``GUARD_COUNT`` guards exist.
    4. Set a stance on every combat unit not yet in ``self._stances_set``:
       ``defend`` for guards, ``attack_anything`` for everyone else.
    5. For barracks and war factories, mark the building with the highest
       ``actor_id`` as primary when at least two of that kind exist.

    Guards are assigned before stances are set, and a newly assigned guard
    is removed from ``self._stances_set``.  Otherwise a unit would be given
    ``attack_anything`` before it became a guard and would never be
    switched to ``defend``.
    """
    production = await self.call("get_production")
    buildings = await self.call("get_buildings")
    units = await self.call("get_units")
    economy = await self.call("get_economy")

    queue = production.get("queue", [])
    available = production.get("available", [])
    cash = economy.get("cash", 0)

    # Train infantry
    has_barracks = any(b["type"] in self.BARRACKS_TYPES for b in buildings)
    infantry_training = any(
        p["queue_type"] == "Infantry" and p["progress"] < 0.99
        for p in queue
    )
    infantry = [u for u in units if u["type"] in self.INFANTRY_TYPES]
    total_target = self.INFANTRY_TARGET + self.GUARD_COUNT

    if has_barracks and not infantry_training and len(infantry) < total_target:
        if "e1" in available and cash >= 100:
            self._log(f"Training e1 ({len(infantry)}/{total_target})")
            await self.call("build_unit", unit_type="e1")

    # Train APC from war factory
    has_weap = any(b["type"] == "weap" for b in buildings)
    vehicle_training = any(
        p["queue_type"] == "Vehicle" and p["progress"] < 0.99
        for p in queue
    )
    if has_weap and not vehicle_training and not self._apc_trained:
        if "apc" in available and cash >= 800:
            self._log("Training APC")
            await self.call("build_unit", unit_type="apc")
            self._apc_trained = True
            # The snapshot in `queue` predates this order; mark the vehicle
            # queue busy so a tank is not queued on top of the APC.
            vehicle_training = True

    # Continuous vehicle production in attack phase
    if self.phase == "attack" and has_weap and not vehicle_training:
        if "1tnk" in available and cash >= 700:
            self._log("Training 1tnk (continuous)")
            await self.call("build_unit", unit_type="1tnk")

    # Assign guards to CY (before stances, see docstring)
    if len(self._guards_assigned) < self.GUARD_COUNT:
        cy = next((b for b in buildings if b["type"] == "fact"), None)
        if cy:
            for u in units:
                if len(self._guards_assigned) >= self.GUARD_COUNT:
                    break
                if (u["type"] in self.INFANTRY_TYPES
                        and u["is_idle"]
                        and u["actor_id"] not in self._guards_assigned):
                    self._log(f"Assigning {u['type']} (actor {u['actor_id']}) to guard CY")
                    await self.call("guard_target", unit_ids=str(u["actor_id"]), target_actor_id=cy["actor_id"])
                    self._guards_assigned.add(u["actor_id"])
                    # Force the stance loop below to revisit this unit.
                    self._stances_set.discard(u["actor_id"])

    # Set stances on new units and on freshly assigned guards
    for u in units:
        if u["actor_id"] in self._stances_set:
            continue
        if u["type"] not in self.COMBAT_TYPES:
            continue
        stance = "defend" if u["actor_id"] in self._guards_assigned else "attack_anything"
        await self.call("set_stance", unit_ids=str(u["actor_id"]), stance=stance)
        self._stances_set.add(u["actor_id"])

    # Set primary on multiple production buildings
    for btype_set in [self.BARRACKS_TYPES, self.WAR_FACTORY_TYPES]:
        bldgs_of_type = [b for b in buildings if b["type"] in btype_set]
        if len(bldgs_of_type) >= 2:
            newest = max(bldgs_of_type, key=lambda b: b["actor_id"])
            if newest["actor_id"] not in self._primary_set:
                self._log(f"Setting primary: {newest['type']} (actor {newest['actor_id']})")
                await self.call("set_primary", building_id=newest["actor_id"])
                self._primary_set.add(newest["actor_id"])
async def _do_sustain(self):
    """Repair, sell, and manage power, then tend harvesters and stragglers.

    Steps, all issued through ``self.call``:

    * Repair buildings below 70% health that are not already repairing and
      have no repair recorded in ``self._repair_issued``, while cash is at
      least $500.
    * Sell non-``fact`` buildings below 20% health, once each.
    * When drained power exceeds provided power, power down at most one
      building per pass, chosen in ``power_down_priority`` order.
    * Send the first idle ``harv`` back to harvesting.
    * Stop up to three combat units whose activity is ``"Flee"``.
    * When more than three unguarded ``jeep``/``e1`` units are idle, move
      the first of them to cell (64, 64).
    """
    buildings = await self.call("get_buildings")
    economy = await self.call("get_economy")

    for b in buildings:
        # Repair damaged buildings
        if (b["hp_percent"] < 0.7
                and not b.get("is_repairing", False)
                and b["actor_id"] not in self._repair_issued
                and economy.get("cash", 0) >= 500):
            self._log(f"Repairing {b['type']} (actor {b['actor_id']}, hp={b['hp_percent']:.0%})")
            await self.call("repair_building", building_id=b["actor_id"])
            self._repair_issued.add(b["actor_id"])

        # Sell heavily damaged buildings
        if (b["hp_percent"] < 0.2
                and b["type"] != "fact"
                and b["actor_id"] not in self._sold):
            self._log(f"Selling {b['type']} (actor {b['actor_id']}, hp={b['hp_percent']:.0%})")
            await self.call("sell_building", building_id=b["actor_id"])
            self._sold.add(b["actor_id"])

    # Power management
    power_balance = economy.get("power_provided", 0) - economy.get("power_drained", 0)
    if power_balance < 0:
        power_down_priority = ["dome", "spen", "syrd", "hpad", "afld", "fix"]
        # Select a single building (one at a time) rather than returning
        # from inside the loops: a return here would also skip the
        # harvester, flee and scout handling below for this pass.
        target = next(
            (
                b
                for btype in power_down_priority
                for b in buildings
                if (b["type"] == btype
                    and b.get("is_powered", True)
                    and b["actor_id"] not in self._powered_down)
            ),
            None,
        )
        if target is not None:
            self._log(f"Powering down {target['type']} (actor {target['actor_id']}) — power: {power_balance}")
            await self.call("power_down", building_id=target["actor_id"])
            self._powered_down.add(target["actor_id"])

    # Send idle harvesters to harvest (one at a time)
    units = await self.call("get_units")
    idle_harvester = next((u for u in units if u["type"] == "harv" and u["is_idle"]), None)
    if idle_harvester is not None:
        self._log(f"Sending harvester {idle_harvester['actor_id']} to harvest")
        await self.call("harvest", unit_id=idle_harvester["actor_id"])

    # Stop fleeing units
    fleeing = [u for u in units if u["type"] in self.COMBAT_TYPES
               and u.get("current_activity") == "Flee"]
    if fleeing:
        await self.call("stop_units", unit_ids=",".join(str(u["actor_id"]) for u in fleeing[:3]))

    # Move scouts
    idle_scouts = [u for u in units
                   if u["type"] in ("jeep", "e1") and u["is_idle"]
                   and u["actor_id"] not in self._guards_assigned]
    if len(idle_scouts) > 3:
        scout = idle_scouts[0]
        await self.call("move_units", unit_ids=str(scout["actor_id"]), target_x=64, target_y=64)
"""Connect to the OpenRA-RL server and play using MCP tools.""" + print(f"Connecting to {url}...") + + async with OpenRAMCPClient(base_url=url, message_timeout_s=300.0) as env: + print("Resetting environment (launching OpenRA)...") + await env.reset() + + # Discover available tools + tools = await env.list_tools() + tool_names = sorted(t.name for t in tools) + print(f"Discovered {len(tools)} MCP tools: {tool_names}") + + # Run bot + bot = MCPBot(env, verbose=verbose, no_planning=no_planning) + result = bot.run(max_turns) + if asyncio.iscoroutine(result): + result = await result + + # Final report + print() + print("=" * 70) + final = result["final_state"] + print(f"Game finished after {result['turns']} turns") + if final.get("done"): + print(f"Result: {final.get('result', '?').upper()}") + + # Score card + mil = final.get("military", {}) + eco = final.get("economy", {}) + planning = result.get("planning_strategy", "") + print() + print("--- SCORECARD ---") + print(f" Planning: {'ON — ' + planning if planning else 'OFF'}") + print(f" Ticks played: {final.get('tick', '?')}") + print(f" Units killed: {mil.get('units_killed', 0)} (value: ${mil.get('kills_cost', 0)})") + print(f" Units lost: {mil.get('units_lost', 0)} (value: ${mil.get('deaths_cost', 0)})") + print(f" Buildings killed: {mil.get('buildings_killed', 0)}") + print(f" Buildings lost: {mil.get('buildings_lost', 0)}") + print(f" Army value: ${mil.get('army_value', 0)}") + print(f" Assets value: ${mil.get('assets_value', 0)}") + print(f" Experience: {mil.get('experience', 0)}") + print(f" Orders issued: {mil.get('order_count', 0)}") + print(f" Cash remaining: ${eco.get('cash', 0)}") + print(f" K/D cost ratio: {mil.get('kills_cost', 0) / max(mil.get('deaths_cost', 1), 1):.2f}") + print() + + print(f"Tools exercised: {result['tools_count']}/{len(tools)}") + print(f" {result['tools_exercised']}") + if result.get("replay", {}).get("path"): + print(f"Replay: {result['replay']['path']}") + print("=" * 70) + + +def 
main(): + parser = argparse.ArgumentParser(description="MCP tool-based Red Alert bot") + parser.add_argument( + "--url", + default="http://localhost:8000", + help="OpenRA-RL server URL (default: http://localhost:8000)", + ) + parser.add_argument( + "--max-turns", + type=int, + default=3000, + help="Maximum turns before stopping (default: 3000)", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Print detailed bot decisions", + ) + parser.add_argument( + "--no-planning", + action="store_true", + help="Disable planning phase (for comparison runs)", + ) + args = parser.parse_args() + + try: + asyncio.run(run_mcp_bot(args.url, args.max_turns, args.verbose, no_planning=args.no_planning)) + except KeyboardInterrupt: + print("\nInterrupted by user") + sys.exit(0) + except ConnectionRefusedError: + print(f"\nCould not connect to {args.url}") + print("Is the OpenRA-RL server running?") + print(" docker run -p 8000:8000 openra-rl") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/examples/scripted_bot.py b/examples/scripted_bot.py new file mode 100644 index 0000000000000000000000000000000000000000..23d6c029387bbbc4bdb88e0e55961c4687edf896 --- /dev/null +++ b/examples/scripted_bot.py @@ -0,0 +1,831 @@ +#!/usr/bin/env python3 +"""Scripted Red Alert bot that plays a full game via the OpenEnv client API. 
+ +Exercises ALL Sprint 4+5 observation fields and action types: + - Observations: spatial_map, visible_enemy_buildings, unit facing/stance/speed/ + attack_range/experience/passengers, building cell coords/can_produce/power/ + rally/repair/sell_value + - Actions: all 20 types including GUARD, SET_STANCE, ENTER_TRANSPORT, UNLOAD, + SET_RALLY_POINT, REPAIR, SELL, POWER_DOWN, SET_PRIMARY + +Usage: + docker run -p 8000:8000 openra-rl + python examples/scripted_bot.py --verbose +""" + +import argparse +import asyncio +import base64 +import sys +from typing import List, Optional, Tuple + +from openra_env.client import OpenRAEnv +from openra_env.models import ( + ActionType, + BuildingInfoModel, + CommandModel, + OpenRAAction, + OpenRAObservation, + UnitInfoModel, +) + +# Stance constants matching C# AutoTarget.UnitStance enum +STANCE_HOLD_FIRE = 0 +STANCE_RETURN_FIRE = 1 +STANCE_DEFEND = 2 +STANCE_ATTACK_ANYTHING = 3 + +STANCE_NAMES = {0: "HoldFire", 1: "ReturnFire", 2: "Defend", 3: "AttackAnything"} + + +class ScriptedBot: + """State-machine bot with a Red Alert build order exercising all actions. 
+ + Phases: + deploy_mcv - Deploy MCV, set stance on starting units + build_base - Build power/barracks/war factory, set rally points + train_army - Train infantry + APC, guard CY, load transport + attack - Attack-move toward enemy buildings, unload APC + sustain - Continuous production, repair, sell damaged buildings + """ + + # Build order uses both faction names — bot picks whichever is available + BARRACKS_TYPES = {"tent", "barr"} # Allied / Soviet + WAR_FACTORY_TYPES = {"weap"} + BUILD_PRIORITY = [ + "powr", # Power Plant ($300) — shared + "barracks", # Placeholder: tent (Allied) or barr (Soviet) + "proc", # Ore Refinery ($2000) — needed before war factory + "weap", # War Factory ($2000) — shared + "powr", # Second Power Plant + ] + + INFANTRY_TRAIN_TARGET = 6 + GUARD_COUNT = 2 # infantry to guard CY + TRANSPORT_TYPE = "apc" + COMBAT_UNIT_TYPES = {"e1", "e2", "e3", "e4", "1tnk", "2tnk", "3tnk", "arty", "jeep", "apc"} + INFANTRY_TYPES = {"e1", "e2", "e3", "e4"} + VEHICLE_TYPES = {"1tnk", "2tnk", "3tnk", "arty", "jeep"} + + def __init__(self, verbose: bool = False): + self.phase = "deploy_mcv" + self.build_index = 0 + self.placement_count = 0 + self.deploy_issued = False + self.verbose = verbose + self._guards_assigned: set[int] = set() # actor IDs guarding CY + self._stances_set: set[int] = set() # actor IDs with stance already set + self._rally_set: set[int] = set() # building actor IDs with rally point set + self._apc_trained = False + self._apc_loaded = False + self._repair_issued: set[int] = set() # building actor IDs being repaired + self._sold: set[int] = set() # building actor IDs sold + self._powered_down: set[int] = set() # building actor IDs powered down + self._primary_set: set[int] = set() # building actor IDs set as primary + + def decide(self, obs: OpenRAObservation) -> OpenRAAction: + """Given current observation, return commands for this tick.""" + commands: List[CommandModel] = [] + + self._update_phase(obs) + + # Priority 1: Place completed 
def _update_phase(self, obs: OpenRAObservation):
    """Advance ``self.phase`` by at most one step based on *obs*.

    Transitions:
        deploy_mcv -> build_base  once a ``fact`` building is observed
        build_base -> train_army  once ``build_index`` has reached the end
                                  of ``BUILD_PRIORITY``
        train_army -> attack      once the combat units that are not
                                  assigned as guards number at least
                                  ``INFANTRY_TRAIN_TARGET``

    No transition leaves ``attack``.  Each transition is reported through
    ``self._log``.
    """
    if self.phase == "deploy_mcv":
        if any(b.type == "fact" for b in obs.buildings):
            self.phase = "build_base"
            self._log("Phase → build_base")
    elif self.phase == "build_base":
        if self.build_index >= len(self.BUILD_PRIORITY):
            self.phase = "train_army"
            self._log("Phase → train_army")
    elif self.phase == "train_army":
        # Guards stay at the construction yard, so they do not count
        # toward the strike force.
        ready = sum(
            1
            for u in obs.units
            if u.type in self.COMBAT_UNIT_TYPES
            and u.actor_id not in self._guards_assigned
        )
        if ready >= self.INFANTRY_TRAIN_TARGET:
            self.phase = "attack"
            self._log(f"Phase → attack ({ready} combat units ready)")
def _handle_rally_points(self, obs: OpenRAObservation) -> List[CommandModel]:
    """Return SET_RALLY_POINT commands for production buildings lacking one.

    Every barracks (``BARRACKS_TYPES``, i.e. both the Allied ``tent`` and
    the Soviet ``barr``) and war factory (``WAR_FACTORY_TYPES``) whose
    ``actor_id`` is not yet in ``self._rally_set`` gets one command aimed
    at the construction yard, and is then recorded in ``self._rally_set``.
    Returns an empty list when no construction yard is observed.
    """
    commands: List[CommandModel] = []
    cy = self._find_building(obs, "fact")
    if not cy:
        return commands

    # The target only depends on the construction yard: resolve it once.
    # A non-positive cell coordinate falls back to pos // 1024.
    rally_x = cy.cell_x if cy.cell_x > 0 else cy.pos_x // 1024
    rally_y = cy.cell_y if cy.cell_y > 0 else cy.pos_y // 1024

    # Use the class constants so Soviet barracks ("barr") are covered too.
    production_types = self.BARRACKS_TYPES | self.WAR_FACTORY_TYPES
    for b in obs.buildings:
        if b.type not in production_types or b.actor_id in self._rally_set:
            continue
        self._log(f"Setting rally on {b.type} (actor {b.actor_id}) → ({rally_x}, {rally_y})")
        commands.append(CommandModel(
            action=ActionType.SET_RALLY_POINT,
            actor_id=b.actor_id,
            target_x=rally_x,
            target_y=rally_y,
        ))
        self._rally_set.add(b.actor_id)
    return commands
def _handle_repairs(self, obs: OpenRAObservation) -> List[CommandModel]:
    """Return REPAIR commands for damaged buildings.

    A building qualifies when its health is below 70%, it is not already
    repairing, no repair has been recorded for it in
    ``self._repair_issued``, and cash is at least $500.  Each qualifying
    building is recorded in ``self._repair_issued``.
    """
    repair_orders = []
    for bldg in obs.buildings:
        if not bldg.hp_percent < 0.7:
            continue
        if bldg.is_repairing:
            continue
        if bldg.actor_id in self._repair_issued:
            continue
        if not obs.economy.cash >= 500:
            continue

        self._log(f"Repairing {bldg.type} (actor {bldg.actor_id}, hp={bldg.hp_percent:.0%})")
        order = CommandModel(action=ActionType.REPAIR, actor_id=bldg.actor_id)
        repair_orders.append(order)
        self._repair_issued.add(bldg.actor_id)
    return repair_orders
def _can_produce_item(self, obs: OpenRAObservation, item_type: str) -> bool:
    """Return True when *item_type* can currently be produced.

    The item counts as producible when it appears in the global
    ``obs.available_production`` list or in the ``can_produce`` list of
    any building in ``obs.buildings``.
    """
    return item_type in obs.available_production or any(
        item_type in b.can_produce for b in obs.buildings
    )
def _handle_guards(self, obs: OpenRAObservation) -> List[CommandModel]:
    """Return GUARD commands that assign idle infantry to the construction yard.

    Idle infantry not yet in ``self._guards_assigned`` are assigned, in
    observation order, until ``GUARD_COUNT`` guards exist.  Returns an
    empty list when the quota is already met or no construction yard
    (``fact``) is observed.

    Each new guard is also removed from ``self._stances_set``.
    ``_handle_stances`` records a unit in that set the first time it sees
    it and skips recorded units afterwards; ``decide`` runs it before this
    method, so a unit has already been handled as a non-guard by the time
    it is picked here.  Clearing the record lets the next
    ``_handle_stances`` pass re-evaluate the unit as a guard and issue the
    Defend stance.
    """
    commands: List[CommandModel] = []
    if len(self._guards_assigned) >= self.GUARD_COUNT:
        return commands

    cy = self._find_building(obs, "fact")
    if not cy:
        return commands

    # Find idle infantry not yet guarding
    for u in obs.units:
        if len(self._guards_assigned) >= self.GUARD_COUNT:
            break
        if (u.type in self.INFANTRY_TYPES
                and u.is_idle
                and u.actor_id not in self._guards_assigned):
            self._log(
                f"Assigning {u.type} (actor {u.actor_id}, "
                f"range={u.attack_range}) to guard CY"
            )
            commands.append(CommandModel(
                action=ActionType.GUARD,
                actor_id=u.actor_id,
                target_actor_id=cy.actor_id,
            ))
            self._guards_assigned.add(u.actor_id)
            # Let the stance handler revisit this unit now that it is a guard.
            self._stances_set.discard(u.actor_id)
    return commands
f"Loading {u.type} (actor {u.actor_id}, " + f"speed={u.speed}) into APC {apc.actor_id}" + ) + commands.append(CommandModel( + action=ActionType.ENTER_TRANSPORT, + actor_id=u.actor_id, + target_actor_id=apc.actor_id, + )) + loaded += 1 + + if loaded > 0: + self._apc_loaded = True + return commands + + # ── Combat ───────────────────────────────────────────────────── + + def _handle_combat(self, obs: OpenRAObservation) -> List[CommandModel]: + commands = [] + if self.phase != "attack": + return commands + + # Unload APC near enemy + commands.extend(self._handle_unload(obs)) + + # Attack-move idle fighters toward enemy + idle_fighters = [ + u for u in obs.units + if (u.type in self.COMBAT_UNIT_TYPES + and u.is_idle + and u.actor_id not in self._guards_assigned) + ] + + if len(idle_fighters) < 2: + return commands + + target_x, target_y = self._find_attack_target(obs) + + for unit in idle_fighters: + commands.append(CommandModel( + action=ActionType.ATTACK_MOVE, + actor_id=unit.actor_id, + target_x=target_x, + target_y=target_y, + )) + + if idle_fighters: + self._log( + f"Attacking with {len(idle_fighters)} units " + f"toward ({target_x}, {target_y})" + ) + return commands + + def _handle_unload(self, obs: OpenRAObservation) -> List[CommandModel]: + """Unload APC when near enemies.""" + commands = [] + for u in obs.units: + if u.type != self.TRANSPORT_TYPE or u.passenger_count <= 0: + continue + + # Check if any enemy is within ~15 cells + for enemy in obs.visible_enemies: + dx = abs(u.cell_x - enemy.cell_x) + dy = abs(u.cell_y - enemy.cell_y) + if dx + dy < 15: + self._log( + f"Unloading APC (actor {u.actor_id}, " + f"{u.passenger_count} passengers) near enemy" + ) + commands.append(CommandModel( + action=ActionType.UNLOAD, + actor_id=u.actor_id, + )) + break + + # Also unload near enemy buildings + for eb in obs.visible_enemy_buildings: + dx = abs(u.cell_x - eb.cell_x) + dy = abs(u.cell_y - eb.cell_y) + if dx + dy < 15: + self._log( + f"Unloading APC near enemy 
def _find_attack_target(self, obs: OpenRAObservation) -> Tuple[int, int]:
    """Return the cell to attack: enemy buildings > enemy units > map center.

    Among visible enemy buildings, the first production building wins:
    a construction yard, helipad or airfield (``fact``/``hpad``/``afld``),
    any barracks in ``BARRACKS_TYPES`` or any war factory in
    ``WAR_FACTORY_TYPES``.  Without one, the first visible enemy building
    is used.  With no buildings visible, the first visible enemy unit is
    used.  With nothing visible, the map center is returned, or (64, 64)
    when ``obs.map_info.width`` is not positive.
    """
    # Priority 1: visible enemy buildings (Sprint 4 field)
    if obs.visible_enemy_buildings:
        # Prefer production buildings.  Barracks come from the class
        # constant so the Soviet "barr" is prioritized like the Allied "tent".
        production_types = (
            {"fact", "hpad", "afld"} | self.BARRACKS_TYPES | self.WAR_FACTORY_TYPES
        )
        target = next(
            (b for b in obs.visible_enemy_buildings if b.type in production_types),
            obs.visible_enemy_buildings[0],
        )
        return target.cell_x, target.cell_y

    # Priority 2: visible enemy units
    if obs.visible_enemies:
        enemy = obs.visible_enemies[0]
        return enemy.cell_x, enemy.cell_y

    # Fallback: map center
    if obs.map_info.width > 0:
        return obs.map_info.width // 2, obs.map_info.height // 2
    return 64, 64
def print_status(step: int, obs: OpenRAObservation, bot: ScriptedBot):
    """Print a one-line status summary for the current step.

    The line shows the step and tick, cash, power balance (provided minus
    drained), own unit count with the number of combat units, visible
    enemy unit and building counts, the sorted distinct own building types
    (``none`` when there are no buildings) and the bot's current phase.

    Args:
        step: Step counter shown at the start of the line.
        obs: Observation the figures are read from.
        bot: Supplies ``COMBAT_UNIT_TYPES`` and ``phase``.
    """
    combat = [u for u in obs.units if u.type in bot.COMBAT_UNIT_TYPES]
    buildings = ", ".join(sorted({b.type for b in obs.buildings})) or "none"
    power_balance = obs.economy.power_provided - obs.economy.power_drained

    # Count enemy intel
    enemy_units = len(obs.visible_enemies)
    enemy_buildings = len(obs.visible_enemy_buildings)

    print(
        f"Step {step:4d} | Tick {obs.tick:5d} | "
        f"${obs.economy.cash:5d} | Pwr:{power_balance:+d} | "
        f"Units:{len(obs.units)} (combat:{len(combat)}) | "
        f"Enemy:{enemy_units}u/{enemy_buildings}b | "
        f"Bldgs:[{buildings}] | {bot.phase}"
    )
{obs.spatial_channels} channels, " + f"{len(raw_bytes)} bytes (expected {expected_bytes})" + ) + else: + print(" Spatial: not populated") + + # Economy + e = obs.economy + print( + f" Economy: ${e.cash} cash, {e.ore} ore, " + f"power {e.power_provided}/{e.power_drained} " + f"({e.power_provided - e.power_drained:+d}), " + f"{e.harvester_count} harvesters" + ) + + # Production queue + if obs.production: + print(f" Production queue ({len(obs.production)}):") + for p in obs.production: + print(f" {p.queue_type}: {p.item} @ {p.progress:.0%} (paused={p.paused})") + if obs.available_production: + print(f" Available production: {', '.join(obs.available_production[:15])}") + else: + print(" Available production: (none)") + + # Own buildings with Sprint 4 fields + print(f" Buildings ({len(obs.buildings)}):") + for b in obs.buildings: + extras = [] + if b.power_amount != 0: + extras.append(f"pwr={b.power_amount:+d}") + if b.is_producing: + extras.append(f"producing={b.producing_item}@{b.production_progress:.0%}") + if b.is_repairing: + extras.append("REPAIRING") + if b.rally_x >= 0: + extras.append(f"rally=({b.rally_x},{b.rally_y})") + if b.can_produce: + extras.append(f"can_produce=[{','.join(b.can_produce[:5])}{'...' 
if len(b.can_produce) > 5 else ''}]") + extra_str = f" ({', '.join(extras)})" if extras else "" + print( + f" {b.type:6s} #{b.actor_id:4d} " + f"cell=({b.cell_x},{b.cell_y}) " + f"hp={b.hp_percent:.0%} " + f"sell=${b.sell_value}{extra_str}" + ) + + # Own units with Sprint 4 fields + print(f" Units ({len(obs.units)}):") + for u in obs.units[:10]: # cap at 10 for readability + stance_name = STANCE_NAMES.get(u.stance, f"?{u.stance}") + extras = [] + if u.experience_level > 0: + extras.append(f"vet={u.experience_level}") + if u.passenger_count >= 0: + extras.append(f"cargo={u.passenger_count}") + extra_str = f" ({', '.join(extras)})" if extras else "" + print( + f" {u.type:6s} #{u.actor_id:4d} " + f"cell=({u.cell_x},{u.cell_y}) " + f"hp={u.hp_percent:.0%} " + f"face={u.facing:4d} spd={u.speed:3d} " + f"rng={u.attack_range:5d} " + f"stance={stance_name} " + f"{'IDLE' if u.is_idle else u.current_activity}{extra_str}" + ) + if len(obs.units) > 10: + print(f" ... and {len(obs.units) - 10} more") + + # Visible enemies + if obs.visible_enemies: + print(f" Visible enemy units ({len(obs.visible_enemies)}):") + for u in obs.visible_enemies[:5]: + print( + f" {u.type:6s} #{u.actor_id:4d} " + f"cell=({u.cell_x},{u.cell_y}) hp={u.hp_percent:.0%} " + f"spd={u.speed} rng={u.attack_range}" + ) + + # Visible enemy buildings (Sprint 4 field) + if obs.visible_enemy_buildings: + print(f" Visible enemy buildings ({len(obs.visible_enemy_buildings)}):") + for b in obs.visible_enemy_buildings[:5]: + print( + f" {b.type:6s} #{b.actor_id:4d} " + f"cell=({b.cell_x},{b.cell_y}) hp={b.hp_percent:.0%} " + f"pwr={b.power_amount:+d}" + ) + + +# ── Main loop ────────────────────────────────────────────────────── + + +async def run_bot(url: str, max_steps: int, verbose: bool): + """Connect to the OpenRA-RL server and play one full game.""" + print(f"Connecting to {url}...") + bot = ScriptedBot(verbose=verbose) + + async with OpenRAEnv(base_url=url, message_timeout_s=300.0) as env: + print("Resetting 
environment...") + result = await env.reset() + obs = result.observation + print(f"Game started! Map: {obs.map_info.map_name} ({obs.map_info.width}x{obs.map_info.height})") + + # Print initial detailed status + if verbose: + print_detailed_status(obs) + + print_status(0, obs, bot) + + step = 0 + total_reward = 0.0 + + while not result.done and step < max_steps: + action = bot.decide(result.observation) + result = await env.step(action) + step += 1 + total_reward += result.reward or 0.0 + obs = result.observation + + if step % 100 == 0: + print_status(step, obs, bot) + + # Detailed dump at key milestones + if verbose and step in (50, 200, 500, 1000): + print_detailed_status(obs) + + # Final report + print() + print("=" * 70) + obs = result.observation + if obs.done: + print(f"GAME OVER: {obs.result.upper()} after {step} steps (tick {obs.tick})") + else: + print(f"Reached max steps ({max_steps}) at tick {obs.tick}") + + print(f"Total reward: {total_reward:.3f}") + print(f"Final cash: ${obs.economy.cash}") + print(f"Power balance: {obs.economy.power_provided - obs.economy.power_drained:+d}") + print(f"Units killed: {obs.military.units_killed}") + print(f"Units lost: {obs.military.units_lost}") + print(f"Buildings killed: {obs.military.buildings_killed}") + print(f"Buildings lost: {obs.military.buildings_lost}") + print(f"Army value: ${obs.military.army_value}") + print(f"Own buildings: {len(obs.buildings)}") + print(f"Visible enemies: {len(obs.visible_enemies)} units, {len(obs.visible_enemy_buildings)} buildings") + + # Spatial map stats + if obs.spatial_channels > 0 and obs.spatial_map: + raw_bytes = base64.b64decode(obs.spatial_map) + n_floats = len(raw_bytes) // 4 + print(f"Spatial map: {n_floats} floats ({obs.spatial_channels} channels)") + else: + print("Spatial map: not populated") + + # Show veteran units + vets = [u for u in obs.units if u.experience_level > 0] + if vets: + print(f"Veterans: {', '.join(f'{u.type}#{u.actor_id}(lvl{u.experience_level})' for u in 
vets)}") + + if verbose: + print_detailed_status(obs) + + print("=" * 70) + + +def main(): + parser = argparse.ArgumentParser(description="Scripted Red Alert bot via OpenEnv") + parser.add_argument( + "--url", + default="http://localhost:8000", + help="OpenRA-RL server URL (default: http://localhost:8000)", + ) + parser.add_argument( + "--max-steps", + type=int, + default=5000, + help="Maximum steps before stopping (default: 5000)", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Print detailed bot decisions and observation dumps", + ) + args = parser.parse_args() + + try: + asyncio.run(run_bot(args.url, args.max_steps, args.verbose)) + except KeyboardInterrupt: + print("\nInterrupted by user") + sys.exit(0) + except ConnectionRefusedError: + print(f"\nCould not connect to {args.url}") + print("Is the OpenRA-RL server running?") + print(" docker run -p 8000:8000 openra-rl") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/models.py b/models.py new file mode 100644 index 0000000000000000000000000000000000000000..dd96532d5797c9f6cf89e5c7fd181ff3609e55c5 --- /dev/null +++ b/models.py @@ -0,0 +1,7 @@ +"""OpenEnv models re-export.""" + +from openra_env.models import ( # noqa: F401 + OpenRAAction, + OpenRAObservation, + OpenRAState, +) diff --git a/openenv.yaml b/openenv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3ddeed7da09c1161028f991a3427181f2df10a2d --- /dev/null +++ b/openenv.yaml @@ -0,0 +1,6 @@ +spec_version: 1 +name: openra_env +type: space +runtime: fastapi +app: openra_env.server.app:app +port: 8000 diff --git a/openra_env/__init__.py b/openra_env/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..67820d468dabf65a27adf8ae7cf51fcfcd40c167 --- /dev/null +++ b/openra_env/__init__.py @@ -0,0 +1,6 @@ +"""OpenRA-RL: Reinforcement Learning Environment for the OpenRA RTS Engine.""" + +from openra_env.client import OpenRAEnv +from openra_env.models import 
OpenRAAction, OpenRAObservation, OpenRAState + +__all__ = ["OpenRAEnv", "OpenRAAction", "OpenRAObservation", "OpenRAState"] diff --git a/openra_env/agent.py b/openra_env/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..7a60a167bc013035b5d6d07f5556b6b057dfa408 --- /dev/null +++ b/openra_env/agent.py @@ -0,0 +1,1156 @@ +"""LLM agent that plays Red Alert using any OpenAI-compatible model. + +Supports OpenRouter, Ollama, LM Studio, or any local/remote endpoint +that implements the OpenAI Chat Completions API with tool calling. +""" + +import asyncio +import json +import logging +import time + +from collections import defaultdict + +import httpx +from openra_env.config import LLMConfig +from openra_env.game_data import get_building_stats, get_faction_info, get_tech_tree, get_unit_stats +from openra_env.mcp_ws_client import OpenRAMCPClient + +logger = logging.getLogger("llm_agent") + + +def _looks_like_tool_capability_error(error_text: str) -> bool: + """Best-effort detection of provider errors indicating no tool support.""" + text = error_text.lower() + # Only match phrases that unambiguously refer to tool-calling capability. + # "no endpoints found" is too generic on its own — guard it with "tool". + if "no endpoints found" in text and "tool" in text: + return True + markers = ( + "support tool use", + "does not support tool", + "tool calling", + "tools are not supported", + ) + return any(m in text for m in markers) + + +def _bench_export_policy(encountered_agent_error: bool) -> tuple[bool, bool, str]: + """Decide whether bench export and upload should run for this match. + + Returns: + (should_export, should_upload, reason) + Local export always happens (useful for debugging). + Upload is skipped when runtime errors occurred. 
+ """ + if encountered_agent_error: + return True, False, "runtime [ERROR] occurred during the match" + return True, True, "" + + +def _format_llm_api_error(status_code: int, error_text: str, llm_config: LLMConfig) -> str: + """Map raw provider errors to clear, actionable runtime messages.""" + error_lower = error_text.lower() + + if status_code in (401, 403): + return ( + f"Authentication failed ({status_code}). " + "Check your API key: openra-rl config" + ) + + if status_code == 400 and "model" in error_lower: + return ( + f"Invalid model ID '{llm_config.model}'. " + "Update with: openra-rl config" + ) + + if status_code == 429: + return "Rate limited by LLM provider. Wait a minute and retry." + + if status_code == 404 and _looks_like_tool_capability_error(error_text): + is_openrouter = "openrouter.ai" in llm_config.base_url.lower() + if is_openrouter: + return ( + f"Model '{llm_config.model}' has no OpenRouter route that supports tool calling. " + "OpenRA-RL requires tool-calling models. " + "Use a tool-capable model/route (often not ':free'), or use Ollama " + "with qwen3:32b or qwen3:4b." + ) + return ( + f"Model '{llm_config.model}' does not support tool calling on this endpoint. " + "OpenRA-RL requires tool-calling models." + ) + + return f"LLM API error {status_code}: {error_text}" + + +async def _preflight_tool_calling_support(llm_config: LLMConfig) -> tuple[bool, str]: + """Check OpenRouter model route support for tool calling before game start. + + Returns: + (True, "") when preflight passes or does not apply. + (False, reason) when preflight confirms tools are unsupported. + """ + if "openrouter.ai" not in llm_config.base_url.lower(): + return True, "" + + preflight_cfg = llm_config.model_copy( + update={ + "max_tokens": 1, + "request_timeout_s": min(llm_config.request_timeout_s, 30.0), + } + ) + preflight_messages = [ + {"role": "user", "content": "Tool-calling preflight check. 
Reply briefly."}, + ] + preflight_tools = [ + { + "type": "function", + "function": { + "name": "preflight_ping", + "description": "Preflight-only tool for capability check.", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + try: + await chat_completion(preflight_messages, preflight_tools, preflight_cfg, verbose=False, prompts=None) + return True, "" + except RuntimeError as e: + msg = str(e) + if _looks_like_tool_capability_error(msg): + return False, msg + raise + + +def _load_default_prompt() -> str: + """Load the default system prompt shipped with the package.""" + from openra_env.prompts import load_default_prompt + return load_default_prompt() + + +# Public constant for backward compatibility (lazy-loaded on first access) +SYSTEM_PROMPT = _load_default_prompt() + + +def load_system_prompt(config) -> str: + """Resolve system prompt from config: inline > file > default. + + Priority: + 1. config.prompts.system_prompt (inline string) + 2. config.prompts.system_prompt_file (path to .txt file) + 3. config.agent.system_prompt (deprecated, backward compat) + 4. config.agent.system_prompt_file (deprecated, backward compat) + 5. 
Built-in default (openra_env/prompts/default.txt) + """ + from pathlib import Path + + # Check prompts.* first (canonical location) + prompts_cfg = getattr(config, "prompts", None) + if prompts_cfg: + if getattr(prompts_cfg, "system_prompt", ""): + return prompts_cfg.system_prompt + prompt_file = getattr(prompts_cfg, "system_prompt_file", "") + if prompt_file: + p = Path(prompt_file).expanduser() + if p.is_file(): + return p.read_text(encoding="utf-8").strip() + raise FileNotFoundError(f"system_prompt_file not found: {p}") + + # Backward compat: check agent.* (deprecated) + agent_cfg = config.agent if hasattr(config, "agent") else config + if getattr(agent_cfg, "system_prompt", ""): + return agent_cfg.system_prompt + prompt_file = getattr(agent_cfg, "system_prompt_file", "") + if prompt_file: + p = Path(prompt_file).expanduser() + if p.is_file(): + return p.read_text(encoding="utf-8").strip() + raise FileNotFoundError(f"system_prompt_file not found: {p}") + + # Default + return SYSTEM_PROMPT + + +def compose_pregame_briefing(state: dict) -> str: + """Compose a strategic briefing from initial game state + static game data. + + Sent once at game start so the LLM knows map, base position, faction, tech tree, + and available units/buildings without needing extra tool calls. 
+ """ + map_info = state.get("map", {}) + map_w = map_info.get("width", 0) + map_h = map_info.get("height", 0) + map_name = map_info.get("map_name", "?") + + # Determine base position from buildings/units + buildings = state.get("buildings_summary", []) + units = state.get("units_summary", []) + all_positions = [(b["cell_x"], b["cell_y"]) for b in buildings] + \ + [(u["cell_x"], u["cell_y"]) for u in units] + if all_positions: + base_x = sum(p[0] for p in all_positions) // len(all_positions) + base_y = sum(p[1] for p in all_positions) // len(all_positions) + else: + base_x, base_y = map_w // 2, map_h // 2 + + # Estimate enemy spawn — opposite side of map + enemy_x = max(2, min(map_w - 2, map_w - base_x)) + enemy_y = max(2, min(map_h - 2, map_h - base_y)) + + # Determine faction and side + faction = state.get("faction", "") + allied_factions = {"england", "france", "germany"} + soviet_factions = {"russia", "ukraine"} + if faction in allied_factions: + side = "Allied" + barracks = "tent" + elif faction in soviet_factions: + side = "Soviet" + barracks = "barr" + else: + # Infer from available production or buildings + avail = state.get("available_production", []) + bldg_types = state.get("building_types", []) + if "tent" in avail or "tent" in bldg_types: + side, barracks = "Allied", "tent" + else: + side, barracks = "Soviet", "barr" + + # Get tech tree — returns {side: [order]} dict + tech = get_tech_tree(side.lower()) + tech_order = tech.get(side.lower(), tech.get("build_order", [])) + + # Get faction info for available units/buildings + faction_info = get_faction_info(faction) if faction else get_faction_info(side.lower()) + avail_units = faction_info.get("available_units", []) if faction_info else [] + avail_buildings = faction_info.get("available_buildings", []) if faction_info else [] + + # Format key units with costs + unit_lines = [] + for utype in avail_units[:12]: # Cap at 12 to keep concise + stats = get_unit_stats(utype) + if stats: + unit_lines.append(f" 
{utype}: {stats['name']} — ${stats['cost']}, {stats.get('category', '?')}") + + # Format key buildings with costs and power + bldg_lines = [] + for btype in avail_buildings[:10]: + stats = get_building_stats(btype) + if stats: + power = stats.get("power", 0) + power_str = f", {power:+d} power" if power else "" + bldg_lines.append(f" {btype}: {stats['name']} — ${stats['cost']}{power_str}") + + # Calculate defense direction + dx = enemy_x - base_x + dy = enemy_y - base_y + dir_parts = [] + if dy < -map_h // 6: + dir_parts.append("North") + elif dy > map_h // 6: + dir_parts.append("South") + if dx > map_w // 6: + dir_parts.append("East") + elif dx < -map_w // 6: + dir_parts.append("West") + defense_direction = "".join(dir_parts) if dir_parts else "Center" + + parts = [ + "## Strategic Briefing", + f"Map: {map_name} ({map_w}x{map_h})", + f"Your faction: {faction or side} ({side})", + f"Your base: ({base_x}, {base_y})", + f"Enemy likely near: ({enemy_x}, {enemy_y})", + f"Enemy approach direction: {defense_direction}", + "", + f"Tech tree: {' → '.join(tech_order[:8])}{'...' 
if len(tech_order) > 8 else ''}", + f"Barracks type: {barracks}", + "", + "Available units:", + *unit_lines, + "", + "Available buildings:", + *bldg_lines, + ] + return "\n".join(parts) + + +def format_state_briefing(state: dict) -> str: + """Format game state (from get_game_state tool) into a compact turn briefing with positions.""" + if not isinstance(state, dict) or "tick" not in state: + return "" + + eco = state.get("economy", {}) + tick = state["tick"] + cash = eco.get("cash", 0) + ore = eco.get("ore", 0) + funds = cash + ore + + parts = [ + f"--- TURN BRIEFING (tick {tick}, ~{tick // 25}s game time) ---", + f"Funds: ${funds} (cash=${cash} + ore=${ore}) | Power: {state.get('power_balance', 0):+d} | Harvesters: {eco.get('harvester_count', 0)} | Explored: {state.get('explored_percent', 0)}%", + ] + + # Minimap (ASCII spatial overview) + minimap = state.get("minimap", "") + if minimap: + parts.append(minimap) + + # Base center from buildings + buildings = state.get("buildings_summary", []) + if buildings: + base_x = sum(b["cell_x"] for b in buildings) // len(buildings) + base_y = sum(b["cell_y"] for b in buildings) // len(buildings) + parts.append(f"Base center: ({base_x},{base_y})") + + # Compact unit summary grouped by type, with IDs, positions, and activity + units = state.get("units_summary", []) + if units: + by_type = defaultdict(list) + idle_ids = [] + for u in units: + by_type[u["type"]].append(u) + if u.get("idle") and u.get("can_attack"): + idle_ids.append(u["id"]) + unit_parts = [] + for utype, us in by_type.items(): + entries = [] + for u in us: + pos = f"{u['id']}@({u['cell_x']},{u['cell_y']})" + if u.get("target_x") is not None: + pos += f"→({u['target_x']},{u['target_y']})" + elif not u.get("idle"): + # Show short activity tag for non-idle units without tracked target + act = u.get("activity", "") + if act and act not in ("Idle", "Unknown", "Wait"): + tag = act[:3].lower() + pos += f"→{tag}" + entries.append(pos) + 
unit_parts.append(f"{len(us)}x{utype}[{','.join(entries)}]") + line = f"Units: {' '.join(unit_parts)}" + if idle_ids: + line += f" | Idle: [{','.join(str(i) for i in idle_ids)}]" + parts.append(line) + else: + parts.append(f"Units: {state.get('own_units', '?')}") + + # Compact building summary with IDs, positions, and production category + _BLDG_CATEGORY = {"tent": "infantry", "barr": "infantry", "weap": "vehicle", + "hpad": "aircraft", "afld": "aircraft", "syrd": "ship", "spen": "ship", + "gun": "defense", "ftur": "defense", "tsla": "defense", + "sam": "defense", "agun": "defense", "pbox": "defense", "hbox": "defense"} + if buildings: + bldg_parts = [] + for b in buildings: + cat = _BLDG_CATEGORY.get(b["type"], "") + cat_str = f"[{cat}]" if cat else "" + bldg_parts.append(f"{b['type']}({b['id']})@({b['cell_x']},{b['cell_y']}){cat_str}") + parts.append(f"Buildings: {' '.join(bldg_parts)}") + else: + parts.append(f"Buildings: {state.get('own_buildings', '?')} ({', '.join(state.get('building_types', []))})") + + # Enemy summary with IDs and positions (units + buildings) + enemies = state.get("enemy_summary", []) + enemy_bldgs = state.get("enemy_buildings_summary", []) + if enemies or enemy_bldgs: + enemy_parts = [] + if enemies: + eby_type = defaultdict(list) + for e in enemies: + eby_type[e["type"]].append(e) + for etype, es in eby_type.items(): + entries = ",".join(f"{e['id']}@({e['cell_x']},{e['cell_y']})" for e in es) + enemy_parts.append(f"{len(es)}x{etype}[{entries}]") + if enemy_bldgs: + ebby_type = defaultdict(list) + for b in enemy_bldgs: + ebby_type[b["type"]].append(b) + for btype, bs in ebby_type.items(): + entries = ",".join(f"{b['id']}@({b['cell_x']},{b['cell_y']})" for b in bs) + enemy_parts.append(f"{len(bs)}x{btype}[{entries}]") + # Average position of all visible enemies + all_enemy_pos = ( + [(e["cell_x"], e["cell_y"]) for e in enemies] + + [(b["cell_x"], b["cell_y"]) for b in enemy_bldgs] + ) + avg_x = sum(p[0] for p in all_enemy_pos) // 
len(all_enemy_pos) + avg_y = sum(p[1] for p in all_enemy_pos) // len(all_enemy_pos) + parts.append(f"Enemies: {' '.join(enemy_parts)} center ({avg_x},{avg_y})") + else: + n_enemy = state.get("visible_enemy_units", 0) + parts.append(f"Enemies: {'none visible' if n_enemy == 0 else f'{n_enemy} visible'}") + + prod = state.get("production_items", []) + if prod: + active = [p for p in prod if "@100%" not in p] + ready = [p.split("@")[0] for p in prod if "@100%" in p] + parts_prod = [] + if active: + parts_prod.append(", ".join(active)) + if ready: + parts_prod.append(f"READY TO PLACE: {', '.join(ready)}") + parts.append(f"Production: {' | '.join(parts_prod)}") + else: + parts.append("Production: IDLE") + + available = state.get("available_production", []) + if available: + parts.append(f"Can build: {', '.join(available)}") + + alerts = state.get("alerts", []) + if alerts: + parts.append("ALERTS:") + for a in alerts: + parts.append(f" ** {a}") + + parts.append("---") + + if state.get("done"): + parts.append(f"GAME OVER: {state.get('result', '?')}") + + return "\n".join(parts) + + +def mcp_tools_to_openai(tools: list) -> list[dict]: + """Convert MCP Tool schemas to OpenAI function calling format.""" + result = [] + for tool in tools: + schema = tool.input_schema if hasattr(tool, 'input_schema') else {} + # Clean up schema — remove 'title' which confuses some models + params = dict(schema) if schema else {} + params.pop("title", None) + if "properties" not in params: + params["properties"] = {} + params["type"] = "object" + + result.append({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description or "", + "parameters": params, + }, + }) + return result + + +def _sanitize_messages(messages: list[dict], prompts=None) -> list[dict]: + """Merge consecutive same-role messages for strict-alternation models (e.g. Mistral). + + Some models require strict user/assistant alternation and reject sequences + like ``user → user`` or ``tool → user``. 
This helper: + 1. Merges consecutive ``user`` messages by joining their content with newlines. + 2. Inserts a bridge ``assistant`` message when a ``tool`` result is followed + by a ``user`` message (Mistral requires tool → assistant → user). + """ + if not messages: + return messages + + bridge = prompts.sanitize_bridge if prompts else "Acknowledged. Continuing." + merged: list[dict] = [dict(messages[0])] + for msg in messages[1:]: + prev = merged[-1] + # Merge consecutive user messages + if msg["role"] == "user" and prev["role"] == "user": + merged[-1] = {**prev, "content": prev["content"] + "\n\n" + msg["content"]} + continue + # Bridge: tool → user needs an assistant message in between + if msg["role"] == "user" and prev["role"] == "tool": + merged.append({"role": "assistant", "content": bridge}) + merged.append(msg) + return merged + + +async def chat_completion( + messages: list[dict], + tools: list[dict], + llm_config: LLMConfig, + verbose: bool = False, + prompts=None, +) -> dict: + """Call an OpenAI-compatible chat completions API. + + Works with OpenRouter, Ollama, LM Studio, or any endpoint + implementing the OpenAI Chat Completions spec with tool calling. 
+ """ + clean_messages = _sanitize_messages(messages, prompts=prompts) + payload = { + "model": llm_config.model, + "messages": clean_messages, + "max_tokens": llm_config.max_tokens, + } + if tools: + payload["tools"] = tools + payload["tool_choice"] = "auto" + if llm_config.temperature is not None: + payload["temperature"] = llm_config.temperature + if llm_config.top_p is not None: + payload["top_p"] = llm_config.top_p + if llm_config.reasoning_effort is not None: + payload["reasoning"] = {"effort": llm_config.reasoning_effort} + + headers = dict(llm_config.extra_headers) + if llm_config.api_key: + headers["Authorization"] = f"Bearer {llm_config.api_key}" + + async with httpx.AsyncClient() as client: + if verbose: + n_msgs = len(clean_messages) + roles = [m.get("role", "?") for m in clean_messages] + print(f" [LLM] Sending {n_msgs} messages to {llm_config.model}...") + print(f" [LLM] Roles: {' → '.join(roles)}") + + response = await client.post( + llm_config.base_url, + headers=headers, + json=payload, + timeout=llm_config.request_timeout_s, + ) + + if response.status_code != 200: + error_text = response.text[:2000] + raise RuntimeError( + _format_llm_api_error(response.status_code, error_text, llm_config) + ) + + try: + data = response.json() + except (json.JSONDecodeError, ValueError) as e: + raise RuntimeError(f"LLM API error 502: invalid JSON response ({e})") + + if "error" in data: + raise RuntimeError(f"LLM API error 500: {data['error']}") + + if verbose: + usage = data.get("usage", {}) + print( + f" [LLM] Response: {usage.get('prompt_tokens', '?')} prompt + " + f"{usage.get('completion_tokens', '?')} completion tokens" + ) + + return data + + +def compress_history(messages: list[dict], keep_last: int = 40, + trigger: int = 0, prompts=None, compression=None) -> list[dict]: + """Compress message history to stay within context limits. 
+ + Keeps the system prompt and the last ``keep_last`` messages, replacing + earlier messages with a state-aware summary that preserves critical + game context (buildings, economy, strategy, military, errors). + + Args: + keep_last: Number of recent messages to keep after compression. + trigger: Compress when total messages exceed this threshold. + 0 (default) means ``keep_last * 2``. + prompts: PromptsConfig for customizable text. + compression: CompressionConfig controlling what to include in summary. + """ + threshold = trigger if trigger > 0 else keep_last * 2 + if len(messages) <= threshold: + return messages + + system = messages[0] + # Find a clean cut point: recent must not start with tool role + cut = len(messages) - keep_last + while cut < len(messages) and messages[cut].get("role") == "tool": + cut += 1 # move cut forward to skip orphaned tool results + if cut >= len(messages) - 2: + return messages # can't compress safely + + old_messages = messages[1:cut] + recent = messages[cut:] + + # Compression config defaults + inc_strategy = compression.include_strategy if compression else True + inc_military = compression.include_military if compression else True + inc_production = compression.include_production if compression else True + + # Extract game state context from old messages + last_state = {} + building_types = set() + unit_types_produced = set() + strategy_text = "" + errors = [] + + for msg in old_messages: + # Extract planning strategy from early user messages + if inc_strategy and msg.get("role") == "user" and not strategy_text: + content_str = msg.get("content", "") + if isinstance(content_str, str): + for line in content_str.split("\n"): + if line.strip().startswith("Strategy:"): + strategy_text = line.strip() + break + + if msg.get("role") != "tool": + continue + try: + content = json.loads(msg["content"]) if isinstance(msg["content"], str) else msg["content"] + if not isinstance(content, dict): + continue + + # Track latest state snapshot + 
if "tick" in content and "economy" in content: + last_state = content + + # Track buildings built + for bt in content.get("building_types", []): + building_types.add(bt) + + # Track units produced (from build_unit notes) + if inc_production and "note" in content: + note = content["note"] + if isinstance(note, str) and "queued" in note: + # Extract unit/building name from "'name' ... queued" + import re + m = re.search(r"'(\w+)'.*queued", note) + if m: + name = m.group(1) + # Distinguish units from buildings + if "per unit" in note or "each" in note: + unit_types_produced.add(name) + else: + building_types.add(name) + + # Track placement failures and errors + if content.get("placement_failed"): + errors.append("placement failed") + elif "error" in content and isinstance(content["error"], str): + err = content["error"] + if len(err) < 80: + errors.append(err) + except (json.JSONDecodeError, TypeError): + pass + + # Build summary + parts = [f"[History: {len(old_messages)} earlier messages removed]"] + + if last_state: + eco = last_state.get("economy", {}) + parts.append( + f"Last state at tick {last_state.get('tick', '?')}: " + f"${eco.get('cash', '?')} cash, " + f"{last_state.get('own_units', '?')} units, " + f"{last_state.get('own_buildings', '?')} buildings" + ) + + if inc_strategy and strategy_text: + parts.append(strategy_text) + + if building_types: + parts.append(f"Buildings built: {', '.join(sorted(building_types))}") + + if inc_production and unit_types_produced: + parts.append(f"Units produced: {', '.join(sorted(unit_types_produced))}") + + if inc_military and last_state: + mil = last_state.get("military", {}) + if mil: + parts.append( + f"Military: {mil.get('units_killed', 0)} kills, " + f"{mil.get('units_lost', 0)} losses" + ) + + if errors: + unique = list(dict.fromkeys(errors))[-3:] + parts.append(f"Recent issues: {'; '.join(unique)}") + + suffix = prompts.compression_suffix if prompts else "Game continues from current state." 
+ parts.append(suffix) + + return [ + system, + {"role": "user", "content": "\n".join(parts)}, + *recent, + ] + + +async def run_agent(config, verbose: bool = False): + """Connect to OpenRA-RL and play a game using an LLM agent.""" + url = config.agent.server_url + llm_config = config.llm + max_turns = config.agent.max_turns + max_time = config.agent.max_time_s + + # Auto-increase timeout for local models (they're slower than cloud APIs) + is_local = any(h in llm_config.base_url for h in ("localhost", "127.0.0.1")) + if is_local and llm_config.request_timeout_s <= 120.0: + llm_config = llm_config.model_copy(update={"request_timeout_s": 300.0}) + + print(f"Connecting to {url}...") + print(f"Model: {llm_config.model} @ {llm_config.base_url}") + if is_local: + print(f"Timeout: {int(llm_config.request_timeout_s)}s (local model)") + + if "openrouter.ai" in llm_config.base_url.lower(): + print("Checking model route for tool-calling support...") + try: + preflight_ok, preflight_err = await _preflight_tool_calling_support(llm_config) + except Exception as e: + print(f" [ERROR] Preflight check failed: {e}") + print(" Aborting before game launch (no match started).") + return + if not preflight_ok: + print(f" [ERROR] Preflight check failed: {preflight_err}") + print(" Aborting before game launch (no match started).") + return + + async with OpenRAMCPClient(base_url=url, message_timeout_s=300.0) as env: + print("Resetting environment (launching OpenRA)...") + await env.reset() + + # Discover and convert tools + mcp_tools = await env.list_tools() + openai_tools = mcp_tools_to_openai(mcp_tools) + tool_names = {t["function"]["name"] for t in openai_tools} + print(f"Discovered {len(mcp_tools)} MCP tools") + + if verbose: + for t in mcp_tools: + print(f" - {t.name}: {t.description[:60]}...") + + # Initialize conversation + system_prompt = load_system_prompt(config) + messages = [{"role": "system", "content": system_prompt}] + + # ─── Pre-Game Planning Phase 
────────────────────────────────── + planning_strategy = "" + planning_status = await env.call_tool("get_planning_status") + + if planning_status.get("planning_enabled", True) is not False: + print("Starting pre-game planning phase...") + planning_data = await env.call_tool("start_planning_phase") + + if planning_data.get("planning_active"): + max_planning_turns = planning_data.get("max_turns", 10) + opponent_summary = planning_data.get("opponent_summary", "") + + prompts = config.prompts + planning_prompt = prompts.planning_prompt.format( + max_turns=max_planning_turns, + map_name=planning_data.get("map", {}).get("map_name", "?"), + map_width=planning_data.get("map", {}).get("width", "?"), + map_height=planning_data.get("map", {}).get("height", "?"), + base_x=planning_data.get("base_position", {}).get("x", "?"), + base_y=planning_data.get("base_position", {}).get("y", "?"), + enemy_x=planning_data.get("enemy_estimated_position", {}).get("x", "?"), + enemy_y=planning_data.get("enemy_estimated_position", {}).get("y", "?"), + faction=planning_data.get("your_faction", "?"), + side=planning_data.get("your_side", "?"), + opponent_summary=opponent_summary, + planning_nudge=prompts.planning_nudge, + ) + messages.append({"role": "user", "content": planning_prompt}) + + # Planning loop (bounded by max_planning_turns + margin) + planning_done = False + for planning_turn in range(max_planning_turns + 2): + try: + response = await chat_completion(messages, openai_tools, llm_config, verbose, prompts=config.prompts) + except (RuntimeError, httpx.ReadTimeout, httpx.ConnectTimeout) as e: + print(f" [Planning] API error: {e}") + print(" Skipping planning phase.") + break + if response is None: + break + + choice = response["choices"][0] + assistant_msg = choice["message"] + messages.append(assistant_msg) + + if verbose and assistant_msg.get("content"): + print(f" [Planning] {assistant_msg['content'][:200]}") + + tool_calls = assistant_msg.get("tool_calls", []) + if not tool_calls: 
+ messages.append({ + "role": "user", + "content": prompts.planning_nudge, + }) + continue + + for tc in tool_calls: + fn_name = tc["function"]["name"] + try: + fn_args = json.loads(tc["function"].get("arguments", "{}")) + except (json.JSONDecodeError, TypeError): + fn_args = {} + + if verbose: + args_str = json.dumps(fn_args) + if len(args_str) > 80: + args_str = args_str[:80] + "..." + print(f" [Planning Tool] {fn_name}({args_str})") + + try: + result = await env.call_tool(fn_name, **fn_args) + except Exception as e: + result = {"error": str(e)} + + messages.append({ + "role": "tool", + "tool_call_id": tc["id"], + "content": json.dumps(result) if not isinstance(result, str) else result, + }) + + # Check if planning ended + if isinstance(result, dict): + if result.get("planning_complete"): + planning_strategy = result.get("strategy", "") + planning_done = True + if verbose: + print(f" [Planning] Strategy: {planning_strategy[:150]}...") + elif result.get("planning_expired"): + planning_strategy = result.get("strategy", "") + planning_done = True + print(f" [Planning] Expired: {result.get('reason', '?')}") + + if planning_done: + break + + if not planning_done: + # Force end planning + try: + result = await env.call_tool( + "end_planning_phase", + strategy="(planning timed out, no explicit strategy)" + ) + planning_strategy = result.get("strategy", "") + except Exception: + pass + print(" Planning phase timed out, proceeding to gameplay.") + + print(f"Planning phase complete. Strategy recorded: {bool(planning_strategy)}") + else: + if verbose: + print(f" Planning: {planning_data.get('message', 'skipped')}") + + # ─── Game Start ─────────────────────────────────────────────── + # Reset messages to just system prompt — planning context is captured + # in the strategy text below. This avoids tool/user role alternation + # issues with models that enforce strict message ordering (e.g. Mistral). 
+ messages = [messages[0]] # keep only system prompt + + state = await env.call_tool("get_game_state") + briefing = compose_pregame_briefing(state) + + strategy_section = "" + if planning_strategy: + strategy_section = f"\n\n## Your Pre-Game Strategy\n{planning_strategy}\n" + + # Find MCV unit ID and barracks type for context + mcv_id = None + for u in state.get("units_summary", []): + if u.get("type") == "mcv": + mcv_id = u["id"] + break + faction = state.get("faction", "") + barracks_type = "tent" if faction in {"england", "france", "germany"} else "barr" + + mcv_note = f" Your MCV is unit {mcv_id}." if mcv_id else "" + + game_start_prompts = config.prompts + messages.append({ + "role": "user", + "content": game_start_prompts.game_start.format( + strategy_section=strategy_section, + briefing=briefing, + barracks_type=barracks_type, + mcv_note=mcv_note, + ), + }) + + total_tool_calls = 0 + total_api_calls = 0 + start_time = time.time() + game_done = False + encountered_agent_error = False + consecutive_errors = 0 + MAX_CONSECUTIVE_ERRORS = 3 + + turn = 0 + while True: + # Check limits + elapsed = time.time() - start_time + if max_time and elapsed >= max_time: + print(f"\n TIME LIMIT reached ({max_time}s). 
Stopping.") + break + if max_turns and turn >= max_turns: + break + turn += 1 + + # Compress history periodically (unless disabled) + if llm_config.compression_strategy != "none": + messages = compress_history( + messages, keep_last=llm_config.keep_last_messages, + trigger=llm_config.compression_trigger, + prompts=config.prompts, + compression=config.prompts.compression) + + # Inject state briefing before LLM thinks (skip first turn — initial state already provided) + if total_api_calls > 0: + try: + briefing_state = await env.call_tool("get_game_state") + briefing = format_state_briefing(briefing_state) + if briefing: + messages.append({"role": "user", "content": briefing}) + if verbose: + # Print just the alerts + for a in briefing_state.get("alerts", []): + print(f" [ALERT] {a}") + # Check game over from briefing + if isinstance(briefing_state, dict) and briefing_state.get("done"): + game_done = True + print(f"\n GAME OVER: {briefing_state.get('result', '?').upper()} at tick {briefing_state.get('tick', '?')}") + break + except Exception: + pass + + # Call LLM with retry for rate limits + response = None + max_retries = llm_config.max_retries + is_local = any(h in llm_config.base_url for h in ("localhost", "127.0.0.1")) + for attempt in range(max_retries): + try: + response = await chat_completion(messages, openai_tools, llm_config, verbose, prompts=config.prompts) + break + except (httpx.ReadTimeout, httpx.ConnectTimeout): + timeout_s = int(llm_config.request_timeout_s) + print(f"\n [ERROR] Request timed out after {timeout_s}s.") + encountered_agent_error = True + if is_local: + print(" [HINT] Local models can be slow. 
Increase timeout in config.yaml:") + print(f" llm.request_timeout_s: {timeout_s * 2}") + break + except RuntimeError as e: + err_str = str(e) + retriable = any(code in err_str for code in ("429", "500", "502", "503", "504")) + if retriable and attempt < max_retries - 1: + wait = llm_config.retry_backoff_s * (attempt + 1) + print(f"\n [RETRY] Provider error, waiting {wait}s ({attempt + 1}/{max_retries})...") + print(f" {e}") + await asyncio.sleep(wait) + else: + print(f"\n [ERROR] API call failed: {e}") + encountered_agent_error = True + break + if response is None: + print(" [ERROR] Stopping agent.") + encountered_agent_error = True + break + + total_api_calls += 1 + choice = response["choices"][0] + assistant_msg = choice["message"] + + # Add assistant response to history + messages.append(assistant_msg) + + # Print assistant's reasoning + if assistant_msg.get("content") and verbose: + print(f"\n [LLM thinks] {assistant_msg['content'][:200]}") + + # Handle tool calls + tool_calls = assistant_msg.get("tool_calls", []) + if not tool_calls: + # No tool calls — prompt to act + if verbose: + content = assistant_msg.get("content", "(no content)") + print(f" [LLM] No tool calls. Response: {content[:100]}") + messages.append({ + "role": "user", + "content": config.prompts.no_tool_nudge, + }) + continue + + # Execute each tool call + for tc in tool_calls: + fn_name = tc["function"]["name"] + try: + fn_args = json.loads(tc["function"].get("arguments", "{}")) + except (json.JSONDecodeError, TypeError): + fn_args = {} + + total_tool_calls += 1 + + if verbose: + args_str = json.dumps(fn_args) + if len(args_str) > 80: + args_str = args_str[:80] + "..." 
+ print(f" [Tool] {fn_name}({args_str})") + + try: + result = await env.call_tool(fn_name, **fn_args) + consecutive_errors = 0 + except Exception as e: + result = {"error": str(e)} + # Suggest similar tools for unknown tool errors + if fn_name not in tool_names: + import difflib + close = difflib.get_close_matches(fn_name, tool_names, n=3, cutoff=0.4) + # Always include canonical build tools for build-related names + build_keywords = {"build", "place", "train", "produce", "construct"} + if any(kw in fn_name.lower() for kw in build_keywords): + for bt in ("build_unit", "build_structure", "build_and_place"): + if bt in tool_names and bt not in close: + close.append(bt) + if close: + result["suggested_tools"] = close + + # Detect game connection lost + if isinstance(result, dict) and "connection lost" in str(result.get("error", "")).lower(): + consecutive_errors += 1 + if consecutive_errors >= MAX_CONSECUTIVE_ERRORS: + print(f"\n GAME CRASHED: {consecutive_errors} consecutive connection errors. Stopping.") + encountered_agent_error = True + game_done = True + + # Format result for message + result_str = json.dumps(result) if not isinstance(result, str) else result + + messages.append({ + "role": "tool", + "tool_call_id": tc["id"], + "content": result_str, + }) + + # Check for game over + if isinstance(result, dict) and result.get("done"): + game_done = True + print(f"\n GAME OVER: {result.get('result', '?').upper()} at tick {result.get('tick', '?')}") + + if verbose and isinstance(result, dict): + result_preview = json.dumps(result) + if len(result_preview) > 500: + result_preview = result_preview[:500] + "..." 
+ print(f" [Result] {result_preview}") + + # Status update + if total_api_calls % 5 == 0 or game_done: + elapsed = time.time() - start_time + limit_str = f"/{max_turns}" if max_turns else "" + time_str = f"{elapsed:.0f}/{max_time}s" if max_time else f"{elapsed:.0f}s" + print( + f" Turn {turn}{limit_str} | " + f"API calls: {total_api_calls} | " + f"Tool calls: {total_tool_calls} | " + f"Time: {time_str}" + ) + + if game_done: + break + + # Check finish reason + if choice.get("finish_reason") == "stop" and not tool_calls: + messages.append({ + "role": "user", + "content": config.prompts.continue_nudge, + }) + + # Surrender so the replay has a proper ending + if not game_done: + try: + await env.call_tool("surrender") + print("\n Surrendered (replay will have proper ending)") + except Exception: + pass + + # Final report + elapsed = time.time() - start_time + print() + print("=" * 70) + print(f"Agent finished after {total_api_calls} API calls, {total_tool_calls} tool calls") + print(f"Time: {elapsed:.1f}s ({elapsed / max(total_api_calls, 1):.1f}s per API call)") + + # Get final state and scorecard + try: + final = await env.call_tool("get_game_state") + mil = final.get("military", {}) + eco = final.get("economy", {}) + print(f"Result: {final.get('result', 'ongoing').upper()}") + print() + print("--- SCORECARD ---") + print(f" Planning: {'ON — ' + planning_strategy[:100] if planning_strategy else 'OFF'}") + print(f" Ticks played: {final.get('tick', '?')}") + print(f" Units killed: {mil.get('units_killed', 0)} (value: ${mil.get('kills_cost', 0)})") + print(f" Units lost: {mil.get('units_lost', 0)} (value: ${mil.get('deaths_cost', 0)})") + print(f" Buildings killed: {mil.get('buildings_killed', 0)}") + print(f" Buildings lost: {mil.get('buildings_lost', 0)}") + print(f" Army value: ${mil.get('army_value', 0)}") + print(f" Assets value: ${mil.get('assets_value', 0)}") + print(f" Experience: {mil.get('experience', 0)}") + print(f" Orders issued: {mil.get('order_count', 
0)}") + print(f" Cash remaining: ${eco.get('cash', 0)}") + print(f" K/D cost ratio: {mil.get('kills_cost', 0) / max(mil.get('deaths_cost', 1), 1):.2f}") + print(f" Own units: {final.get('own_units', '?')}") + print(f" Own buildings: {final.get('own_buildings', '?')}") + print(f" Explored: {final.get('explored_percent', 0)}%") + rv = final.get("reward_vector", {}) + if rv: + print(" Reward vector:") + for dim, val in rv.items(): + print(f" {dim:15s} {val:+.3f}") + print() + except Exception as e: + print(f" (could not get final state: {e})") + + # Get replay + replay = {} + try: + replay = await env.call_tool("get_replay_path") + if replay.get("path"): + print(f"Replay: {replay['path']}") + except Exception: + pass + + # Auto-export bench submission JSON (always local, upload gated on errors) + should_export, should_upload, skip_reason = _bench_export_policy(encountered_agent_error) + try: + from datetime import datetime, timezone + from pathlib import Path + + resolved_name = config.agent.agent_name or llm_config.model + sub = { + "agent_name": resolved_name, + "agent_type": config.agent.agent_type or "LLM", + "agent_url": config.agent.agent_url, + "opponent": config.opponent.bot_type.capitalize(), + "games": 1, + "result": final.get("result", ""), + "win": final.get("result") == "win", + "ticks": final.get("tick", 0), + "kills_cost": mil.get("kills_cost", 0), + "deaths_cost": mil.get("deaths_cost", 0), + "kd_ratio": round(mil.get("kills_cost", 0) / max(mil.get("deaths_cost", 1), 1), 2), + "assets_value": mil.get("assets_value", 0), + "explored_percent": final.get("explored_percent", 0), + "reward_vector": final.get("reward_vector", {}), + "replay_path": replay.get("path", ""), + "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), + } + export_dir = Path.home() / ".openra-rl" / "bench-exports" + export_dir.mkdir(parents=True, exist_ok=True) + ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + slug = resolved_name.replace("/", "_")[:40] 
def _as_plain_dict(value: Any) -> Dict[str, Any]:
    """Return *value* as a plain dict.

    Handles objects exposing ``model_dump()``, plain attribute containers
    (via ``vars``), and anything ``dict()`` accepts.
    """
    if hasattr(value, "model_dump"):
        return value.model_dump()
    if not isinstance(value, dict) and hasattr(value, "__dict__"):
        return dict(vars(value))
    return dict(value)


def build_bench_export(
    obs: Any,
    agent_name: str,
    agent_type: str = "RL",
    opponent: str = "Normal",
    agent_url: str = "",
    replay_path: str = "",
    export_dir: Optional[Path] = None,
) -> Dict[str, Any]:
    """Build and save a bench export JSON from a final observation.

    Args:
        obs: Final observation — a dict, an object with ``model_dump()``, or a
            plain attribute container. The keys/attributes read are
            ``military``, ``tick``, ``result``, ``explored_percent`` and
            ``reward_vector``; ``military`` may itself be a dict or an object.
        agent_name: Display name for the leaderboard.
        agent_type: One of "Scripted", "LLM", "RL".
        opponent: Difficulty tier (Beginner/Easy/Medium/Normal/Hard).
        agent_url: Optional GitHub/project URL.
        replay_path: Optional path to .orarep replay file.
        export_dir: Where to save the JSON (default:
            ~/.openra-rl/bench-exports/). A ``str`` path is accepted as well.

    Returns:
        Dict with all submission fields plus "path" pointing to the saved file.
        The "path" key is added after writing, so it is not in the file itself.
    """
    import re  # local import: only needed for the filename slug

    obs_dict = _as_plain_dict(obs)

    # ``military`` can be a nested object when obs is not a dict/Pydantic model.
    mil = _as_plain_dict(obs_dict.get("military") or {})
    # ``or 0`` guards against explicit None values, which would break max().
    kills = mil.get("kills_cost") or 0
    deaths = mil.get("deaths_cost") or 0

    # One clock reading so the embedded timestamp and the filename agree.
    now = datetime.now(timezone.utc)

    sub = {
        "agent_name": agent_name,
        "agent_type": agent_type,
        "agent_url": agent_url,
        "opponent": opponent,
        "games": 1,
        "result": obs_dict.get("result", ""),
        "win": obs_dict.get("result") == "win",
        "ticks": obs_dict.get("tick", 0),
        "kills_cost": kills,
        "deaths_cost": deaths,
        # Clamp the divisor to 1 so a deathless game does not divide by zero.
        "kd_ratio": round(kills / max(deaths, 1), 2),
        "assets_value": mil.get("assets_value", 0),
        "explored_percent": obs_dict.get("explored_percent", 0),
        "reward_vector": obs_dict.get("reward_vector", {}),
        "replay_path": replay_path,
        "timestamp": now.strftime("%Y-%m-%dT%H:%M:%SZ"),
    }

    # Save to disk
    if export_dir is None:
        target_dir = Path.home() / ".openra-rl" / "bench-exports"
    else:
        target_dir = Path(export_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    ts = now.strftime("%Y%m%dT%H%M%SZ")
    # Replace every character that is not a word char, dot or dash, so names
    # such as "org/model:tag" cannot produce separators or invalid filenames.
    slug = re.sub(r"[^\w.-]", "_", agent_name)[:40]
    export_path = target_dir / f"bench-{slug}-{ts}.json"
    export_path.write_text(json.dumps(sub, indent=2), encoding="utf-8")

    sub["path"] = str(export_path)
    return sub
def _gradio_call(bench_url: str, api_name: str, payload: dict, timeout: float = 30) -> str:
    """Call a Gradio SSE endpoint (two-step protocol).

    1. POST /gradio_api/call/{api_name} → {"event_id": "..."}
    2. GET /gradio_api/call/{api_name}/{event_id} → SSE stream

    Args:
        bench_url: Base URL of the Gradio app; a trailing slash is ignored.
        api_name: Name of the API endpoint to invoke.
        payload: JSON body for the POST request.
        timeout: Timeout passed to both HTTP requests.

    Returns:
        The first element of the result list, or ``str(result)`` when the
        payload is not a non-empty list.

    Raises:
        RuntimeError: On a non-200 response, a missing event_id, an ``error``
            event in the stream, or a stream that ends without a result.
    """
    base = bench_url.rstrip("/")

    resp = httpx.post(
        f"{base}/gradio_api/call/{api_name}",
        json=payload,
        timeout=timeout,
    )
    if resp.status_code != 200:
        raise RuntimeError(f"HTTP {resp.status_code}: {resp.text[:200]}")

    event_id = resp.json().get("event_id")
    if not event_id:
        raise RuntimeError(f"No event_id in response: {resp.text[:200]}")

    with httpx.stream(
        "GET",
        f"{base}/gradio_api/call/{api_name}/{event_id}",
        timeout=timeout,
    ) as stream:
        if stream.status_code != 200:
            raise RuntimeError(f"HTTP {stream.status_code} from result stream")

        # SSE messages arrive as an "event:" line followed by a "data:" line.
        # Track the event type so only a real result is returned.
        event_type = None
        for line in stream.iter_lines():
            if line.startswith("event:"):
                event_type = line[len("event:"):].strip()
                continue
            if not line.startswith("data:"):
                continue
            raw = line[len("data:"):].strip()

            # Heartbeats carry no result; "generating" is an intermediate
            # update. Keep reading until the final event.
            if event_type in ("heartbeat", "generating"):
                continue
            if event_type == "error":
                raise RuntimeError(f"Bench returned an error: {raw[:200]}")

            result = json.loads(raw)
            if isinstance(result, list) and result:
                return result[0]
            return str(result)

    raise RuntimeError("No result received from SSE stream")
def gradio_submit(
    bench_url: str,
    data: dict,
    replay_path: str = "",
    timeout: float = 30,
) -> str:
    """Send bench results to the Gradio leaderboard, with a replay if available.

    When ``replay_path`` is non-empty and names an existing file, the file is
    uploaded first and the ``submit_with_replay`` endpoint receives the JSON
    payload plus the upload reference. In every other case the JSON-only
    ``submit`` endpoint is used.

    Returns:
        Whatever the underlying Gradio call returns.
    """
    attach_replay = bool(replay_path) and Path(replay_path).is_file()

    if attach_replay:
        # Upload before serialising, matching the call order callers rely on.
        uploaded_ref = gradio_upload_file(bench_url, replay_path, timeout=timeout)
        endpoint = "submit_with_replay"
        inputs = [json.dumps(data), uploaded_ref]
    else:
        endpoint = "submit"
        inputs = [json.dumps(data)]

    return _gradio_call(bench_url, endpoint, {"data": inputs}, timeout=timeout)
def _wait_for_local_server(proc, health_url: str, timeout_s: float = 60) -> bool:
    """Poll *health_url* until it answers 200, *proc* exits, or time runs out."""
    import time
    import urllib.error
    import urllib.request

    deadline = time.time() + timeout_s
    while time.time() < deadline:
        # Stop polling as soon as the server process has died.
        exit_code = proc.poll()
        if exit_code is not None:
            error(f"Local server exited early (exit code {exit_code}).")
            return False
        try:
            # Context manager closes the HTTP response on every probe.
            with urllib.request.urlopen(health_url, timeout=3) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass
        time.sleep(2)

    error(f"Local server did not become ready within {timeout_s:.0f}s.")
    return False


def _stop_local_server(proc, grace_s: float = 10) -> None:
    """Terminate *proc*, escalating to kill if it has not exited after *grace_s*."""
    proc.terminate()
    try:
        proc.wait(timeout=grace_s)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()


def _confirm_stop(prompt: str) -> bool:
    """Ask a [Y/n] question; EOF or Ctrl+C counts as yes."""
    try:
        answer = input(prompt).strip().lower()
    except (EOFError, KeyboardInterrupt):
        answer = "y"
    return answer in ("", "y", "yes")


def cmd_play(
    provider: Optional[str] = None,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    difficulty: str = "normal",
    verbose: bool = False,
    port: int = 8000,
    server_url: Optional[str] = None,
    local: bool = False,
    image_version: Optional[str] = None,
) -> None:
    """Run the LLM agent against the game server.

    Args:
        provider: CLI override merged into the saved config.
        model: CLI override merged into the saved config.
        api_key: CLI override merged into the saved config.
        difficulty: Passed to the Docker server when this command starts it.
        verbose: Forwarded to the agent runner.
        port: Port used for the default server URL and the Docker server.
        server_url: Existing server to use; when set, no server is started.
        local: Start the server as a local subprocess instead of Docker.
        image_version: Docker image version; when omitted and several concrete
            versions exist locally, the user is prompted to choose.

    Exits with status 1 when Docker is unavailable, required LLM settings are
    missing, or a server this command started fails to become ready.
    """
    use_docker = server_url is None and not local

    # 1. Check Docker (unless --local or --server-url)
    if use_docker and not docker.check_docker():
        sys.exit(1)

    # 1b. Version selection — let user pick if multiple versions exist locally
    if use_docker and image_version is None:
        versions = docker.list_local_versions()
        # Filter out "latest" for display — only show concrete version tags
        concrete = [v for v in versions if v != "latest"]
        if len(concrete) > 1:
            info(f"Multiple engine versions available: {', '.join(concrete)}")
            try:
                choice = input(f" Version to use [{concrete[0]}]: ").strip()
            except (EOFError, KeyboardInterrupt):
                choice = ""
            image_version = choice or concrete[0]

    # 2. Load or create config
    if any([provider, model, api_key]):
        config = load_saved_config() or {}
        config = merge_cli_into_config(config, provider=provider, model=model, api_key=api_key)
    elif has_saved_config():
        config = load_saved_config() or {}
    else:
        config = run_wizard()

    # Validate we have enough config to proceed
    llm_cfg = config.get("llm", {})
    base_url = llm_cfg.get("base_url", "")
    # A locally hosted LLM endpoint is allowed to run without an API key.
    is_local_llm = any(h in base_url for h in ("localhost", "127.0.0.1", "0.0.0.0"))
    if not llm_cfg.get("api_key") and not is_local_llm:
        error("No API key configured. Run `openra-rl config` or pass --api-key.")
        sys.exit(1)
    if not llm_cfg.get("model"):
        error("No model configured. Run `openra-rl config` or pass --model.")
        sys.exit(1)

    # 3. Start/reuse server
    actual_url = server_url or f"http://localhost:{port}"
    we_started_server = False
    local_server_proc = None

    if local:
        # Run the server locally (for developers with local OpenRA build)
        header("Starting local server...")
        # NOTE(review): the subprocess is not told about `port`; presumably the
        # server app binds its own default — confirm it matches `actual_url`.
        local_server_proc = subprocess.Popen(
            [sys.executable, "-m", "openra_env.server.app"],
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        we_started_server = True
        step(f"Waiting for local server on port {port}...")
        if not _wait_for_local_server(local_server_proc, f"{actual_url}/health"):
            # Reap the child so no orphan/zombie process is left behind.
            _stop_local_server(local_server_proc)
            sys.exit(1)
        success("Local server is ready!")
    elif use_docker:
        if docker.is_running():
            info(f"Server already running on port {port}.")
        else:
            if not docker.start_server(port=port, difficulty=difficulty, version=image_version):
                sys.exit(1)
            we_started_server = True
            if not docker.wait_for_health(port=port):
                sys.exit(1)

    # 4. Run the LLM agent
    header("Starting LLM agent...")
    provider_name = config.get("provider", "custom")
    info(f"Model: {llm_cfg.get('model', '?')} via {provider_name}")
    print()

    # Agent failures are reported but not re-raised, so replay copying and
    # server cleanup below still run.
    try:
        _run_llm_agent(config, actual_url, verbose)
    except KeyboardInterrupt:
        print("\nInterrupted.")
    except ConnectionRefusedError:
        error(f"Could not connect to {actual_url}.")
        info("Try: openra-rl server start")
        info("Check: openra-rl doctor")
    except Exception as e:
        error(f"Agent error: {e}")
        info("Run with --verbose for full details, or check: openra-rl doctor")

    # 5. Auto-copy replays from Docker
    if use_docker and docker.is_running():
        new_replays = docker.copy_replays()
        if new_replays:
            print()
            for f in new_replays:
                success(f"Replay saved: {docker.LOCAL_REPLAY_DIR / f}")
            info("Watch with: openra-rl replay watch")

    # 6. Cleanup — only offer to stop a server this command started
    if we_started_server:
        print()
        if local_server_proc is not None:
            if _confirm_stop(" Stop local server? [Y/n] "):
                _stop_local_server(local_server_proc)
                success("Local server stopped.")
        elif use_docker:
            if _confirm_stop(" Stop game server? [Y/n] "):
                docker.stop_server()
def cmd_doctor() -> None:
    """Print a checklist of system prerequisites and a final verdict.

    Checks, in order: the Docker CLI and daemon, the game image, the game
    server, the Python version, and the saved configuration. A missing Docker
    CLI, a stopped daemon, or Python below 3.10 counts as a failure; the
    other findings are informational only.
    """
    header("OpenRA-RL Doctor")
    failed_checks = 0

    # Docker CLI and daemon
    if not shutil.which("docker"):
        error("Docker CLI: not found")
        dim(" Install from https://docs.docker.com/get-docker/")
        failed_checks += 1
    else:
        success("Docker CLI: installed")
        from openra_env.cli.docker_manager import _run
        daemon_probe = _run(["docker", "info"])
        if daemon_probe.returncode != 0:
            warn("Docker daemon: not running")
            failed_checks += 1
        else:
            success("Docker daemon: running")

    # Game image (informational)
    if not docker.image_exists():
        warn("Game image: not pulled yet (will be pulled on first `openra-rl play`)")
    else:
        success(f"Game image: available ({docker.IMAGE})")

    # Game server (informational)
    if not docker.is_running():
        dim("Game server: not running")
    else:
        success("Game server: running")

    # Python interpreter version
    vi = sys.version_info
    py_version = f"{vi.major}.{vi.minor}.{vi.micro}"
    if vi < (3, 10):
        error(f"Python: {py_version} (requires 3.10+)")
        failed_checks += 1
    else:
        success(f"Python: {py_version}")

    # Saved configuration (informational)
    if not has_saved_config():
        dim("Config: not yet configured (run `openra-rl play` or `openra-rl config`)")
    else:
        saved = load_saved_config() or {}
        provider = saved.get("provider", "unknown")
        model = saved.get("llm", {}).get("model", "unknown")
        success(f"Config: {CONFIG_PATH}")
        dim(f" Provider: {provider}, Model: {model}")

    print()
    if failed_checks:
        warn("Some checks failed. Fix the issues above and try again.")
    else:
        success("All checks passed!")
def cmd_replay_list() -> None:
    """Print the replays found in the Docker container and on the local disk.

    Container replays are listed only while the game server is running; local
    replays are the ``*.orarep`` files in ``docker.LOCAL_REPLAY_DIR``, sorted.
    """
    header("Game Replays")

    # Replays inside the running container
    if not docker.is_running():
        dim(" Docker server not running — cannot list container replays.")
    else:
        container_replays = docker.list_replays()
        if not container_replays:
            dim(" No replays in Docker container.")
        else:
            info(f"In Docker container ({len(container_replays)}):")
            for replay in container_replays:
                dim(f" {Path(replay).name}")

    # Replays already copied to the local directory
    print()
    local_dir = docker.LOCAL_REPLAY_DIR
    if not local_dir.exists():
        dim(f" No local replay directory ({local_dir})")
        return

    local_replays = sorted(local_dir.glob("*.orarep"))
    if not local_replays:
        dim(f" No local replays in {local_dir}")
        return

    info(f"Local ({len(local_replays)}) — {local_dir}:")
    for replay in local_replays:
        dim(f" {replay.name}")
Replays are in {docker.LOCAL_REPLAY_DIR}") + + +def cmd_replay_stop() -> None: + """Stop the replay viewer.""" + docker.stop_replay_viewer() diff --git a/openra_env/cli/console.py b/openra_env/cli/console.py new file mode 100644 index 0000000000000000000000000000000000000000..427c1105457463c32766cd6309a95a128362682d --- /dev/null +++ b/openra_env/cli/console.py @@ -0,0 +1,43 @@ +"""ANSI colored console output helpers (no external deps).""" + +import sys + +# ANSI codes — disabled when not a TTY +_IS_TTY = hasattr(sys.stdout, "isatty") and sys.stdout.isatty() + +_RESET = "\033[0m" if _IS_TTY else "" +_BOLD = "\033[1m" if _IS_TTY else "" +_GREEN = "\033[32m" if _IS_TTY else "" +_YELLOW = "\033[33m" if _IS_TTY else "" +_RED = "\033[31m" if _IS_TTY else "" +_CYAN = "\033[36m" if _IS_TTY else "" +_DIM = "\033[2m" if _IS_TTY else "" + + +def info(msg: str) -> None: + print(f" {msg}") + + +def success(msg: str) -> None: + print(f" {_GREEN}{msg}{_RESET}") + + +def error(msg: str) -> None: + print(f" {_RED}{msg}{_RESET}", file=sys.stderr) + + +def warn(msg: str) -> None: + print(f" {_YELLOW}{msg}{_RESET}") + + +def step(msg: str) -> None: + """Print a progress step (e.g. 
'Pulling image...').""" + print(f" {_CYAN}{msg}{_RESET}") + + +def header(msg: str) -> None: + print(f"\n {_BOLD}{msg}{_RESET}") + + +def dim(msg: str) -> None: + print(f" {_DIM}{msg}{_RESET}") diff --git a/openra_env/cli/docker_manager.py b/openra_env/cli/docker_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..acb50261352f8004ecca5851564ba6c6caeeeee5 --- /dev/null +++ b/openra_env/cli/docker_manager.py @@ -0,0 +1,600 @@ +"""Docker orchestration for the OpenRA-RL game server.""" + +import json +import os +import shutil +import subprocess +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +from openra_env.cli.console import error, info, step, success + +IMAGE_REPO = "ghcr.io/yxc20089/openra-rl" +IMAGE = f"{IMAGE_REPO}:latest" +CONTAINER_NAME = "openra-rl-server" +REPLAY_CONTAINER = "openra-rl-replay" +REPLAY_DIR_IN_CONTAINER = "/root/.config/openra/Replays/ra" +LOCAL_REPLAY_DIR = Path.home() / ".openra-rl" / "replays" +MANIFEST_PATH = LOCAL_REPLAY_DIR / "manifest.json" + + +def _run(args: list[str], capture: bool = True, **kwargs) -> subprocess.CompletedProcess: + """Run a subprocess command, capturing output by default.""" + return subprocess.run( + args, + capture_output=capture, + text=True, + encoding="utf-8", + **kwargs, + ) + + +def check_docker() -> bool: + """Verify docker CLI is available and daemon is running.""" + if not shutil.which("docker"): + error("Docker not found. Install it from https://docs.docker.com/get-docker/") + return False + result = _run(["docker", "info"]) + if result.returncode != 0: + error("Docker daemon is not running. 
def check_docker() -> bool:
    """Verify the docker CLI is installed and the daemon responds.

    Returns:
        True when ``docker`` is found on PATH and ``docker info`` exits with 0;
        otherwise an error is printed and False is returned.
    """
    if not shutil.which("docker"):
        error("Docker not found. Install it from https://docs.docker.com/get-docker/")
        return False
    if _run(["docker", "info"]).returncode != 0:
        error("Docker daemon is not running. Start Docker Desktop and try again.")
        return False
    return True


def _image_tag(version: Optional[str] = None) -> str:
    """Return the full image reference for *version* (default: ``latest``)."""
    return f"{IMAGE_REPO}:{version or 'latest'}"


def pull_image(version: Optional[str] = None, quiet: bool = False) -> bool:
    """Pull the game server image from GHCR.

    Args:
        version: Image tag to pull; ``None`` means ``latest``.
        quiet: When True, suppress progress messages and docker's own output.

    Returns:
        True if ``docker pull`` exited with 0, False otherwise.
    """
    image = _image_tag(version)
    if not quiet:
        step(f"Pulling game server image ({image})...")
    result = subprocess.run(
        ["docker", "pull", image],
        stdout=subprocess.DEVNULL if quiet else sys.stdout,
        stderr=subprocess.DEVNULL if quiet else sys.stderr,
    )
    if result.returncode != 0:
        error(f"Failed to pull {image}")
        return False
    if not quiet:
        success("Image pulled successfully.")
    return True


def image_exists(version: Optional[str] = None) -> bool:
    """Check if the game server image is available locally."""
    listing = _run(["docker", "images", "-q", _image_tag(version)])
    return bool(listing.stdout.strip())


def _version_sort_key(tag: str) -> tuple:
    """Sort key that compares dotted numeric tags by value, not as text.

    ``0.10.0`` must rank above ``0.9.0``; a plain string sort gets that wrong.
    Tags that are not dotted integers (e.g. ``dev``) rank below every numeric
    version and keep alphabetical order among themselves.
    """
    base = tag[1:] if tag[:1] in ("v", "V") else tag
    core, _, prerelease = base.partition("-")
    pieces = core.split(".")
    if all(piece.isascii() and piece.isdigit() for piece in pieces):
        # A release outranks its own pre-release (1.0.0 > 1.0.0-rc1).
        return (1, tuple(int(piece) for piece in pieces), not prerelease, tag)
    return (0, (), True, tag)


def _sort_version_tags(tags: list[str]) -> list[str]:
    """Order image tags newest first, with ``latest`` leading when present."""
    # "<none>" is the placeholder docker prints for untagged images.
    usable = [tag for tag in tags if tag and tag != "<none>"]
    ordered = sorted(
        (tag for tag in usable if tag != "latest"),
        key=_version_sort_key,
        reverse=True,
    )
    if "latest" in usable:
        ordered.insert(0, "latest")
    return ordered


def list_local_versions() -> list[str]:
    """List all locally available openra-rl image versions (tags), newest first.

    Returns:
        ``latest`` first (when present), then numeric versions in descending
        numeric order, then any other tags. Empty list if ``docker images``
        fails.
    """
    result = _run([
        "docker", "images", IMAGE_REPO,
        "--format", "{{.Tag}}",
    ])
    if result.returncode != 0:
        return []
    return _sort_version_tags([line.strip() for line in result.stdout.splitlines()])


def _tag_from_image_ref(image: str) -> str:
    """Extract the tag from an image reference, defaulting to ``latest``.

    Only a colon in the final path segment separates a tag; a colon earlier in
    the reference belongs to a registry port, and anything after ``@`` is a
    digest rather than a tag.
    """
    name = image.split("@", 1)[0]
    last_segment = name.rsplit("/", 1)[-1]
    repo, colon, tag = last_segment.partition(":")
    if not colon or not tag or repo == "sha256":
        return "latest"
    return tag


def get_running_image_tag() -> Optional[str]:
    """Get the image tag of the currently running game server container.

    Returns:
        The tag (e.g. ``0.2.1``), ``latest`` when the reference carries no tag,
        or None when the server is not running or ``docker inspect`` fails.
    """
    if not is_running():
        return None
    result = _run([
        "docker", "inspect", CONTAINER_NAME,
        "--format", "{{.Config.Image}}",
    ])
    if result.returncode != 0:
        return None
    return _tag_from_image_ref(result.stdout.strip())
# ── Replay manifest ──────────────────────────────────────────────────


def _load_manifest() -> dict:
    """Load the replay manifest (replay filename → image tag).

    Best effort: a missing, unreadable, malformed or non-object manifest yields
    an empty dict so callers can always use ``.get`` on the result.
    """
    try:
        data = json.loads(MANIFEST_PATH.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and undecodable bytes.
        return {}
    return data if isinstance(data, dict) else {}


def _save_manifest(manifest: dict) -> None:
    """Write the replay manifest as indented JSON, creating its directory."""
    MANIFEST_PATH.parent.mkdir(parents=True, exist_ok=True)
    MANIFEST_PATH.write_text(json.dumps(manifest, indent=2) + "\n", encoding="utf-8")


def get_replay_image_tag(replay_filename: str) -> Optional[str]:
    """Look up which image tag was used to record a replay (None if unknown)."""
    return _load_manifest().get(replay_filename)


def _record_replays_in_manifest(filenames: list[str], image_tag: str) -> None:
    """Record which image tag was used for newly copied replays."""
    if not filenames:
        return
    manifest = _load_manifest()
    manifest.update(dict.fromkeys(filenames, image_tag))
    _save_manifest(manifest)


def is_running() -> bool:
    """Check if the game server container is running.

    ``docker ps --filter name=`` matches substrings, so the names it prints are
    compared exactly; a container such as ``<name>-old`` must not count.
    """
    result = _run([
        "docker", "ps", "--filter", f"name={CONTAINER_NAME}",
        "--format", "{{.Names}}",
    ])
    return CONTAINER_NAME in result.stdout.split()


def start_server(
    port: int = 8000,
    difficulty: str = "normal",
    detach: bool = True,
    version: Optional[str] = None,
) -> bool:
    """Start the game server container.

    Args:
        port: Host port published to the container's port 8000.
        difficulty: Passed to the container as ``BOT_TYPE``.
        detach: Run in the background (``docker run -d``). When False the call
            blocks until the container exits and its output goes straight to
            the terminal.
        version: Image tag to run; ``None`` means ``latest``.

    Returns:
        True if the server is already running or ``docker run`` exited with 0;
        False if the image could not be pulled or ``docker run`` failed.
    """
    if is_running():
        info(f"Server already running on port {port}.")
        return True

    image = _image_tag(version)
    if not image_exists(version) and not pull_image(version):
        return False

    step(f"Starting game server on port {port} ({image})...")
    cmd = ["docker", "run", "--rm"]
    if detach:
        cmd.append("-d")
    cmd += [
        "-p", f"{port}:8000",
        "--name", CONTAINER_NAME,
        "-e", f"BOT_TYPE={difficulty}",
        image,
    ]

    # Only capture in detached mode: a foreground server runs indefinitely, and
    # capturing would hide its output while buffering all of it in memory.
    result = _run(cmd, capture=detach)
    if result.returncode != 0:
        detail = (result.stderr or "").strip() or f"docker run exited with code {result.returncode}"
        error(f"Failed to start server: {detail}")
        return False
    return True


def stop_server() -> bool:
    """Stop the game server container.

    Returns:
        True if the server was not running or ``docker stop`` succeeded.
    """
    if not is_running():
        info("Server is not running.")
        return True
    step("Stopping game server...")
    result = _run(["docker", "stop", CONTAINER_NAME])
    if result.returncode != 0:
        error(f"Failed to stop server: {result.stderr.strip()}")
        return False
    success("Server stopped.")
    return True


def wait_for_health(port: int = 8000, timeout: int = 120) -> bool:
    """Poll the health endpoint until the server is ready.

    Args:
        port: Host port the server is published on.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True as soon as ``/health`` answers with HTTP 200, False on timeout.
    """
    import http.client
    import urllib.error
    import urllib.request

    url = f"http://localhost:{port}/health"
    step(f"Waiting for server to be ready (timeout {timeout}s)...")
    # Monotonic clock: a wall-clock adjustment must not stretch or cut the wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # Context manager closes the response so sockets are not leaked
            # across the (potentially many) polling attempts.
            with urllib.request.urlopen(url, timeout=3) as response:
                healthy = response.status == 200
        except (urllib.error.URLError, http.client.HTTPException, OSError):
            healthy = False
        if healthy:
            success("Server is ready!")
            return True
        time.sleep(2)
    error(f"Server did not become healthy within {timeout}s.")
    return False


def get_logs(follow: bool = False) -> None:
    """Print container logs by running ``docker logs`` (``-f`` when *follow*).

    Output is not captured; docker writes directly to this process's streams.
    """
    cmd = ["docker", "logs"]
    if follow:
        cmd.append("-f")
    cmd.append(CONTAINER_NAME)
    subprocess.run(cmd)


def server_status() -> Optional[dict]:
    """Get server container status info.

    Returns:
        ``{"status": ..., "ports": ...}`` for the container named exactly
        ``CONTAINER_NAME``, or None if it is not running.
    """
    if not is_running():
        return None
    result = _run([
        "docker", "ps", "--filter", f"name={CONTAINER_NAME}",
        "--format", "{{.Names}}\t{{.Status}}\t{{.Ports}}",
    ])
    # The name filter is a substring match and may return several rows; pick
    # the row whose name is exactly ours instead of parsing the whole output.
    for row in result.stdout.splitlines():
        name, _, rest = row.partition("\t")
        if name.strip() != CONTAINER_NAME:
            continue
        status, _, ports = rest.partition("\t")
        return {"status": status.strip() or "unknown", "ports": ports.strip()}
    return None
# ── Replay viewer settings ───────────────────────────────────────────


@dataclass(frozen=True)
class ReplayViewerSettings:
    """Tunable replay viewer settings for quality/performance tradeoffs."""

    width: int = 1280
    height: int = 960
    ui_scale: float = 1.0
    viewport_distance: str = "Medium"
    mute: bool = True
    render_mode: str = "auto"  # auto | gpu | cpu
    vnc_quality: int = 8
    vnc_compression: int = 4
    cpu_cores: int = 4  # Docker --cpus limit for software rendering (0 = all available)


def _parse_resolution(value: str) -> tuple[int, int]:
    """Parse a WxH resolution string (``x`` or ``,`` separator, case-insensitive).

    Raises:
        ValueError: If the text is not two integers, or falls outside
            320x240..7680x4320.
    """
    raw = value.strip().lower().replace(" ", "")
    for sep in ("x", ","):
        if sep not in raw:
            continue
        left, right = raw.split(sep, 1)
        try:
            w, h = int(left), int(right)
        except ValueError:
            break
        if w < 320 or h < 240 or w > 7680 or h > 4320:
            raise ValueError(f"resolution out of range (320x240..7680x4320): {value}")
        return w, h
    raise ValueError(f"resolution must be WxH (e.g. 960x540), got: {value!r}")


def _normalize_render_mode(value: str) -> str:
    """Validate and normalize render mode to one of auto/gpu/cpu."""
    mode = value.strip().lower()
    if mode not in ("auto", "gpu", "cpu"):
        raise ValueError(f"render mode must be auto/gpu/cpu, got: {value!r}")
    return mode


def _normalize_viewport(value: str) -> str:
    """Validate viewport distance and return its capitalized form."""
    mapping = {"close": "Close", "medium": "Medium", "far": "Far"}
    key = value.strip().lower()
    if key not in mapping:
        raise ValueError(f"viewport must be close/medium/far, got: {value!r}")
    return mapping[key]


def _env_number(name: str, default: str, cast):
    """Read environment variable *name* and convert it with *cast* (int/float).

    Raises:
        ValueError: Naming the offending variable, so the user knows what to
            fix instead of seeing a bare ``invalid literal for int()``.
    """
    raw = os.environ.get(name, default)
    try:
        return cast(raw.strip())
    except ValueError:
        kind = "an integer" if cast is int else "a number"
        raise ValueError(f"{name} must be {kind}, got: {raw!r}") from None


def load_replay_viewer_settings(
    resolution: Optional[str] = None,
    render_mode: Optional[str] = None,
    vnc_quality: Optional[int] = None,
    vnc_compression: Optional[int] = None,
    cpu_cores: Optional[int] = None,
) -> ReplayViewerSettings:
    """Load replay viewer settings from CLI overrides → env vars → defaults.

    Args:
        resolution: ``WxH`` override; falls back to OPENRA_RL_REPLAY_RESOLUTION.
        render_mode: ``auto``/``gpu``/``cpu``; falls back to OPENRA_RL_REPLAY_RENDER.
        vnc_quality: Clamped to 0..9; falls back to OPENRA_RL_REPLAY_VNC_QUALITY.
        vnc_compression: Clamped to 0..9; falls back to
            OPENRA_RL_REPLAY_VNC_COMPRESSION.
        cpu_cores: Clamped to 1..32, where a value <= 0 means "all cores
            reported by os.cpu_count()"; falls back to OPENRA_RL_REPLAY_CPU_CORES.

    Returns:
        A validated, immutable ReplayViewerSettings.

    Raises:
        ValueError: If any override or environment value is malformed.
    """
    env = os.environ

    width, height = _parse_resolution(
        resolution or env.get("OPENRA_RL_REPLAY_RESOLUTION", "1280x960")
    )
    mode = _normalize_render_mode(
        render_mode if render_mode is not None else env.get("OPENRA_RL_REPLAY_RENDER", "auto")
    )

    quality = (
        vnc_quality if vnc_quality is not None
        else _env_number("OPENRA_RL_REPLAY_VNC_QUALITY", "8", int)
    )
    compression = (
        vnc_compression if vnc_compression is not None
        else _env_number("OPENRA_RL_REPLAY_VNC_COMPRESSION", "4", int)
    )
    quality = max(0, min(9, quality))
    compression = max(0, min(9, compression))

    cores = (
        cpu_cores if cpu_cores is not None
        else _env_number("OPENRA_RL_REPLAY_CPU_CORES", "4", int)
    )
    if cores <= 0:
        cores = os.cpu_count() or 4
    cores = max(1, min(32, cores))

    ui_scale = _env_number("OPENRA_RL_REPLAY_UI_SCALE", "1", float)
    # Chained comparison is False for NaN, so this rejects 0, negatives, NaN
    # and infinity in one test.
    if not 0 < ui_scale < float("inf"):
        raise ValueError(
            f"OPENRA_RL_REPLAY_UI_SCALE must be a positive, finite number, got: {ui_scale!r}"
        )
    viewport = _normalize_viewport(env.get("OPENRA_RL_REPLAY_VIEWPORT_DISTANCE", "medium"))
    # Anything other than an explicit "off" spelling keeps audio muted.
    mute = env.get("OPENRA_RL_REPLAY_MUTE", "true").strip().lower() not in ("0", "false", "no", "off")

    return ReplayViewerSettings(
        width=width, height=height, ui_scale=ui_scale, viewport_distance=viewport,
        mute=mute, render_mode=mode, vnc_quality=quality, vnc_compression=compression,
        cpu_cores=cores,
    )


def _settings_env_args(settings: ReplayViewerSettings) -> list[str]:
    """Convert settings to docker ``-e KEY=VAL`` args."""
    return [
        "-e", f"OPENRA_RL_REPLAY_RESOLUTION={settings.width}x{settings.height}",
        "-e", f"OPENRA_RL_REPLAY_UI_SCALE={settings.ui_scale}",
        "-e", f"OPENRA_RL_REPLAY_VIEWPORT_DISTANCE={settings.viewport_distance}",
        "-e", f"OPENRA_RL_REPLAY_MUTE={'True' if settings.mute else 'False'}",
        "-e", "SDL_AUDIODRIVER=dummy",
        "-e", "OPENRA_DISPLAY_SCALE=1",
    ]
def _gpu_docker_args(mode: str, cpu_cores: int = 4) -> list[list[str]]:
    """Return docker arg variants for GPU passthrough, in preference order.

    auto: try GPU variants first, fall back to CPU.
    gpu: only try GPU variants (fail if none work).
    cpu: only try CPU (software rendering).
    cpu_cores: number of llvmpipe threads for software rendering.
    """
    cpu = ["-e", "LIBGL_ALWAYS_SOFTWARE=1", "-e", f"LP_NUM_THREADS={cpu_cores}"]
    gpu_variants = [
        ["--gpus", "all"],  # NVIDIA
        ["--device", "/dev/dxg:/dev/dxg",  # WSL2 (AMD/NVIDIA/Intel)
         "-v", "/usr/lib/wsl:/usr/lib/wsl:ro",
         "-e", "LD_LIBRARY_PATH=/usr/lib/wsl/lib"],
        ["--device", "/dev/kfd:/dev/kfd",  # AMD ROCm (native Linux)
         "--device", "/dev/dri:/dev/dri",
         "--group-add", "video"],
        ["--device", "/dev/dri:/dev/dri"],  # Generic DRI (AMD/Intel)
    ]
    if mode == "cpu":
        return [cpu]
    if mode == "gpu":
        return gpu_variants
    # auto: try all GPU variants, then CPU fallback
    return gpu_variants + [cpu]


# ── Replay viewer ────────────────────────────────────────────────────


def list_replays() -> list[str]:
    """List .orarep files inside the game server container.

    Returns:
        Container paths sorted as strings; empty if the server is not running
        or the ``find`` command fails.
    """
    if not is_running():
        return []
    result = _run([
        "docker", "exec", CONTAINER_NAME,
        "find", REPLAY_DIR_IN_CONTAINER, "-name", "*.orarep", "-type", "f",
    ])
    if result.returncode != 0:
        return []
    return sorted(line.strip() for line in result.stdout.splitlines() if line.strip())


def get_latest_replay() -> Optional[str]:
    """Return the last replay path in sorted order, or None if there are none.

    NOTE(review): "latest" relies on replay paths sorting chronologically;
    confirm against the file naming used by the game.
    """
    replays = list_replays()
    return replays[-1] if replays else None


def copy_replays() -> list[str]:
    """Copy all replays from the game server container to ~/.openra-rl/replays/.

    Returns list of newly copied filenames.
    Also records the image tag in the manifest so replay watch uses the right version.
    """
    if not is_running():
        error("Game server is not running — cannot copy replays.")
        return []

    LOCAL_REPLAY_DIR.mkdir(parents=True, exist_ok=True)

    replays = list_replays()
    if not replays:
        return []

    def local_names() -> set:
        return {f.name for f in LOCAL_REPLAY_DIR.iterdir() if f.suffix == ".orarep"}

    # Snapshot before copying so files that appear afterwards count as new.
    before = local_names()

    # Copy each replay individually (docker cp doesn't glob well)
    for replay_path in replays:
        filename = os.path.basename(replay_path)
        result = _run([
            "docker", "cp",
            f"{CONTAINER_NAME}:{replay_path}",
            str(LOCAL_REPLAY_DIR / filename),
        ])
        if result.returncode != 0:
            error(f"Failed to copy {filename}: {result.stderr.strip()}")

    new_files = sorted(local_names() - before)

    # Record the image version that produced these replays
    if new_files:
        _record_replays_in_manifest(new_files, get_running_image_tag() or "latest")

    return new_files


def is_replay_viewer_running() -> bool:
    """Check if the replay viewer container is running.

    The ``name=`` filter matches substrings, so printed names are compared
    exactly rather than searched for inside the raw output.
    """
    result = _run([
        "docker", "ps", "--filter", f"name={REPLAY_CONTAINER}",
        "--format", "{{.Names}}",
    ])
    return REPLAY_CONTAINER in result.stdout.split()


def replay_viewer_exists() -> bool:
    """Check if the replay viewer container exists (running or exited)."""
    result = _run([
        "docker", "ps", "-a", "--filter", f"name={REPLAY_CONTAINER}",
        "--format", "{{.Names}}",
    ])
    return REPLAY_CONTAINER in result.stdout.split()


def get_replay_viewer_logs(tail: int = 200) -> str:
    """Return recent replay viewer logs, or empty string if unavailable.

    ``docker logs`` replays the container's stdout and stderr on the matching
    streams, so both are returned; reading stdout alone would drop whatever the
    viewer wrote to stderr.
    """
    if not replay_viewer_exists():
        return ""
    result = _run(["docker", "logs", "--tail", str(tail), REPLAY_CONTAINER])
    out = (result.stdout or "").strip()
    err = (result.stderr or "").strip()
    if result.returncode != 0:
        return err or out
    return "\n".join(part for part in (out, err) if part)
def start_replay_viewer(
    replay_path: str,
    port: int = 6080,
    version: Optional[str] = None,
    settings: Optional["ReplayViewerSettings"] = None,
) -> bool:
    """Start the replay viewer container.

    Args:
        replay_path: Path to .orarep file (container path or local path).
        port: noVNC port to expose (default 6080).
        version: Docker image version to use (default: auto-detect from manifest).
        settings: Replay viewer tuning (resolution, render mode, etc.).

    Returns:
        True once a ``docker run`` variant succeeds; False if the viewer is
        already running, the image cannot be pulled, the replay cannot be
        located, or every render variant fails.
    """
    if settings is None:
        settings = load_replay_viewer_settings()

    if is_replay_viewer_running():
        error("Replay viewer is already running. Stop it first with: openra-rl replay stop")
        return False

    # Clean up stale (exited) container if it exists
    if replay_viewer_exists():
        _run(["docker", "rm", "-f", REPLAY_CONTAINER])

    # Auto-detect version from manifest if not specified
    if version is None:
        version = get_replay_image_tag(os.path.basename(replay_path))
        if version:
            info(f"Using image version '{version}' (from manifest)")

    image = _image_tag(version)

    if not image_exists(version):
        step(f"Image {image} not found locally, pulling...")
        if not pull_image(version):
            return False

    # Determine if this is a local file or a container path.
    local_file: Optional[str] = None
    container_replay_path = replay_path
    local_path = Path(replay_path).resolve()
    server_up = False

    if local_path.exists():
        local_file = str(local_path)
        container_replay_path = f"/tmp/replay/{local_path.name}"
    elif not replay_path.startswith("/"):
        error(f"Replay file not found: {local_path}")
        return False
    else:
        server_up = is_running()
        if not server_up:
            # Nothing can be mounted in this case; launching anyway would only
            # produce a viewer that exits without the replay.
            error(
                f"Replay file not found: {replay_path} "
                "(not on this machine, and the game server is not running to copy it from)"
            )
            return False
        # Container path — copy locally first so we can mount it reliably
        # (--volumes-from only shares Docker volumes, not the writable layer)
        filename = os.path.basename(replay_path)
        LOCAL_REPLAY_DIR.mkdir(parents=True, exist_ok=True)
        local_dest = LOCAL_REPLAY_DIR / filename
        cp_result = _run(["docker", "cp", f"{CONTAINER_NAME}:{replay_path}", str(local_dest)])
        if cp_result.returncode == 0 and local_dest.exists():
            local_file = str(local_dest)
            container_replay_path = f"/tmp/replay/{filename}"

    step(f"Starting replay viewer on port {port} ({image})...")

    base_cmd = [
        "docker", "run", "-d",
        "-p", f"{port}:6080",
        "--name", REPLAY_CONTAINER,
        "--entrypoint", "/replay-viewer.sh",
    ]
    base_cmd.extend(_settings_env_args(settings))

    if local_file:
        base_cmd.extend(["-v", f"{local_file}:{container_replay_path}:ro"])
    elif server_up:
        # Copy failed: fall back to sharing the server container's volumes.
        base_cmd.extend(["--volumes-from", CONTAINER_NAME])

    # Try GPU variants in order, fall back to CPU
    last_stderr = ""
    for gpu_args in _gpu_docker_args(settings.render_mode, cpu_cores=settings.cpu_cores):
        is_gpu = "--gpus" in gpu_args or "--device" in gpu_args
        # Limit CPU for software rendering to prevent runaway usage.
        # llvmpipe busy-loops without GPU; --cpus caps Docker scheduler.
        cpu_limit = [] if is_gpu else ["--cpus", str(settings.cpu_cores)]
        result = _run(base_cmd + cpu_limit + gpu_args + [image, container_replay_path])
        if result.returncode == 0:
            if not is_gpu:
                info(f"Rendering mode: CPU (software, {settings.cpu_cores} cores)")
            elif "--gpus" in gpu_args:
                info("Rendering mode: GPU (NVIDIA)")
            elif any("/dev/dxg" in arg for arg in gpu_args):
                info("Rendering mode: GPU (WSL2 DirectX)")
            elif any("/dev/kfd" in arg for arg in gpu_args):
                info("Rendering mode: GPU (AMD ROCm)")
            else:
                info("Rendering mode: GPU (DRI)")
            success("Replay viewer started.")
            return True
        last_stderr = result.stderr.strip()
        # Clean up the failed container before trying next variant
        _run(["docker", "rm", "-f", REPLAY_CONTAINER])

    error(f"Failed to start replay viewer: {last_stderr}")
    return False
+ cpu_limit = [] if is_gpu else ["--cpus", str(settings.cpu_cores)] + cmd = base_cmd + cpu_limit + gpu_args + [image, container_replay_path] + result = _run(cmd) + if result.returncode == 0: + if is_gpu: + gpu_args_str = " ".join(gpu_args) + if "--gpus" in gpu_args_str: + info("Rendering mode: GPU (NVIDIA)") + elif "/dev/dxg" in gpu_args_str: + info("Rendering mode: GPU (WSL2 DirectX)") + elif "/dev/kfd" in gpu_args_str: + info("Rendering mode: GPU (AMD ROCm)") + else: + info("Rendering mode: GPU (DRI)") + else: + info(f"Rendering mode: CPU (software, {settings.cpu_cores} cores)") + success("Replay viewer started.") + return True + last_stderr = result.stderr.strip() + # Clean up the failed container before trying next variant + _run(["docker", "rm", "-f", REPLAY_CONTAINER]) + + error(f"Failed to start replay viewer: {last_stderr}") + return False + + +def stop_replay_viewer() -> bool: + """Stop and remove the replay viewer container.""" + if not replay_viewer_exists(): + info("Replay viewer is not running.") + return True + step("Stopping replay viewer...") + result = _run(["docker", "rm", "-f", REPLAY_CONTAINER]) + if result.returncode != 0: + error(f"Failed to stop replay viewer: {result.stderr.strip()}") + return False + success("Replay viewer stopped.") + return True diff --git a/openra_env/cli/main.py b/openra_env/cli/main.py new file mode 100644 index 0000000000000000000000000000000000000000..eb22e675565dc432482d3676b2c36aa152615ea9 --- /dev/null +++ b/openra_env/cli/main.py @@ -0,0 +1,212 @@ +"""CLI entry point for openra-rl.""" + +import argparse +import sys + + +def main() -> None: + parser = argparse.ArgumentParser( + prog="openra-rl", + description="Play Red Alert with AI agents", + ) + parser.add_argument( + "--version", action="store_true", + help="Print version and exit", + ) + subparsers = parser.add_subparsers(dest="command") + + # ── play ──────────────────────────────────────────────────────── + play_parser = subparsers.add_parser( + "play", 
help="Run the LLM agent against the game", + ) + play_parser.add_argument( + "--provider", choices=["openrouter", "ollama", "lmstudio"], + help="LLM provider (overrides saved config)", + ) + play_parser.add_argument("--model", help="Model ID") + play_parser.add_argument("--api-key", help="API key for LLM endpoint") + play_parser.add_argument( + "--difficulty", choices=["easy", "normal", "hard"], default="normal", + help="AI opponent difficulty (default: normal)", + ) + play_parser.add_argument("--verbose", action="store_true", help="Verbose output") + play_parser.add_argument("--port", type=int, default=8000, help="Game server port (default: 8000)") + play_parser.add_argument("--server-url", help="Connect to existing server URL (skip Docker)") + play_parser.add_argument("--local", action="store_true", help="Run server locally instead of Docker (for developers)") + play_parser.add_argument("--version", dest="image_version", default=None, help="Docker image version to use (default: latest)") + + # ── config ────────────────────────────────────────────────────── + subparsers.add_parser("config", help="Re-run the setup wizard") + + # ── server ────────────────────────────────────────────────────── + server_parser = subparsers.add_parser("server", help="Manage the game server") + server_sub = server_parser.add_subparsers(dest="server_command") + + start_parser = server_sub.add_parser("start", help="Start the game server") + start_parser.add_argument("--port", type=int, default=8000, help="Port (default: 8000)") + start_parser.add_argument( + "--difficulty", choices=["easy", "normal", "hard"], default="normal", + ) + start_parser.add_argument("--detach", action="store_true", default=True, help="Run in background (default)") + + server_sub.add_parser("stop", help="Stop the game server") + server_sub.add_parser("status", help="Show server status") + + logs_parser = server_sub.add_parser("logs", help="Show server logs") + logs_parser.add_argument("--follow", "-f", 
action="store_true", help="Follow log output") + + # ── mcp-server ────────────────────────────────────────────────── + mcp_parser = subparsers.add_parser("mcp-server", help="Start MCP stdio server") + mcp_parser.add_argument("--server-url", help="Game server URL") + mcp_parser.add_argument("--port", type=int, default=8000, help="Game server port (default: 8000)") + + # ── replay ───────────────────────────────────────────────────── + replay_parser = subparsers.add_parser("replay", help="Manage and watch game replays") + replay_sub = replay_parser.add_subparsers(dest="replay_command") + + watch_parser = replay_sub.add_parser("watch", help="Watch a replay in your browser (via VNC)") + watch_parser.add_argument("file", nargs="?", default=None, help="Replay file (local path or container path; default: latest)") + watch_parser.add_argument("--port", type=int, default=6080, help="noVNC port (default: 6080)") + watch_parser.add_argument( + "--resolution", default=None, + help="Replay viewer resolution WxH (default: 1280x960)", + ) + watch_parser.add_argument( + "--render", dest="render_mode", choices=["auto", "gpu", "cpu"], default=None, + help="Render backend: auto tries GPU then CPU (default: auto)", + ) + watch_parser.add_argument( + "--vnc-quality", type=int, default=None, + help="VNC quality 0-9, higher = sharper (default: 8)", + ) + watch_parser.add_argument( + "--vnc-compression", type=int, default=None, + help="VNC compression 0-9, higher = smaller (default: 4)", + ) + watch_parser.add_argument( + "--cpus", type=int, default=None, + help="CPU cores for software rendering (default: 4, 0 = all available).", + ) + + replay_sub.add_parser("list", help="List available replays") + replay_sub.add_parser("copy", help="Copy replays from Docker to ~/.openra-rl/replays/") + replay_sub.add_parser("stop", help="Stop the replay viewer") + + # ── bench ───────────────────────────────────────────────────────── + bench_parser = subparsers.add_parser("bench", help="Benchmark 
leaderboard tools") + bench_sub = bench_parser.add_subparsers(dest="bench_command") + + bench_submit_parser = bench_sub.add_parser("submit", help="Upload game result JSON to the leaderboard") + bench_submit_parser.add_argument("json_file", type=str, help="Path to bench export JSON file") + bench_submit_parser.add_argument("--agent-name", default=None, help="Override agent name") + bench_submit_parser.add_argument("--agent-type", default=None, help="Override agent type (Scripted/LLM/RL)") + bench_submit_parser.add_argument("--agent-url", default=None, help="GitHub/project URL") + bench_submit_parser.add_argument("--replay", default=None, help="Path to .orarep replay file") + bench_submit_parser.add_argument( + "--bench-url", default=None, + help="Bench leaderboard URL (default: https://openra-rl-openra-bench.hf.space)", + ) + + # ── doctor ────────────────────────────────────────────────────── + subparsers.add_parser("doctor", help="Check system prerequisites") + + # ── version ───────────────────────────────────────────────────── + subparsers.add_parser("version", help="Print version") + + args = parser.parse_args() + + # Handle --version at top level + if args.version: + from openra_env.cli.commands import cmd_version + cmd_version() + return + + if args.command is None: + parser.print_help() + sys.exit(0) + + # Dispatch + from openra_env.cli import commands + + if args.command == "play": + commands.cmd_play( + provider=args.provider, + model=args.model, + api_key=args.api_key, + difficulty=args.difficulty, + verbose=args.verbose, + port=args.port, + server_url=args.server_url, + local=args.local, + image_version=args.image_version, + ) + elif args.command == "config": + commands.cmd_config() + elif args.command == "server": + if args.server_command == "start": + commands.cmd_server_start( + port=args.port, + difficulty=args.difficulty, + detach=args.detach, + ) + elif args.server_command == "stop": + commands.cmd_server_stop() + elif args.server_command == 
"status": + commands.cmd_server_status() + elif args.server_command == "logs": + commands.cmd_server_logs(follow=args.follow) + else: + server_parser.print_help() + elif args.command == "replay": + if args.replay_command == "watch": + commands.cmd_replay_watch( + file=args.file, + port=args.port, + resolution=args.resolution, + render_mode=args.render_mode, + vnc_quality=args.vnc_quality, + vnc_compression=args.vnc_compression, + cpu_cores=args.cpus, + ) + elif args.replay_command == "list": + commands.cmd_replay_list() + elif args.replay_command == "copy": + commands.cmd_replay_copy() + elif args.replay_command == "stop": + commands.cmd_replay_stop() + else: + replay_parser.print_help() + elif args.command == "mcp-server": + commands.cmd_mcp_server( + server_url=args.server_url, + port=args.port, + ) + elif args.command == "bench": + if args.bench_command == "submit": + from openra_env.bench_submit import main as bench_submit_main + # Patch sys.argv so bench_submit's argparse sees the right args + submit_argv = ["openra-rl bench submit", args.json_file] + if args.agent_name: + submit_argv += ["--agent-name", args.agent_name] + if args.agent_type: + submit_argv += ["--agent-type", args.agent_type] + if args.agent_url: + submit_argv += ["--agent-url", args.agent_url] + if args.replay: + submit_argv += ["--replay", args.replay] + if args.bench_url: + submit_argv += ["--bench-url", args.bench_url] + sys.argv = submit_argv + bench_submit_main() + else: + bench_parser.print_help() + elif args.command == "doctor": + commands.cmd_doctor() + elif args.command == "version": + commands.cmd_version() + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/openra_env/cli/wizard.py b/openra_env/cli/wizard.py new file mode 100644 index 0000000000000000000000000000000000000000..25301f1271290e2ccf35865081b7fe38dd9c9180 --- /dev/null +++ b/openra_env/cli/wizard.py @@ -0,0 +1,166 @@ +"""Interactive first-run setup wizard.""" + +from pathlib import Path 
+from typing import Optional + +import yaml + +from openra_env.cli.console import dim, error, header, info, success, warn + +CONFIG_DIR = Path.home() / ".openra-rl" +CONFIG_PATH = CONFIG_DIR / "config.yaml" + +# Provider presets +PROVIDERS = { + "openrouter": { + "name": "OpenRouter", + "base_url": "https://openrouter.ai/api/v1/chat/completions", + "needs_key": True, + "key_help": "Get one at https://openrouter.ai/keys", + "default_model": "qwen/qwen3-coder-next", + }, + "ollama": { + "name": "Ollama", + "base_url": "http://localhost:11434/v1/chat/completions", + "needs_key": False, + "default_model": "qwen3:32b", + }, + "lmstudio": { + "name": "LM Studio", + "base_url": "http://localhost:1234/v1/chat/completions", + "needs_key": False, + "default_model": "", + "models": [], + }, +} + + +def _prompt(question: str, default: str = "") -> str: + """Prompt user for input with optional default.""" + if default: + raw = input(f" {question} [{default}]: ").strip() + return raw or default + else: + while True: + raw = input(f" {question}: ").strip() + if raw: + return raw + error("Please enter a value.") + + +def _choose(question: str, options: list[tuple[str, str]], allow_custom: bool = False) -> str: + """Present numbered options and get user choice.""" + print(f"\n {question}") + for i, (value, label) in enumerate(options, 1): + print(f" [{i}] {label}") + if allow_custom: + print(f" [{len(options) + 1}] Enter custom value") + + max_choice = len(options) + (1 if allow_custom else 0) + while True: + raw = input(" > ").strip() + try: + idx = int(raw) + if 1 <= idx <= len(options): + return options[idx - 1][0] + if allow_custom and idx == max_choice: + return _prompt("Enter value") + except ValueError: + # Allow typing the value directly + if raw: + return raw + error(f"Please enter a number 1-{max_choice}.") + + +def has_saved_config() -> bool: + """Check if a saved config exists.""" + return CONFIG_PATH.exists() + + +def load_saved_config() -> Optional[dict]: + """Load 
def save_config(config: dict) -> None:
    """Write *config* as YAML to ``CONFIG_PATH`` and report success.

    The config may contain the LLM API key (``run_wizard`` stores it under
    ``llm.api_key``), so the file is created with owner-only permissions
    instead of whatever the process umask would allow.

    Args:
        config: Mapping to serialise; key order is preserved.

    Raises:
        OSError: If the directory or file cannot be created or written.
    """
    import os  # function-scope import: only this function needs it

    CONFIG_DIR.mkdir(parents=True, exist_ok=True)

    # O_BINARY (Windows only) keeps newline handling identical to open().
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    fd = os.open(CONFIG_PATH, flags, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        try:
            # The mode passed to os.open only applies to newly created files;
            # tighten a pre-existing file before the secret is written.
            os.chmod(CONFIG_PATH, 0o600)
        except OSError:
            # Best effort: some filesystems/platforms do not support chmod.
            pass
        yaml.dump(config, f, default_flow_style=False, sort_keys=False)
    success(f"Config saved to {CONFIG_PATH}")
class OpenRAEnv(EnvClient[OpenRAAction, OpenRAObservation, OpenRAState]):
    """WebSocket client for the OpenRA-RL environment.

    Usage:
        async with OpenRAEnv(base_url="http://localhost:8000") as env:
            result = await env.reset()
            while not result.done:
                action = OpenRAAction(commands=[...])
                result = await env.step(action)
    """

    async def connect(self) -> "OpenRAEnv":
        """Open the WebSocket with ping keepalive disabled.

        OpenRA operations (especially reset) can take 60-120+ seconds
        with software rendering. The default websockets ping_interval=20s
        would kill the connection before the server responds.

        When the target host is a loopback name, ``NO_PROXY`` is temporarily
        extended with ``localhost`` and ``127.0.0.1`` for the duration of the
        connection attempt and restored afterwards.

        Returns:
            This client (already-connected clients are returned unchanged).

        Raises:
            ConnectionError: If the WebSocket cannot be opened.
        """
        from urllib.parse import urlsplit  # function-scope: only needed here

        if self._ws is not None:
            return self

        # Decide on the parsed hostname, not on a substring of the whole URL,
        # so e.g. "ws://notlocalhost.example.com" is not treated as local.
        try:
            host = (urlsplit(self._ws_url).hostname or "").lower()
        except ValueError:
            host = ""
        is_localhost = host in ("localhost", "127.0.0.1") or host.endswith(".localhost")

        # NOTE(review): only the upper-case variable is adjusted; if a
        # lower-case ``no_proxy`` is also set it is left untouched - confirm
        # which one the proxy lookup honours on the target platforms.
        old_no_proxy = os.environ.get("NO_PROXY")
        if is_localhost:
            current_no_proxy = old_no_proxy or ""
            entries = {
                entry.strip().lower()
                for entry in current_no_proxy.split(",")
                if entry.strip()
            }
            # Check each loopback name individually: "localhost" being listed
            # does not imply "127.0.0.1" is. A "*" entry is left alone.
            missing = [h for h in ("localhost", "127.0.0.1") if h not in entries]
            if missing and "*" not in entries:
                prefix = [current_no_proxy] if current_no_proxy else []
                os.environ["NO_PROXY"] = ",".join(prefix + missing)

        try:
            self._ws = await ws_connect(
                self._ws_url,
                open_timeout=self._connect_timeout,
                max_size=self._max_message_size,
                ping_interval=None,
            )
        except Exception as e:
            raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e
        finally:
            # Always put the environment back exactly as it was found.
            if is_localhost:
                if old_no_proxy is None:
                    os.environ.pop("NO_PROXY", None)
                else:
                    os.environ["NO_PROXY"] = old_no_proxy

        return self

    def _step_payload(self, action: OpenRAAction) -> Dict[str, Any]:
        """Convert action to JSON for WebSocket transport."""
        return action.model_dump()

    def _parse_result(self, data: Dict[str, Any]) -> StepResult[OpenRAObservation]:
        """Parse a server response into a ``StepResult``.

        The observation is read from ``data["observation"]`` when present,
        otherwise from ``data`` itself. Missing sections fall back to empty
        values; sections sent as an explicit ``null`` are treated the same way
        instead of raising ``TypeError``.
        """
        obs_data = data.get("observation", data)

        def section(key: str) -> Dict[str, Any]:
            # Mapping-valued section; absent or null -> {}.
            return obs_data.get(key) or {}

        def items(key: str) -> list:
            # List-valued section; absent or null -> [].
            return obs_data.get(key) or []

        observation = OpenRAObservation(
            tick=obs_data.get("tick", 0),
            economy=EconomyInfo(**section("economy")),
            military=MilitaryInfo(**section("military")),
            units=[UnitInfoModel(**u) for u in items("units")],
            buildings=[BuildingInfoModel(**b) for b in items("buildings")],
            production=[ProductionInfoModel(**p) for p in items("production")],
            visible_enemies=[UnitInfoModel(**u) for u in items("visible_enemies")],
            visible_enemy_buildings=[
                BuildingInfoModel(**b) for b in items("visible_enemy_buildings")
            ],
            map_info=MapInfoModel(**section("map_info")),
            available_production=obs_data.get("available_production", []),
            done=obs_data.get("done", False),
            reward=obs_data.get("reward"),
            result=obs_data.get("result", ""),
            spatial_map=obs_data.get("spatial_map", ""),
            spatial_channels=obs_data.get("spatial_channels", 0),
        )

        # Top-level reward/done take precedence over the copies nested in the
        # observation payload.
        return StepResult(
            observation=observation,
            reward=data.get("reward", obs_data.get("reward")),
            done=data.get("done", obs_data.get("done", False)),
        )

    def _parse_state(self, data: Dict[str, Any]) -> OpenRAState:
        """Parse state response into OpenRAState."""
        return OpenRAState(**data)
class ToolsConfig(BaseModel):
    """Selects which agent tools get registered.

    Consulted by ``should_register_tool``: a tool named in ``disabled`` is
    always rejected; otherwise the flag for the tool's category (looked up in
    ``TOOL_CATEGORIES``) decides. Tools with no category entry are enabled.
    """

    # Per-category on/off switches; every category defaults to enabled.
    categories: ToolCategoriesConfig = Field(default_factory=ToolCategoriesConfig)
    # Individual tool names to turn off regardless of their category flag.
    disabled: list[str] = Field(default_factory=list)
"https://openrouter.ai/api/v1/chat/completions" + api_key: str = "" + model: str = "qwen/qwen3-coder-next" + max_tokens: int = 1500 + temperature: Optional[float] = None + top_p: Optional[float] = None + keep_last_messages: int = 40 + compression_strategy: str = "sliding_window" # "sliding_window" or "none" + compression_trigger: int = 0 # 0 = keep_last_messages * 2 + max_retries: int = 4 + retry_backoff_s: int = 10 + request_timeout_s: float = 120.0 + reasoning_effort: Optional[str] = None # "none", "low", "medium", "high" + extra_headers: dict[str, str] = Field( + default_factory=lambda: { + "HTTP-Referer": "https://github.com/openra-rl", + "X-Title": "OpenRA-RL Agent", + } + ) + + +class AgentConfig(BaseModel): + server_url: str = "http://localhost:8000" + max_turns: int = 0 # 0 = unlimited + max_time_s: int = 1800 + verbose: bool = False + log_file: str = "" + agent_name: str = "" # Display name on leaderboard; empty = model name + agent_type: str = "" # Scripted/LLM/RL; empty = auto-detect + agent_url: str = "" # GitHub/project URL shown on leaderboard + bench_upload: bool = True # Auto-upload results to bench after each game + bench_url: str = "https://openra-rl-openra-bench.hf.space" + system_prompt: str = "" # deprecated — use prompts.system_prompt + system_prompt_file: str = "" # deprecated — use prompts.system_prompt_file + + +class AlertPromptsConfig(BaseModel): + """Templates for in-game alert messages. + + All templates use Python str.format() placeholders (e.g. {balance}). 
class CompressionConfig(BaseModel):
    """Controls what context is preserved in history compression summaries.

    Embedded in ``PromptsConfig`` as its ``compression`` field. Every switch
    defaults to ``True``.
    """
    include_strategy: bool = True  # Preserve planning strategy
    include_military: bool = True  # Include kill/death counts
    include_production: bool = True  # Track what was produced
+ """ + + # ── System prompt ──────────────────────────────────────────────── + system_prompt: str = "" # inline override (highest priority) + system_prompt_file: str = "" # path to .txt file override + prompts_file: str = "" # path to YAML with all prompts below + + # ── Planning phase ─────────────────────────────────────────────── + # Variables: {max_turns}, {map_name}, {map_width}, {map_height}, + # {base_x}, {base_y}, {enemy_x}, {enemy_y}, {faction}, {side}, + # {opponent_summary}, {planning_nudge} + planning_prompt: str = ( + "## PRE-GAME PLANNING PHASE\n" + "You have {max_turns} turns to plan.\n\n" + "### Map Intel\n" + "Map: {map_name} ({map_width}x{map_height})\n" + "Your base: ({base_x}, {base_y})\n" + "Enemy estimated: ({enemy_x}, {enemy_y})\n" + "Your faction: {faction} ({side})\n\n" + "### Opponent Intelligence\n{opponent_summary}\n\n" + "{planning_nudge}" + ) + planning_nudge: str = "Call end_planning_phase(strategy='...') when ready to start." + planning_instructions: str = ( + "Planning phase active. Available tools: get_faction_briefing " + "(all unit/building stats), get_map_analysis (terrain/resources), " + "get_opponent_intel (enemy profile), batch_lookup (multi-item queries). " + "Call end_planning_phase(strategy=...) to begin gameplay." + ) + planning_complete: str = "Planning complete. Game is now live." + + # ── Game start ─────────────────────────────────────────────────── + # Variables: {strategy_section}, {briefing}, {barracks_type}, {mcv_note} + game_start: str = ( + "Game started!{strategy_section}\n\n{briefing}\n\n" + "Your barracks type is '{barracks_type}'.{mcv_note}" + ) + + # ── Agent nudges ───────────────────────────────────────────────── + no_tool_nudge: str = "No tool was called. A tool call is required each turn." + continue_nudge: str = "The game is still in progress." + compression_suffix: str = "Game continues from current state." + sanitize_bridge: str = "Acknowledged. Continuing." 
+ + # ── Tool warnings ──────────────────────────────────────────────── + # Variables: {building}, {drain}, {balance} + power_warning: str = ( + "POWER WARNING: {building} drains {drain} power. " + "Balance will be {balance}." + ) + # Variables: {available}, {item}, {cost} + insufficient_funds: str = ( + "Insufficient funds: ${available} available, " + "{item} costs ${cost}." + ) + + # ── Placement feedback ─────────────────────────────────────────── + placement_success: str = "AUTO-PLACED: {building}" + placement_failed: str = "PLACEMENT FAILED: {building} — {reason}. Auto-cancelling." + placement_water: str = "WATER BUILDING: {building} requires water tiles for placement." + + # ── Build confirmations ─────────────────────────────────────────── + # Variables: {building}, {cost}, {ticks}, {seconds} + build_queued: str = ( + "'{building}' (${cost}) queued, auto-places on completion. " + "~{ticks} ticks (~{seconds}s)." + ) + build_structure_queued: str = ( + "'{building}' (${cost}) queued. ~{ticks} ticks (~{seconds}s) to complete." + ) + # Variables: {count}, {unit}, {cost}, {ticks_each}, {ticks_total}, {seconds_total} + build_unit_queued: str = ( + "{count}x '{unit}' (${cost} each) queued. " + "~{ticks_each} ticks per unit, ~{ticks_total} ticks (~{seconds_total}s) total." + ) + + # ── Build guards ────────────────────────────────────────────────── + # Variables: {building} + build_already_pending: str = "'{building}' is already queued and pending auto-placement." + place_auto_managed: str = ( + "'{building}' is queued via build_and_place — placement is automatic." + ) + + # ── Movement feedback ──────────────────────────────────────────── + # Variables: {ticks}, {seconds} + move_eta: str = "Units moving. Slowest arrives in ~{ticks} ticks (~{seconds}s)." 
class OpenRARLConfig(BaseModel):
    """Top-level settings object for OpenRA-RL, one attribute per section.

    Any section left out of the input is built from its model's defaults.
    After validation, two hooks reconcile sections that depend on each other.
    """

    game: GameConfig = Field(default_factory=GameConfig)
    opponent: OpponentConfig = Field(default_factory=OpponentConfig)
    planning: PlanningConfig = Field(default_factory=PlanningConfig)
    reward: RewardConfig = Field(default_factory=RewardConfig)
    reward_vector: RewardVectorConfig = Field(default_factory=RewardVectorConfig)
    tools: ToolsConfig = Field(default_factory=ToolsConfig)
    alerts: AlertsConfig = Field(default_factory=AlertsConfig)
    llm: LLMConfig = Field(default_factory=LLMConfig)
    agent: AgentConfig = Field(default_factory=AgentConfig)
    prompts: PromptsConfig = Field(default_factory=PromptsConfig)

    @model_validator(mode="after")
    def sync_planning_tools(self) -> "OpenRARLConfig":
        """Switch the planning tool category off whenever planning is off."""
        if self.planning.enabled:
            return self
        self.tools.categories.planning = False
        return self

    @model_validator(mode="after")
    def migrate_system_prompt(self) -> "OpenRARLConfig":
        """Carry deprecated ``agent.system_prompt*`` values over to ``prompts``.

        A legacy value is copied only when it is non-empty and the matching
        ``prompts`` field is still empty, so new-style settings always win.
        """
        for field_name in ("system_prompt", "system_prompt_file"):
            legacy_value = getattr(self.agent, field_name)
            if legacy_value and not getattr(self.prompts, field_name):
                setattr(self.prompts, field_name, legacy_value)
        return self
"read", + "get_exploration_status": "read", + # Knowledge + "lookup_unit": "knowledge", + "lookup_building": "knowledge", + "lookup_tech_tree": "knowledge", + "lookup_faction": "knowledge", + # Bulk Knowledge + "get_faction_briefing": "bulk_knowledge", + "get_map_analysis": "bulk_knowledge", + "batch_lookup": "bulk_knowledge", + # Planning + "get_opponent_intel": "planning", + "start_planning_phase": "planning", + "end_planning_phase": "planning", + "get_planning_status": "planning", + # Game Control + "advance": "game_control", + # Movement + "move_units": "movement", + "attack_move": "movement", + "attack_target": "movement", + "stop_units": "movement", + # Production + "build_unit": "production", + "build_structure": "production", + "build_and_place": "production", + # Building/Unit Actions + "place_building": "building_actions", + "cancel_production": "building_actions", + "deploy_unit": "building_actions", + "sell_building": "building_actions", + "repair_building": "building_actions", + "set_rally_point": "building_actions", + "guard_target": "building_actions", + "set_stance": "building_actions", + "harvest": "building_actions", + "power_down": "building_actions", + "set_primary": "building_actions", + # Placement + "get_valid_placements": "placement", + # Unit Groups + "assign_group": "unit_groups", + "add_to_group": "unit_groups", + "get_groups": "unit_groups", + "command_group": "unit_groups", + # Compound + "batch": "compound", + "plan": "compound", + # Utility + "get_replay_path": "utility", + "surrender": "utility", + # Terrain + "get_terrain_at": "terrain", +} + + +# ── Env Var Mapping ─────────────────────────────────────────────────── + +# Ordered so that more-specific vars (LLM_*) overwrite less-specific (OPENROUTER_*) +_ENV_VAR_MAP: list[tuple[str, str]] = [ + # game + ("OPENRA_PATH", "game.openra_path"), + ("RECORD_REPLAYS", "game.record_replays"), + # opponent + ("BOT_TYPE", "opponent.bot_type"), + ("AI_SLOT", "opponent.ai_slot"), + # planning + 
("PLANNING_ENABLED", "planning.enabled"), + ("PLANNING_MAX_TURNS", "planning.max_turns"), + ("PLANNING_MAX_TIME", "planning.max_time_s"), + # llm — legacy OpenRouter names first, then generic LLM_ names (override) + ("OPENROUTER_API_KEY", "llm.api_key"), + ("OPENROUTER_MODEL", "llm.model"), + ("LLM_BASE_URL", "llm.base_url"), + ("LLM_API_KEY", "llm.api_key"), + ("LLM_MODEL", "llm.model"), + # agent + ("OPENRA_URL", "agent.server_url"), + ("MAX_TIME", "agent.max_time_s"), + ("LLM_AGENT_LOG", "agent.log_file"), + ("AGENT_NAME", "agent.agent_name"), + ("AGENT_TYPE", "agent.agent_type"), + ("AGENT_URL", "agent.agent_url"), + ("BENCH_UPLOAD", "agent.bench_upload"), + ("BENCH_URL", "agent.bench_url"), + ("SYSTEM_PROMPT_FILE", "agent.system_prompt_file"), + # prompts + ("SYSTEM_PROMPT_FILE", "prompts.system_prompt_file"), # also maps to prompts.* + ("PROMPTS_FILE", "prompts.prompts_file"), +] + + +# ── Helper Functions ────────────────────────────────────────────────── + + +def _deep_merge(base: dict, override: dict) -> None: + """Recursively merge *override* into *base* in place.""" + for key, value in override.items(): + if key in base and isinstance(base[key], dict) and isinstance(value, dict): + _deep_merge(base[key], value) + else: + base[key] = value + + +def _set_nested(d: dict, path: str, value: object) -> None: + """Set a value in a nested dict via dotted path (e.g. 
``'game.mod'``).""" + keys = path.split(".") + for key in keys[:-1]: + d = d.setdefault(key, {}) + d[keys[-1]] = value + + +def _coerce_value(value: str) -> object: + """Coerce a string env-var value to bool / int / float / str.""" + lower = value.lower() + if lower in ("true", "1", "yes"): + return True + if lower in ("false", "0", "no"): + return False + try: + return int(value) + except ValueError: + pass + try: + return float(value) + except ValueError: + pass + return value + + +def should_register_tool(tool_name: str, tools_config: ToolsConfig) -> bool: + """Return True if *tool_name* should be registered based on config.""" + if tool_name in tools_config.disabled: + return False + category = TOOL_CATEGORIES.get(tool_name) + if category is not None: + return getattr(tools_config.categories, category, True) + return True # unknown tools default to enabled + + +# ── Config Loading ──────────────────────────────────────────────────── + + +def load_config( + config_path: Optional[str] = None, + cli_overrides: Optional[dict] = None, + **overrides: object, +) -> OpenRARLConfig: + """Load configuration with precedence: CLI > env vars > overrides > file > defaults. + + Parameters + ---------- + config_path: + Explicit path to a YAML config file. When ``None``, searches for + ``config.yaml`` in the current working directory and the project root. + cli_overrides: + Dict of overrides from explicit CLI flags. Applied last (highest + priority), beating even environment variables. Use this for values + the user typed on the command line. + **overrides: + Keyword arguments that are deep-merged on top of the file values. + Keys should be top-level section names (e.g. ``game={...}``). + """ + config_dict: dict = {} + + # 1. Load YAML file + resolved_path = _resolve_config_path(config_path) + if resolved_path is not None: + with open(resolved_path, encoding="utf-8") as f: + file_dict = yaml.safe_load(f) or {} + _deep_merge(config_dict, file_dict) + + # 2. 
Apply programmatic overrides (e.g. constructor args) + if overrides: + _deep_merge(config_dict, overrides) + + # 3. Apply environment variable overrides + for env_var, dotted_path in _ENV_VAR_MAP: + value = os.environ.get(env_var) + if value is not None: + _set_nested(config_dict, dotted_path, _coerce_value(value)) + + # 4. Apply CLI overrides (highest priority — explicit user intent) + if cli_overrides: + _deep_merge(config_dict, cli_overrides) + + # 5. Validate and return + return OpenRARLConfig(**config_dict) + + +def _resolve_config_path(config_path: Optional[str]) -> Optional[str]: + """Find the config file to load, or None if none exists.""" + if config_path is not None: + p = Path(config_path) + return str(p) if p.exists() else None + + # Auto-discover: CWD first, then project root + candidates = [ + Path.cwd() / "config.yaml", + Path(__file__).resolve().parent.parent / "config.yaml", + ] + for candidate in candidates: + if candidate.exists(): + return str(candidate) + return None diff --git a/openra_env/game_data.py b/openra_env/game_data.py new file mode 100644 index 0000000000000000000000000000000000000000..6e781232a6c1ecc9eb7b1286d9161c6885a88f21 --- /dev/null +++ b/openra_env/game_data.py @@ -0,0 +1,984 @@ +"""Static Red Alert mod data for game knowledge tools. + +Provides unit stats, building stats, tech tree, and faction information +extracted from OpenRA Red Alert mod rules. This gives an LLM agent the same +reference knowledge a human player would have from experience. +""" + +from typing import Optional + + +# ─── Unit Data ──────────────────────────────────────────────────────────────── + +RA_UNITS: dict[str, dict] = { + # Infantry + "e1": { + "name": "Rifle Infantry", + "category": "infantry", + "cost": 100, + "hp": 5000, + "speed": 56, + "armor": "none", + "side": "both", + "prerequisites": ["barr|tent"], + "description": "Basic infantry unit. 
Cheap and fast to produce.", + }, + "e2": { + "name": "Grenadier", + "category": "infantry", + "cost": 150, + "hp": 5000, + "speed": 56, + "armor": "none", + "side": "both", + "prerequisites": ["barr|tent"], + "description": "Anti-structure infantry. Grenades deal area damage.", + }, + "e3": { + "name": "Rocket Soldier", + "category": "infantry", + "cost": 300, + "hp": 4500, + "speed": 56, + "armor": "none", + "side": "both", + "prerequisites": ["barr|tent"], + "description": "Anti-armor and anti-air infantry.", + }, + "e4": { + "name": "Flamethrower", + "category": "infantry", + "cost": 300, + "hp": 4000, + "speed": 56, + "armor": "none", + "side": "soviet", + "prerequisites": ["barr", "ftur"], + "description": "Short-range anti-infantry/structure. Soviet only.", + }, + "e6": { + "name": "Engineer", + "category": "infantry", + "cost": 400, + "hp": 4000, + "speed": 56, + "armor": "none", + "side": "both", + "prerequisites": ["barr|tent"], + "description": "Captures enemy buildings. Cannot attack.", + }, + "e7": { + "name": "Tanya", + "category": "infantry", + "cost": 1800, + "hp": 10000, + "speed": 68, + "armor": "none", + "side": "allied", + "prerequisites": ["tent", "atek"], + "build_limit": 1, + "description": "Elite commando. Destroys buildings with C4, kills infantry instantly. Allied only.", + }, + "medi": { + "name": "Medic", + "category": "infantry", + "cost": 200, + "hp": 6000, + "speed": 49, + "armor": "none", + "side": "allied", + "prerequisites": ["tent"], + "description": "Heals nearby infantry. Cannot attack.", + }, + "mech": { + "name": "Mechanic", + "category": "infantry", + "cost": 500, + "hp": 8000, + "speed": 49, + "armor": "none", + "side": "allied", + "prerequisites": ["tent", "fix"], + "description": "Repairs nearby vehicles. 
Cannot attack.", + }, + "spy": { + "name": "Spy", + "category": "infantry", + "cost": 500, + "hp": 2500, + "speed": 56, + "armor": "none", + "side": "allied", + "prerequisites": ["tent", "dome"], + "description": "Disguises as enemy infantry. Infiltrates buildings for bonuses.", + }, + "thf": { + "name": "Thief", + "category": "infantry", + "cost": 500, + "hp": 5000, + "speed": 68, + "armor": "none", + "side": "allied", + "prerequisites": ["tent", "dome"], + "description": "Steals credits from enemy refineries.", + }, + "shok": { + "name": "Shock Trooper", + "category": "infantry", + "cost": 350, + "hp": 5000, + "speed": 49, + "armor": "none", + "side": "soviet", + "prerequisites": ["barr", "stek", "tsla"], + "description": "Tesla infantry. High damage vs all targets. Soviet only.", + }, + "dog": { + "name": "Attack Dog", + "category": "infantry", + "cost": 200, + "hp": 2000, + "speed": 99, + "armor": "none", + "side": "soviet", + "prerequisites": ["kenn"], + "description": "Fast anti-infantry unit. Kills spies. Soviet only.", + }, + + # Vehicles + "1tnk": { + "name": "Light Tank", + "category": "vehicle", + "cost": 700, + "hp": 23000, + "speed": 113, + "armor": "heavy", + "side": "allied", + "prerequisites": ["weap"], + "description": "Fast medium tank. Good all-around. Allied only.", + }, + "2tnk": { + "name": "Medium Tank", + "category": "vehicle", + "cost": 850, + "hp": 30000, + "speed": 72, + "armor": "heavy", + "side": "allied", + "prerequisites": ["weap", "fix"], + "description": "Main battle tank. Balanced stats. Allied only. Requires Repair Facility.", + }, + "3tnk": { + "name": "Heavy Tank", + "category": "vehicle", + "cost": 1150, + "hp": 46000, + "speed": 64, + "armor": "heavy", + "side": "soviet", + "prerequisites": ["weap", "fix"], + "description": "Powerful main battle tank. Dual cannons. Soviet only. 
Requires Repair Facility.", + }, + "4tnk": { + "name": "Mammoth Tank", + "category": "vehicle", + "cost": 2000, + "hp": 60000, + "speed": 43, + "armor": "heavy", + "side": "soviet", + "prerequisites": ["weap", "fix", "stek"], + "description": "Heaviest tank. Dual cannons + missiles. Self-healing. Soviet only.", + }, + "v2rl": { + "name": "V2 Rocket Launcher", + "category": "vehicle", + "cost": 900, + "hp": 15000, + "speed": 72, + "armor": "light", + "side": "soviet", + "prerequisites": ["weap", "dome"], + "description": "Long-range artillery. High damage, inaccurate. Soviet only.", + }, + "jeep": { + "name": "Ranger", + "category": "vehicle", + "cost": 500, + "hp": 15000, + "speed": 164, + "armor": "light", + "side": "allied", + "prerequisites": ["weap"], + "description": "Fast scout vehicle with machine gun. Allied only.", + }, + "apc": { + "name": "APC", + "category": "vehicle", + "cost": 850, + "hp": 20000, + "speed": 128, + "armor": "heavy", + "side": "soviet", + "prerequisites": ["weap"], + "description": "Armored troop transport. Carries 5 infantry. Soviet only.", + }, + "arty": { + "name": "Artillery", + "category": "vehicle", + "cost": 850, + "hp": 7500, + "speed": 54, + "armor": "light", + "side": "allied", + "prerequisites": ["weap", "dome"], + "description": "Long-range siege weapon. Allied only.", + }, + "harv": { + "name": "Ore Truck", + "category": "vehicle", + "cost": 1100, + "hp": 60000, + "speed": 72, + "armor": "heavy", + "side": "both", + "prerequisites": ["proc"], + "description": "Harvests ore and delivers to refinery. Free with refinery.", + }, + "mcv": { + "name": "MCV", + "category": "vehicle", + "cost": 2000, + "hp": 60000, + "speed": 60, + "armor": "light", + "side": "both", + "prerequisites": ["weap", "fix"], + "description": "Deploys into Construction Yard. 
Mobile base.", + }, + "ftrk": { + "name": "Flak Truck", + "category": "vehicle", + "cost": 600, + "hp": 15000, + "speed": 113, + "armor": "light", + "side": "soviet", + "prerequisites": ["weap"], + "description": "Mobile anti-air unit. Soviet only.", + }, + "mnly": { + "name": "Minelayer", + "category": "vehicle", + "cost": 800, + "hp": 30000, + "speed": 113, + "armor": "heavy", + "side": "both", + "prerequisites": ["weap", "fix"], + "description": "Lays anti-tank mines.", + }, + "ttnk": { + "name": "Tesla Tank", + "category": "vehicle", + "cost": 1350, + "hp": 30000, + "speed": 92, + "armor": "light", + "side": "soviet", + "prerequisites": ["weap", "stek", "tsla"], + "description": "Tesla weapon on tracks. Effective vs all targets. Soviet only.", + }, + "ctnk": { + "name": "Chrono Tank", + "category": "vehicle", + "cost": 1350, + "hp": 20000, + "speed": 86, + "armor": "light", + "side": "allied", + "prerequisites": ["weap", "atek"], + "description": "Teleporting tank. Hit and run tactics. Allied only.", + }, + "stnk": { + "name": "Phase Transport", + "category": "vehicle", + "cost": 1000, + "hp": 11000, + "speed": 128, + "armor": "light", + "side": "allied", + "prerequisites": ["weap", "atek"], + "description": "Cloaked APC. Invisible when not firing. Allied only.", + }, + "qtnk": { + "name": "MAD Tank", + "category": "vehicle", + "cost": 2000, + "hp": 22000, + "speed": 46, + "armor": "heavy", + "side": "soviet", + "prerequisites": ["weap", "stek"], + "description": "Deploys seismic charge, destroying self and nearby vehicles. Soviet only.", + }, + "dtrk": { + "name": "Demolition Truck", + "category": "vehicle", + "cost": 2500, + "hp": 11000, + "speed": 113, + "armor": "light", + "side": "soviet", + "prerequisites": ["weap", "stek"], + "description": "Suicide vehicle. Massive area nuclear explosion on death. 
Soviet only.", + }, + "mgg": { + "name": "Mobile Gap Generator", + "category": "vehicle", + "cost": 1000, + "hp": 11000, + "speed": 72, + "armor": "heavy", + "side": "allied", + "prerequisites": ["weap", "atek"], + "description": "Creates mobile shroud area. Allied only.", + }, + "mrj": { + "name": "Mobile Radar Jammer", + "category": "vehicle", + "cost": 1000, + "hp": 11000, + "speed": 68, + "armor": "heavy", + "side": "allied", + "prerequisites": ["weap", "atek"], + "description": "Jams enemy radar in area. Allied only.", + }, + "truk": { + "name": "Supply Truck", + "category": "vehicle", + "cost": 500, + "hp": 11000, + "speed": 113, + "armor": "light", + "side": "both", + "prerequisites": ["weap"], + "description": "Delivers cash when reaching allied structures.", + }, + + # Aircraft + "heli": { + "name": "Longbow", + "category": "aircraft", + "cost": 2000, + "hp": 12000, + "speed": 149, + "armor": "light", + "side": "allied", + "prerequisites": ["hpad"], + "description": "Anti-armor helicopter with missiles. Allied only.", + }, + "hind": { + "name": "Hind", + "category": "aircraft", + "cost": 1500, + "hp": 12000, + "speed": 112, + "armor": "light", + "side": "soviet", + "prerequisites": ["afld"], + "description": "Anti-ground attack helicopter. Soviet only.", + }, + "mh60": { + "name": "Black Hawk", + "category": "aircraft", + "cost": 1500, + "hp": 12000, + "speed": 112, + "armor": "light", + "side": "allied", + "prerequisites": ["hpad"], + "description": "Transport/attack helicopter. Allied only.", + }, + "tran": { + "name": "Chinook", + "category": "aircraft", + "cost": 900, + "hp": 14000, + "speed": 128, + "armor": "light", + "side": "both", + "prerequisites": ["hpad|afld"], + "description": "Transport helicopter. Carries 5 infantry.", + }, + "yak": { + "name": "Yak", + "category": "aircraft", + "cost": 1350, + "hp": 6000, + "speed": 178, + "armor": "light", + "side": "soviet", + "prerequisites": ["afld"], + "description": "Fast anti-infantry attack plane. 
Soviet only.", + }, + "mig": { + "name": "MiG", + "category": "aircraft", + "cost": 2000, + "hp": 8000, + "speed": 223, + "armor": "light", + "side": "soviet", + "prerequisites": ["afld", "stek"], + "description": "Anti-structure/armor attack plane with missiles. Soviet only.", + }, + + # Ships + "ss": { + "name": "Submarine", + "category": "ship", + "cost": 950, + "hp": 25000, + "speed": 78, + "armor": "light", + "side": "soviet", + "prerequisites": ["spen"], + "description": "Invisible anti-ship unit. Soviet only.", + }, + "dd": { + "name": "Destroyer", + "category": "ship", + "cost": 1000, + "hp": 40000, + "speed": 92, + "armor": "heavy", + "side": "allied", + "prerequisites": ["syrd", "dome"], + "description": "Multi-role warship. Anti-sub, anti-air, anti-surface. Allied only.", + }, + "ca": { + "name": "Cruiser", + "category": "ship", + "cost": 2400, + "hp": 80000, + "speed": 44, + "armor": "heavy", + "side": "allied", + "prerequisites": ["syrd", "atek"], + "description": "Heavy bombardment ship. Long range. Allied only.", + }, + "pt": { + "name": "Gunboat", + "category": "ship", + "cost": 500, + "hp": 20000, + "speed": 142, + "armor": "heavy", + "side": "both", + "prerequisites": ["syrd|spen"], + "description": "Fast patrol boat.", + }, + "lst": { + "name": "Transport", + "category": "ship", + "cost": 500, + "hp": 40000, + "speed": 115, + "armor": "heavy", + "side": "both", + "prerequisites": ["syrd|spen"], + "description": "Naval transport. Carries vehicles and infantry.", + }, + "msub": { + "name": "Missile Submarine", + "category": "ship", + "cost": 2000, + "hp": 40000, + "speed": 44, + "armor": "light", + "side": "soviet", + "prerequisites": ["spen", "stek"], + "description": "Long-range missile submarine. 
Soviet only.", + }, +} + + +# ─── Building Data ──────────────────────────────────────────────────────────── + +RA_BUILDINGS: dict[str, dict] = { + "fact": { + "name": "Construction Yard", + "cost": 2000, + "hp": 150000, + "power": 0, + "side": "both", + "prerequisites": [], + "produces": ["Building", "Defense"], + "description": "Primary base structure. Required to build other structures.", + }, + "powr": { + "name": "Power Plant", + "cost": 300, + "hp": 40000, + "power": 100, + "side": "both", + "prerequisites": [], + "produces": [], + "description": "Basic power supply. Most structures need power to function.", + }, + "apwr": { + "name": "Advanced Power Plant", + "cost": 500, + "hp": 70000, + "power": 200, + "side": "both", + "prerequisites": ["dome"], + "produces": [], + "description": "Double power output. Requires radar dome tech.", + }, + "barr": { + "name": "Soviet Barracks", + "cost": 500, + "hp": 60000, + "power": -20, + "side": "soviet", + "prerequisites": ["powr"], + "produces": ["Infantry"], + "description": "Soviet infantry production. Required for all Soviet infantry.", + }, + "tent": { + "name": "Allied Barracks", + "cost": 500, + "hp": 60000, + "power": -20, + "side": "allied", + "prerequisites": ["powr"], + "produces": ["Infantry"], + "description": "Allied infantry production. Required for all Allied infantry.", + }, + "proc": { + "name": "Ore Refinery", + "cost": 1400, + "hp": 90000, + "power": -30, + "side": "both", + "prerequisites": ["powr"], + "produces": [], + "description": "Processes ore into credits. Comes with a free Ore Truck.", + }, + "weap": { + "name": "War Factory", + "cost": 2000, + "hp": 150000, + "power": -30, + "side": "both", + "prerequisites": ["proc"], + "produces": ["Vehicle"], + "description": "Vehicle production facility. 
Required for all vehicles.", + }, + "dome": { + "name": "Radar Dome", + "cost": 1500, + "hp": 100000, + "power": -40, + "side": "both", + "prerequisites": ["proc"], + "produces": [], + "description": "Provides minimap radar. Unlocks advanced tech.", + }, + "fix": { + "name": "Service Depot", + "cost": 1200, + "hp": 80000, + "power": -30, + "side": "both", + "prerequisites": ["weap"], + "produces": [], + "description": "Repairs vehicles. Unlocks MCV and Minelayer.", + }, + "atek": { + "name": "Allied Tech Center", + "cost": 1500, + "hp": 60000, + "power": -200, + "side": "allied", + "prerequisites": ["dome", "weap"], + "produces": [], + "description": "Unlocks advanced Allied units. GPS satellite.", + }, + "stek": { + "name": "Soviet Tech Center", + "cost": 1500, + "hp": 80000, + "power": -100, + "side": "soviet", + "prerequisites": ["dome", "weap"], + "produces": [], + "description": "Unlocks advanced Soviet units.", + }, + "hpad": { + "name": "Helipad", + "cost": 500, + "hp": 80000, + "power": -10, + "side": "allied", + "prerequisites": ["dome"], + "produces": ["Aircraft"], + "description": "Allied aircraft production. Rearming pad.", + }, + "afld": { + "name": "Airfield", + "cost": 500, + "hp": 100000, + "power": -20, + "side": "soviet", + "prerequisites": ["dome"], + "produces": ["Aircraft"], + "description": "Soviet aircraft production. Rearming strip.", + }, + "spen": { + "name": "Sub Pen", + "cost": 800, + "hp": 100000, + "power": -20, + "side": "soviet", + "prerequisites": ["powr"], + "produces": ["Ship"], + "terrain": "water", + "description": "Soviet naval production. Repairs ships. REQUIRES WATER — cannot build on land maps.", + }, + "syrd": { + "name": "Naval Yard", + "cost": 1000, + "hp": 100000, + "power": -20, + "side": "allied", + "prerequisites": ["powr"], + "produces": ["Ship"], + "terrain": "water", + "description": "Allied naval production. Repairs ships. 
REQUIRES WATER — cannot build on land maps.", + }, + "silo": { + "name": "Ore Silo", + "cost": 150, + "hp": 30000, + "power": -10, + "side": "both", + "prerequisites": ["proc"], + "produces": [], + "description": "Additional ore storage capacity.", + }, + "kenn": { + "name": "Kennel", + "cost": 200, + "hp": 30000, + "power": -10, + "side": "soviet", + "prerequisites": ["powr"], + "produces": ["Infantry"], + "description": "Produces attack dogs. Soviet only.", + }, + + # Defenses + "pbox": { + "name": "Pillbox", + "cost": 600, + "hp": 40000, + "power": 0, + "side": "allied", + "prerequisites": ["tent"], + "produces": [], + "description": "Anti-infantry defense turret. Allied only.", + }, + "hbox": { + "name": "Camo Pillbox", + "cost": 750, + "hp": 40000, + "power": 0, + "side": "allied", + "prerequisites": ["tent"], + "produces": [], + "description": "Hidden anti-infantry defense. Allied only.", + }, + "gun": { + "name": "Turret", + "cost": 800, + "hp": 40000, + "power": -20, + "side": "allied", + "prerequisites": ["weap"], + "produces": [], + "description": "Anti-armor defense turret. Allied only.", + }, + "ftur": { + "name": "Flame Tower", + "cost": 600, + "hp": 40000, + "power": -20, + "side": "soviet", + "prerequisites": ["barr"], + "produces": [], + "description": "Short-range anti-infantry defense. Soviet only.", + }, + "tsla": { + "name": "Tesla Coil", + "cost": 1200, + "hp": 40000, + "power": -75, + "side": "soviet", + "prerequisites": ["weap"], + "produces": [], + "description": "Powerful anti-ground defense. High power cost. Soviet only.", + }, + "agun": { + "name": "AA Gun", + "cost": 800, + "hp": 40000, + "power": -50, + "side": "allied", + "prerequisites": ["dome"], + "produces": [], + "description": "Anti-air defense turret. Allied only.", + }, + "sam": { + "name": "SAM Site", + "cost": 700, + "hp": 40000, + "power": -20, + "side": "soviet", + "prerequisites": ["dome"], + "produces": [], + "description": "Anti-air missile defense. 
Soviet only.", + }, + "gap": { + "name": "Gap Generator", + "cost": 800, + "hp": 50000, + "power": -60, + "side": "allied", + "prerequisites": ["atek"], + "produces": [], + "description": "Creates shroud area over your base. Allied only.", + }, + + # Superweapons + "iron": { + "name": "Iron Curtain", + "cost": 2000, + "hp": 100000, + "power": -200, + "side": "soviet", + "prerequisites": ["stek"], + "produces": [], + "build_limit": 1, + "description": "Superweapon: Makes one unit/building invulnerable temporarily.", + }, + "pdox": { + "name": "Chronosphere", + "cost": 1500, + "hp": 100000, + "power": -200, + "side": "allied", + "prerequisites": ["atek"], + "produces": [], + "build_limit": 1, + "description": "Superweapon: Teleports units across the map.", + }, + "mslo": { + "name": "Missile Silo", + "cost": 2500, + "hp": 100000, + "power": -150, + "side": "soviet", + "prerequisites": ["stek"], + "produces": [], + "build_limit": 1, + "description": "Superweapon: Launches nuclear missile at target location.", + }, +} + + +# ─── Tech Tree ──────────────────────────────────────────────────────────────── + +RA_TECH_TREE: dict[str, list[str]] = { + "soviet": [ + "powr", # Power Plant (base) + "barr", # Barracks → infantry (requires powr) + "kenn", # Kennel → dogs (requires powr) + "proc", # Ore Refinery (requires powr) + "weap", # War Factory (requires proc) + "spen", # Sub Pen (requires powr, needs water) + "dome", # Radar Dome (requires proc) + "fix", # Service Depot (requires weap) + "afld", # Airfield (requires dome) + "stek", # Tech Center (requires dome + weap) + "tsla", # Tesla Coil (requires weap) + "sam", # SAM Site (requires dome) + "ftur", # Flame Tower (requires barr) + "iron", # Iron Curtain (requires stek) + "mslo", # Missile Silo (requires stek) + ], + "allied": [ + "powr", # Power Plant (base) + "tent", # Barracks → infantry (requires powr) + "proc", # Ore Refinery (requires powr) + "weap", # War Factory (requires proc) + "syrd", # Naval Yard (requires 
# ─── Faction Data ─────────────────────────────────────────────────────────────

RA_FACTIONS: dict[str, dict] = {
    "england": {
        "side": "allied",
        "display_name": "England",
        "unique_units": [],
        "description": "Standard Allied faction.",
    },
    "france": {
        "side": "allied",
        "display_name": "France",
        "unique_units": ["stnk"],
        "description": "Allied faction with Phase Transport (cloaked APC).",
    },
    "germany": {
        "side": "allied",
        "display_name": "Germany",
        "unique_units": ["ctnk"],
        "description": "Allied faction with Chrono Tank (teleporting tank).",
    },
    "russia": {
        "side": "soviet",
        "display_name": "Russia",
        "unique_units": ["ttnk"],
        "description": "Soviet faction with Tesla Tank.",
    },
    "ukraine": {
        "side": "soviet",
        "display_name": "Ukraine",
        "unique_units": ["dtrk"],
        "description": "Soviet faction with Demolition Truck (nuclear suicide vehicle).",
    },
}


# ─── Query Functions ──────────────────────────────────────────────────────────


def get_unit_stats(unit_type: str) -> Optional[dict]:
    """Get stats for a unit type. Returns None if not found.

    The lookup is case-insensitive. The returned dict is a shallow copy, so
    callers can add or replace keys without altering ``RA_UNITS``.
    """
    stats = RA_UNITS.get(unit_type.lower())
    return dict(stats) if stats is not None else None


def get_building_stats(building_type: str) -> Optional[dict]:
    """Get stats for a building type. Returns None if not found.

    The lookup is case-insensitive. The returned dict is a shallow copy, so
    callers can add or replace keys without altering ``RA_BUILDINGS``.
    """
    stats = RA_BUILDINGS.get(building_type.lower())
    return dict(stats) if stats is not None else None


def get_tech_tree(faction: Optional[str] = None) -> dict:
    """Get the tech tree build order.

    Args:
        faction: Faction name (e.g., 'russia') or side ('allied', 'soviet').
            If None, returns both sides.

    Returns:
        A dict mapping side name to its ordered list of building types, or an
        empty dict when ``faction`` is neither a known faction nor a known
        side. The lists are copies, so editing them does not alter
        ``RA_TECH_TREE``.
    """
    if faction is None:
        return {side: list(order) for side, order in RA_TECH_TREE.items()}

    # A faction name resolves to its side; a side name is used directly.
    side = faction.lower()
    if side in RA_FACTIONS:
        side = RA_FACTIONS[side]["side"]

    if side in RA_TECH_TREE:
        return {side: list(RA_TECH_TREE[side])}

    return {}


def get_faction_info(faction: str) -> Optional[dict]:
    """Get faction info including available units and buildings.

    Args:
        faction: Faction name (case-insensitive), e.g. 'france'.

    Returns:
        The faction's ``RA_FACTIONS`` entry extended with ``faction`` (the
        lower-cased name), ``available_units`` and ``available_buildings``
        (both sorted), or None when the faction is unknown.
    """
    faction = faction.lower()
    info = RA_FACTIONS.get(faction)
    if info is None:
        return None

    side = info["side"]
    own_unique = set(info.get("unique_units", []))

    # Faction-unique units are tagged in RA_UNITS with their side only, so a
    # plain side filter would also hand them to every sibling faction (e.g.
    # England would receive France's "stnk"). Exclude anything that another
    # faction lists as unique and this faction does not.
    foreign_unique = {
        unit
        for name, other in RA_FACTIONS.items()
        if name != faction
        for unit in other.get("unique_units", [])
    } - own_unique

    available_units = {
        unit_type
        for unit_type, data in RA_UNITS.items()
        if data.get("side", "") in (side, "both")
        and unit_type not in foreign_unique
    }
    # This faction's own unique units are always available when they exist
    # in the unit table, whatever side they are tagged with.
    available_units.update(u for u in own_unique if u in RA_UNITS)

    available_buildings = [
        bldg_type
        for bldg_type, data in RA_BUILDINGS.items()
        if data.get("side", "") in (side, "both")
    ]

    return {
        **info,
        "faction": faction,
        "available_units": sorted(available_units),
        "available_buildings": sorted(available_buildings),
    }


def get_all_unit_types() -> list[str]:
    """Get all available unit type names, sorted alphabetically."""
    return sorted(RA_UNITS)


def get_all_building_types() -> list[str]:
    """Get all available building type names, sorted alphabetically."""
    return sorted(RA_BUILDINGS)


def get_all_units_for_side(side: str) -> dict[str, dict]:
    """Get all units available to a side ('allied' or 'soviet') with full stats.

    Returns dict keyed by unit type name, each value is a shallow copy of the
    full stats dict. Includes units with side='both' plus units specific to
    the given side. This is a side-level listing: units that only one faction
    of the side can build are included as well.
    """
    side = side.lower()
    return {
        utype: dict(data)
        for utype, data in RA_UNITS.items()
        if data.get("side") in (side, "both")
    }


def get_all_buildings_for_side(side: str) -> dict[str, dict]:
    """Get all buildings available to a side ('allied' or 'soviet') with full stats.

    Returns dict keyed by building type name, each value is a shallow copy of
    the full stats dict. Includes buildings with side='both' plus buildings
    specific to the given side.
    """
    side = side.lower()
    return {
        btype: dict(data)
        for btype, data in RA_BUILDINGS.items()
        if data.get("side") in (side, "both")
    }
+# NO CHECKED-IN PROTOBUF GENCODE +# source: rl_bridge.proto +# Protobuf Python Version: 6.31.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 31, + 1, + '', + 'rl_bridge.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0frl_bridge.proto\x12\topenra.rl\"\x97\x04\n\x0fGameObservation\x12\x0c\n\x04tick\x18\x01 \x01(\x05\x12\x12\n\nepisode_id\x18\x02 \x01(\t\x12%\n\x07\x65\x63onomy\x18\x03 \x01(\x0b\x32\x14.openra.rl.RlEconomy\x12\'\n\x08military\x18\x04 \x01(\x0b\x32\x15.openra.rl.RlMilitary\x12$\n\x05units\x18\x05 \x03(\x0b\x32\x15.openra.rl.RlUnitInfo\x12,\n\tbuildings\x18\x06 \x03(\x0b\x32\x19.openra.rl.RlBuildingInfo\x12/\n\nproduction\x18\x07 \x03(\x0b\x32\x1b.openra.rl.RlProductionInfo\x12.\n\x0fvisible_enemies\x18\x08 \x03(\x0b\x32\x15.openra.rl.RlUnitInfo\x12&\n\x08map_info\x18\t \x01(\x0b\x32\x14.openra.rl.RlMapInfo\x12\x13\n\x0bspatial_map\x18\n \x01(\x0c\x12\x18\n\x10spatial_channels\x18\x0b \x01(\x05\x12\x0c\n\x04\x64one\x18\x0c \x01(\x08\x12\x0e\n\x06reward\x18\r \x01(\x02\x12\x0e\n\x06result\x18\x0e \x01(\t\x12\x1c\n\x14\x61vailable_production\x18\x0f \x03(\t\x12:\n\x17visible_enemy_buildings\x18\x10 \x03(\x0b\x32\x19.openra.rl.RlBuildingInfo\"\x89\x01\n\tRlEconomy\x12\x0c\n\x04\x63\x61sh\x18\x01 \x01(\x05\x12\x0b\n\x03ore\x18\x02 \x01(\x05\x12\x16\n\x0epower_provided\x18\x03 \x01(\x05\x12\x15\n\rpower_drained\x18\x04 \x01(\x05\x12\x19\n\x11resource_capacity\x18\x05 \x01(\x05\x12\x17\n\x0fharvester_count\x18\x06 
\x01(\x05\"\xff\x01\n\nRlMilitary\x12\x14\n\x0cunits_killed\x18\x01 \x01(\x05\x12\x12\n\nunits_lost\x18\x02 \x01(\x05\x12\x18\n\x10\x62uildings_killed\x18\x03 \x01(\x05\x12\x16\n\x0e\x62uildings_lost\x18\x04 \x01(\x05\x12\x12\n\narmy_value\x18\x05 \x01(\x05\x12\x19\n\x11\x61\x63tive_unit_count\x18\x06 \x01(\x05\x12\x12\n\nkills_cost\x18\x07 \x01(\x05\x12\x13\n\x0b\x64\x65\x61ths_cost\x18\x08 \x01(\x05\x12\x14\n\x0c\x61ssets_value\x18\t \x01(\x05\x12\x12\n\nexperience\x18\n \x01(\x05\x12\x13\n\x0border_count\x18\x0b \x01(\x05\"\xe7\x02\n\nRlUnitInfo\x12\x10\n\x08\x61\x63tor_id\x18\x01 \x01(\r\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05pos_x\x18\x03 \x01(\x05\x12\r\n\x05pos_y\x18\x04 \x01(\x05\x12\x0e\n\x06\x63\x65ll_x\x18\x05 \x01(\x05\x12\x0e\n\x06\x63\x65ll_y\x18\x06 \x01(\x05\x12\x12\n\nhp_percent\x18\x07 \x01(\x02\x12\x0f\n\x07is_idle\x18\x08 \x01(\x08\x12\x18\n\x10\x63urrent_activity\x18\t \x01(\t\x12\r\n\x05owner\x18\n \x01(\t\x12\x0c\n\x04\x61mmo\x18\x0b \x01(\x05\x12\x12\n\ncan_attack\x18\x0c \x01(\x08\x12\x0e\n\x06\x66\x61\x63ing\x18\r \x01(\x05\x12\x18\n\x10\x65xperience_level\x18\x0e \x01(\x05\x12\x0e\n\x06stance\x18\x0f \x01(\x05\x12\r\n\x05speed\x18\x10 \x01(\x05\x12\x14\n\x0c\x61ttack_range\x18\x11 \x01(\x05\x12\x17\n\x0fpassenger_count\x18\x12 \x01(\x05\x12\x13\n\x0bis_building\x18\x13 \x01(\x08\"\xe7\x02\n\x0eRlBuildingInfo\x12\x10\n\x08\x61\x63tor_id\x18\x01 \x01(\r\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05pos_x\x18\x03 \x01(\x05\x12\r\n\x05pos_y\x18\x04 \x01(\x05\x12\x12\n\nhp_percent\x18\x05 \x01(\x02\x12\r\n\x05owner\x18\x06 \x01(\t\x12\x14\n\x0cis_producing\x18\x07 \x01(\x08\x12\x1b\n\x13production_progress\x18\x08 \x01(\x02\x12\x16\n\x0eproducing_item\x18\t \x01(\t\x12\x12\n\nis_powered\x18\n \x01(\x08\x12\x14\n\x0cis_repairing\x18\x0b \x01(\x08\x12\x12\n\nsell_value\x18\x0c \x01(\x05\x12\x0f\n\x07rally_x\x18\r \x01(\x05\x12\x0f\n\x07rally_y\x18\x0e \x01(\x05\x12\x14\n\x0cpower_amount\x18\x0f 
\x01(\x05\x12\x13\n\x0b\x63\x61n_produce\x18\x10 \x03(\t\x12\x0e\n\x06\x63\x65ll_x\x18\x11 \x01(\x05\x12\x0e\n\x06\x63\x65ll_y\x18\x12 \x01(\x05\"\x87\x01\n\x10RlProductionInfo\x12\x12\n\nqueue_type\x18\x01 \x01(\t\x12\x0c\n\x04item\x18\x02 \x01(\t\x12\x10\n\x08progress\x18\x03 \x01(\x02\x12\x17\n\x0fremaining_ticks\x18\x04 \x01(\x05\x12\x16\n\x0eremaining_cost\x18\x05 \x01(\x05\x12\x0e\n\x06paused\x18\x06 \x01(\x08\"<\n\tRlMapInfo\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x10\n\x08map_name\x18\x03 \x01(\t\"3\n\x0b\x41gentAction\x12$\n\x08\x63ommands\x18\x01 \x03(\x0b\x32\x12.openra.rl.Command\"\xa2\x01\n\x07\x43ommand\x12%\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x15.openra.rl.ActionType\x12\x10\n\x08\x61\x63tor_id\x18\x02 \x01(\r\x12\x17\n\x0ftarget_actor_id\x18\x03 \x01(\r\x12\x10\n\x08target_x\x18\x04 \x01(\x05\x12\x10\n\x08target_y\x18\x05 \x01(\x05\x12\x11\n\titem_type\x18\x06 \x01(\t\x12\x0e\n\x06queued\x18\x07 \x01(\x08\"\x91\x01\n\tGameState\x12\x12\n\nepisode_id\x18\x01 \x01(\t\x12\x0c\n\x04tick\x18\x02 \x01(\x05\x12\r\n\x05phase\x18\x03 \x01(\t\x12\x0e\n\x06winner\x18\x04 \x01(\t\x12\x14\n\x0cplayer_count\x18\x05 \x01(\x05\x12\x16\n\x0eplayer_faction\x18\x06 \x01(\t\x12\x15\n\renemy_faction\x18\x07 
\x01(\t\"\x0e\n\x0cStateRequest*\xb9\x02\n\nActionType\x12\t\n\x05NO_OP\x10\x00\x12\x08\n\x04MOVE\x10\x01\x12\x0f\n\x0b\x41TTACK_MOVE\x10\x02\x12\n\n\x06\x41TTACK\x10\x03\x12\x08\n\x04STOP\x10\x04\x12\x0b\n\x07HARVEST\x10\x05\x12\t\n\x05\x42UILD\x10\x06\x12\t\n\x05TRAIN\x10\x07\x12\n\n\x06\x44\x45PLOY\x10\x08\x12\x08\n\x04SELL\x10\t\x12\n\n\x06REPAIR\x10\n\x12\x12\n\x0ePLACE_BUILDING\x10\x0b\x12\x15\n\x11\x43\x41NCEL_PRODUCTION\x10\x0c\x12\x13\n\x0fSET_RALLY_POINT\x10\r\x12\t\n\x05GUARD\x10\x0e\x12\x0e\n\nSET_STANCE\x10\x0f\x12\x13\n\x0f\x45NTER_TRANSPORT\x10\x10\x12\n\n\x06UNLOAD\x10\x11\x12\x0e\n\nPOWER_DOWN\x10\x12\x12\x0f\n\x0bSET_PRIMARY\x10\x13\x12\r\n\tSURRENDER\x10\x14\x32\x8c\x01\n\x08RLBridge\x12\x45\n\x0bGameSession\x12\x16.openra.rl.AgentAction\x1a\x1a.openra.rl.GameObservation(\x01\x30\x01\x12\x39\n\x08GetState\x12\x17.openra.rl.StateRequest\x1a\x14.openra.rl.GameStateB\x18\xaa\x02\x15OpenRA.Mods.Common.RLb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'rl_bridge_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\252\002\025OpenRA.Mods.Common.RL' + _globals['_ACTIONTYPE']._serialized_start=2273 + _globals['_ACTIONTYPE']._serialized_end=2586 + _globals['_GAMEOBSERVATION']._serialized_start=31 + _globals['_GAMEOBSERVATION']._serialized_end=566 + _globals['_RLECONOMY']._serialized_start=569 + _globals['_RLECONOMY']._serialized_end=706 + _globals['_RLMILITARY']._serialized_start=709 + _globals['_RLMILITARY']._serialized_end=964 + _globals['_RLUNITINFO']._serialized_start=967 + _globals['_RLUNITINFO']._serialized_end=1326 + _globals['_RLBUILDINGINFO']._serialized_start=1329 + _globals['_RLBUILDINGINFO']._serialized_end=1688 + _globals['_RLPRODUCTIONINFO']._serialized_start=1691 + _globals['_RLPRODUCTIONINFO']._serialized_end=1826 + 
_globals['_RLMAPINFO']._serialized_start=1828 + _globals['_RLMAPINFO']._serialized_end=1888 + _globals['_AGENTACTION']._serialized_start=1890 + _globals['_AGENTACTION']._serialized_end=1941 + _globals['_COMMAND']._serialized_start=1944 + _globals['_COMMAND']._serialized_end=2106 + _globals['_GAMESTATE']._serialized_start=2109 + _globals['_GAMESTATE']._serialized_end=2254 + _globals['_STATEREQUEST']._serialized_start=2256 + _globals['_STATEREQUEST']._serialized_end=2270 + _globals['_RLBRIDGE']._serialized_start=2589 + _globals['_RLBRIDGE']._serialized_end=2729 +# @@protoc_insertion_point(module_scope) diff --git a/openra_env/generated/rl_bridge_pb2_grpc.py b/openra_env/generated/rl_bridge_pb2_grpc.py new file mode 100644 index 0000000000000000000000000000000000000000..30acedd33bf526b4fe8231edc5c2940423c91a6a --- /dev/null +++ b/openra_env/generated/rl_bridge_pb2_grpc.py @@ -0,0 +1,148 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from openra_env.generated import rl_bridge_pb2 as rl__bridge__pb2 + +GRPC_GENERATED_VERSION = '1.75.1' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + ' but the generated code in rl_bridge_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class RLBridgeStub(object): + """The RL Bridge service allows an external agent to interact with OpenRA + via bidirectional streaming (lock-step) or unary state queries. 
+ """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.GameSession = channel.stream_stream( + '/openra.rl.RLBridge/GameSession', + request_serializer=rl__bridge__pb2.AgentAction.SerializeToString, + response_deserializer=rl__bridge__pb2.GameObservation.FromString, + _registered_method=True) + self.GetState = channel.unary_unary( + '/openra.rl.RLBridge/GetState', + request_serializer=rl__bridge__pb2.StateRequest.SerializeToString, + response_deserializer=rl__bridge__pb2.GameState.FromString, + _registered_method=True) + + +class RLBridgeServicer(object): + """The RL Bridge service allows an external agent to interact with OpenRA + via bidirectional streaming (lock-step) or unary state queries. + """ + + def GameSession(self, request_iterator, context): + """Bidirectional streaming: game sends observations, agent sends actions. + Each observation waits for an action before advancing to the next tick. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetState(self, request, context): + """Unary: query current game state on demand. 
+ """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_RLBridgeServicer_to_server(servicer, server): + rpc_method_handlers = { + 'GameSession': grpc.stream_stream_rpc_method_handler( + servicer.GameSession, + request_deserializer=rl__bridge__pb2.AgentAction.FromString, + response_serializer=rl__bridge__pb2.GameObservation.SerializeToString, + ), + 'GetState': grpc.unary_unary_rpc_method_handler( + servicer.GetState, + request_deserializer=rl__bridge__pb2.StateRequest.FromString, + response_serializer=rl__bridge__pb2.GameState.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'openra.rl.RLBridge', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('openra.rl.RLBridge', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class RLBridge(object): + """The RL Bridge service allows an external agent to interact with OpenRA + via bidirectional streaming (lock-step) or unary state queries. 
+ """ + + @staticmethod + def GameSession(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_stream( + request_iterator, + target, + '/openra.rl.RLBridge/GameSession', + rl__bridge__pb2.AgentAction.SerializeToString, + rl__bridge__pb2.GameObservation.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetState(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/openra.rl.RLBridge/GetState', + rl__bridge__pb2.StateRequest.SerializeToString, + rl__bridge__pb2.GameState.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/openra_env/mcp_server.py b/openra_env/mcp_server.py new file mode 100644 index 0000000000000000000000000000000000000000..ae7d7123e77ff5cac315cd40d7a2583f022c4c1a --- /dev/null +++ b/openra_env/mcp_server.py @@ -0,0 +1,454 @@ +"""Standard MCP server for OpenRA-RL (stdio transport). + +Exposes all game tools over the MCP protocol using FastMCP. +Connects to the game server WebSocket and proxies tool calls. + +Usage: + openra-rl mcp-server + openra-rl mcp-server --server-url http://localhost:8000 + +Works with OpenClaw, Claude Desktop, and any MCP client. 
async def _get_client():
    """Get or create the WebSocket client connection.

    Returns:
        The shared ``OpenRAMCPClient``, connected to ``_server_url``.

    Any exception raised by ``connect()`` propagates to the caller, and the
    next call makes a fresh connection attempt.
    """
    global _client
    if _client is not None:
        return _client
    from openra_env.mcp_ws_client import OpenRAMCPClient

    client = OpenRAMCPClient(base_url=_server_url, message_timeout_s=300.0)
    await client.connect()
    # Publish the client only once it is connected: caching it before
    # connect() would leave a dead client behind after a failed attempt and
    # every later call would return it without reconnecting.
    _client = client
    return _client


def _port_from_url(url: str, default: int = 8000) -> int:
    """Return the TCP port named in *url*, or *default* if it names none."""
    from urllib.parse import urlsplit

    # A bare "host:port" has no "//", so urlsplit would read "host" as the
    # scheme; prefixing "//" makes it parse as a network location.
    parts = urlsplit(url if "//" in url else f"//{url}")
    try:
        port = parts.port
    except ValueError:
        # Non-numeric or out-of-range port component.
        return default
    return default if port is None else port


def _server_is_healthy(base_url: str, timeout: float = 3) -> bool:
    """Return True if ``GET {base_url}/health`` answers with HTTP 200."""
    import urllib.error
    import urllib.request

    try:
        # The context manager closes the response (and its socket) promptly.
        with urllib.request.urlopen(f"{base_url}/health", timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        return False


async def _ensure_game() -> None:
    """Ensure game server is running and a game is started.

    Probes the server's ``/health`` endpoint; if it does not answer with
    HTTP 200, tries to start the server through
    ``openra_env.cli.docker_manager``. Then connects, resets the game and
    records that a game is running so later calls return immediately.

    Raises:
        RuntimeError: If Docker is not available, or the Docker manager
            module cannot be imported.
    """
    global _game_started
    if _game_started:
        return

    import asyncio

    # urlopen() blocks, so run the probe in a worker thread rather than
    # stalling the event loop for up to the probe timeout.
    healthy = await asyncio.to_thread(_server_is_healthy, _server_url)

    if not healthy:
        # Try starting Docker container
        try:
            from openra_env.cli.docker_manager import (
                check_docker,
                is_running,
                start_server,
                wait_for_health,
            )

            if not is_running():
                if not check_docker():
                    raise RuntimeError(
                        "Docker is not available. Start the game server manually: "
                        "docker run -p 8000:8000 ghcr.io/yxc20089/openra-rl:latest"
                    )
                port = _port_from_url(_server_url)
                start_server(port=port)
                wait_for_health(port=port)
        except ImportError as exc:
            raise RuntimeError(
                f"Game server not reachable at {_server_url}. "
                "Start it manually: docker run -p 8000:8000 ghcr.io/yxc20089/openra-rl:latest"
            ) from exc

    # Connect/reset failures are deliberately not swallowed here: the server
    # is already known to be up (or was just started), so launching Docker
    # again would not help and would hide the real error.
    client = await _get_client()
    await client.reset()
    _game_started = True


async def _call(tool_name: str, **kwargs) -> Any:
    """Call a game tool and return the result.

    Makes sure a game is running first, then forwards ``tool_name`` and
    ``kwargs`` to the shared client's ``call_tool``.
    """
    await _ensure_game()
    client = await _get_client()
    return await client.call_tool(tool_name, **kwargs)


def _format(result: Any) -> str:
    """Format a tool result as a string.

    Strings pass through unchanged; anything else is rendered as indented
    JSON, with values JSON cannot encode converted via ``str``.
    """
    if isinstance(result, str):
        return result
    return json.dumps(result, indent=2, default=str)
@mcp.tool()
async def lookup_building(building_type: str) -> str:
    """Look up stats for a building type (e.g. 'powr', 'weap')."""
    # Forward to the game-side tool of the same name, then render the reply as text.
    stats = await _call("lookup_building", building_type=building_type)
    return _format(stats)
@mcp.tool()
async def build_structure(building_type: str) -> str:
    """Start constructing a building (manual placement workflow).
    Call advance(ticks) to let construction finish, then place_building() to place it.
    Prefer build_and_place() which handles placement automatically."""
    # Forward to the game-side tool of the same name, then render the reply as text.
    outcome = await _call("build_structure", building_type=building_type)
    return _format(outcome)
@mcp.tool()
async def harvest(unit_id: int, cell_x: int = 0, cell_y: int = 0) -> str:
    """Send a harvester to harvest at a location."""
    # Forward all three arguments unchanged to the game-side "harvest" tool.
    outcome = await _call(
        "harvest",
        unit_id=unit_id,
        cell_x=cell_x,
        cell_y=cell_y,
    )
    return _format(outcome)
@mcp.tool()
async def command_group(
    group_name: str,
    command_type: str,
    target_x: int = 0,
    target_y: int = 0,
    target_actor_id: int = 0,
    queued: bool = False,
) -> str:
    """Issue a command to a unit group. command_type: move, attack_move, attack, stop, guard."""
    # Every parameter is forwarded under its own name; no client-side
    # validation of command_type happens here.
    outcome = await _call(
        "command_group",
        group_name=group_name,
        command_type=command_type,
        target_x=target_x,
        target_y=target_y,
        target_actor_id=target_actor_id,
        queued=queued,
    )
    return _format(outcome)
def main(server_url: Optional[str] = None) -> None:
    """Start the MCP server on the stdio transport.

    Args:
        server_url: Base URL of the game server. When empty or ``None`` the
            current module-level ``_server_url`` is left untouched.
    """
    global _server_url
    # A falsy override keeps whatever URL is already configured.
    _server_url = server_url or _server_url
    mcp.run(transport="stdio")
class OpenRAMCPClient:
    """Async WebSocket client for OpenRA-RL with MCP tool support.

    Each request is one JSON message sent over the socket followed by one
    JSON reply: ``reset`` uses the ``"reset"`` envelope, the tool methods use
    the ``"mcp"`` envelope carrying a JSON-RPC 2.0 request. A reply whose
    ``"type"`` is ``"error"`` is turned into a ``RuntimeError`` by every
    request method.

    Usage:
        async with OpenRAMCPClient("http://localhost:8000") as client:
            await client.reset()
            tools = await client.list_tools()
            result = await client.call_tool("get_game_state")
            result = await client.call_tool("build_structure", building_type="powr")
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        message_timeout_s: float = 300.0,
    ):
        """Store connection settings; no network activity happens here.

        Args:
            base_url: HTTP(S) base URL of the server; mapped to ws:// or wss://.
            message_timeout_s: Seconds to wait for the reply to each message.
        """
        # Convert HTTP URL to WebSocket URL
        ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
        ws_url = ws_url.rstrip("/")
        self._ws_url = f"{ws_url}/ws"
        self._timeout = message_timeout_s
        self._ws = None
        self._rpc_id = 0
        self._tools_cache: Optional[list[Tool]] = None

    async def connect(self) -> "OpenRAMCPClient":
        """Connect to the WebSocket endpoint (no-op when already connected).

        Raises:
            RuntimeError: if the connection cannot be established.
        """
        if self._ws is not None:
            return self

        # Handle proxy bypass for localhost: NO_PROXY is extended only for the
        # duration of the connect call and restored in the finally block.
        ws_lower = self._ws_url.lower()
        is_localhost = "localhost" in ws_lower or "127.0.0.1" in ws_lower
        old_no_proxy = os.environ.get("NO_PROXY")

        if is_localhost:
            current = old_no_proxy or ""
            if "localhost" not in current.lower():
                os.environ["NO_PROXY"] = (
                    f"{current},localhost,127.0.0.1" if current else "localhost,127.0.0.1"
                )

        try:
            self._ws = await ws_connect(
                self._ws_url,
                open_timeout=30.0,
                max_size=50 * 1024 * 1024,  # 50 MB
                ping_interval=None,
            )
        except (asyncio.TimeoutError, OSError) as e:
            # ConnectionRefusedError is an OSError subclass, so it is covered.
            raise RuntimeError(
                f"Could not connect to game server at {self._ws_url}: {e}\n"
                f"  Is the server running? Try: openra-rl server start"
            ) from e
        finally:
            if is_localhost:
                if old_no_proxy is None:
                    os.environ.pop("NO_PROXY", None)
                else:
                    os.environ["NO_PROXY"] = old_no_proxy

        return self

    async def close(self):
        """Close the WebSocket connection; errors while closing are ignored."""
        if self._ws:
            try:
                await self._ws.close()
            except Exception:
                # Best effort: the socket is being discarded either way.
                pass
            self._ws = None

    async def __aenter__(self) -> "OpenRAMCPClient":
        return await self.connect()

    async def __aexit__(self, *args):
        await self.close()

    async def _send_recv(self, message: dict) -> dict:
        """Send one JSON message and return the next decoded reply.

        Raises:
            RuntimeError: if ``connect()`` has not been called.
            asyncio.TimeoutError: if no reply arrives within the timeout.
        """
        if self._ws is None:
            raise RuntimeError("Not connected. Call connect() first.")

        await self._ws.send(json.dumps(message))
        raw = await asyncio.wait_for(self._ws.recv(), timeout=self._timeout)
        return json.loads(raw)

    @staticmethod
    def _describe_error(error: Any, fallback: Any) -> Any:
        """Pick a printable message out of an error payload of unknown shape.

        Dict payloads yield their ``"message"`` entry (or *fallback*); any
        other non-None payload is returned as-is instead of crashing on
        ``.get``.
        """
        if isinstance(error, dict):
            return error.get("message", fallback)
        return fallback if error is None else error

    @classmethod
    def _rpc_payload(cls, response: dict, what: str) -> dict:
        """Return the JSON-RPC payload of an ``"mcp"`` reply.

        Raises RuntimeError for a ``{"type": "error"}`` envelope, the same
        way ``reset()`` does, so server-level failures are not mistaken for
        an empty result.
        """
        if response.get("type") == "error":
            message = cls._describe_error(response.get("data"), "?")
            raise RuntimeError(f"{what} failed: {message}")
        payload = response.get("data")
        return payload if isinstance(payload, dict) else {}

    # ── Environment Control ───────────────────────────────────────

    async def reset(self, **kwargs) -> dict:
        """Reset the environment and start a new game.

        Keyword arguments are sent as the ``"data"`` of the reset message.

        Raises:
            RuntimeError: if the server answers with an error envelope.
        """
        response = await self._send_recv({"type": "reset", "data": kwargs})
        if response.get("type") == "error":
            message = self._describe_error(response.get("data"), "?")
            raise RuntimeError(f"Reset failed: {message}")
        data = response.get("data")
        return data if data is not None else {}

    # ── MCP Tool Operations ───────────────────────────────────────

    async def list_tools(self, use_cache: bool = True) -> list[Tool]:
        """List available MCP tools.

        Args:
            use_cache: Return the previously fetched list when one exists.

        Raises:
            RuntimeError: if the server reports an error.
        """
        if use_cache and self._tools_cache is not None:
            return self._tools_cache

        self._rpc_id += 1
        rpc_request = {
            "jsonrpc": "2.0",
            "method": "tools/list",
            "params": {},
            "id": self._rpc_id,
        }

        response = await self._send_recv({"type": "mcp", "data": rpc_request})
        rpc_response = self._rpc_payload(response, "tools/list")

        if "error" in rpc_response:
            raise RuntimeError(f"tools/list failed: {rpc_response['error']}")

        tools_data = rpc_response.get("result", {}).get("tools", [])
        self._tools_cache = [
            Tool(
                name=t.get("name", ""),
                description=t.get("description", ""),
                input_schema=t.get("inputSchema", t.get("input_schema", {})),
            )
            for t in tools_data
        ]
        return self._tools_cache

    async def call_tool(self, name: str, **kwargs) -> Any:
        """Call an MCP tool by name with keyword arguments.

        Returns:
            The tool result unwrapped by ``_unwrap_mcp_result``.

        Raises:
            RuntimeError: if the server or the JSON-RPC layer reports an error.
        """
        self._rpc_id += 1
        rpc_request = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": name, "arguments": kwargs},
            "id": self._rpc_id,
        }

        response = await self._send_recv({"type": "mcp", "data": rpc_request})
        rpc_response = self._rpc_payload(response, f"Tool '{name}'")

        if "error" in rpc_response:
            error = rpc_response["error"]
            raise RuntimeError(f"Tool '{name}' failed: {self._describe_error(error, error)}")

        result = rpc_response.get("result")
        return self._unwrap_mcp_result(result)

    @staticmethod
    def _unwrap_mcp_result(result: Any) -> Any:
        """Unwrap FastMCP tool result to plain Python data.

        FastMCP wraps results as:
            {
                "content": [{"type": "text", "text": "..."}],
                "structured_content": {"result": <value>},
                "data": <value>,
                "is_error": false
            }

        Lookup order (first match wins):
            1. ``data`` when it is a non-empty dict
            2. ``structured_content["result"]`` (or the camelCase
               ``structuredContent``) when it is neither None nor ""
            3. ``data`` when it is a list other than ``[{}]``
            4. the ``content`` text items, JSON-decoded where possible
               (a single item is returned bare, several as a list)
            5. the raw ``result`` unchanged

        Non-dict inputs are returned unchanged.
        """
        if not isinstance(result, dict):
            return result

        # data field is correct for dicts, buggy ([{}]) for lists.
        # structured_content.result is correct for lists, empty string for dicts.
        data = result.get("data")
        if isinstance(data, dict) and data:
            return data

        sc = result.get("structured_content")
        if sc is None:
            # Accept the camelCase spelling too, as list_tools does for inputSchema.
            sc = result.get("structuredContent")
        if isinstance(sc, dict):
            sc_result = sc.get("result")
            if sc_result is not None and sc_result != "":
                return sc_result

        # data for empty lists (both data=[] and sc.result=[])
        if isinstance(data, list) and data != [{}]:
            return data

        # Fallback: parse content text items
        content = result.get("content")
        if isinstance(content, list) and content:
            texts = []
            for item in content:
                if isinstance(item, dict) and item.get("type") == "text":
                    text = item.get("text", "")
                    try:
                        texts.append(json.loads(text))
                    except (json.JSONDecodeError, TypeError):
                        texts.append(text)
                else:
                    texts.append(item)
            if len(texts) == 1:
                return texts[0]
            return texts

        return result
class ActionType(str, Enum):
    """Available command types matching the protobuf ActionType enum."""

    # Mixing in ``str`` makes every member compare equal to its lowercase
    # string value, so members can be built from and compared with raw strings.
    NO_OP = "no_op"
    MOVE = "move"
    ATTACK_MOVE = "attack_move"
    ATTACK = "attack"
    STOP = "stop"
    HARVEST = "harvest"
    BUILD = "build"
    TRAIN = "train"
    DEPLOY = "deploy"
    SELL = "sell"
    REPAIR = "repair"
    PLACE_BUILDING = "place_building"
    CANCEL_PRODUCTION = "cancel_production"
    SET_RALLY_POINT = "set_rally_point"
    GUARD = "guard"
    SET_STANCE = "set_stance"
    ENTER_TRANSPORT = "enter_transport"
    UNLOAD = "unload"
    POWER_DOWN = "power_down"
    SET_PRIMARY = "set_primary"
    SURRENDER = "surrender"
# NOTE(review): this is observation data, yet it derives from the OpenEnv
# ``Action`` base like the other nested models in this file. Confirm that
# inheriting Action's fields/config is intended rather than a plain model base.
class EconomyInfo(Action):
    """Player economic state."""

    # Every field defaults to 0, so the model can be built with no arguments
    # (OpenRAObservation relies on that via default_factory=EconomyInfo).
    cash: int = Field(default=0, description="Available cash")
    ore: int = Field(default=0, description="Raw ore in silos")
    power_provided: int = Field(default=0, description="Total power generation")
    power_drained: int = Field(default=0, description="Total power consumption")
    resource_capacity: int = Field(default=0, description="Maximum resource storage")
    harvester_count: int = Field(default=0, description="Number of active harvesters")
class BuildingInfoModel(Action):
    """Information about a single building."""

    # Identity: the only two required fields (no default).
    actor_id: int = Field(..., description="Unique actor ID")
    type: str = Field(..., description="Actor type (e.g., 'powr', 'barr', 'weap')")
    # World-space position; the cell-space position is in cell_x / cell_y below.
    pos_x: int = Field(default=0, description="World position X")
    pos_y: int = Field(default=0, description="World position Y")
    hp_percent: float = Field(default=1.0, description="Health percentage 0.0-1.0")
    owner: str = Field(default="", description="Owner player internal name")
    # Production status of this building.
    is_producing: bool = Field(default=False, description="Whether actively producing")
    production_progress: float = Field(default=0.0, description="Production progress 0.0-1.0")
    producing_item: str = Field(default="", description="Item currently being produced")
    is_powered: bool = Field(default=True, description="Whether powered")

    # Sprint 4: enriched building data
    is_repairing: bool = Field(default=False, description="Actively being repaired")
    sell_value: int = Field(default=0, description="Refund amount if sold")
    # -1 is the "no rally point" sentinel for both coordinates.
    rally_x: int = Field(default=-1, description="Rally point cell X (-1 if none)")
    rally_y: int = Field(default=-1, description="Rally point cell Y (-1 if none)")
    # Signed: positive for producers, negative for consumers.
    power_amount: int = Field(default=0, description="Power provided (+) or consumed (-)")
    can_produce: List[str] = Field(default_factory=list, description="Items this building can produce")
    cell_x: int = Field(default=0, description="Cell position X")
    cell_y: int = Field(default=0, description="Cell position Y")
class OpenRAState(State):
    """Environment state tracking episode metadata.

    Extends the base State with OpenRA-specific fields.
    """

    game_tick: int = Field(default=0, description="Current game tick")
    map_name: str = Field(default="", description="Active map name")
    # Defaults to the normal-difficulty bot when nothing else is recorded.
    opponent_type: str = Field(default="bot_normal", description="Opponent type: bot_easy, bot_normal, bot_hard")
    # Planning-phase bookkeeping; both stay at their defaults if planning is unused.
    planning_strategy: str = Field(default="", description="Agent's pre-game strategy if planning was used")
    planning_turns_used: int = Field(default=0, description="Number of planning turns used")

    # Inherited from State:
    #   episode_id: Optional[str] = None
    #   step_count: int = 0
+""" + +from typing import Optional + + +# ── Opponent Profiles ────────────────────────────────────────────────────── + +AI_PROFILES: dict[str, dict] = { + "beginner": { + "difficulty": "Beginner", + "display_name": "Beginner AI", + "aggressiveness": "minimal", + "expansion_tendency": "none", + "unit_diversity": "very_low", + "build_order_quality": "very_poor", + "estimated_win_rate_vs_new_player": 0.10, + "typical_first_attack_tick": 150000, + "behavioral_traits": [ + "Almost never attacks — first attack after 100+ minutes", + "Builds only basic infantry (rifle soldiers, grenadiers)", + "No vehicles, no aircraft, no navy", + "Tiny squads of 3-5 units that pose almost no threat", + "Stays at starting base, never expands", + "Extremely slow economy — one refinery, one harvester", + "Does not repair damaged buildings", + "Very slow construction speed — 8x slower than normal AI", + "Does not use superweapons or advanced tech", + "Barely defends base — minimal turrets placed very late", + ], + "recommended_counters": [ + "Any military force will win — even 3-4 infantry can overwhelm", + "Take your time building economy and army — no rush needed", + "Good difficulty for learning basic game mechanics", + "Practice build orders without pressure", + ], + "typical_army_composition": { + "infantry": 1.0, + "vehicles": 0.0, + "aircraft": 0.0, + "ships": 0.0, + }, + "recent_match_history": [ + {"result": "loss", "duration_ticks": 8000, "score": 400}, + {"result": "loss", "duration_ticks": 6000, "score": 300}, + {"result": "loss", "duration_ticks": 10000, "score": 600}, + ], + }, + "easy": { + "difficulty": "Easy", + "display_name": "Easy AI", + "aggressiveness": "low", + "expansion_tendency": "very_low", + "unit_diversity": "low", + "build_order_quality": "poor", + "estimated_win_rate_vs_new_player": 0.25, + "typical_first_attack_tick": 80000, + "behavioral_traits": [ + "Passive — first attack after ~50 minutes of game time", + "Builds basic infantry and some light vehicles 
(light tanks, APCs)", + "No aircraft, no navy, no advanced tech", + "Small attack squads of 8-12 units", + "Rarely expands beyond starting base", + "Slow economy — 1-2 refineries with 2-4 harvesters", + "Repairs buildings slowly (5x slower than normal)", + "Moderate construction speed — 3x slower than normal AI", + "Limited unit caps — cannot mass large armies", + "Defenses delayed but eventually builds pillboxes and turrets", + ], + "recommended_counters": [ + "Build a small army of 10-15 units and attack before their defenses solidify", + "Any combined arms force (infantry + tanks) will overwhelm them", + "Economy is their weakness — denying resources cripples them further", + "No need to rush — focus on good build order first", + ], + "typical_army_composition": { + "infantry": 0.6, + "vehicles": 0.4, + "aircraft": 0.0, + "ships": 0.0, + }, + "recent_match_history": [ + {"result": "loss", "duration_ticks": 5000, "score": 800}, + {"result": "loss", "duration_ticks": 7000, "score": 1200}, + {"result": "win", "duration_ticks": 15000, "score": 2500}, + ], + }, + "medium": { + "difficulty": "Medium", + "display_name": "Medium AI", + "aggressiveness": "moderate", + "expansion_tendency": "moderate", + "unit_diversity": "moderate", + "build_order_quality": "decent", + "estimated_win_rate_vs_new_player": 0.50, + "typical_first_attack_tick": 5000, + "behavioral_traits": [ + "Moderately aggressive — sends first attack around tick 5000 (~3 minutes)", + "Builds a balanced ground force (infantry, tanks, artillery)", + "No aircraft or naval units — ground-focused only", + "Medium-sized attack squads of 20-35 units", + "Will expand to a second base if resources allow", + "Decent economy — 2-3 refineries with up to 6 harvesters", + "Repairs buildings at normal speed", + "Slightly slower construction than Hard/Brutal AI", + "Builds advanced tech eventually (tech centers delayed ~8 minutes)", + "Uses superweapons if available but slowly", + "Limited production capacity — fewer 
factories than Hard AI", + ], + "recommended_counters": [ + "Build early defenses — first attack comes around tick 5000", + "Scout by tick 2000 to identify expansion attempts", + "Match their economy with 2+ refineries minimum", + "Combined arms with anti-armor focus works well", + "Their lack of air power means you can skip AA early", + "Deny expansion to keep resource advantage", + ], + "typical_army_composition": { + "infantry": 0.35, + "vehicles": 0.65, + "aircraft": 0.0, + "ships": 0.0, + }, + "recent_match_history": [ + {"result": "win", "duration_ticks": 7000, "score": 3200}, + {"result": "loss", "duration_ticks": 9000, "score": 3800}, + {"result": "win", "duration_ticks": 8000, "score": 4200}, + {"result": "loss", "duration_ticks": 10000, "score": 3500}, + ], + }, + "normal": { + "difficulty": "Normal", + "display_name": "Normal AI", + "aggressiveness": "high", + "expansion_tendency": "high", + "unit_diversity": "high", + "build_order_quality": "good", + "estimated_win_rate_vs_new_player": 0.65, + "typical_first_attack_tick": 1500, + "behavioral_traits": [ + "Very aggressive — sends attack waves frequently starting around tick 1500", + "Masters all different unit types (infantry, tanks, aircraft, ships)", + "Eager to open a second base near your position or mid-way on the map", + "Strong economy — builds 2-3 refineries with multiple harvesters", + "Rebuilds destroyed buildings quickly and adapts composition", + "Will target your harvesters and exposed, undefended buildings", + "Uses combined arms effectively (infantry + vehicles + air strikes)", + "Scouts your base early and adjusts strategy based on what you build", + ], + "recommended_counters": [ + "Build early defenses (turrets) at base entrance — first attack comes ~tick 1500", + "Scout early (by tick 500) to find and deny expansion attempts", + "Send a small raiding force to destroy their second base before it's established", + "Maintain power surplus at all times — their attacks exploit brownouts", + 
"Build anti-air (SAM/AA Gun) by mid-game to counter their aircraft", + "Match their economy: build 2+ refineries minimum to keep up", + "Don't turtle — they will out-expand and out-resource you", + ], + "typical_army_composition": { + "infantry": 0.30, + "vehicles": 0.45, + "aircraft": 0.15, + "ships": 0.10, + }, + "recent_match_history": [ + {"result": "win", "duration_ticks": 8000, "score": 5200}, + {"result": "win", "duration_ticks": 6500, "score": 4800}, + {"result": "loss", "duration_ticks": 10000, "score": 6100}, + {"result": "win", "duration_ticks": 7200, "score": 5500}, + {"result": "loss", "duration_ticks": 9000, "score": 4000}, + ], + }, + "hard": { + "difficulty": "Hard", + "display_name": "Hard AI", + "aggressiveness": "very_high", + "expansion_tendency": "very_high", + "unit_diversity": "very_high", + "build_order_quality": "optimal", + "estimated_win_rate_vs_new_player": 0.85, + "typical_first_attack_tick": 1000, + "behavioral_traits": [ + "Extremely aggressive — attacks within first 1000 ticks with combined forces", + "Optimal build orders — wastes no time or resources, perfect macro", + "Expands aggressively with multiple bases across the map", + "Uses superweapons if tech allows (nuclear missile, iron curtain)", + "Coordinates multi-front attacks simultaneously from different angles", + "Excellent at resource denial — prioritizes harvesters and refineries", + "Rapid tech progression to advanced units (Mammoth tanks, MiGs)", + "Will cheat slightly on resource gathering speed", + ], + "recommended_counters": [ + "MUST build defenses immediately — turrets before second refinery", + "Scout by tick 300 — their expansion is very fast", + "Deny expansions aggressively or you'll be completely out-resourced", + "Build multiple production buildings for faster unit output", + "Never let power go negative — they will exploit it ruthlessly", + "Mix anti-air into every attack group — they will use aircraft", + "Prepare for superweapons by mid-game — keep army 
def get_opponent_profile(difficulty: str) -> Optional[dict]:
    """Get the opponent intelligence profile for a given AI difficulty.

    The lookup is case-insensitive and ignores surrounding whitespace. A
    single leading "bot_" prefix is stripped before the lookup; the text is
    only treated as a prefix, never removed from the middle of the name.

    Args:
        difficulty: One of "beginner", "easy", "medium", "normal", "hard",
            optionally prefixed with "bot_" (e.g. "bot_hard").

    Returns:
        The matching AI_PROFILES entry, or None if not found.
    """
    key = difficulty.strip().lower().removeprefix("bot_")
    return AI_PROFILES.get(key)


def get_opponent_summary(difficulty: str) -> str:
    """Get a human-readable scouting report for LLM consumption.

    Args:
        difficulty: AI difficulty name, resolved via get_opponent_profile().

    Returns:
        A multi-line report (header, ratings, recent record, army mix,
        behavioral traits, recommended counters), or the message
        "Unknown AI difficulty: <difficulty>" when no profile matches.
    """
    profile = get_opponent_profile(difficulty)
    if profile is None:
        return f"Unknown AI difficulty: {difficulty}"

    # NOTE(review): bullet prefix copied as it appears in this view of the
    # file — confirm the exact leading whitespace against the original.
    traits = "\n".join(f" - {t}" for t in profile["behavioral_traits"])
    counters = "\n".join(f" - {c}" for c in profile["recommended_counters"])

    history = profile["recent_match_history"]
    total = len(history)
    wins = sum(1 for m in history if m["result"] == "win")
    # An empty history must not crash the report with ZeroDivisionError.
    avg_score = sum(m["score"] for m in history) // total if total else 0

    army = profile["typical_army_composition"]
    # Only list unit classes the opponent actually fields.
    army_str = ", ".join(f"{k}: {v:.0%}" for k, v in army.items() if v > 0)

    return (
        f"## Opponent Scouting Report: {profile['display_name']}\n"
        f"Aggressiveness: {profile['aggressiveness']}\n"
        f"Expansion tendency: {profile['expansion_tendency']}\n"
        f"Unit diversity: {profile['unit_diversity']}\n"
        f"Build order quality: {profile['build_order_quality']}\n"
        f"Estimated first attack: ~tick {profile['typical_first_attack_tick']}\n"
        f"Win rate vs new players: {profile['estimated_win_rate_vs_new_player']:.0%}\n"
        f"Recent record: {wins}W-{total - wins}L (avg score: {avg_score})\n"
        f"Typical army mix: {army_str}\n"
        f"\nBehavioral traits:\n{traits}\n"
        f"\nRecommended counters:\n{counters}"
    )
Between your turns, a TURN BRIEFING shows current state: economy, units, buildings, enemies, production queue, and available builds. A STRATEGIC BRIEFING at game start provides map size, base position, enemy spawn estimate, faction, tech tree, and unit/building stats. + +## Economy + +Funds = cash + ore. Harvesters collect ore from patches and deliver it to refineries. Each ore refinery (proc, $2000) comes with one free harvester. Silos ($150) add +1500 ore storage capacity. Construction costs are paid incrementally — if funds reach $0, all production pauses until income resumes. + +## Power + +Power plants (powr, $300) provide +100 power each. When power demand exceeds supply, all production runs at 1/3 speed. Each building has a power drain listed in its stats. + +## Production + +- Barracks (tent/barr, $400): trains infantry (e1 rifleman $100, e2 grenadier $160, e3 rocket soldier $300, e6 engineer $500) +- War factory (weap, $2000): builds vehicles (light tank 1tnk $600, medium tank 2tnk $800, heavy tank 3tnk $950, APC $800, harvester $1400) +- Multiple production buildings of the same type increase build speed + +## Tech Tree + +Buildings have prerequisites. Standard unlock chain: powr → barracks → proc → weap. A war factory requires an ore refinery. The "Can build:" line in each briefing shows what is currently available. + +## Combat + +Vehicles are much stronger than infantry. A heavy tank (3tnk, $950) defeats ~10 riflemen (e1, $100 each). Rocket soldiers (e3) are effective against vehicles and aircraft. Engineers (e6) can capture enemy buildings. + +## Defense Structures + +Defense turrets are stationary and engage enemies automatically. 
+ +Allied: +- Pillbox (pbox, $400, no power drain, needs barracks) — anti-infantry +- Camo Pillbox (hbox, $600, no power drain, needs barracks) — hidden anti-infantry +- Gun Turret (gun, $600, -20 power, needs war factory) — anti-armor + +Soviet: +- Flame Tower (ftur, $600, -20 power, needs barracks) — anti-infantry +- Tesla Coil (tsla, $1500, -75 power, needs war factory) — high damage, all targets + +Anti-air: +- SAM Site (sam, Soviet, $750, needs radar dome) +- AA Gun (agun, Allied, $600, needs radar dome) +- Rocket soldiers (e3) provide mobile anti-air + +## Map & Fog of War + +The map is partially hidden by fog of war. Units reveal terrain as they move. Exploration status (overall %, per-quadrant) is available via tools. Enemy positions are only visible when within your units' sight range. + +## Rally Points + +Production buildings can have rally points set. Newly produced units automatically move to the rally point location. + +## Planning Phase + +If planning is enabled, a planning phase occurs before gameplay. It provides map metadata, faction info, and an opponent intelligence report with behavioral tendencies and difficulty assessment. Use bulk knowledge tools (get_faction_briefing, get_map_analysis) to gather information efficiently. + +## Briefing Format + +Each turn briefing includes: +- Funds (cash + ore), power balance, harvester count +- Your units with IDs, types, and positions +- Your buildings with IDs, types, and positions +- Visible enemies with IDs, types, and positions +- Current production queue and available builds +- ALERTS for events needing attention (attacks, low power, idle production, etc.) 
\ No newline at end of file diff --git a/openra_env/prompts/default_prompts.yaml b/openra_env/prompts/default_prompts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..21eb701c18136ee5846d52201d8c85fa80fcf6b5 --- /dev/null +++ b/openra_env/prompts/default_prompts.yaml @@ -0,0 +1,144 @@ +# OpenRA-RL Default Prompts +# ========================= +# All LLM-facing text used by the agent. Copy this file and customize +# to change what the model sees. Point to your copy via: +# prompts: +# prompts_file: "path/to/my_prompts.yaml" +# in config.yaml, or set PROMPTS_FILE environment variable. +# +# Templates use Python str.format() placeholders: {variable_name} +# Available variables are documented in comments above each template. + +# ── Planning Phase ─────────────────────────────────────────────────── + +# Variables: {max_turns}, {map_name}, {map_width}, {map_height}, +# {base_x}, {base_y}, {enemy_x}, {enemy_y}, {faction}, {side}, +# {opponent_summary}, {planning_nudge} +planning_prompt: | + ## PRE-GAME PLANNING PHASE + You have {max_turns} turns to plan. + + ### Map Intel + Map: {map_name} ({map_width}x{map_height}) + Your base: ({base_x}, {base_y}) + Enemy estimated: ({enemy_x}, {enemy_y}) + Your faction: {faction} ({side}) + + ### Opponent Intelligence + {opponent_summary} + + {planning_nudge} + +planning_nudge: "Call end_planning_phase(strategy='...') when ready to start." + +planning_instructions: >- + Planning phase active. Available tools: get_faction_briefing + (all unit/building stats), get_map_analysis (terrain/resources), + get_opponent_intel (enemy profile), batch_lookup (multi-item queries). + Call end_planning_phase(strategy=...) to begin gameplay. + +planning_complete: "Planning complete. Game is now live." 
+ +# ── Game Start ─────────────────────────────────────────────────────── + +# Variables: {strategy_section}, {briefing}, {barracks_type}, {mcv_note} +game_start: "Game started!{strategy_section}\n\n{briefing}\n\nYour barracks type is '{barracks_type}'.{mcv_note}" + +# ── Agent Nudges ───────────────────────────────────────────────────── + +no_tool_nudge: "No tool was called. A tool call is required each turn." +continue_nudge: "The game is still in progress." +compression_suffix: "Game continues from current state." +sanitize_bridge: "Acknowledged. Continuing." + +# ── Tool Warnings ──────────────────────────────────────────────────── + +# Variables: {building}, {drain}, {balance} +power_warning: "POWER WARNING: {building} drains {drain} power. Balance will be {balance}." + +# Variables: {available}, {item}, {cost} +insufficient_funds: "Insufficient funds: ${available} available, {item} costs ${cost}." + +# ── Placement Feedback ─────────────────────────────────────────────── + +# Variables: {building} +placement_success: "AUTO-PLACED: {building}" + +# Variables: {building}, {reason} +placement_failed: "PLACEMENT FAILED: {building} — {reason}. Auto-cancelling." + +# Variables: {building} +placement_water: "WATER BUILDING: {building} requires water tiles for placement." + +# ── Build Confirmations ───────────────────────────────────────── +# Factual confirmation notes returned by build tools. + +# Variables: {building}, {cost}, {ticks}, {seconds} +build_queued: "'{building}' (${cost}) queued, auto-places on completion. ~{ticks} ticks (~{seconds}s)." + +# Variables: {building}, {cost}, {ticks}, {seconds} +build_structure_queued: "'{building}' (${cost}) queued. ~{ticks} ticks (~{seconds}s) to complete." + +# Variables: {count}, {unit}, {cost}, {ticks_each}, {ticks_total}, {seconds_total} +build_unit_queued: "{count}x '{unit}' (${cost} each) queued. ~{ticks_each} ticks per unit, ~{ticks_total} ticks (~{seconds_total}s) total." 
+ +# ── Build Guards ──────────────────────────────────────────────── + +# Variables: {building} +build_already_pending: "'{building}' is already queued and pending auto-placement." + +# Variables: {building} +place_auto_managed: "'{building}' is queued via build_and_place — placement is automatic." + +# ── Movement Feedback ─────────────────────────────────────────────── + +# Variables: {ticks}, {seconds} +move_eta: "Units moving. Slowest arrives in ~{ticks} ticks (~{seconds}s)." + +# ── Alerts ─────────────────────────────────────────────────────────── +# These appear in the TURN BRIEFING under ALERTS. +# Each alert type can be enabled/disabled separately in the alerts config. + +alerts: + # Variables: {type}, {id} + under_attack: "UNDER ATTACK: enemy {type} id={id} near base" + + # Variables: {count}, {breakdown} + under_attack_mass: "UNDER ATTACK: {count} enemies near base ({breakdown})" + + # Variables: {type}, {id}, {hp} + damaged: "DAMAGED: {type} id={id} at {hp} HP" + + # Variables: {balance} + low_power: "LOW POWER: {balance} — production runs at 1/3 speed" + + # Variables: {balance} + power_tight: "POWER TIGHT: {balance} surplus — next building may cause low power" + + # Variables: {funds}, {harvesters} + idle_funds: "IDLE FUNDS: ${funds} available, {harvesters} harvester(s)" + + # Variables: {ore}, {cap} + ore_full: "ORE FULL: {ore}/{cap} storage — income is being lost" + + idle_production: "IDLE PRODUCTION: no active production queue" + + # Variables: {item}, {progress} + stalled: "STALLED: {item}@{progress} — $0 funds, production paused" + + # Variables: {building} + building_stuck: "BUILDING STUCK: {building} — auto-placement failing" + + # Variables: {building} + ready_to_place: "READY TO PLACE: {building} — completed, awaiting placement" + + # Variables: {count} + stance: "STANCE: {count} combat unit(s) on ReturnFire (only fire when fired upon)" + + # Variables: {count} + idle_army: "IDLE ARMY: {count} combat units idle" + + no_defenses: "NO 
@dataclass
class RewardWeights:
    """Tunable coefficients applied by OpenRARewardFunction.compute()."""

    # Flat bonus added on every compute() call.
    survival: float = 0.001
    # Multiplier for positive cash deltas, measured per 1000 cash.
    economic_efficiency: float = 0.01
    # Reward per newly killed enemy unit or building.
    aggression: float = 0.1
    # Penalty per newly lost own unit or building.
    defense: float = 0.05
    # Added when an observation has done set and result == "win".
    victory: float = 1.0
    # Added when an observation has done set and result == "lose".
    defeat: float = -1.0


@dataclass
class RewardState:
    """Values remembered from the previous observation for delta rewards."""

    prev_cash: int = 0
    # Recorded by compute() but not read by any reward term in this class.
    prev_army_value: int = 0
    prev_units_killed: int = 0
    prev_units_lost: int = 0
    prev_buildings_killed: int = 0
    prev_buildings_lost: int = 0


class OpenRARewardFunction:
    """Shaped reward for OpenRA observations, scalar and optionally vector.

    The scalar reward is a weighted sum of survival, cash gain, kills,
    losses and a terminal win/lose term. When ``vector_enabled`` is set, a
    RewardVectorComputer is created as well and compute_vector() /
    compute_all() expose its output alongside the scalar.
    """

    def __init__(
        self,
        weights: Optional[RewardWeights] = None,
        vector_enabled: bool = False,
        vector_weights: Optional[dict[str, float]] = None,
    ):
        self._state = RewardState()
        self.weights = weights if weights else RewardWeights()

        # Vector mode: the computer only exists when explicitly enabled.
        self.vector_enabled = vector_enabled
        self._vector_weights = vector_weights
        self._vector_computer = RewardVectorComputer() if vector_enabled else None

    def reset(self) -> None:
        """Start a new episode: forget all tracked previous values."""
        self._state = RewardState()
        computer = self._vector_computer
        if computer is not None:
            computer.reset()

    def compute(self, obs_dict: dict) -> float:
        """Compute the scalar reward for one observation and update tracking.

        Args:
            obs_dict: Observation with optional "economy" and "military"
                dicts plus "done" and "result" fields; missing entries
                default to empty / 0 / False / "".

        Returns:
            The weighted scalar reward for this observation.
        """
        w = self.weights
        prev = self._state

        economy = obs_dict.get("economy", {})
        military = obs_dict.get("military", {})
        finished = obs_dict.get("done", False)
        outcome = obs_dict.get("result", "")

        cash = economy.get("cash", 0)
        units_killed = military.get("units_killed", 0)
        buildings_killed = military.get("buildings_killed", 0)
        units_lost = military.get("units_lost", 0)
        buildings_lost = military.get("buildings_lost", 0)

        # Survival bonus is unconditional.
        total = 0.0
        total += w.survival

        # Only cash increases are rewarded; spending is not penalised.
        # NOTE(review): after reset() prev_cash is 0, so the first
        # observation's whole cash balance counts as a gain — confirm intended.
        gained = cash - prev.prev_cash
        if gained > 0:
            total += w.economic_efficiency * (gained / 1000.0)

        # Kills and losses are cumulative counters, so reward the deltas.
        new_kills = (units_killed - prev.prev_units_killed) + (
            buildings_killed - prev.prev_buildings_killed
        )
        total += w.aggression * new_kills

        new_losses = (units_lost - prev.prev_units_lost) + (
            buildings_lost - prev.prev_buildings_lost
        )
        total -= w.defense * new_losses

        # Terminal bonus / penalty.
        if finished and outcome == "win":
            total += w.victory
        elif finished and outcome == "lose":
            total += w.defeat

        # Remember this observation for the next delta computation.
        prev.prev_cash = cash
        prev.prev_units_killed = units_killed
        prev.prev_units_lost = units_lost
        prev.prev_buildings_killed = buildings_killed
        prev.prev_buildings_lost = buildings_lost
        prev.prev_army_value = military.get("army_value", 0)

        return total

    def compute_vector(self, obs_dict: dict) -> Optional[RewardVector]:
        """Compute the multi-dimensional reward vector.

        Args:
            obs_dict: Full observation dictionary.

        Returns:
            The RewardVectorComputer result, or None when vector mode is
            not enabled.
        """
        computer = self._vector_computer
        return None if computer is None else computer.compute(obs_dict)

    def compute_all(self, obs_dict: dict) -> tuple[float, Optional[dict[str, float]]]:
        """Compute the scalar reward and, if enabled, the vector as a dict.

        Returns:
            (scalar_reward, reward_vector_dict_or_None)
        """
        scalar = self.compute(obs_dict)
        vector = self.compute_vector(obs_dict)
        return scalar, (None if vector is None else vector.as_dict())
" + "Keep it concise and accessible to viewers who may not know RTS games well." +) +_COMMENTARY_MAX_TOKENS = 150 + + +def _sse(event_type: str, data: dict) -> str: + """Format a Server-Sent Event.""" + return f"event: {event_type}\ndata: {json.dumps(data)}\n\n" + + +async def _generate_commentary(user_content: str, llm_config, broadcaster) -> None: + """Generate commentary in the background and broadcast it.""" + import httpx as _httpx + + try: + headers = dict(llm_config.extra_headers) + if llm_config.api_key: + headers["Authorization"] = f"Bearer {llm_config.api_key}" + + payload = { + "model": llm_config.model, + "messages": [ + {"role": "system", "content": _COMMENTARY_SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + ], + "max_tokens": 400, + "reasoning": {"effort": "low"}, + "temperature": 0.6, + "top_p": 0.95, + } + + async with _httpx.AsyncClient() as client: + resp = await client.post( + llm_config.base_url, + headers=headers, + json=payload, + timeout=llm_config.request_timeout_s, + ) + + if resp.status_code != 200: + return + data = resp.json() + msg = data["choices"][0]["message"] + text = msg.get("content") or "" + if not text: + # Reasoning models may put output in 'reasoning' if content is empty + text = msg.get("reasoning") or "" + if text: + sentences = [s.strip() for s in text.replace("\n", " ").split(".") if s.strip()] + text = ". ".join(sentences[-2:]) + "." 
class TryGameBroadcaster:
    """Fans one demo game's SSE events out to any number of subscribers.

    Every broadcast event is kept in a history list so that a subscriber
    can be replayed what it missed. At most one game runs at a time;
    start_game() is a no-op while a game is running.
    """

    def __init__(self):
        self._subscribers: set[asyncio.Queue] = set()
        self._event_history: list[str] = []
        self._opponent: str = ""
        self._game_running: bool = False
        self._game_task: asyncio.Task | None = None
        # Serialises start_game() so two callers cannot both launch a game.
        self._start_lock = asyncio.Lock()

    @property
    def game_running(self) -> bool:
        """Whether a game is currently in progress."""
        return self._game_running

    def subscribe(self) -> asyncio.Queue:
        """Register and return a fresh queue that receives future events."""
        listener: asyncio.Queue = asyncio.Queue()
        self._subscribers.add(listener)
        return listener

    def unsubscribe(self, queue: asyncio.Queue) -> None:
        """Stop delivering events to ``queue``; unknown queues are ignored."""
        self._subscribers.discard(queue)

    def _broadcast(self, event: str) -> None:
        """Record ``event`` in the history and push it to every subscriber."""
        self._event_history.append(event)
        for listener in self._subscribers:
            listener.put_nowait(event)

    async def replay_to(self, queue: asyncio.Queue) -> None:
        """Put every event recorded so far onto ``queue``, oldest first."""
        # Iterate over a snapshot so events appended meanwhile are not replayed.
        snapshot = tuple(self._event_history)
        for past_event in snapshot:
            await queue.put(past_event)

    async def start_game(self, opponent: str) -> None:
        """Launch a game against ``opponent`` unless one is already running."""
        async with self._start_lock:
            if self._game_running:
                return
            # A new game starts with an empty history.
            self._event_history.clear()
            self._opponent = opponent
            self._game_running = True
            self._game_task = asyncio.create_task(self._run_game(opponent))

    async def _run_game(self, opponent: str) -> None:
        """Relay the agent's events, then signal end-of-stream to subscribers."""
        try:
            async for event in _run_try_agent(opponent):
                self._broadcast(event)
        finally:
            self._game_running = False
            # The end marker goes to current subscribers only; it is not
            # added to the history.
            end_marker = _sse("_stream_end", {})
            for listener in self._subscribers:
                listener.put_nowait(end_marker)
import OpenRAMCPClient + + api_key = os.environ.get("OPENROUTER_API_KEY", "") + if not api_key: + yield _sse("error_event", {"message": "Server not configured for demo play (no API key)."}) + return + + llm_config = LLMConfig( + api_key=api_key, + model="stepfun/step-3.5-flash", + base_url="https://openrouter.ai/api/v1/chat/completions", + max_tokens=1500, + temperature=1.0, + top_p=0.95, + reasoning_effort="low", + extra_headers={ + "HTTP-Referer": "https://openra-rl.dev", + "X-Title": "OpenRA-RL Try Agent", + }, + ) + commentary_config = LLMConfig( + api_key=api_key, + model="stepfun/step-3.5-flash", + base_url="https://openrouter.ai/api/v1/chat/completions", + max_tokens=_COMMENTARY_MAX_TOKENS, + request_timeout_s=15.0, + extra_headers={ + "HTTP-Referer": "https://openra-rl.dev", + "X-Title": "OpenRA-RL Commentary", + }, + ) + + # Configure opponent difficulty for the next game + os.environ["BOT_TYPE"] = opponent.lower() + + yield _sse("status", {"message": f"Launching game vs {opponent} AI..."}) + + try: + async with OpenRAMCPClient( + base_url="http://localhost:8000", message_timeout_s=300.0 + ) as env: + yield _sse("status", {"message": "Resetting environment..."}) + await env.reset() + + # Discover tools + mcp_tools = await env.list_tools() + openai_tools = mcp_tools_to_openai(mcp_tools) + + # Start + end planning to trigger session start (unpauses game) + yield _sse("status", {"message": "Starting game session..."}) + await env.call_tool("start_planning_phase") + await env.call_tool("end_planning_phase", strategy="Demo game - aggressive rush") + yield _sse("status", {"message": f"Game started. 
{len(mcp_tools)} tools available."}) + + # Initialize conversation + messages = [{"role": "system", "content": SYSTEM_PROMPT}] + + # Get initial state and compose briefing + state = await env.call_tool("get_game_state") + briefing = compose_pregame_briefing(state) + + messages.append({ + "role": "user", + "content": ( + f"Game started!\n\n{briefing}\n\n" + f"## Current State\n```json\n{json.dumps(state, indent=2)}\n```\n\n" + f"ACT NOW! Deploy your MCV immediately, then start building power plant + barracks. " + f"Expand fast — every idle second costs you. Use plan() to chain: " + f"deploy MCV → build power plant → build barracks → build refinery. " + f"Then focus on economy (3+ refineries) and defense turrets toward the enemy." + ), + }) + + yield _sse("game_state", { + "tick": state.get("tick", 0), + "units": state.get("own_units", 0), + "buildings": state.get("own_buildings", 0), + "cash": state.get("economy", {}).get("cash", 0), + }) + + total_tool_calls = 0 + total_api_calls = 0 + start_time = time.time() + game_done = False + consecutive_errors = 0 + + for turn in range(1, _TRY_MAX_TURNS + 1): + elapsed = time.time() - start_time + if elapsed >= _TRY_MAX_TIME: + yield _sse("status", {"message": f"Time limit reached ({_TRY_MAX_TIME}s)."}) + break + + # Compress history to stay within context limits + messages = compress_history(messages, keep_last=40) + + # Inject state briefing (skip first turn — initial state already sent) + if total_api_calls > 0: + try: + briefing_state = await env.call_tool("get_game_state") + brief = format_state_briefing(briefing_state) + if brief: + messages.append({"role": "user", "content": brief}) + if isinstance(briefing_state, dict) and briefing_state.get("done"): + game_done = True + yield _sse("done", { + "result": briefing_state.get("result", "?"), + "tick": briefing_state.get("tick", 0), + }) + break + except Exception: + pass + + # Call LLM + try: + response = await chat_completion(messages, openai_tools, llm_config) + except 
Exception as e: + yield _sse("error_event", {"message": f"LLM error: {e}"}) + break + + total_api_calls += 1 + choice = response["choices"][0] + assistant_msg = choice["message"] + messages.append(assistant_msg) + + # Emit LLM reasoning + if assistant_msg.get("content"): + yield _sse("llm", {"content": assistant_msg["content"][:500]}) + + yield _sse("turn", { + "turn": turn, + "api_calls": total_api_calls, + "elapsed": round(elapsed), + }) + + # Handle tool calls + tool_calls = assistant_msg.get("tool_calls", []) + if not tool_calls: + messages.append({ + "role": "user", + "content": "Please use the game tools to take action.", + }) + continue + + for tc in tool_calls: + fn_name = tc["function"]["name"] + try: + fn_args = json.loads(tc["function"].get("arguments", "{}")) + except (json.JSONDecodeError, TypeError): + fn_args = {} + + total_tool_calls += 1 + + args_str = json.dumps(fn_args) + if len(args_str) > 120: + args_str = args_str[:120] + "..." + yield _sse("tool_call", {"name": fn_name, "args": args_str}) + + try: + result = await env.call_tool(fn_name, **fn_args) + consecutive_errors = 0 + except Exception as e: + result = {"error": str(e)} + + # Detect game crash + if isinstance(result, dict) and "connection lost" in str( + result.get("error", "") + ).lower(): + consecutive_errors += 1 + if consecutive_errors >= 3: + yield _sse("error_event", {"message": "Game connection lost."}) + game_done = True + + result_str = ( + json.dumps(result) if not isinstance(result, str) else result + ) + messages.append({ + "role": "tool", + "tool_call_id": tc["id"], + "content": result_str, + }) + + # Check game over + if isinstance(result, dict): + if result.get("done"): + game_done = True + yield _sse("done", { + "result": result.get("result", "?"), + "tick": result.get("tick", 0), + }) + elif "tick" in result and "economy" in result: + yield _sse("game_state", { + "tick": result.get("tick", 0), + "units": result.get("own_units", 0), + "buildings": 
@app.get("/try-agent")
async def try_agent(
    opponent: str = Query("Normal", pattern="^(Easy|Normal|Hard)$"),
):
    """Stream a Red Alert game played by an LLM agent as server-sent events.

    Any number of viewers may watch at once. If no game is running, this
    request starts one against *opponent*; otherwise the caller joins the
    game in progress and is first sent the events recorded so far.
    """
    viewer_queue = _broadcaster.subscribe()

    if not _broadcaster.game_running:
        await _broadcaster.start_game(opponent)
    else:
        joining_notice = _sse("status", {"message": "Joining ongoing game as spectator..."})
        await viewer_queue.put(joining_notice)
        await _broadcaster.replay_to(viewer_queue)

    async def stream():
        # Forward queued events until the end-of-stream sentinel arrives or
        # nothing has been received for 360 seconds.
        try:
            while True:
                next_event = await asyncio.wait_for(viewer_queue.get(), timeout=360)
                if '"_stream_end"' in next_event:
                    return
                yield next_event
        except asyncio.TimeoutError:
            return
        finally:
            _broadcaster.unsubscribe(viewer_queue)

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
+ + + + + +
+
+ + SYSTEM OVERRIDE ACTIVE +
+

OPENRA-RL

+
+ OpenEnv environment for training AI agents to play + Red Alert through the OpenRA engine. + Connect via WebSocket or HTTP, send actions, observe the battlefield. +
+ +
+ + +
+

Endpoints

+
+
+

API DOCS

+

Interactive Swagger UI with all REST and WebSocket endpoints.

+ /docs → +
+
+

HEALTH CHECK

+

Server status and readiness probe for monitoring.

+ /health → +
+
+

ENV SCHEMA

+

JSON schemas for actions, observations, and game state.

+ /schema → +
+
+
+ + +
+
+
+
+

Connect to Environment

+

+ Use the Python client to connect, reset the environment, + and step through the game loop. Works with both local + Docker and this HuggingFace-hosted server. +

+ + + API REFERENCE + +
+
+
+
+
+
+ terminal +
+
+
$ pip install openra-rl
+
+from openra_env.client import OpenRAEnv
+from openra_env.models import OpenRAAction
+
+url = "https://openra-rl-openra-rl.hf.space"
+
+async with OpenRAEnv(url) as env:
+    obs = await env.reset()
+    while not obs.done:
+        action = your_agent(obs)
+        obs = await env.step(action)
+
+
+
+
+
+ + + + +""" + + +@app.get("/", response_class=HTMLResponse) +async def root(): + """Landing page for the HuggingFace Space.""" + return LANDING_PAGE + + +# ── Try Page: Watch AI Play ────────────────────────────────────────────────── + +TRY_PAGE = """\ + + + + + +Try — Watch AI Play Red Alert + + + + + +
+ + + +
+
+

Watch AI Play

+

A pre-configured LLM agent plays Red Alert against the built-in AI. No setup needed.

+
+ +
+ + +
+ +
Waiting to start...\n
+ +
+

Scorecard

+
+
+
+ +
© 2025 OpenRA-RL Contributors — Home
class BridgeClient:
    """Async gRPC client for the OpenRA RL Bridge.

    Uses bidirectional streaming: the game sends observations continuously
    at its natural tick rate, and the agent sends actions when ready.
    A background reader task keeps the latest observation cached.
    """

    def __init__(self, host: str = "localhost", port: int = 9999, timeout_s: float = 30.0):
        self.host = host
        self.port = port
        # Deadline for the unary GetState RPC and for each wait on a new observation.
        self.timeout_s = timeout_s
        self._channel: Optional[grpc.aio.Channel] = None
        self._stub: Optional[rl_bridge_pb2_grpc.RLBridgeStub] = None
        self._session_call = None
        self._action_queue: asyncio.Queue[rl_bridge_pb2.AgentAction] = asyncio.Queue()
        self._connected = False

        # Background observation reader state
        self._latest_obs: Optional[rl_bridge_pb2.GameObservation] = None
        self._obs_event: asyncio.Event = asyncio.Event()
        self._obs_tick: int = 0
        self._obs_reader_task: Optional[asyncio.Task] = None

    @staticmethod
    def _is_stream_end(message) -> bool:
        """Return True when a value from ``read()`` marks the end of the stream.

        ``grpc.aio`` signals end-of-stream with the ``grpc.aio.EOF`` sentinel
        rather than ``None``; ``None`` is still accepted for safety.
        """
        return message is None or message is grpc.aio.EOF

    async def connect(self) -> None:
        """Establish a fresh gRPC channel and stub.

        A channel left over from an earlier call is closed once the new one
        is in place, unless a streaming session is still running on it.
        ``wait_for_ready`` calls this on every retry, so this keeps retries
        from accumulating open channels.
        """
        target = f"{self.host}:{self.port}"
        superseded = self._channel
        self._channel = grpc.aio.insecure_channel(
            target,
            options=[
                ("grpc.max_receive_message_length", 64 * 1024 * 1024),
                ("grpc.max_send_message_length", 16 * 1024 * 1024),
                ("grpc.keepalive_time_ms", 10000),
                ("grpc.keepalive_timeout_ms", 5000),
            ],
        )
        self._stub = rl_bridge_pb2_grpc.RLBridgeStub(self._channel)
        self._connected = True
        if superseded is not None and self._session_call is None:
            await superseded.close()
        logger.info(f"Connected to OpenRA bridge at {target}")

    async def wait_for_ready(self, max_retries: int = 30, retry_interval: float = 1.0) -> bool:
        """Wait for the gRPC server to become available.

        Each attempt reconnects and issues a GetState RPC. Returns True on
        the first success and False once ``max_retries`` attempts have failed.
        """
        for attempt in range(max_retries):
            try:
                await self.connect()
                state = await self.get_state()
                logger.info(f"Bridge ready after {attempt + 1} attempts, phase={state.phase}")
                return True
            except grpc.aio.AioRpcError as e:
                if attempt < max_retries - 1:
                    logger.debug(f"Bridge not ready (attempt {attempt + 1}): {e.code()}")
                    await asyncio.sleep(retry_interval)
                else:
                    logger.error(f"Bridge failed to become ready after {max_retries} attempts")
                    return False
            except Exception as e:
                if attempt < max_retries - 1:
                    logger.debug(f"Connection attempt {attempt + 1} failed: {e}")
                    await asyncio.sleep(retry_interval)
                else:
                    return False
        return False

    @property
    def session_started(self) -> bool:
        """Whether the streaming session has been started."""
        return self._session_call is not None

    async def start_session(self) -> rl_bridge_pb2.GameObservation:
        """Start a bidirectional streaming session and return the first observation.

        The game sends observations continuously; a background reader task
        keeps the latest observation cached. Actions are sent via step().

        Idempotent: if the session is already started, returns the latest observation.

        Raises:
            ConnectionError: the stream ended before any observation arrived.
        """
        if self._session_call is not None:
            # Already started — return latest cached observation
            return self._latest_obs

        if not self._connected:
            await self.connect()

        self._action_queue = asyncio.Queue()
        call = self._stub.GameSession(self._action_request_iterator())
        self._session_call = call

        try:
            first_obs = await call.read()
            if self._is_stream_end(first_obs):
                raise ConnectionError("Bridge stream closed before sending initial observation")
        except BaseException:
            # Roll back so a failed start does not leave the client claiming
            # an active session with no observation behind it.
            call.cancel()
            self._session_call = None
            raise

        # Initialize observation state and start background reader
        self._latest_obs = first_obs
        self._obs_tick = first_obs.tick
        self._obs_event = asyncio.Event()
        self._obs_event.set()
        self._obs_reader_task = asyncio.create_task(self._bg_obs_reader())

        logger.info(f"Session started, initial tick={first_obs.tick}")
        return first_obs

    async def _action_request_iterator(self) -> AsyncIterator[rl_bridge_pb2.AgentAction]:
        """Yield actions from the queue as the gRPC stream requests them."""
        while True:
            action = await self._action_queue.get()
            yield action

    async def _bg_obs_reader(self):
        """Background task: continuously read observations from the gRPC stream.

        Updates _latest_obs and signals _obs_event each time a new
        observation arrives. The game sends observations at its natural
        tick rate regardless of agent actions. The event is also signalled
        when the reader exits, so waiters notice a dead stream immediately
        instead of sleeping until their timeout.
        """
        call = self._session_call
        try:
            while True:
                obs = await call.read()
                if self._is_stream_end(obs):
                    logger.info("gRPC observation stream ended")
                    break
                self._latest_obs = obs
                self._obs_tick = obs.tick
                self._obs_event.set()
                if obs.done:
                    logger.info(f"Game over at tick {obs.tick}: {obs.result}")
                    break
        except grpc.aio.AioRpcError as e:
            logger.error(f"Background observation reader error: {e.code()}")
        except asyncio.CancelledError:
            logger.debug("Background observation reader cancelled")
        finally:
            # Wake any step()/wait_ticks() caller; it re-checks reader liveness.
            self._obs_event.set()

    def _check_reader_alive(self):
        """Raise if the background observation reader has exited (game died)."""
        task = self._obs_reader_task
        if task is None or not task.done():
            return
        # Task.exception() itself raises on a cancelled task, so guard it.
        if not task.cancelled():
            exc = task.exception()
            if exc:
                raise ConnectionError(f"Game connection lost: {exc}") from exc
        raise ConnectionError("Game connection lost (observation stream ended)")

    async def _wait_for_observation(self) -> None:
        """Block until the reader signals, failing fast if it has already exited."""
        self._check_reader_alive()
        self._obs_event.clear()
        await asyncio.wait_for(self._obs_event.wait(), timeout=self.timeout_s)

    async def step(self, action: rl_bridge_pb2.AgentAction) -> rl_bridge_pb2.GameObservation:
        """Send an action and wait for the next observation.

        The action is queued immediately. Then we wait for an observation
        with a tick newer than the current one (confirming the game has
        processed at least one more tick since the action was sent).

        Raises:
            RuntimeError: the session has not been started.
            ConnectionError: the observation reader has exited.
            asyncio.TimeoutError: no observation arrived within ``timeout_s``.
        """
        if self._session_call is None:
            raise RuntimeError("Session not started. Call start_session() first.")

        current_tick = self._obs_tick
        await self._action_queue.put(action)

        # Wait for an observation newer than when we sent the action
        while self._obs_tick <= current_tick:
            await self._wait_for_observation()

        return self._latest_obs

    async def wait_ticks(self, n: int) -> rl_bridge_pb2.GameObservation:
        """Wait for approximately N game ticks to pass.

        The game runs at its natural speed (~25 ticks/sec at default).
        Returns the observation at or after the target tick, or the final
        observation if the game ends first.
        """
        target_tick = self._obs_tick + n
        while self._obs_tick < target_tick:
            await self._wait_for_observation()
            if self._latest_obs and self._latest_obs.done:
                break
        return self._latest_obs

    async def observe(self) -> Optional[rl_bridge_pb2.GameObservation]:
        """Return the latest cached observation without sending any action."""
        return self._latest_obs

    async def get_state(self) -> rl_bridge_pb2.GameState:
        """Query current game state via unary RPC."""
        if not self._connected or self._stub is None:
            raise RuntimeError("Not connected. Call connect() first.")
        request = rl_bridge_pb2.StateRequest()
        return await self._stub.GetState(request, timeout=self.timeout_s)

    async def close(self) -> None:
        """Close the gRPC channel and clean up."""
        # Cancel background observation reader
        reader = self._obs_reader_task
        if reader is not None:
            reader.cancel()
            try:
                await reader
            except asyncio.CancelledError:
                pass
            except Exception as exc:
                # The reader had already failed; keep tearing everything down.
                logger.debug(f"Observation reader had failed before close: {exc}")
            self._obs_reader_task = None

        if self._session_call is not None:
            self._session_call.cancel()
            self._session_call = None

        if self._channel is not None:
            await self._channel.close()
            self._channel = None

        self._stub = None
        self._connected = False
        self._latest_obs = None
        logger.info("Bridge connection closed")

    @property
    def is_connected(self) -> bool:
        """Whether connect() has succeeded and close() has not been called since."""
        return self._connected
obs.military.units_lost, + "buildings_killed": obs.military.buildings_killed, + "buildings_lost": obs.military.buildings_lost, + "army_value": obs.military.army_value, + "active_unit_count": obs.military.active_unit_count, + "kills_cost": obs.military.kills_cost, + "deaths_cost": obs.military.deaths_cost, + "assets_value": obs.military.assets_value, + "experience": obs.military.experience, + "order_count": obs.military.order_count, + }, + "units": [ + { + "actor_id": u.actor_id, + "type": u.type, + "pos_x": u.pos_x, + "pos_y": u.pos_y, + "cell_x": u.cell_x, + "cell_y": u.cell_y, + "hp_percent": u.hp_percent, + "is_idle": u.is_idle, + "current_activity": u.current_activity, + "owner": u.owner, + "can_attack": u.can_attack, + "facing": u.facing, + "experience_level": u.experience_level, + "stance": u.stance, + "speed": u.speed, + "attack_range": u.attack_range, + "passenger_count": u.passenger_count, + "is_building": u.is_building, + } + for u in obs.units + ], + "buildings": [ + { + "actor_id": b.actor_id, + "type": b.type, + "pos_x": b.pos_x, + "pos_y": b.pos_y, + "hp_percent": b.hp_percent, + "owner": b.owner, + "is_producing": b.is_producing, + "production_progress": b.production_progress, + "producing_item": b.producing_item, + "is_powered": b.is_powered, + "is_repairing": b.is_repairing, + "sell_value": b.sell_value, + "rally_x": b.rally_x, + "rally_y": b.rally_y, + "power_amount": b.power_amount, + "can_produce": list(b.can_produce), + "cell_x": b.cell_x, + "cell_y": b.cell_y, + } + for b in obs.buildings + ], + "production": [ + { + "queue_type": p.queue_type, + "item": p.item, + "progress": p.progress, + "remaining_ticks": p.remaining_ticks, + "remaining_cost": p.remaining_cost, + "paused": p.paused, + } + for p in obs.production + ], + "visible_enemies": [ + { + "actor_id": u.actor_id, + "type": u.type, + "pos_x": u.pos_x, + "pos_y": u.pos_y, + "cell_x": u.cell_x, + "cell_y": u.cell_y, + "hp_percent": u.hp_percent, + "is_idle": u.is_idle, + 
def commands_to_proto(commands: list[dict]) -> rl_bridge_pb2.AgentAction:
    """Convert a list of command dicts to a protobuf AgentAction.

    Each dict may carry ``action``, ``actor_id``, ``target_actor_id``,
    ``target_x``, ``target_y``, ``item_type`` and ``queued``; missing keys
    fall back to the defaults used below. An unrecognised ``action`` string
    is still sent as NO_OP, but a warning is logged so a misspelt command
    does not vanish silently.
    """
    action_type_map = {
        "no_op": rl_bridge_pb2.NO_OP,
        "move": rl_bridge_pb2.MOVE,
        "attack_move": rl_bridge_pb2.ATTACK_MOVE,
        "attack": rl_bridge_pb2.ATTACK,
        "stop": rl_bridge_pb2.STOP,
        "harvest": rl_bridge_pb2.HARVEST,
        "build": rl_bridge_pb2.BUILD,
        "train": rl_bridge_pb2.TRAIN,
        "deploy": rl_bridge_pb2.DEPLOY,
        "sell": rl_bridge_pb2.SELL,
        "repair": rl_bridge_pb2.REPAIR,
        "place_building": rl_bridge_pb2.PLACE_BUILDING,
        "cancel_production": rl_bridge_pb2.CANCEL_PRODUCTION,
        "set_rally_point": rl_bridge_pb2.SET_RALLY_POINT,
        "guard": rl_bridge_pb2.GUARD,
        "set_stance": rl_bridge_pb2.SET_STANCE,
        "enter_transport": rl_bridge_pb2.ENTER_TRANSPORT,
        "unload": rl_bridge_pb2.UNLOAD,
        "power_down": rl_bridge_pb2.POWER_DOWN,
        "set_primary": rl_bridge_pb2.SET_PRIMARY,
        "surrender": rl_bridge_pb2.SURRENDER,
    }

    proto_commands = []
    for cmd in commands:
        action_str = cmd.get("action", "no_op")
        # Compare against None explicitly: a valid enum value may be 0.
        action_type = action_type_map.get(action_str)
        if action_type is None:
            logger.warning("Unknown command action %r; sending NO_OP instead", action_str)
            action_type = rl_bridge_pb2.NO_OP
        proto_cmd = rl_bridge_pb2.Command(
            action=action_type,
            actor_id=cmd.get("actor_id", 0),
            target_actor_id=cmd.get("target_actor_id", 0),
            target_x=cmd.get("target_x", 0),
            target_y=cmd.get("target_y", 0),
            item_type=cmd.get("item_type", ""),
            queued=cmd.get("queued", False),
        )
        proto_commands.append(proto_cmd)

    return rl_bridge_pb2.AgentAction(commands=proto_commands)
def _estimate_build_ticks(cost: int) -> int:
    """Approximate the number of ticks needed to build an item of this cost.

    Follows the OpenRA formula ``cost * BuildDurationModifier / 100`` with
    Buildable.BuildDurationModifier at its default of 60
    (ProductionQueue.BuildDurationModifier defaults to 100). The result is
    rounded down.
    """
    build_duration_modifier = 60
    scaled_cost = cost * build_duration_modifier
    return scaled_cost // 100
def _render_minimap(obs: dict, max_cols: int = 28) -> str:
    """Render a compact ASCII minimap from spatial tensor + unit/building positions.

    Downsamples the map to ~max_cols columns. Each cell shows the highest-priority
    feature in its area using these characters (highest priority first):
      ! = enemy unit, X = enemy building, @ = own unit, B = own building,
      $ = resources (ore/gems), ~ = water/impassable, . = explored land,
      # = unexplored (shroud)

    Returns an empty string when the observation has no usable map: missing
    or non-positive dimensions, no spatial data, undecodable base64, or fewer
    channels per cell than the fog channel this function reads. A
    non-positive ``max_cols`` disables downsampling.
    """
    import base64
    import struct
    from math import ceil

    # Per-cell channel offsets read from the spatial tensor. They are
    # contiguous, so one 3-float read fetches all of them.
    resource_ch, fog_ch = 2, 4

    map_info = obs.get("map_info", {})
    w = map_info.get("width", 0)
    h = map_info.get("height", 0)
    channels = obs.get("spatial_channels", 0)
    spatial = obs.get("spatial_map", "")

    if w <= 0 or h <= 0 or not spatial:
        return ""
    if channels <= fog_ch:
        # With fewer channels the offsets below would land in a neighbouring
        # cell's data and draw a meaningless map.
        return ""

    try:
        raw = base64.b64decode(spatial)
    except (ValueError, TypeError):
        return ""

    # A zero max_cols used to raise ZeroDivisionError; negative values already
    # resulted in scale 1, so treat every non-positive value the same way.
    scale = max(1, ceil(w / max_cols)) if max_cols > 0 else 1
    grid_w = ceil(w / scale)
    grid_h = ceil(h / scale)

    # Same native float format as a single "f" read, three values at a time.
    read_cell = struct.Struct("3f").unpack_from
    half = scale // 2

    # Initialize grid from spatial tensor
    grid = []
    for gy in range(grid_h):
        # Sample the centre of each block, clamped to the map edge.
        sy = min(gy * scale + half, h - 1)
        row = []
        for gx in range(grid_w):
            sx = min(gx * scale + half, w - 1)
            base_idx = (sy * w + sx) * channels
            try:
                resources, passability, fog = read_cell(raw, (base_idx + resource_ch) * 4)
            except struct.error:
                # Buffer too short to hold this cell's fog value: show as shroud.
                row.append("#")
                continue

            if fog <= 0.25:
                row.append("#")
            elif passability < 0.5:
                row.append("~")
            else:
                row.append("$" if resources > 0 else ".")
        grid.append(row)

    # Overlay unit and building positions (priority: ! > X > @ > B)
    priority = {"B": 1, "@": 2, "X": 3, "!": 4}
    layers = (
        ("buildings", "B"),
        ("units", "@"),
        ("visible_enemy_buildings", "X"),
        ("visible_enemies", "!"),
    )
    for key, char in layers:
        for item in obs.get(key, []):
            cx = item.get("cell_x", -1)
            cy = item.get("cell_y", -1)
            if cx < 0 or cy < 0:
                continue
            gx = cx // scale
            gy = cy // scale
            if not (0 <= gx < grid_w and 0 <= gy < grid_h):
                continue
            # Terrain characters have no entry and therefore priority 0.
            if priority[char] >= priority.get(grid[gy][gx], 0):
                grid[gy][gx] = char

    lines = ["".join(row) for row in grid]
    header = f"Map ({grid_w}x{grid_h}, 1cell={scale}x{scale}):"
    legend = (
        "YOUR: B=building @=unit | ENEMY: X=building !=unit | "
        "terrain: .=land ~=water $=ore #=unexplored"
    )
    return header + "\n" + "\n".join(lines) + "\n" + legend
def __init__(
    self,
    openra_path: Optional[str] = None,
    mod: str = "ra",
    map_name: str = "singles.oramap",
    grpc_port: int = 9999,
    bot_type: str = "normal",
    ai_slot: str = "Multi0",
    reward_weights: Optional[RewardWeights] = None,
    record_replays: bool = False,
    planning_enabled: bool = True,
    planning_max_turns: int = 10,
    planning_max_time_s: float = 60.0,
    config: Optional[OpenRARLConfig] = None,
):
    """Build the environment: config, MCP tools, game bridge, tracking state.

    If ``config`` is supplied it is used unchanged and all other arguments
    are ignored. Otherwise each argument that differs from its signature
    default (or, for ``reward_weights``, is not ``None``) is passed to
    ``load_config`` as a per-section override.
    """
    reward_fields = (
        "survival",
        "economic_efficiency",
        "aggression",
        "defense",
        "victory",
        "defeat",
    )

    # ── Resolve the unified config ───────────────────────────────
    if config is not None:
        self._app_config = config
    else:
        # (section, key, value to send, "caller changed it?")
        # NOTE(review): an argument explicitly passed at its default value is
        # indistinguishable from "not passed" and is therefore never sent.
        requested = (
            ("game", "openra_path", openra_path, openra_path is not None),
            ("game", "mod", mod, mod != "ra"),
            ("game", "map_name", map_name, map_name != "singles.oramap"),
            ("game", "grpc_port", grpc_port, grpc_port != 9999),
            ("game", "record_replays", True, record_replays),
            ("opponent", "bot_type", bot_type, bot_type != "normal"),
            ("opponent", "ai_slot", ai_slot, ai_slot != "Multi0"),
            ("planning", "enabled", False, not planning_enabled),
            ("planning", "max_turns", planning_max_turns, planning_max_turns != 10),
            ("planning", "max_time_s", planning_max_time_s, planning_max_time_s != 60.0),
        )
        pending: dict = {}
        for section, key, value, changed in requested:
            if changed:
                # A section dict is only created once it has an override.
                pending.setdefault(section, {})[key] = value
        if reward_weights is not None:
            pending["reward"] = {
                field: getattr(reward_weights, field) for field in reward_fields
            }
        # Per the original note, env vars are applied inside load_config().
        self._app_config = load_config(**pending)

    settings = self._app_config

    # Tools are registered before the base initializer runs; registration
    # reads self._app_config, so the config has to be resolved first.
    server = FastMCP("openra")
    self._register_tools(server)
    super().__init__(server)

    game_cfg = settings.game
    opponent_cfg = settings.opponent
    self._config = OpenRAConfig(
        openra_path=game_cfg.openra_path,
        mod=game_cfg.mod,
        map_name=game_cfg.map_name,
        grpc_port=game_cfg.grpc_port,
        bot_type=opponent_cfg.bot_type,
        ai_slot=opponent_cfg.ai_slot,
        record_replays=game_cfg.record_replays,
        headless=game_cfg.headless,
        seed=game_cfg.seed,
    )
    self._process = OpenRAProcessManager(self._config)
    self._bridge = BridgeClient(port=game_cfg.grpc_port)

    weights = RewardWeights(
        **{field: getattr(settings.reward, field) for field in reward_fields}
    )
    vector_cfg = settings.reward_vector
    self._reward_fn = OpenRARewardFunction(
        weights=weights,
        vector_enabled=vector_cfg.enabled,
        vector_weights=vector_cfg.weights if vector_cfg.enabled else None,
    )

    # ── Per-episode tracking state ───────────────────────────────
    self._state = OpenRAState()
    self._last_obs: Optional[dict] = None
    self._unit_groups: dict[str, list[int]] = {}  # group name -> unit IDs
    self._pending_placements: dict[str, dict] = {}  # building_type -> {cell_x, cell_y}
    self._move_targets: dict[int, tuple[int, int]] = {}  # unit_id -> (target_x, target_y)
    self._attempted_placements: dict[str, int] = {}  # building_type -> attempt count
    self._placement_results: list[str] = []  # alerts from auto-placement attempts
    self._player_faction: str = ""
    self._enemy_faction: str = ""
    self._last_production_progress: dict[str, float] = {}  # item -> progress (stall detection)
    self._prev_buildings: dict[int, str] = {}  # actor_id -> type (loss detection)
    self._prev_unit_ids: dict[int, str] = {}  # actor_id -> type (loss detection)
    self._enemy_ever_seen: bool = False  # suppresses the no-scouting alert after first contact
    self._accumulated_reward_vector: dict[str, float] = {}  # running sum of reward vector

    # ── Planning phase (limits come from the unified config) ─────
    planning_cfg = settings.planning
    self._planning_enabled = planning_cfg.enabled
    self._planning_max_turns = planning_cfg.max_turns
    self._planning_max_time_s = planning_cfg.max_time_s
    self._planning_active = False
    self._planning_start_time: float = 0.0
    self._planning_turns_used: int = 0
    self._planning_strategy: str = ""

    # Dedicated event loop on a daemon thread for the async gRPC bridge, so
    # bridge coroutines never run on the web server's own loop.
    self._loop = asyncio.new_event_loop()
    self._loop_thread = threading.Thread(
        target=self._loop.run_forever, daemon=True, name="openra-bridge-loop"
    )
    self._loop_thread.start()
Call advance() or reset first."} + + eco = obs["economy"] + power_balance = eco["power_provided"] - eco["power_drained"] + + # Compute alerts (each type gated by config) + alerts = [] + acfg = env._app_config.alerts + pcfg = env._app_config.prompts.alerts + + # Under attack: enemy units near our buildings + if acfg.under_attack: + attackers = [] # enemies near base buildings + for enemy in obs["visible_enemies"]: + for bldg in obs["buildings"]: + dx = abs(enemy.get("cell_x", 0) - bldg.get("cell_x", 0)) + dy = abs(enemy.get("cell_y", 0) - bldg.get("cell_y", 0)) + if dx + dy < 12: + attackers.append(enemy) + break + if len(attackers) <= 3: + for enemy in attackers: + alerts.append((1, pcfg.under_attack.format( + type=enemy["type"], id=enemy["actor_id"]))) + elif attackers: + from collections import Counter + type_counts = Counter(e["type"] for e in attackers) + breakdown = ", ".join(f"{cnt}x {t}" for t, cnt in type_counts.most_common()) + alerts.append((1, pcfg.under_attack_mass.format( + count=len(attackers), breakdown=breakdown))) + + # Damaged buildings + if acfg.damaged_building: + for bldg in obs["buildings"]: + if bldg["hp_percent"] < 0.5: + alerts.append((5, pcfg.damaged.format( + type=bldg["type"], id=bldg["actor_id"], + hp=f"{bldg['hp_percent']:.0%}"))) + + # Power crisis + if acfg.low_power: + if power_balance < 0: + alerts.append((2, pcfg.low_power.format( + balance=f"{power_balance:+d}"))) + elif 0 <= power_balance < 30: + building_power = any( + p["item"] in ("powr", "apwr") + for p in obs["production"] + ) + if not building_power: + alerts.append((6, pcfg.power_tight.format( + balance=f"{power_balance:+d}"))) + + # Idle funds with few harvesters + if acfg.idle_funds: + total_funds = eco["cash"] + eco.get("ore", 0) + if total_funds > 2000 and eco["harvester_count"] < 2: + alerts.append((6, pcfg.idle_funds.format( + funds=total_funds, harvesters=eco["harvester_count"]))) + + # Ore storage near capacity — income is being wasted + if acfg.ore_full: + ore = 
eco.get("ore", 0) + res_cap = eco.get("resource_capacity", 0) + if res_cap > 0 and ore >= res_cap * 0.9: + alerts.append((4, pcfg.ore_full.format(ore=ore, cap=res_cap))) + + # Nothing being produced + if acfg.idle_production: + if not obs["production"] and len(obs["buildings"]) >= 3: + alerts.append((4, pcfg.idle_production)) + + # Production stalled due to $0 funds + if acfg.production_stalled: + total_funds = eco["cash"] + eco.get("ore", 0) + current_progress = {p["item"]: p["progress"] for p in obs["production"] + if p["progress"] < 0.99} + last_progress = getattr(env, "_last_production_progress", {}) + if total_funds == 0 and current_progress: + for item, prog in current_progress.items(): + if item in last_progress and abs(prog - last_progress[item]) < 0.01: + alerts.append((2, pcfg.stalled.format( + item=item, progress=f"{prog:.0%}"))) + break # one alert is enough + env._last_production_progress = current_progress + + # Building ready to place or stuck in auto-placement + if acfg.building_ready: + pending = getattr(env, "_pending_placements", {}) + attempted = getattr(env, "_attempted_placements", {}) + for p in obs["production"]: + if p["queue_type"] in env._PLACEABLE_QUEUE_TYPES and p["progress"] >= 0.99: + btype = p["item"] + if btype in attempted: + alerts.append((3, pcfg.building_stuck.format(building=btype))) + elif btype not in pending: + alerts.append((3, pcfg.ready_to_place.format(building=btype))) + + # Auto-placement results from build_and_place (always shown — these are action feedback) + placement_results = getattr(env, "_placement_results", []) + if placement_results: + alerts.extend((3, msg) for msg in placement_results) + placement_results.clear() + + # Combat units on default ReturnFire stance + if acfg.stance_warning: + returnfire_count = sum( + 1 for u in obs["units"] + if u.get("can_attack") and u.get("stance", 1) == 1 + ) + if returnfire_count > 0: + alerts.append((7, pcfg.stance.format(count=returnfire_count))) + + # Idle army + if 
acfg.idle_army: + idle_combat = [u for u in obs["units"] if u.get("can_attack") and u.get("is_idle")] + if len(idle_combat) >= 4: + alerts.append((7, pcfg.idle_army.format(count=len(idle_combat)))) + + # No defenses + if acfg.no_defenses: + _DEFENSE_BUILDINGS = {"gun", "ftur", "tsla", "sam", "agun", "pbox", "hbox"} + building_types = {b["type"] for b in obs["buildings"]} + if len(obs["buildings"]) >= 4 and not (building_types & _DEFENSE_BUILDINGS): + alerts.append((7, pcfg.no_defenses)) + + # Track enemy discovery history + if obs.get("visible_enemies") or obs.get("visible_enemy_buildings"): + env._enemy_ever_seen = True + + # No scouting — factual alert (suppress after first contact) + if acfg.no_scouting: + if obs["tick"] > 750 and not obs["visible_enemies"] and not obs.get("visible_enemy_buildings"): + if not getattr(env, "_enemy_ever_seen", False): + idle_combat = sum(1 for u in obs["units"] if u.get("can_attack") and u.get("is_idle")) + # Compute exploration % from spatial data if available + _expl_pct = "?" 
+ _spatial = obs.get("spatial_map", "") + _mi = obs.get("map_info", {}) + _sw, _sh, _sc = _mi.get("width", 0), _mi.get("height", 0), obs.get("spatial_channels", 0) + if _spatial and _sw > 0 and _sc > 0: + import base64 as _b64 + import struct as _st + try: + _raw = _b64.b64decode(_spatial) + _explored = sum( + 1 for _i in range(_sw * _sh) + if _st.unpack_from("f", _raw, (_i * _sc + 4) * 4)[0] > 0.25 + ) + _expl_pct = f"{round(100 * _explored / (_sw * _sh), 1)}%" + except Exception: + pass + alerts.append((7, pcfg.no_scouting.format( + explored=_expl_pct, idle=idle_combat))) + + # Sort alerts by priority and apply cap + alerts.sort(key=lambda x: x[0]) + max_a = env._app_config.alerts.max_alerts + if max_a > 0 and len(alerts) > max_a: + alerts = alerts[:max_a] + alert_texts = [text for _, text in alerts] + + # Compact summaries with actor IDs for planning + units_summary = [] + for u in obs["units"]: + uid = u["actor_id"] + entry = {"id": uid, "type": u["type"], "idle": u["is_idle"], + "can_attack": u["can_attack"], "stance": u["stance"], + "cell_x": u["cell_x"], "cell_y": u["cell_y"], + "activity": u["current_activity"]} + # Clear stale move targets for idle units + move_targets = getattr(env, "_move_targets", {}) + if u["is_idle"] and uid in move_targets: + del move_targets[uid] + # Attach tracked destination + if uid in move_targets: + tx, ty = move_targets[uid] + entry["target_x"] = tx + entry["target_y"] = ty + units_summary.append(entry) + buildings_summary = [ + {"id": b["actor_id"], "type": b["type"], + "cell_x": b["cell_x"], "cell_y": b["cell_y"]} + for b in obs["buildings"] + ] + enemy_summary = [ + {"id": e["actor_id"], "type": e["type"], + "cell_x": e["cell_x"], "cell_y": e["cell_y"]} + for e in obs["visible_enemies"] + ] + enemy_buildings_summary = [ + {"id": b["actor_id"], "type": b["type"], + "cell_x": b["cell_x"], "cell_y": b["cell_y"]} + for b in obs.get("visible_enemy_buildings", []) + ] + + # Render minimap (gated by config) + minimap = "" + if 
env._app_config.alerts.minimap: + minimap = _render_minimap(obs) + + # Compute exploration % from spatial tensor channel 4 (fog) + explored_pct = 0.0 + _map = obs.get("map_info", {}) + _w, _h = _map.get("width", 0), _map.get("height", 0) + _sp = obs.get("spatial_map", "") + _ch = obs.get("spatial_channels", 0) + if _sp and _ch > 0 and _w > 0 and _h > 0: + import base64 as _b64 + import struct as _st + try: + _raw = _b64.b64decode(_sp) + _explored = sum( + 1 for _i in range(_w * _h) + if _st.unpack_from("f", _raw, (_i * _ch + 4) * 4)[0] > 0.25 + ) + explored_pct = round(100 * _explored / (_w * _h), 1) + except Exception: + pass + + result = { + "tick": obs["tick"], + "done": obs["done"], + "result": obs.get("result", ""), + "faction": getattr(env, "_player_faction", ""), + "economy": obs["economy"], + "power_balance": power_balance, + "military": obs["military"], + "own_units": len(obs["units"]), + "own_buildings": len(obs["buildings"]), + "building_types": list(set(b["type"] for b in obs["buildings"])), + "visible_enemy_units": len(obs["visible_enemies"]), + "visible_enemy_buildings": len(obs.get("visible_enemy_buildings", [])), + "production_queues": len(obs["production"]), + "production_items": [ + f"{p['item']}@{p['progress']:.0%}(~{p.get('remaining_ticks', 0)} ticks)" + for p in obs["production"] + ], + "available_production": obs.get("available_production", []), + "units_summary": units_summary, + "buildings_summary": buildings_summary, + "enemy_summary": enemy_summary, + "enemy_buildings_summary": enemy_buildings_summary, + "minimap": minimap, + "explored_percent": explored_pct, + "reward_vector": dict(getattr(env, "_accumulated_reward_vector", {})), + "alerts": alert_texts, + "map": obs["map_info"], + } + + # Include planning phase context + if env._planning_active: + result["planning_active"] = True + result["planning_turns_remaining"] = max( + 0, env._planning_max_turns - env._planning_turns_used + ) + if env._planning_strategy: + 
@configurable_tool
def get_units() -> list[dict]:
    """Get list of own units with id, type, position, hp, activity, stance."""
    # The docstring is left verbatim: it is presumably what the MCP layer
    # shows to the agent as the tool description — TODO confirm.
    env._refresh_obs()
    snapshot = env._last_obs
    if snapshot is None:
        return []

    leading_keys = ("actor_id", "type", "cell_x", "cell_y")
    trailing_keys = ("is_idle", "current_activity", "can_attack", "stance", "attack_range")

    rows = []
    for unit in snapshot["units"]:
        # Key order mirrors the tool's established output layout.
        row = {key: unit[key] for key in leading_keys}
        row["hp_percent"] = round(unit["hp_percent"], 2)
        for key in trailing_keys:
            row[key] = unit[key]
        rows.append(row)
    return rows
@configurable_tool
def get_exploration_status() -> dict:
    """Get fog-of-war exploration status: overall explored %, per-quadrant
    explored %, enemy visibility, base position, idle combat/infantry counts.
    Use to understand how much of the map has been revealed."""
    # Docstring kept verbatim (it presumably doubles as the tool description).
    env._refresh_obs()
    snapshot = env._last_obs
    if snapshot is None:
        return {"error": "No observation available. Call reset first."}

    dims = snapshot.get("map_info", {})
    width = dims.get("width", 0)
    height = dims.get("height", 0)
    n_channels = snapshot.get("spatial_channels", 0)
    encoded = snapshot.get("spatial_map", "")

    # Base position = integer centroid of every owned building and unit,
    # falling back to the map centre when nothing is owned.
    own_buildings = snapshot.get("buildings", [])
    own_units = snapshot.get("units", [])
    positions = [(b["cell_x"], b["cell_y"]) for b in own_buildings]
    positions = positions + [(u["cell_x"], u["cell_y"]) for u in own_units]
    if positions:
        base_x = sum(px for px, _ in positions) // len(positions)
        base_y = sum(py for _, py in positions) // len(positions)
    else:
        base_x, base_y = width // 2, height // 2

    idle_combat = 0
    for unit in own_units:
        if unit.get("can_attack") and unit.get("is_idle"):
            idle_combat += 1

    from openra_env.game_data import RA_UNITS
    infantry_types = set()
    for unit_type, info in RA_UNITS.items():
        if info.get("category") == "infantry":
            infantry_types.add(unit_type)
    idle_infantry = 0
    for unit in own_units:
        if unit.get("is_idle") and unit["type"] in infantry_types:
            idle_infantry += 1

    visible_now = len(snapshot.get("visible_enemies", []))
    visible_now = visible_now + len(snapshot.get("visible_enemy_buildings", []))
    result = {
        "map_width": width,
        "map_height": height,
        "base_position": {"x": base_x, "y": base_y},
        "enemy_found": bool(getattr(env, "_enemy_ever_seen", False)),
        "enemy_currently_visible": visible_now,
        "idle_combat_count": idle_combat,
        "idle_infantry_count": idle_infantry,
    }

    def without_grid() -> dict:
        # Shared exit for "no spatial data" and "spatial data undecodable".
        result["explored_percent"] = 0.0
        result["unexplored_percent"] = 100.0
        result["quadrant_exploration"] = {}
        return result

    if not encoded or width == 0 or n_channels == 0:
        return without_grid()

    import base64
    import struct

    try:
        grid = base64.b64decode(encoded)
    except Exception:
        return without_grid()

    quadrants = ["NW", "NE", "SW", "SE"]
    mid_x = width // 2
    mid_y = height // 2
    seen_per_quad = dict.fromkeys(quadrants, 0)
    cells_per_quad = dict.fromkeys(quadrants, 0)
    seen_total = 0

    for row in range(height):
        north_south = "N" if row < mid_y else "S"
        for col in range(width):
            # Channels-last layout: channel 4 of each cell is read as fog.
            offset = ((row * width + col) * n_channels + 4) * 4
            try:
                (fog,) = struct.unpack_from("f", grid, offset)
            except struct.error:
                # Cell lies beyond the decoded buffer; it is not counted.
                continue
            quad = north_south + ("W" if col < mid_x else "E")
            cells_per_quad[quad] += 1
            if fog > 0.25:
                seen_total += 1
                seen_per_quad[quad] += 1

    explored_pct = round(100 * seen_total / max(width * height, 1), 1)
    result["explored_percent"] = explored_pct
    result["unexplored_percent"] = round(100 - explored_pct, 1)

    # Tag the quadrant that contains the base centroid.
    home_quad = ("N" if base_y < mid_y else "S") + ("W" if base_x < mid_x else "E")
    per_quadrant = {}
    for quad in quadrants:
        denominator = max(cells_per_quad[quad], 1)
        per_quadrant[quad] = {
            "explored_percent": round(100 * seen_per_quad[quad] / denominator, 1),
            "label": "your base area" if quad == home_quad else "",
        }
    result["quadrant_exploration"] = per_quadrant

    return result
@configurable_tool
def lookup_unit(unit_type: str) -> dict:
    """Look up stats for a unit type (e.g., 'e1', '3tnk', 'mig').
    Returns cost, HP, speed, armor, prerequisites, and description."""
    # Docstring kept verbatim (it presumably doubles as the tool description).
    stats = get_unit_stats(unit_type)
    if stats is not None:
        return stats
    # Unknown type: answer with the list of valid names so the caller can retry.
    known_types = get_all_unit_types()
    return {"error": f"Unknown unit type '{unit_type}'", "available_types": known_types}
@configurable_tool
def get_map_analysis() -> dict:
    """Analyze the map and produce a strategic summary: resource patch
    locations, water presence, passability overview, quadrant breakdown,
    and key terrain features. Essential for planning."""
    # Docstring kept verbatim (it presumably doubles as the tool description).
    #
    # Returns a dict that always has map name/size, base position, a mirrored
    # enemy position estimate and their Manhattan distance. When the spatial
    # grid decodes, it also has exploration, passability, map type, resource
    # patches (max 10, nearest first), per-quadrant stats and strategic notes.
    from collections import deque

    env._refresh_obs()
    obs = env._last_obs
    if obs is None:
        return {"error": "No observation available. Call reset first."}

    map_info = obs.get("map_info", {})
    w = map_info.get("width", 0)
    h = map_info.get("height", 0)
    channels = obs.get("spatial_channels", 0)
    spatial = obs.get("spatial_map", "")

    # Base position: integer centroid of owned buildings and units.
    buildings = obs.get("buildings", [])
    units = obs.get("units", [])
    all_pos = (
        [(b["cell_x"], b["cell_y"]) for b in buildings]
        + [(u["cell_x"], u["cell_y"]) for u in units]
    )
    if all_pos:
        base_x = sum(p[0] for p in all_pos) // len(all_pos)
        base_y = sum(p[1] for p in all_pos) // len(all_pos)
    else:
        base_x, base_y = w // 2, h // 2
    # Enemy estimate: the base mirrored through the map centre, kept 2 cells
    # inside the border.
    enemy_x = max(2, min(w - 2, w - base_x))
    enemy_y = max(2, min(h - 2, h - base_y))
    distance = abs(enemy_x - base_x) + abs(enemy_y - base_y)

    result = {
        "map_name": map_info.get("map_name", "?"),
        "width": w,
        "height": h,
        "base_position": {"x": base_x, "y": base_y},
        "enemy_estimated_position": {"x": enemy_x, "y": enemy_y},
        "base_to_enemy_distance": distance,
    }

    if not spatial or w == 0 or channels == 0:
        result["note"] = "No spatial data available for detailed analysis"
        return result

    import base64
    import struct

    try:
        raw = base64.b64decode(spatial)
    except Exception:
        result["note"] = "Failed to decode spatial data"
        return result

    total_cells = w * h
    passable_count = 0
    water_count = 0
    explored_count = 0
    visible_count = 0
    resource_cells = []
    half_w = w // 2
    half_h = h // 2
    quad_stats = {
        "NW": {"passable": 0, "total": 0, "resources": 0, "explored": 0},
        "NE": {"passable": 0, "total": 0, "resources": 0, "explored": 0},
        "SW": {"passable": 0, "total": 0, "resources": 0, "explored": 0},
        "SE": {"passable": 0, "total": 0, "resources": 0, "explored": 0},
    }

    for y in range(h):
        for x in range(w):
            # Row-major, channels-last: ch2 = resource, ch3 = passable, ch4 = fog.
            base_idx = (y * w + x) * channels
            try:
                passable = struct.unpack_from("f", raw, (base_idx + 3) * 4)[0]
                resource = struct.unpack_from("f", raw, (base_idx + 2) * 4)[0]
                fog = struct.unpack_from("f", raw, (base_idx + 4) * 4)[0]
            except struct.error:
                # Cell lies beyond the decoded buffer; skip it.
                continue

            quad = ("N" if y < half_h else "S") + ("W" if x < half_w else "E")
            quad_stats[quad]["total"] += 1

            if passable > 0.5:
                passable_count += 1
                quad_stats[quad]["passable"] += 1
            else:
                # NOTE(review): every impassable cell is tallied as "water",
                # so cliffs/obstacles also feed has_water and map_type.
                water_count += 1

            if resource > 0:
                resource_cells.append((x, y, resource))
                quad_stats[quad]["resources"] += 1

            # Fog of war: 0.0=shroud, 0.5=fog(explored), 1.0=visible
            if fog > 0.25:
                explored_count += 1
                quad_stats[quad]["explored"] += 1
            if fog > 0.75:
                visible_count += 1

    # Exploration
    explored_pct = round(100 * explored_count / max(total_cells, 1), 1)
    visible_pct = round(100 * visible_count / max(total_cells, 1), 1)
    result["exploration"] = {
        "explored_percent": explored_pct,
        "unexplored_percent": round(100 - explored_pct, 1),
        "visible_percent": visible_pct,
    }

    # Passability
    passable_ratio = passable_count / max(total_cells, 1)
    result["passable_ratio"] = round(passable_ratio, 2)
    result["has_water"] = water_count > (total_cells * 0.02)

    # Map type
    water_ratio = water_count / max(total_cells, 1)
    if water_ratio > 0.4:
        map_type = "island/naval"
    elif water_ratio > 0.1:
        map_type = "mixed terrain"
    elif passable_ratio > 0.8:
        map_type = "open land"
    else:
        map_type = "confined terrain"
    result["map_type"] = map_type

    # Cluster resource cells into patches: BFS where two resource cells are
    # linked when they lie within a 7x7 window (offsets -3..3) of each other.
    patches = []
    if resource_cells:
        # O(1) coordinate lookup; the grid scan above adds each (x, y) at
        # most once, so there are no duplicate keys to collapse.
        density_at = {(rx, ry): rd for rx, ry, rd in resource_cells}
        visited = set()
        for rx, ry, _ in resource_cells:
            if (rx, ry) in visited:
                continue
            cluster = []
            queue = deque([(rx, ry)])
            visited.add((rx, ry))
            while queue:
                cx, cy = queue.popleft()
                cluster.append((cx, cy, density_at[(cx, cy)]))
                for dx in range(-3, 4):
                    for dy in range(-3, 4):
                        neighbor = (cx + dx, cy + dy)
                        if neighbor not in visited and neighbor in density_at:
                            visited.add(neighbor)
                            queue.append(neighbor)

            if len(cluster) >= 2:  # Only report meaningful patches
                cx = sum(c[0] for c in cluster) // len(cluster)
                cy = sum(c[1] for c in cluster) // len(cluster)
                total_density = sum(c[2] for c in cluster)
                near_base = abs(cx - base_x) + abs(cy - base_y) < distance // 3
                patches.append({
                    "center_x": cx,
                    "center_y": cy,
                    "cells": len(cluster),
                    "total_density": round(total_density, 1),
                    "near_base": near_base,
                })

    # Sort patches: nearest to base first
    patches.sort(key=lambda p: abs(p["center_x"] - base_x) + abs(p["center_y"] - base_y))
    result["resource_patches"] = patches[:10]  # Cap at 10

    # Quadrant summary
    base_quad = ("N" if base_y < half_h else "S") + ("W" if base_x < half_w else "E")
    enemy_quad = ("N" if enemy_y < half_h else "S") + ("W" if enemy_x < half_w else "E")
    quadrant_summary = {}
    for quad, stats in quad_stats.items():
        total = max(stats["total"], 1)
        note = ""
        if quad == base_quad:
            note = "your base area"
        elif quad == enemy_quad:
            note = "enemy base area"
        quadrant_summary[quad] = {
            "passable_ratio": round(stats["passable"] / total, 2),
            "resource_cells": stats["resources"],
            "explored_percent": round(100 * stats["explored"] / total, 1),
            "note": note,
        }
    result["quadrant_summary"] = quadrant_summary

    # Strategic notes
    notes = []
    if result["has_water"]:
        notes.append("Naval buildings possible — water detected on map")
    else:
        notes.append("Land-only map — skip naval buildings (spen/syrd)")
    if patches:
        nearest = patches[0]
        dist_to_ore = abs(nearest["center_x"] - base_x) + abs(nearest["center_y"] - base_y)
        notes.append(f"Nearest ore patch at ({nearest['center_x']},{nearest['center_y']}), "
                     f"{dist_to_ore} cells from base, {nearest['cells']} resource cells")
    if distance < 40:
        notes.append("Short distance to enemy — expect early aggression, prioritize defense")
    elif distance > 100:
        notes.append("Long distance to enemy — time to build economy before attacking")
    result["strategic_notes"] = notes

    return result
Use this during planning to prepare your strategy.""" + difficulty = env._config.bot_type # "easy", "normal", "hard" + profile = get_opponent_profile(difficulty) + if profile is None: + return { + "difficulty": difficulty, + "note": "No detailed profile available for this opponent type.", + } + result = dict(profile) + result["your_faction"] = env._player_faction + result["enemy_faction"] = env._enemy_faction + return result + + @configurable_tool + def start_planning_phase() -> dict: + """Begin the pre-game planning phase. Returns map metadata, faction info, + opponent intelligence, tech tree, and available units/buildings. + + Available knowledge tools during planning: get_faction_briefing, + get_map_analysis, get_opponent_intel, batch_lookup, lookup_unit, + lookup_building, lookup_tech_tree, lookup_faction. + + Planning has a turn limit and time limit. If exceeded, planning ends + automatically. Call end_planning_phase(strategy=...) to finish.""" + if not env._planning_enabled: + return { + "planning_enabled": False, + "message": "Planning phase is disabled. 
Proceed directly to gameplay.", + } + + if env._planning_active: + return { + "error": "Planning phase already active.", + "turns_used": env._planning_turns_used, + "turns_remaining": max(0, env._planning_max_turns - env._planning_turns_used), + } + + env._planning_active = True + env._planning_start_time = time.time() + env._planning_turns_used = 0 + env._planning_strategy = "" + + # Gather initial game metadata + env._refresh_obs() + obs = env._last_obs or {} + + map_info = obs.get("map_info", {}) + buildings = obs.get("buildings", []) + units = obs.get("units", []) + + # Base position + all_positions = ( + [(b["cell_x"], b["cell_y"]) for b in buildings] + + [(u["cell_x"], u["cell_y"]) for u in units] + ) + if all_positions: + base_x = sum(p[0] for p in all_positions) // len(all_positions) + base_y = sum(p[1] for p in all_positions) // len(all_positions) + else: + base_x = map_info.get("width", 128) // 2 + base_y = map_info.get("height", 128) // 2 + + # Enemy spawn estimate (opposite side of map) + map_w = map_info.get("width", 128) + map_h = map_info.get("height", 128) + enemy_x = max(2, min(map_w - 2, map_w - base_x)) + enemy_y = max(2, min(map_h - 2, map_h - base_y)) + + # Faction and tech tree + faction = env._player_faction + enemy_faction = env._enemy_faction + faction_info = get_faction_info(faction) if faction else None + side = faction_info["side"] if faction_info else "unknown" + tech_tree = get_tech_tree(side) if side != "unknown" else {} + + # Opponent intel + opponent_profile = get_opponent_profile(env._config.bot_type) + opponent_summary = get_opponent_summary(env._config.bot_type) + + # Key units/buildings with full stats (top 8 by cost) + key_units = {} + key_buildings = {} + if side != "unknown": + all_units = get_all_units_for_side(side) + sorted_units = sorted(all_units.items(), key=lambda x: x[1].get("cost", 0), reverse=True) + for utype, udata in sorted_units[:8]: + key_units[utype] = udata + all_bldgs = get_all_buildings_for_side(side) + 
sorted_bldgs = sorted(all_bldgs.items(), key=lambda x: x[1].get("cost", 0), reverse=True) + for btype, bdata in sorted_bldgs[:8]: + key_buildings[btype] = bdata + + return { + "planning_active": True, + "max_turns": env._planning_max_turns, + "max_time_seconds": env._planning_max_time_s, + "map": map_info, + "base_position": {"x": base_x, "y": base_y}, + "enemy_estimated_position": {"x": enemy_x, "y": enemy_y}, + "your_faction": faction, + "your_side": side, + "enemy_faction": enemy_faction, + "tech_tree": tech_tree, + "available_units": faction_info.get("available_units", []) if faction_info else [], + "available_buildings": faction_info.get("available_buildings", []) if faction_info else [], + "key_units": key_units, + "key_buildings": key_buildings, + "starting_units": [ + {"type": u["type"], "id": u["actor_id"], "cell_x": u["cell_x"], "cell_y": u["cell_y"]} + for u in units + ], + "starting_buildings": [ + {"type": b["type"], "id": b["actor_id"], "cell_x": b["cell_x"], "cell_y": b["cell_y"]} + for b in buildings + ], + "opponent_intel": opponent_profile, + "opponent_summary": opponent_summary, + "reward_dimensions": { + "combat": "Cost-weighted damage exchange — kill expensive enemies, protect your own units", + "economy": "Economic growth — build refineries, expand harvester fleet, deny enemy economy", + "infrastructure": "Base building — unlock new building types, keep production active, maintain power", + "intelligence": "Scouting & discovery — explore the map, spot enemy units and buildings", + "composition": "Army mix quality — build units that counter the enemy's army composition", + "tempo": "Action efficiency — keep units active, issue orders, avoid idle armies", + "disruption": "Strategic sabotage — destroy enemy power plants, production, and tech buildings", + "outcome": "Win (+1.0) or lose (-1.0)", + }, + "instructions": env._app_config.prompts.planning_instructions, + } + + @configurable_tool + def end_planning_phase(strategy: str = "") -> dict: + 
"""End the planning phase and transition to gameplay. + + Args: + strategy: Your formulated strategy as a text summary. This will be + available as context during gameplay. + + Returns game state summary to begin gameplay.""" + if not env._planning_active: + return { + "error": "No planning phase active.", + "planning_enabled": env._planning_enabled, + } + + elapsed = time.time() - env._planning_start_time + env._planning_active = False + env._planning_strategy = strategy.strip() if strategy else "" + env._planning_turns_used = env._planning_turns_used + env._state.planning_strategy = env._planning_strategy + env._state.planning_turns_used = env._planning_turns_used + + # Start the streaming session NOW — this unpauses the game. + # The game has been paused since reset, so tick should still be 0. + try: + loop = getattr(env, '_loop', None) + bridge = getattr(env, '_bridge', None) + if loop and bridge: + asyncio.run_coroutine_threadsafe( + env._ensure_session_started(), loop + ).result(timeout=30) + except (AttributeError, RuntimeError) as e: + logger.debug(f"Could not start session in end_planning_phase: {e}") + + # Get current game state to hand off + env._refresh_obs() + obs = env._last_obs or {} + + return { + "planning_complete": True, + "planning_duration_seconds": round(elapsed, 1), + "planning_turns_used": env._planning_turns_used, + "strategy_recorded": bool(env._planning_strategy), + "strategy": env._planning_strategy, + "tick": obs.get("tick", 0), + "economy": obs.get("economy", {}), + "own_units": len(obs.get("units", [])), + "own_buildings": len(obs.get("buildings", [])), + "message": env._app_config.prompts.planning_complete, + } + + @configurable_tool + def get_planning_status() -> dict: + """Check the status of the planning phase — turns used, time remaining.""" + if not env._planning_enabled: + return {"planning_enabled": False} + if not env._planning_active: + return { + "planning_active": False, + "strategy": env._planning_strategy or "(none)", + 
} + + elapsed = time.time() - env._planning_start_time + return { + "planning_active": True, + "turns_used": env._planning_turns_used, + "turns_remaining": max(0, env._planning_max_turns - env._planning_turns_used), + "time_elapsed_seconds": round(elapsed, 1), + "time_remaining_seconds": round(max(0, env._planning_max_time_s - elapsed), 1), + } + + # ── Action Tools (advance game state) ──────────────────────────── + + @configurable_tool + def advance(ticks: int = 1) -> dict: + """Advance the game by N ticks (max 500 per call, ~25 ticks = 1 second). + Production, movement, combat, and building auto-placement all require + game time to progress — nothing happens without calling advance(). + Also triggers auto-placement of buildings queued via build_and_place(). + Typical build times: power plant ~300 ticks, barracks ~500 ticks, + war factory ~750 ticks. Returns updated game summary.""" + requested = ticks + ticks = max(1, min(ticks, 500)) # clamp to [1, 500] + try: + future = asyncio.run_coroutine_threadsafe( + env._bridge.wait_ticks(ticks), env._loop + ) + proto_obs = future.result(timeout=300) + obs_dict = observation_to_dict(proto_obs) + env._last_obs = obs_dict + except Exception: + # Connection lost — check if game ended while waiting + env._refresh_obs() + obs_dict = env._last_obs + if obs_dict is None or not obs_dict.get("done"): + raise + + env._state.game_tick = obs_dict["tick"] + # Accumulate reward vector if enabled + try: + _, reward_vec = env._reward_fn.compute_all(obs_dict) + if reward_vec: + for k, v in reward_vec.items(): + env._accumulated_reward_vector[k] = env._accumulated_reward_vector.get(k, 0.0) + v + except Exception: + pass + # Track losses and trigger auto-placement + if env._app_config.alerts.loss_tracking: + env._update_loss_tracking() + env._process_pending_placements() + result = { + "tick": obs_dict["tick"], + "done": obs_dict["done"], + "result": obs_dict.get("result", ""), + "economy": obs_dict["economy"], + "own_units": 
len(obs_dict["units"]), + "own_buildings": len(obs_dict["buildings"]), + "visible_enemies": len(obs_dict["visible_enemies"]), + } + if requested > 500: + result["note"] = f"Clamped from {requested} to 500 ticks (max per call)." + return result + + @configurable_tool + def move_units(unit_ids: str, target_x: int, target_y: int, queued: bool = False) -> dict: + """Move units to a map cell position. Units pathfind automatically. + unit_ids: comma-separated IDs, "all_combat", "all_idle", "type:e1", "all_infantry", "all_vehicles", or a group name.""" + env._refresh_obs() + resolved = env._resolve_unit_ids(unit_ids, env._last_obs or {}) + if not resolved: + return {"error": "No matching units found"} + commands = [ + CommandModel(action=ActionType.MOVE, actor_id=uid, target_x=target_x, target_y=target_y, queued=queued) + for uid in resolved + ] + result = env._execute_commands(commands) + targets = getattr(env, "_move_targets", None) + if targets is not None: + for uid in resolved: + targets[uid] = (target_x, target_y) + return env._add_unit_feedback(result, resolved, target_x=target_x, target_y=target_y) + + @configurable_tool + def attack_move(unit_ids: str, target_x: int, target_y: int, queued: bool = False) -> dict: + """Move units toward a cell, attacking enemies encountered along the way. 
+ unit_ids: comma-separated IDs, "all_combat", "all_idle", "type:e1", "all_infantry", "all_vehicles", or a group name.""" + env._refresh_obs() + resolved = env._resolve_unit_ids(unit_ids, env._last_obs or {}) + if not resolved: + return {"error": "No matching units found"} + commands = [ + CommandModel(action=ActionType.ATTACK_MOVE, actor_id=uid, target_x=target_x, target_y=target_y, queued=queued) + for uid in resolved + ] + result = env._execute_commands(commands) + targets = getattr(env, "_move_targets", None) + if targets is not None: + for uid in resolved: + targets[uid] = (target_x, target_y) + return env._add_unit_feedback(result, resolved, target_x=target_x, target_y=target_y) + + @configurable_tool + def attack_target(unit_ids: str, target_actor_id: int, queued: bool = False) -> dict: + """Order units to attack a specific enemy actor by ID. + unit_ids: comma-separated IDs, "all_combat", "all_idle", "type:e1", "all_infantry", "all_vehicles", or a group name.""" + env._refresh_obs() + resolved = env._resolve_unit_ids(unit_ids, env._last_obs or {}) + if not resolved: + return {"error": "No matching units found"} + commands = [ + CommandModel(action=ActionType.ATTACK, actor_id=uid, target_actor_id=target_actor_id, queued=queued) + for uid in resolved + ] + result = env._execute_commands(commands) + return env._add_unit_feedback(result, resolved) + + @configurable_tool + def stop_units(unit_ids: str) -> dict: + """Stop units from their current activity. 
+ unit_ids: comma-separated IDs, "all_combat", "all_idle", "type:e1", "all_infantry", "all_vehicles", or a group name.""" + env._refresh_obs() + resolved = env._resolve_unit_ids(unit_ids, env._last_obs or {}) + if not resolved: + return {"error": "No matching units found"} + commands = [CommandModel(action=ActionType.STOP, actor_id=uid) for uid in resolved] + result = env._execute_commands(commands) + return env._add_unit_feedback(result, resolved) + + @configurable_tool + def build_unit(unit_type: str, count: int = 1) -> dict: + """Start training units (infantry, vehicle, aircraft, ship). + The unit_type is the internal name (e.g., 'e1', '3tnk', 'mig'). + Use count > 1 to queue multiple of the same type.""" + # Validate against available production + env._refresh_obs() + if env._last_obs: + available = env._last_obs.get("available_production", []) + if not available: + # No production buildings — nothing can be built + return { + "error": "No production buildings available. Build a Construction Yard (deploy MCV), Barracks, War Factory, etc. 
first.", + "available_units": [], + } + if unit_type not in available: + all_buildings = get_all_building_types() + avail_units = [u for u in available if u not in all_buildings] + diag = env._diagnose_unavailable(unit_type) + result = { + "error": diag["reason"], + "available_units": avail_units, + } + if "missing_prerequisites" in diag: + result["missing_prerequisites"] = diag["missing_prerequisites"] + return result + # Check funds + eco = env._last_obs.get("economy", {}) + total_funds = eco.get("cash", 0) + eco.get("ore", 0) + unit_stats = get_unit_stats(unit_type) + unit_cost = unit_stats["cost"] if unit_stats else 0 + if unit_cost > 0 and total_funds < unit_cost: + return { + "error": env._app_config.prompts.insufficient_funds.format( + available=total_funds, item=unit_type, cost=unit_cost), + } + count = max(1, min(count, 10)) + commands = [CommandModel(action=ActionType.TRAIN, item_type=unit_type) + for _ in range(count)] + result = env._execute_commands(commands) + # Factual build confirmation with estimated ticks + unit_stats = get_unit_stats(unit_type) + unit_cost = unit_stats["cost"] if unit_stats else 0 + if unit_cost > 0: + ticks_each = _estimate_build_ticks(unit_cost) + ticks_total = ticks_each * count + result["note"] = env._app_config.prompts.build_unit_queued.format( + count=count, unit=unit_type, cost=unit_cost, + ticks_each=ticks_each, ticks_total=ticks_total, + seconds_total=round(ticks_total / 25, 1)) + return result + + @configurable_tool + def build_structure(building_type: str) -> dict: + """Start constructing a building (manual placement workflow). + After calling this, call advance(ticks) to let construction finish, + then call place_building() to place it on the map. + Prefer build_and_place() which handles placement automatically. 
+ building_type: internal name (e.g., 'powr', 'barr', 'weap').""" + # Reject if same building already in production queue + env._refresh_obs() + if env._last_obs: + available = env._last_obs.get("available_production", []) + if not available: + return { + "error": "No Construction Yard (fact) — requires MCV deployment to build.", + "available_buildings": [], + } + if building_type not in available: + all_buildings = get_all_building_types() + avail_bldgs = [b for b in available if b in all_buildings] + diag = env._diagnose_unavailable(building_type) + result = { + "error": diag["reason"], + "available_buildings": avail_bldgs, + } + if "missing_prerequisites" in diag: + result["missing_prerequisites"] = diag["missing_prerequisites"] + return result + if building_type in env._pending_placements: + return {"note": env._app_config.prompts.build_already_pending.format( + building=building_type)} + already = any( + p["queue_type"] in env._PLACEABLE_QUEUE_TYPES and p["item"] == building_type + for p in env._last_obs.get("production", []) + ) + if already: + return {"error": f"'{building_type}' is already in the production queue."} + commands = [CommandModel(action=ActionType.BUILD, item_type=building_type)] + result = env._execute_commands(commands) + # Factual build confirmation with estimated ticks + stats = get_building_stats(building_type) + bld_cost = stats["cost"] if stats else 0 + if bld_cost > 0: + ticks = _estimate_build_ticks(bld_cost) + result["note"] = env._app_config.prompts.build_structure_queued.format( + building=building_type, cost=bld_cost, + ticks=ticks, seconds=round(ticks / 25, 1)) + # Proactive power warning + if stats and env._last_obs: + power_drain = stats.get("power", 0) + if power_drain < 0: + eco = env._last_obs.get("economy", {}) + current_balance = eco.get("power_provided", 0) - eco.get("power_drained", 0) + if current_balance + power_drain < 0: + new_balance = current_balance + power_drain + result["warning"] = 
env._app_config.prompts.power_warning.format( + building=building_type, drain=abs(power_drain), + balance=f"{new_balance:+d}") + return result + + @configurable_tool + def build_and_place(building_type: str, cell_x: int = 0, cell_y: int = 0) -> dict: + """Build a structure and auto-place it when construction finishes. + After calling this, you must call advance(ticks) to let construction + complete — the building auto-places once done. Do NOT call + place_building() on buildings queued this way — placement is automatic. + Coordinates are optional — the engine finds a valid position near + your base if omitted. Returns updated game summary.""" + # Validate and reject duplicates + env._refresh_obs() + if env._last_obs: + available = env._last_obs.get("available_production", []) + if not available: + return { + "error": "No Construction Yard (fact) — requires MCV deployment to build.", + "available_buildings": [], + } + if building_type not in available: + all_buildings = get_all_building_types() + avail_bldgs = [b for b in available if b in all_buildings] + diag = env._diagnose_unavailable(building_type) + result = { + "error": diag["reason"], + "available_buildings": avail_bldgs, + } + if "missing_prerequisites" in diag: + result["missing_prerequisites"] = diag["missing_prerequisites"] + return result + if building_type in env._pending_placements: + return {"note": env._app_config.prompts.build_already_pending.format( + building=building_type)} + already = any( + p["queue_type"] in env._PLACEABLE_QUEUE_TYPES and p["item"] == building_type + for p in env._last_obs.get("production", []) + ) + if already: + return {"error": f"'{building_type}' is already in the production queue."} + commands = [CommandModel(action=ActionType.BUILD, item_type=building_type)] + result = env._execute_commands(commands) + env._pending_placements[building_type] = {"cell_x": cell_x, "cell_y": cell_y} + # Factual build confirmation with estimated ticks + stats = 
get_building_stats(building_type) + bld_cost = stats["cost"] if stats else 0 + if bld_cost > 0: + ticks = _estimate_build_ticks(bld_cost) + result["note"] = env._app_config.prompts.build_queued.format( + building=building_type, cost=bld_cost, + ticks=ticks, seconds=round(ticks / 25, 1)) + # Proactive power warning + if stats and env._last_obs: + power_drain = stats.get("power", 0) + if power_drain < 0: + eco = env._last_obs.get("economy", {}) + current_balance = eco.get("power_provided", 0) - eco.get("power_drained", 0) + if current_balance + power_drain < 0: + new_balance = current_balance + power_drain + result["warning"] = env._app_config.prompts.power_warning.format( + building=building_type, drain=abs(power_drain), + balance=f"{new_balance:+d}") + return result + + @configurable_tool + def place_building(building_type: str, cell_x: int = 0, cell_y: int = 0) -> dict: + """Place a completed building on the map (only for build_structure workflow). + The building must be at 100% in the production queue or this will error. + Do NOT use this on buildings queued via build_and_place() — those + auto-place via advance(). 
Cell coordinates are optional — the engine + auto-finds a valid position near your base if omitted.""" + # Guard: building queued via build_and_place auto-places + if building_type in env._pending_placements: + return { + "note": env._app_config.prompts.place_auto_managed.format( + building=building_type), + "tick": env._last_obs.get("tick", 0) if env._last_obs else 0, + } + # Check building is ready in production queue + env._refresh_obs() + pre_obs = env._last_obs + if pre_obs: + ready = any( + p["queue_type"] in env._PLACEABLE_QUEUE_TYPES and p["item"] == building_type and p["progress"] >= 0.99 + for p in pre_obs["production"] + ) + if not ready: + return { + "error": f"'{building_type}' not ready to place — not at 100% in production queue.", + "tick": pre_obs.get("tick", 0), + } + + env._pending_placements.pop(building_type, None) + commands = [CommandModel(action=ActionType.PLACE_BUILDING, + item_type=building_type, target_x=cell_x, target_y=cell_y)] + return env._execute_commands(commands) + + @configurable_tool + def cancel_production(item_type: str) -> dict: + """Cancel production of an item currently in a production queue.""" + env._refresh_obs() + obs = env._last_obs or {} + queue = obs.get("production", []) + in_queue = any(p.get("type", "").lower() == item_type.lower() for p in queue) + if not in_queue: + queued_items = [p.get("type", "") for p in queue] + return {"error": f"'{item_type}' is not in the production queue.", "current_queue": queued_items} + commands = [CommandModel(action=ActionType.CANCEL_PRODUCTION, item_type=item_type)] + return env._execute_commands(commands) + + @configurable_tool + def deploy_unit(unit_id: int) -> dict: + """Deploy a unit (e.g., MCV → Construction Yard).""" + env._refresh_obs() + obs = env._last_obs or {} + units = obs.get("units", []) + if not any(u.get("actor_id") == unit_id for u in units): + return {"error": f"Unit {unit_id} not found. 
It may have been destroyed.", + "your_units": [{"id": u["actor_id"], "type": u["type"]} for u in units[:20]]} + commands = [CommandModel(action=ActionType.DEPLOY, actor_id=unit_id)] + return env._execute_commands(commands) + + @configurable_tool + def sell_building(building_id: int) -> dict: + """Sell a building for partial refund.""" + env._refresh_obs() + obs = env._last_obs or {} + buildings = obs.get("buildings", []) + if not any(b.get("actor_id") == building_id for b in buildings): + return {"error": f"Building {building_id} not found. It may have been destroyed or sold.", + "your_buildings": [{"id": b["actor_id"], "type": b["type"]} for b in buildings[:20]]} + commands = [CommandModel(action=ActionType.SELL, actor_id=building_id)] + return env._execute_commands(commands) + + @configurable_tool + def repair_building(building_id: int) -> dict: + """Toggle repair on a building. Costs credits over time.""" + env._refresh_obs() + obs = env._last_obs or {} + buildings = obs.get("buildings", []) + if not any(b.get("actor_id") == building_id for b in buildings): + return {"error": f"Building {building_id} not found. It may have been destroyed or sold.", + "your_buildings": [{"id": b["actor_id"], "type": b["type"]} for b in buildings[:20]]} + commands = [CommandModel(action=ActionType.REPAIR, actor_id=building_id)] + return env._execute_commands(commands) + + @configurable_tool + def set_rally_point(building_id: int, cell_x: int, cell_y: int) -> dict: + """Set rally point for a production building. Newly produced units + will move to this location.""" + env._refresh_obs() + obs = env._last_obs or {} + buildings = obs.get("buildings", []) + if not any(b.get("actor_id") == building_id for b in buildings): + return {"error": f"Building {building_id} not found. 
It may have been destroyed or sold.", + "your_buildings": [{"id": b["actor_id"], "type": b["type"]} for b in buildings[:20]]} + commands = [CommandModel(action=ActionType.SET_RALLY_POINT, actor_id=building_id, target_x=cell_x, target_y=cell_y)] + return env._execute_commands(commands) + + @configurable_tool + def guard_target(unit_ids: str, target_actor_id: int, queued: bool = False) -> dict: + """Order units to guard another actor, following and protecting it. + unit_ids: comma-separated IDs, "all_combat", "all_idle", "type:e1", "all_infantry", "all_vehicles", or a group name.""" + env._refresh_obs() + resolved = env._resolve_unit_ids(unit_ids, env._last_obs or {}) + if not resolved: + return {"error": "No matching units found"} + commands = [ + CommandModel(action=ActionType.GUARD, actor_id=uid, target_actor_id=target_actor_id, queued=queued) + for uid in resolved + ] + result = env._execute_commands(commands) + return env._add_unit_feedback(result, resolved) + + @configurable_tool + def set_stance(unit_ids: str, stance: str) -> dict: + """Set combat stance for units. + Stances: 'hold_fire' (0), 'return_fire' (1), 'defend' (2), 'attack_anything' (3). + unit_ids: comma-separated IDs, "all_combat", "all_idle", "type:e1", "all_infantry", "all_vehicles", or a group name.""" + env._refresh_obs() + resolved = env._resolve_unit_ids(unit_ids, env._last_obs or {}) + if not resolved: + return {"error": "No matching units found"} + stance_map = {"hold_fire": 0, "return_fire": 1, "defend": 2, "attack_anything": 3} + stance_val = stance_map.get(stance.lower(), 3) + commands = [ + CommandModel(action=ActionType.SET_STANCE, actor_id=uid, target_x=stance_val) + for uid in resolved + ] + result = env._execute_commands(commands) + return env._add_unit_feedback(result, resolved) + + @configurable_tool + def harvest(unit_id: int, cell_x: int = 0, cell_y: int = 0) -> dict: + """Send a harvester to collect ore. If cell_x/cell_y are provided, + harvest at that location. 
Otherwise, auto-harvest nearest ore.""" + env._refresh_obs() + obs = env._last_obs or {} + units = obs.get("units", []) + if not any(u.get("actor_id") == unit_id for u in units): + return {"error": f"Unit {unit_id} not found. It may have been destroyed.", + "your_units": [{"id": u["actor_id"], "type": u["type"]} for u in units[:20]]} + commands = [CommandModel(action=ActionType.HARVEST, actor_id=unit_id, target_x=cell_x, target_y=cell_y)] + return env._execute_commands(commands) + + @configurable_tool + def power_down(building_id: int) -> dict: + """Toggle power on/off for a building. Reduces power consumption + but disables the building's function.""" + env._refresh_obs() + obs = env._last_obs or {} + buildings = obs.get("buildings", []) + if not any(b.get("actor_id") == building_id for b in buildings): + return {"error": f"Building {building_id} not found. It may have been destroyed or sold.", + "your_buildings": [{"id": b["actor_id"], "type": b["type"]} for b in buildings[:20]]} + commands = [CommandModel(action=ActionType.POWER_DOWN, actor_id=building_id)] + return env._execute_commands(commands) + + @configurable_tool + def set_primary(building_id: int) -> dict: + """Set a production building as the primary producer. New units will + exit from this building.""" + env._refresh_obs() + obs = env._last_obs or {} + buildings = obs.get("buildings", []) + if not any(b.get("actor_id") == building_id for b in buildings): + return {"error": f"Building {building_id} not found. It may have been destroyed or sold.", + "your_buildings": [{"id": b["actor_id"], "type": b["type"]} for b in buildings[:20]]} + commands = [CommandModel(action=ActionType.SET_PRIMARY, actor_id=building_id)] + return env._execute_commands(commands) + + # ── Placement Helper ──────────────────────────────────────────── + + @configurable_tool + def get_valid_placements(building_type: str, max_results: int = 8) -> dict: + """Get suggested placement positions for a building near your base. 
+ Returns positions sorted by distance from Construction Yard. + Use the first position with place_building(). If it fails, try the next.""" + env._refresh_obs() + obs = env._last_obs + if obs is None: + return {"error": "No observation available"} + + candidates = env._find_placement_candidates(building_type, obs) + if not candidates: + return {"error": "No Construction Yard found — deploy MCV first"} + + bw, bh = env._FOOTPRINTS.get(building_type, (2, 2)) + # Find CY position for response + cy_pos = {"cell_x": 0, "cell_y": 0} + for b in obs.get("buildings", []): + if b["type"] == "fact": + cy_pos = {"cell_x": b["cell_x"], "cell_y": b["cell_y"]} + break + + suggestions = candidates[:max(1, min(max_results, 15))] + + return { + "building_type": building_type, + "size": f"{bw}x{bh}", + "cy_position": cy_pos, + "suggestions": suggestions, + } + + # ── Unit Group Tools ──────────────────────────────────────────── + + @configurable_tool + def assign_group(group_name: str, unit_ids: list[int]) -> dict: + """Assign units to a named group (like Ctrl+1 in-game). + Groups persist across turns. Use group names in other commands. 
+ Example: assign_group("attackers", [155, 160, 170])""" + env._unit_groups[group_name] = list(unit_ids) + return {"group": group_name, "unit_count": len(unit_ids), "unit_ids": unit_ids} + + @configurable_tool + def add_to_group(group_name: str, unit_ids: list[int]) -> dict: + """Add units to an existing group (like Shift+Ctrl+1).""" + existing = env._unit_groups.get(group_name, []) + for uid in unit_ids: + if uid not in existing: + existing.append(uid) + env._unit_groups[group_name] = existing + return {"group": group_name, "unit_count": len(existing), "unit_ids": existing} + + @configurable_tool + def get_groups() -> dict: + """List all unit groups and their members.""" + # Prune dead units from groups + env._refresh_obs() + alive_ids = set() + if env._last_obs: + alive_ids = {u["actor_id"] for u in env._last_obs.get("units", [])} + result = {} + for name, ids in env._unit_groups.items(): + alive = [uid for uid in ids if uid in alive_ids] + env._unit_groups[name] = alive # auto-prune dead units + if alive: + result[name] = alive + return result + + @configurable_tool + def command_group( + group_name: str, + command: str, + target_x: int = 0, + target_y: int = 0, + target_actor_id: int = 0, + stance: str = "attack_anything", + ) -> dict: + """Send a command to all units in a named group. 
+ command: "attack_move", "move_units", "attack_target", "set_stance", "stop_units" + For attack_move/move_units: provide target_x, target_y + For attack_target: provide target_actor_id + For set_stance: provide stance name""" + ids = env._unit_groups.get(group_name, []) + if not ids: + return {"error": f"Group '{group_name}' not found or empty"} + + # Prune dead units + env._refresh_obs() + if env._last_obs: + alive_ids = {u["actor_id"] for u in env._last_obs.get("units", [])} + ids = [uid for uid in ids if uid in alive_ids] + env._unit_groups[group_name] = ids + if not ids: + return {"error": f"All units in group '{group_name}' are dead"} + + if command == "attack_move": + commands = [CommandModel(action=ActionType.ATTACK_MOVE, actor_id=uid, + target_x=target_x, target_y=target_y) for uid in ids] + elif command == "move_units": + commands = [CommandModel(action=ActionType.MOVE, actor_id=uid, + target_x=target_x, target_y=target_y) for uid in ids] + elif command == "attack_target": + commands = [CommandModel(action=ActionType.ATTACK, actor_id=uid, + target_actor_id=target_actor_id) for uid in ids] + elif command == "set_stance": + stance_map = {"hold_fire": 0, "return_fire": 1, "defend": 2, "attack_anything": 3} + stance_val = stance_map.get(stance.lower(), 3) + commands = [CommandModel(action=ActionType.SET_STANCE, actor_id=uid, + target_x=stance_val) for uid in ids] + elif command == "stop_units": + commands = [CommandModel(action=ActionType.STOP, actor_id=uid) for uid in ids] + else: + return {"error": f"Unknown group command '{command}'"} + + result = env._execute_commands(commands) + result["group"] = group_name + result["units_commanded"] = len(ids) + return env._add_unit_feedback(result, ids) + + @configurable_tool + def get_replay_path() -> dict: + """Get the path to the most recent replay file from this session.""" + replay_dir = env._get_replay_dir() + if not replay_dir.exists(): + return {"error": "No replay directory found"} + replays = 
sorted(replay_dir.rglob("*.orarep"), key=lambda p: p.stat().st_mtime, reverse=True) + if not replays: + return {"error": "No replay files found"} + return {"path": str(replays[0]), "size_bytes": replays[0].stat().st_size} + + @configurable_tool + def surrender() -> dict: + """Surrender / resign the current game. Ends the game as a loss. + Use this when you want to concede the match.""" + commands = [CommandModel(action=ActionType.SURRENDER)] + result = env._execute_commands(commands) + if not result.get("done"): + # Game may need a tick to process surrender + try: + adv_result = advance(ticks=50) + if adv_result.get("done"): + return adv_result + except Exception: + pass + return result + + @configurable_tool + def batch(actions: list[dict]) -> dict: + """Send multiple commands that all execute concurrently (same game tick). + + Actions use same format as individual tools: + {"tool": "build_unit", "unit_type": "e1", "count": 3} + {"tool": "attack_move", "unit_ids": "155,160", "target_x": 50, "target_y": 30} + {"tool": "set_stance", "unit_ids": "all_combat", "stance": "attack_anything"} + {"tool": "deploy_unit", "unit_id": 120} + + Unit selectors for unit_ids: + "all_combat" — all own combat units + "all_idle" — all idle combat units + "155,160" — comma-separated actor IDs + group name — a previously assigned group + + All commands are sent in a single call. The game resolves any + conflicts by its own logic. + + Note: batch does NOT advance game time. Cannot contain advance(), + get_game_state, or other query/flow-control tools — those are + silently skipped. Use advance() as a separate call after batch. + + Returns: game state summary after commands are processed. 
+ """ + env._refresh_obs() + obs = env._last_obs + if obs is None: + return {"error": "No observation available"} + if obs.get("done"): + return {"error": "Game is over", "done": True, "result": obs.get("result", "")} + + # Tools that cannot be batched (flow-control or read-only) + _BATCH_UNSUPPORTED = { + "advance", "get_game_state", "get_units", "get_buildings", + "get_terrain_at", "get_map_analysis", "get_valid_placements", + "lookup_unit", "lookup_building", "get_replay", + "surrender", "plan", "batch", + } + + all_commands = [] + action_names = [] + for action in actions: + tool = action.get("tool", "?") + if tool in _BATCH_UNSUPPORTED: + action_names.append(f"{tool}:SKIPPED (use standalone)") + continue + cmds = env._action_to_commands(action, obs) + if cmds: + all_commands.extend(cmds) + action_names.append(tool) + else: + action_names.append(f"{tool}:FAILED") + + if not all_commands: + return {"error": "No valid commands generated", "actions": action_names} + + try: + result = env._execute_commands(all_commands) + result["actions"] = action_names + return result + except Exception as e: + return {"error": f"Command execution failed: {e}"} + + @configurable_tool + def plan(steps: list[dict]) -> dict: + """Execute steps sequentially. Each step's commands are sent, then + the observation is refreshed before the next step. Use conditions + to gate steps on game state. + + Each step is a dict with: + actions: list of action dicts to execute + condition: optional — only execute if condition is met, else skip + + Conditions: "enemies_visible", "no_enemies_visible", "under_attack", + "building_ready", "funds_above:2000", "funds_below:500" + (funds = cash + ore; "cash_above"/"cash_below" also work) + + Note: plan does NOT advance game time between steps. It only + refreshes observations. Use advance() as a standalone call to + let construction and movement complete. 
+ + Example — deploy then build: + [ + {"actions": [{"tool": "deploy_unit", "unit_id": 120}]}, + {"actions": [{"tool": "build_structure", "building_type": "powr"}]}, + {"condition": "building_ready", + "actions": [{"tool": "place_building", "building_type": "powr"}]} + ] + + Returns: game state summary + execution log. + """ + execution_log = [] + start_tick = env._last_obs["tick"] if env._last_obs else 0 + + for i, step in enumerate(steps): + step_num = i + 1 + env._refresh_obs() + obs = env._last_obs + if obs is None: + execution_log.append(f"Step {step_num}: ERROR no observation") + break + if obs.get("done"): + execution_log.append(f"Step {step_num}: game over") + break + + condition = step.get("condition") + if condition and not env._check_plan_condition(condition, obs): + execution_log.append(f"Step {step_num}: SKIPPED ({condition} = false)") + continue + + actions = step.get("actions", []) + all_commands = [] + action_names = [] + for action in actions: + cmds = env._action_to_commands(action, obs) + all_commands.extend(cmds) + action_names.append(action.get("tool", "?")) + + if all_commands: + try: + result = env._execute_commands(all_commands) + if result.get("done"): + execution_log.append( + f"Step {step_num}: {', '.join(action_names)} -> game over" + ) + break + except Exception as e: + execution_log.append( + f"Step {step_num}: {', '.join(action_names)} -> ERROR {e}" + ) + break + + execution_log.append(f"Step {step_num}: {', '.join(action_names)} OK") + + env._refresh_obs() + obs = env._last_obs or {} + end_tick = obs.get("tick", start_tick) + executed = sum(1 for e in execution_log if "OK" in e) + skipped = sum(1 for e in execution_log if "SKIPPED" in e) + + return { + "steps_total": len(steps), + "steps_executed": executed, + "steps_skipped": skipped, + "tick": end_tick, + "done": obs.get("done", False), + "result": obs.get("result", ""), + "economy": obs.get("economy", {}), + "own_units": len(obs.get("units", [])), + "own_buildings": 
def _check_plan_condition(self, condition: str, obs: dict) -> bool:
    """Evaluate a plan-step condition against the current observation.

    Supported conditions:
      - "enemies_visible" / "no_enemies_visible"
      - "under_attack": a visible enemy is closer than 12 cells (Manhattan
        distance) to one of our buildings
      - "building_ready": an item in a placeable queue has progress >= 0.99
      - "funds_above:N" / "cash_above:N": cash + ore is strictly above N
      - "funds_below:N" / "cash_below:N": cash + ore is strictly below N

    Args:
        condition: Condition string taken from a plan step.
        obs: Observation dict to evaluate against.

    Returns:
        True when the condition holds. An unrecognised condition name
        returns True (the step proceeds). A funds condition whose threshold
        is not a number returns False: the step was explicitly gated, and a
        gate that cannot be evaluated is treated as not met instead of
        raising and aborting the whole plan.
    """
    hostiles = obs.get("visible_enemies", [])

    if condition == "enemies_visible":
        return len(hostiles) > 0
    if condition == "no_enemies_visible":
        return len(hostiles) == 0

    if condition == "under_attack":
        structures = obs.get("buildings", [])
        return any(
            abs(enemy.get("cell_x", 0) - bldg.get("cell_x", 0))
            + abs(enemy.get("cell_y", 0) - bldg.get("cell_y", 0)) < 12
            for enemy in hostiles
            for bldg in structures
        )

    if condition == "building_ready":
        return any(
            p.get("queue_type") in self._PLACEABLE_QUEUE_TYPES
            and p.get("progress", 0) >= 0.99
            for p in obs.get("production", [])
        )

    wants_above = condition.startswith(("cash_above:", "funds_above:"))
    wants_below = condition.startswith(("cash_below:", "funds_below:"))
    if wants_above or wants_below:
        raw_threshold = condition.split(":")[1].strip()
        try:
            threshold = int(raw_threshold)
        except ValueError:
            try:
                threshold = float(raw_threshold)
            except ValueError:
                return False  # malformed threshold: gate cannot be evaluated
        eco = obs.get("economy", {})
        funds = eco.get("cash", 0) + eco.get("ore", 0)
        return funds > threshold if wants_above else funds < threshold

    return True  # unknown condition → proceed
+ + Accepts: + - list of ints: [145, 146] — validated against living units + - "all_combat": all units that can attack + - "all_idle": idle units that can attack + - "type:e1": all units of a specific type + - "all_infantry", "all_vehicles", "all_aircraft", "all_ships": by category + - group name: units in a named group + - comma-separated IDs: "145,146,148" + - stringified list: "[145, 146]" + """ + living_ids = {u["actor_id"] for u in obs.get("units", [])} + + if isinstance(selector, list): + return self._filter_living(selector, living_ids) + if not isinstance(selector, str): + return [] + selector = selector.strip() + if selector == "all_combat": + return [u["actor_id"] for u in obs.get("units", []) if u.get("can_attack")] + if selector == "all_idle": + return [u["actor_id"] for u in obs.get("units", []) + if u.get("can_attack") and u.get("is_idle")] + # Type selector: "type:e1" + if selector.startswith("type:"): + target_type = selector[5:].strip() + return [u["actor_id"] for u in obs.get("units", []) if u["type"] == target_type] + # Category selectors: "all_infantry", "all_vehicles", etc. 
def _filter_living(self, unit_ids: list[int], living_ids: set[int]) -> list[int]:
    """Return the IDs from ``unit_ids`` that appear in ``living_ids``.

    Input order and duplicates are preserved. When any requested ID is not
    in ``living_ids``, a single "DEAD UNITS" notice listing those IDs is
    appended to ``self._placement_results``.
    """
    survivors: list[int] = []
    missing: list[int] = []
    for candidate in unit_ids:
        bucket = survivors if candidate in living_ids else missing
        bucket.append(candidate)

    if missing:
        notice = f"DEAD UNITS: IDs {missing} not found — units were destroyed or invalid"
        self._placement_results.append(notice)
    return survivors
def _diagnose_unavailable(self, item_type: str) -> dict:
    """Explain why ``item_type`` cannot currently be produced.

    Checks, in order: whether the type is known at all, whether a building
    is requested without an owned Construction Yard ("fact"), which
    prerequisite buildings are missing, and finally whether the item is
    tied to a specific side.

    Returns:
        A dict with a human-readable "reason"; when prerequisites are
        missing it also carries "missing_prerequisites" (list of the unmet
        prerequisite entries, alternatives kept in "a|b" form).
    """
    info = get_unit_stats(item_type) or get_building_stats(item_type)
    if not info:
        return {"reason": f"'{item_type}' is not a known unit or building type."}

    def owned_building_types() -> set:
        # Building types in the cached observation; empty when there is none.
        snapshot = self._last_obs
        if not snapshot:
            return set()
        return {bldg["type"] for bldg in snapshot.get("buildings", [])}

    # Buildings require a Construction Yard to produce
    if get_building_stats(item_type) and self._last_obs:
        if "fact" not in owned_building_types():
            return {"reason": "No Construction Yard (fact) — requires MCV deployment to build."}

    required = info.get("prerequisites", [])
    if not required:
        return {"reason": f"'{item_type}' is not available. Check your faction."}

    # A prerequisite written as "barr|tent" is satisfied by owning either one.
    owned = owned_building_types()
    unmet = [req for req in required if owned.isdisjoint(req.split("|"))]
    if unmet:
        listing = ", ".join(unmet)
        return {
            "reason": f"'{item_type}' unavailable — requires {listing} which you don't have.",
            "missing_prerequisites": unmet,
        }

    # Prerequisites are satisfied, so the remaining suspect is the faction.
    faction = info.get("side", "")
    if faction:
        return {"reason": f"'{item_type}' is not available for your faction (it's {faction}-only)."}

    return {"reason": f"'{item_type}' is not available. Check your faction and tech tree."}
"cell_x": action.get("cell_x", 0), "cell_y": action.get("cell_y", 0) + } + return [CommandModel(action=ActionType.BUILD, item_type=btype)] + elif tool == "place_building": + return [CommandModel(action=ActionType.PLACE_BUILDING, + item_type=action["building_type"], + target_x=action.get("cell_x", 0), target_y=action.get("cell_y", 0))] + elif tool == "attack_move": + return [CommandModel(action=ActionType.ATTACK_MOVE, actor_id=uid, + target_x=action["target_x"], target_y=action["target_y"], + queued=queued) + for uid in unit_ids] + elif tool == "move_units": + return [CommandModel(action=ActionType.MOVE, actor_id=uid, + target_x=action["target_x"], target_y=action["target_y"], + queued=queued) + for uid in unit_ids] + elif tool == "attack_target": + return [CommandModel(action=ActionType.ATTACK, actor_id=uid, + target_actor_id=action["target_actor_id"], queued=queued) + for uid in unit_ids] + elif tool == "set_stance": + stance_map = {"hold_fire": 0, "return_fire": 1, "defend": 2, "attack_anything": 3} + stance_val = stance_map.get(action.get("stance", "attack_anything").lower(), 3) + return [CommandModel(action=ActionType.SET_STANCE, actor_id=uid, target_x=stance_val) + for uid in unit_ids] + elif tool == "deploy_unit": + uid = action["unit_id"] + if not any(u.get("actor_id") == uid for u in obs.get("units", [])): + return [] + return [CommandModel(action=ActionType.DEPLOY, actor_id=uid)] + elif tool == "set_rally_point": + bid = action["building_id"] + if not any(b.get("actor_id") == bid for b in obs.get("buildings", [])): + return [] + return [CommandModel(action=ActionType.SET_RALLY_POINT, + actor_id=bid, + target_x=action["cell_x"], target_y=action["cell_y"])] + elif tool == "repair_building": + bid = action["building_id"] + if not any(b.get("actor_id") == bid for b in obs.get("buildings", [])): + return [] + return [CommandModel(action=ActionType.REPAIR, actor_id=bid)] + elif tool == "stop_units": + return [CommandModel(action=ActionType.STOP, actor_id=uid) 
def _build_initial_obs_from_state(self, game_state) -> dict:
    """Create a placeholder observation dict from the unary GameState reply.

    Only the tick comes from ``game_state`` (0 when it is missing); map size
    and name come from ``self._config`` with 128x128 as the fallback size.
    Every other field is zeroed or empty. This gives reset a ``_last_obs``
    cache for planning tools without starting the streaming session (which
    unpauses the game).
    """
    cfg = self._config
    start_tick = game_state.tick if game_state else 0

    economy_fields = (
        "cash", "ore", "power_provided", "power_drained",
        "resource_capacity", "harvester_count",
    )
    military_fields = (
        "units_killed", "units_lost", "buildings_killed",
        "buildings_lost", "army_value", "active_unit_count",
        "kills_cost", "deaths_cost", "assets_value",
        "experience", "order_count",
    )
    map_info = {
        "width": getattr(cfg, "map_width", 128),
        "height": getattr(cfg, "map_height", 128),
        "map_name": cfg.map_name or "",
    }

    return {
        "tick": start_tick,
        "economy": dict.fromkeys(economy_fields, 0),
        "military": dict.fromkeys(military_fields, 0),
        "units": [],
        "buildings": [],
        "production": [],
        "visible_enemies": [],
        "visible_enemy_buildings": [],
        "map_info": map_info,
        "available_production": [],
        "done": False,
        "reward": 0.0,
        "result": "",
        "spatial_map": "",
        "spatial_channels": 0,
    }
def _refresh_obs(self) -> None:
    """Update ``_last_obs`` from the bridge's latest observation.

    In real-time mode the game runs continuously, so read tools call this
    to pick up fresh state. Does nothing until the streaming session has
    started (the observation cached by reset stays in use). Refreshing is
    best-effort: on any failure the previous ``_last_obs`` is kept.
    Pending auto-placements are processed afterwards either way.
    """
    bridge = getattr(self, "_bridge", None)
    if not bridge or not bridge.session_started:
        return  # Session not started yet; use cached _last_obs from reset

    future = None
    try:
        future = asyncio.run_coroutine_threadsafe(bridge.observe(), self._loop)
        proto_obs = future.result(timeout=5)
        if proto_obs is not None:
            self._last_obs = observation_to_dict(proto_obs)
            self._state.game_tick = self._last_obs["tick"]
    except Exception:
        # Keep existing _last_obs if refresh fails. A timed-out wait leaves
        # the coroutine running on the loop, so cancel it explicitly
        # (cancel() is a no-op once the future has completed).
        if future is not None:
            future.cancel()
        logger.debug("Observation refresh failed; keeping cached observation", exc_info=True)
    self._process_pending_placements()
def _find_placement_candidates(self, building_type: str, obs: dict) -> list[dict]:
    """List free cells around the Construction Yard where a building fits.

    Scans every offset within a Manhattan distance of 2..15 cells of the
    first "fact" building in ``obs``. A cell qualifies when the footprint
    (from ``_FOOTPRINTS``, default 2x2) stays inside the map and does not
    touch any existing building's footprint padded by one cell.

    Ordering: non-defense buildings are sorted by distance to the yard.
    Defense buildings are ranked first by how far the cell lies along the
    enemy direction (average visible enemy position, or the mirrored map
    position when none are visible), then by how close the distance is to 5.

    Returns:
        Candidate dicts with "cell_x", "cell_y" and "distance", best first;
        an empty list when no Construction Yard is owned.
    """
    structures = obs.get("buildings", [])

    yard = next((s for s in structures if s["type"] == "fact"), None)
    if yard is None:
        return []

    origin_x, origin_y = yard["cell_x"], yard["cell_y"]
    width, height = self._FOOTPRINTS.get(building_type, (2, 2))

    # Cells taken by existing buildings, each grown by a 1-cell margin.
    blocked = set()
    for s in structures:
        foot_w, foot_h = self._FOOTPRINTS.get(s["type"], (2, 2))
        left, top = s["cell_x"], s["cell_y"]
        blocked.update(
            (left + off_x, top + off_y)
            for off_x in range(-1, foot_w + 1)
            for off_y in range(-1, foot_h + 1)
        )

    map_info = obs.get("map_info", {})
    map_w = map_info.get("width", 128)
    map_h = map_info.get("height", 128)

    # Direction toward the enemy, only relevant for defense structures.
    defensive = building_type in self._DEFENSE_BUILDINGS
    bias_x, bias_y = 0, 0
    if defensive:
        hostiles = obs.get("visible_enemies", []) + obs.get("visible_enemy_buildings", [])
        if hostiles:
            mean_x = sum(h.get("cell_x", h.get("pos_x", 0) // 1024) for h in hostiles) // len(hostiles)
            mean_y = sum(h.get("cell_y", h.get("pos_y", 0) // 1024) for h in hostiles) // len(hostiles)
            bias_x = mean_x - origin_x
            bias_y = mean_y - origin_y
        else:
            # Nothing scouted yet: assume the enemy sits at the mirrored position.
            bias_x = (map_w - origin_x) - origin_x
            bias_y = (map_h - origin_y) - origin_y

        # Scale so the larger component has magnitude 1.
        scale = max(abs(bias_x), abs(bias_y), 1)
        bias_x /= scale
        bias_y /= scale

    reach = 15  # Full RA build radius
    offsets = range(-reach, reach + 1)
    spots = []
    for off_x in offsets:
        for off_y in offsets:
            manhattan = abs(off_x) + abs(off_y)
            if not 2 <= manhattan <= reach:
                continue

            px, py = origin_x + off_x, origin_y + off_y
            if px < 0 or py < 0 or px + width > map_w or py + height > map_h:
                continue

            collides = any(
                (px + ix, py + iy) in blocked
                for ix in range(width)
                for iy in range(height)
            )
            if not collides:
                spots.append({"cell_x": px, "cell_y": py, "distance": manhattan})

    if defensive and (bias_x or bias_y):
        def toward_enemy_rank(spot):
            rel_x = spot["cell_x"] - origin_x
            rel_y = spot["cell_y"] - origin_y
            # Dot product with the enemy direction (higher = more toward enemy).
            alignment = rel_x * bias_x + rel_y * bias_y
            return (-alignment, abs(spot["distance"] - 5))

        return sorted(spots, key=toward_enemy_rank)

    return sorted(spots, key=lambda spot: spot["distance"])
def _process_pending_placements(self) -> None:
    """Auto-place buildings that finished construction via build_and_place.

    Uses smart placement: finds valid positions in the full build radius
    around the Construction Yard, tries them in order. Cancels after
    _MAX_PLACEMENT_ATTEMPTS failures to avoid blocking the queue.

    State used:
      - ``_pending_placements``: building type -> requested coordinates
      - ``_attempted_placements``: building type -> index of the next
        candidate to try once a placement command has been sent
      - ``_placement_results``: messages reported back to the agent

    The method is not re-entrant: placement commands go through
    ``_execute_commands``, which calls this method again after stepping.
    Nested calls return immediately, so a failed placement is retried with
    the next candidate on the following pass instead of recursing on the
    same cell.
    """
    if getattr(self, "_placement_pass_active", False):
        return
    self._placement_pass_active = True
    try:
        if not self._last_obs:
            return
        production = self._last_obs.get("production", [])

        # Create the tracking dicts on first use so updates are never made
        # to a throwaway default.
        attempted = getattr(self, "_attempted_placements", None)
        if attempted is None:
            attempted = self._attempted_placements = {}
        pending = getattr(self, "_pending_placements", None)
        if pending is None:
            pending = self._pending_placements = {}

        def awaiting_placement(btype) -> bool:
            # True while a finished structure of this type sits in a queue.
            return any(
                p["queue_type"] in self._PLACEABLE_QUEUE_TYPES
                and p["item"] == btype
                and p["progress"] >= 0.99
                for p in production
            )

        # Phase 2: Check previously attempted placements for failure
        finished = []
        for btype, attempt_idx in list(attempted.items()):
            if not awaiting_placement(btype):
                # Building no longer in queue → placement succeeded
                self._placement_results.append(
                    self._app_config.prompts.placement_success.format(building=btype))
                finished.append(btype)
                continue

            # Building is still in queue → last placement failed, try next candidate
            candidates = self._find_placement_candidates(btype, self._last_obs)
            if attempt_idx < min(len(candidates), self._MAX_PLACEMENT_ATTEMPTS):
                pos = candidates[attempt_idx]
                attempted[btype] = attempt_idx + 1
                try:
                    self._execute_commands([CommandModel(
                        action=ActionType.PLACE_BUILDING,
                        item_type=btype,
                        target_x=pos["cell_x"],
                        target_y=pos["cell_y"],
                    )])
                except Exception:
                    # Best-effort: the next pass moves on to the next candidate.
                    logger.debug("Placement retry for %s failed", btype, exc_info=True)
                continue

            # Exhausted all candidates — report failure and cancel
            if btype in self._WATER_BUILDINGS:
                reason = f"{btype} requires water tiles — must be placed on water, not land"
            else:
                reason = f"no valid position found (tried {attempt_idx} spots)"
            self._placement_results.append(
                self._app_config.prompts.placement_failed.format(
                    building=btype, reason=reason)
            )
            try:
                self._execute_commands(
                    [CommandModel(action=ActionType.CANCEL_PRODUCTION, item_type=btype)])
            except Exception:
                logger.debug("Cancelling unplaceable %s failed", btype, exc_info=True)
            finished.append(btype)

        for btype in finished:
            attempted.pop(btype, None)
            pending.pop(btype, None)

        # Phase 1: Send placement commands for newly ready buildings
        for btype, coords in list(pending.items()):
            if btype in attempted:
                continue  # already being tracked
            if not awaiting_placement(btype):
                continue

            # Water buildings can't auto-place on land — warn and skip
            if btype in self._WATER_BUILDINGS:
                self._placement_results.append(
                    self._app_config.prompts.placement_water.format(building=btype)
                )
                pending.pop(btype, None)
                continue

            # Find best placement position using full CY radius search
            candidates = self._find_placement_candidates(btype, self._last_obs)

            # (0, 0) means "no position requested": use the best candidate.
            cx, cy = coords["cell_x"], coords["cell_y"]
            used_best_candidate = False
            if cx == 0 and cy == 0 and candidates:
                cx, cy = candidates[0]["cell_x"], candidates[0]["cell_y"]
                used_best_candidate = True

            try:
                self._execute_commands([CommandModel(
                    action=ActionType.PLACE_BUILDING,
                    item_type=btype,
                    target_x=cx,
                    target_y=cy,
                )])
            except Exception:
                self._placement_results.append(
                    self._app_config.prompts.placement_failed.format(
                        building=btype, reason="command error")
                )
                pending.pop(btype, None)
                continue

            # If the caller's own coordinates were used, candidate 0 has not
            # been tried yet, so a retry must start there rather than at 1.
            attempted[btype] = 1 if used_best_candidate else 0
    finally:
        self._placement_pass_active = False
def _get_replay_dir(self) -> Path:
    """Get the OpenRA replays directory for current mod.

    OpenRA stores replays at {SupportDir}/Replays/{mod}/{version}/.
    On macOS: ~/Library/Application Support/OpenRA/
    On Linux: ~/.config/openra/ (modern) or ~/.openra/ (legacy)
    Also checks {EngineDir}/Support/ (local override).

    Returns:
        The first candidate directory that exists, or the highest-priority
        candidate when none exists yet.
    """
    mod = self._config.mod
    candidates = []

    # Local Support dir (takes priority if it exists)
    engine_support = Path(self._config.openra_path) / "Support"
    if engine_support.exists():
        candidates.append(engine_support / "Replays" / mod)

    if sys.platform == "darwin":
        candidates.append(Path.home() / "Library/Application Support/OpenRA/Replays" / mod)
    else:
        # Modern path. Per the XDG Base Directory spec an unset, empty or
        # relative XDG_CONFIG_HOME is ignored in favour of ~/.config;
        # otherwise the result would be relative to the working directory.
        xdg = os.environ.get("XDG_CONFIG_HOME", "")
        if not xdg or not os.path.isabs(xdg):
            xdg = str(Path.home() / ".config")
        candidates.append(Path(xdg) / "openra/Replays" / mod)
        # Legacy path
        candidates.append(Path.home() / ".openra/Replays" / mod)

    for base in candidates:
        if base.exists():
            return base

    # Fallback: return first candidate (will be created if needed)
    return candidates[0]
= 0.0 + self._planning_turns_used = 0 + self._planning_strategy = "" + self._enemy_ever_seen = False + self._prev_buildings = {} + self._prev_unit_ids = {} + + # Update seed if provided + if seed is not None: + self._config.seed = seed + + # Launch OpenRA + logger.info(f"Launching OpenRA: map={self._config.map_name}, mod={self._config.mod}") + self._process.launch() + logger.info(f"OpenRA process launched (PID={self._process.pid})") + + # Wait for gRPC server to be ready + logger.info("Waiting for gRPC bridge to become ready...") + ready = await self._bridge.wait_for_ready(max_retries=120, retry_interval=2.0) + if not ready: + alive = self._process.is_alive() + logger.error(f"Bridge failed to start. Process alive={alive}") + raise RuntimeError("OpenRA gRPC bridge failed to start") + + # Get faction info from GameState (unary RPC — game stays paused) + game_state = None + try: + game_state = await self._bridge.get_state() + self._player_faction = game_state.player_faction or "" + self._enemy_faction = game_state.enemy_faction or "" + except Exception: + self._player_faction = "" + self._enemy_faction = "" + + # Build a minimal initial observation WITHOUT starting the streaming + # session. The game stays paused until _ensure_session_started() is + # called (from end_planning_phase or the first game action). 
def _step_impl(
    self,
    action: Action,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> Observation:
    """Handle non-MCP actions (OpenRAAction for backward compat).

    Args:
        action: The action to execute. Only ``OpenRAAction`` is supported.
        timeout_s: Maximum seconds to wait for the step to finish. ``None``
            keeps the previous fixed limit of 300 seconds.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        The observation built from the step result, or a generic
        ``Observation`` carrying an error message in ``metadata`` when the
        action type is not supported.

    Raises:
        Whatever the step coroutine raises, or a timeout error when the
        step does not finish within the limit.
    """
    if not isinstance(action, OpenRAAction):
        return Observation(
            done=False,
            reward=0.0,
            metadata={"error": f"Unknown action type: {type(action).__name__}. Use MCP tools or OpenRAAction."},
        )

    wait_s = 300 if timeout_s is None else timeout_s
    future = asyncio.run_coroutine_threadsafe(
        self._async_step_internal(action), self._loop
    )
    try:
        obs_dict = future.result(timeout=wait_s)
    except Exception:
        # On timeout the coroutine is still running on the loop; cancel it
        # so an abandoned step cannot keep going behind the caller's back.
        # cancel() is a no-op when the coroutine itself raised.
        future.cancel()
        raise

    self._last_obs = obs_dict
    reward, reward_vec = self._reward_fn.compute_all(obs_dict)
    if reward_vec:
        # Keep a running per-component total across the episode.
        for k, v in reward_vec.items():
            self._accumulated_reward_vector[k] = self._accumulated_reward_vector.get(k, 0.0) + v
    return self._build_observation(obs_dict, reward, reward_vec)
import logging
import os
import select
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import IO, Optional

logger = logging.getLogger(__name__)

# Default path to the OpenRA installation
DEFAULT_OPENRA_PATH = os.environ.get("OPENRA_PATH", "/opt/openra")

# Map user-friendly difficulty names to actual OpenRA bot type strings.
# Users can set either the friendly name or the raw OpenRA name.
# Difficulty tiers: beginner < easy < medium < hard < brutal
# Play styles (raw pass-through): rush, normal, turtle, naval
BOT_TYPE_MAP: dict[str, str] = {
    "beginner": "beginner",
    "easy": "easy",
    "medium": "medium",
    "hard": "normal",
    "brutal": "rush",
}


@dataclass
class OpenRAConfig:
    """Configuration for launching an OpenRA game instance."""

    openra_path: str = DEFAULT_OPENRA_PATH
    mod: str = "ra"
    map_name: str = "singles.oramap"
    # NOTE(review): grpc_port, bot_name and seed are not referenced by
    # OpenRAProcessManager._build_command() -- presumably they are consumed
    # elsewhere or meant to be passed through extra_args; confirm.
    grpc_port: int = 9999
    bot_name: str = "Normal AI"
    bot_type: str = "normal"  # Friendly tier from BOT_TYPE_MAP or a raw OpenRA bot type
    rl_slot: str = "Multi1"
    ai_slot: str = "Multi0"  # Empty string launches without an AI opponent
    seed: Optional[int] = None
    headless: bool = True  # Use Null renderer (no GPU needed)
    record_replays: bool = False  # Enable .orarep replay recording
    extra_args: dict = field(default_factory=dict)  # Extra "key=value" launch arguments


class OpenRAProcessManager:
    """Manages an OpenRA game subprocess for RL training.

    Each episode starts a new OpenRA process with the ExternalBotBridge
    trait enabled. The process communicates with the Python environment
    via gRPC on the configured port.
    """

    # Upper bound on the bytes returned by one get_stdout()/get_stderr() call.
    _READ_CHUNK = 4096

    def __init__(self, config: Optional[OpenRAConfig] = None):
        """Store the configuration; no process is started until launch()."""
        self.config = config or OpenRAConfig()
        self._process: Optional[subprocess.Popen] = None
        # Every chunk handed out by get_stdout()/get_stderr() is also kept here.
        self._stdout_log: list[str] = []
        self._stderr_log: list[str] = []

    def launch(self) -> int:
        """Launch a new OpenRA game instance.

        A still-running process started by this manager is killed first.

        Returns the PID of the launched process.

        Raises:
            FileNotFoundError: if no game client or launch script is found
                under ``config.openra_path`` (see ``_build_command``).

        NOTE(review): stdout and stderr are pipes that are only drained by
        get_stdout()/get_stderr(); a process that writes a lot may block once
        the OS pipe buffer is full unless callers poll those methods.
        """
        if self._process is not None and self._process.poll() is None:
            logger.warning("Killing existing OpenRA process before launching new one")
            self.kill()

        cmd = self._build_command()
        logger.info(f"Launching OpenRA: {' '.join(cmd)}")

        env = os.environ.copy()
        # Only a default: an explicit DOTNET_ROLL_FORWARD from the caller's
        # environment wins.
        env.setdefault("DOTNET_ROLL_FORWARD", "LatestMajor")

        self._process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=self.config.openra_path,
            env=env,
        )
        logger.info(f"OpenRA launched with PID {self._process.pid}")
        return self._process.pid

    def _build_command(self) -> list[str]:
        """Build the command line for launching OpenRA.

        Uses the game client (OpenRA.dll) with Launch.Map and Launch.Bots
        to auto-start a local game with the RL bot and optional AI opponent.

        Raises:
            FileNotFoundError: if neither ``OpenRA.dll`` (in the install root
                or its ``bin/`` directory) nor ``launch-rl.sh`` exists.
        """
        openra_path = Path(self.config.openra_path)

        # Find the game client executable (OpenRA.dll, not OpenRA.Server.dll)
        exe = None
        for search_dir in (openra_path, openra_path / "bin"):
            game_dll = search_dir / "OpenRA.dll"
            if game_dll.exists():
                exe = ["dotnet", str(game_dll)]
                break

        if exe is None:
            # Fallback: look for the RL launch script
            launch_script = openra_path / "launch-rl.sh"
            if not launch_script.exists():
                raise FileNotFoundError(
                    f"Could not find OpenRA game client in {openra_path}. "
                    "Expected OpenRA.dll in root or bin/, or launch-rl.sh"
                )
            exe = ["bash", str(launch_script)]

        # Build bots configuration: slot:bottype,slot:bottype
        bots = f"{self.config.rl_slot}:rl-agent"
        if self.config.ai_slot:
            # Map difficulty tiers to OpenRA bot types; unknown names pass through
            actual_type = BOT_TYPE_MAP.get(self.config.bot_type, self.config.bot_type)
            bots += f",{self.config.ai_slot}:{actual_type}"

        args = [
            *exe,
            f"Engine.EngineDir={self.config.openra_path}",
            f"Game.Mod={self.config.mod}",
            f"Launch.Map={self.config.map_name}",
            f"Launch.Bots={bots}",
        ]

        # Use Null renderer for headless operation (no GPU/OpenGL needed)
        if self.config.headless:
            args.append("Game.Platform=Null")

        if self.config.record_replays:
            args.append("Server.RecordReplays=True")

        args.extend(f"{key}={value}" for key, value in self.config.extra_args.items())

        # Every element is a non-empty string by construction, so no
        # filtering of empty arguments is needed.
        return args

    def is_alive(self) -> bool:
        """Check if the OpenRA process is still running."""
        if self._process is None:
            return False
        return self._process.poll() is None

    def kill(self, timeout: float = 5.0) -> Optional[int]:
        """Terminate the OpenRA process.

        Sends a terminate request and waits up to ``timeout`` seconds; if the
        process is still running after that, it is force-killed.

        Returns the exit code, or None if the process had to be force-killed
        (or if there was no process to terminate).
        """
        if self._process is None:
            return None

        pid = self._process.pid

        # Try graceful termination first
        try:
            self._process.terminate()
            try:
                exit_code = self._process.wait(timeout=timeout)
                logger.info(f"OpenRA process {pid} terminated gracefully (exit code {exit_code})")
                return exit_code
            except subprocess.TimeoutExpired:
                pass
        except ProcessLookupError:
            # The process is already gone.
            self._process = None
            return None

        # Force kill
        try:
            self._process.kill()
            self._process.wait(timeout=2.0)
            logger.warning(f"OpenRA process {pid} force-killed")
        except (ProcessLookupError, subprocess.TimeoutExpired):
            pass

        self._process = None
        return None

    def _read_available(self, stream: Optional[IO[bytes]], log: list[str]) -> str:
        """Return the bytes currently readable from ``stream`` without waiting.

        At most ``_READ_CHUNK`` bytes are consumed. The decoded text is
        appended to ``log`` and returned; "" means nothing was available,
        the pipe reached EOF, or the pipe could not be polled.
        """
        if stream is None:
            return ""
        try:
            if not select.select([stream], [], [], 0.0)[0]:
                return ""
            # Read straight from the descriptor: read(n) on the buffered pipe
            # object keeps waiting until n bytes or EOF arrive, whereas
            # os.read returns whatever is available right now. Nothing else
            # in this class reads through the buffered object, so no data is
            # stranded in its buffer.
            data = os.read(stream.fileno(), self._READ_CHUNK)
        except (OSError, ValueError):
            # Closed pipe, or a platform where select() cannot poll pipes.
            return ""
        if not data:
            return ""
        text = data.decode("utf-8", errors="replace")
        log.append(text)
        return text

    def get_stdout(self) -> str:
        """Read available stdout from the process."""
        if self._process is None:
            return ""
        return self._read_available(self._process.stdout, self._stdout_log)

    def get_stderr(self) -> str:
        """Read available stderr from the process."""
        if self._process is None:
            return ""
        return self._read_available(self._process.stderr, self._stderr_log)

    @property
    def pid(self) -> Optional[int]:
        """Get the PID of the running process."""
        if self._process is None:
            return None
        return self._process.pid

    def __del__(self):
        """Ensure cleanup on garbage collection."""
        if self._process is not None and self._process.poll() is None:
            try:
                self._process.kill()
            except Exception:
                # Never raise from a finalizer.
                pass
0000000000000000000000000000000000000000..7ac0f90545e61d92d05bcd098169a8d10f2b7258 --- /dev/null +++ b/openra_env/server/requirements.txt @@ -0,0 +1,7 @@ +openenv-core>=0.2.0 +grpcio>=1.60.0 +grpcio-tools>=1.60.0 +protobuf>=4.25.0 +pydantic>=2.0.0 +fastapi>=0.100.0 +uvicorn>=0.20.0 diff --git a/proto/rl_bridge.proto b/proto/rl_bridge.proto new file mode 100644 index 0000000000000000000000000000000000000000..262fbb45f77d8dfd31426d9f7cda7faa8447a35c --- /dev/null +++ b/proto/rl_bridge.proto @@ -0,0 +1,188 @@ +syntax = "proto3"; + +package openra.rl; + +option csharp_namespace = "OpenRA.Mods.Common.RL"; + +// The RL Bridge service allows an external agent to interact with OpenRA +// via bidirectional streaming (lock-step) or unary state queries. +service RLBridge { + // Bidirectional streaming: game sends observations, agent sends actions. + // Each observation waits for an action before advancing to the next tick. + rpc GameSession(stream AgentAction) returns (stream GameObservation); + + // Unary: query current game state on demand. + rpc GetState(StateRequest) returns (GameState); +} + +// ─── Observations (Game → Agent) ──────────────────────────────────────────── + +message GameObservation { + int32 tick = 1; + string episode_id = 2; + + // Structured observations + RlEconomy economy = 3; + RlMilitary military = 4; + repeated RlUnitInfo units = 5; + repeated RlBuildingInfo buildings = 6; + repeated RlProductionInfo production = 7; + repeated RlUnitInfo visible_enemies = 8; + RlMapInfo map_info = 9; + + // Binary-encoded spatial tensor (terrain, unit density, fog, etc.) 
+ // Format: flat float32 array, row-major channels-last + // Shape: map_info.height × map_info.width × spatial_channels + bytes spatial_map = 10; + int32 spatial_channels = 11; + + // Episode signals + bool done = 12; + float reward = 13; + string result = 14; // "win", "lose", "draw", "" + + // Available actions context + repeated string available_production = 15; + + // Visible enemy buildings (separate from visible_enemies which only has units) + repeated RlBuildingInfo visible_enemy_buildings = 16; +} + +message RlEconomy { + int32 cash = 1; + int32 ore = 2; + int32 power_provided = 3; + int32 power_drained = 4; + int32 resource_capacity = 5; + int32 harvester_count = 6; +} + +message RlMilitary { + int32 units_killed = 1; + int32 units_lost = 2; + int32 buildings_killed = 3; + int32 buildings_lost = 4; + int32 army_value = 5; + int32 active_unit_count = 6; + int32 kills_cost = 7; // Total cost value of enemy units/buildings killed + int32 deaths_cost = 8; // Total cost value of own units/buildings lost + int32 assets_value = 9; // Total value of all assets (units + buildings) + int32 experience = 10; // Player experience points + int32 order_count = 11; // Total orders issued +} + +message RlUnitInfo { + uint32 actor_id = 1; + string type = 2; // e.g. 
"e1", "1tnk", "harv" + int32 pos_x = 3; // WPos X + int32 pos_y = 4; // WPos Y + int32 cell_x = 5; // CPos X (grid) + int32 cell_y = 6; // CPos Y (grid) + float hp_percent = 7; // 0.0 - 1.0 + bool is_idle = 8; + string current_activity = 9; + string owner = 10; // Player internal name + int32 ammo = 11; // -1 if not applicable + bool can_attack = 12; + + // Sprint 4: enriched unit data + int32 facing = 13; // WAngle 0-1023 (direction unit faces) + int32 experience_level = 14; // Veterancy level (0 = none) + int32 stance = 15; // 0=HoldFire, 1=ReturnFire, 2=Defend, 3=AttackAnything + int32 speed = 16; // Current movement speed (with modifiers) + int32 attack_range = 17; // Max attack range in WDist units + int32 passenger_count = 18; // Cargo count (0 if not transport, -1 if N/A) + bool is_building = 19; // false for units (helps distinguish in visible_enemies) +} + +message RlBuildingInfo { + uint32 actor_id = 1; + string type = 2; // e.g. "powr", "barr", "weap" + int32 pos_x = 3; + int32 pos_y = 4; + float hp_percent = 5; + string owner = 6; + bool is_producing = 7; + float production_progress = 8; // 0.0 - 1.0 if producing + string producing_item = 9; + bool is_powered = 10; + + // Sprint 4: enriched building data + bool is_repairing = 11; // Actively being repaired + int32 sell_value = 12; // Refund amount if sold + int32 rally_x = 13; // Rally point cell X (-1 if none) + int32 rally_y = 14; // Rally point cell Y (-1 if none) + int32 power_amount = 15; // Power provided (positive) or consumed (negative) + repeated string can_produce = 16; // Items this building can produce + int32 cell_x = 17; // Cell position X + int32 cell_y = 18; // Cell position Y +} + +message RlProductionInfo { + string queue_type = 1; // "Building", "Infantry", "Vehicle", "Aircraft" + string item = 2; // Actor type being produced + float progress = 3; // 0.0 - 1.0 + int32 remaining_ticks = 4; + int32 remaining_cost = 5; + bool paused = 6; +} + +message RlMapInfo { + int32 width = 1; // 
Map width in cells + int32 height = 2; // Map height in cells + string map_name = 3; +} + +// ─── Actions (Agent → Game) ────────────────────────────────────────────────── + +message AgentAction { + repeated Command commands = 1; +} + +message Command { + ActionType action = 1; + uint32 actor_id = 2; // Subject actor (for unit commands) + uint32 target_actor_id = 3; // Target actor (for attack, enter, etc.) + int32 target_x = 4; // Target CPos X (for move, deploy, etc.) + int32 target_y = 5; // Target CPos Y + string item_type = 6; // For build/train: actor type to produce + bool queued = 7; // Queue after current activity vs interrupt +} + +enum ActionType { + NO_OP = 0; + MOVE = 1; + ATTACK_MOVE = 2; + ATTACK = 3; + STOP = 4; + HARVEST = 5; + BUILD = 6; // Start production of a building + TRAIN = 7; // Start production of a unit + DEPLOY = 8; // Deploy MCV, unpack + SELL = 9; // Sell building + REPAIR = 10; // Repair building + PLACE_BUILDING = 11; // Place a completed building at target location + CANCEL_PRODUCTION = 12; + SET_RALLY_POINT = 13; + GUARD = 14; // Guard another actor (target_actor_id) + SET_STANCE = 15; // Set unit stance (target_x: 0=HoldFire, 1=ReturnFire, 2=Defend, 3=AttackAnything) + ENTER_TRANSPORT = 16; // Load into transport (target_actor_id) + UNLOAD = 17; // Unload all passengers at current location + POWER_DOWN = 18; // Toggle building power on/off (actor_id = building) + SET_PRIMARY = 19; // Set primary production building (actor_id = building) + SURRENDER = 20; // Surrender / resign — ends the game as a loss +} + +// ─── State Query ───────────────────────────────────────────────────────────── + +message GameState { + string episode_id = 1; + int32 tick = 2; + string phase = 3; // "waiting", "playing", "game_over" + string winner = 4; // Player internal name, empty if ongoing + int32 player_count = 5; + string player_faction = 6; // Player faction internal name (e.g. 
"england", "ukraine") + string enemy_faction = 7; // Opponent faction internal name (if known) +} + +message StateRequest {} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..05c2d00e0fe3b03f05bf366124a843284cdc5289 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,53 @@ +[project] +name = "openra-rl" +version = "0.4.0" +description = "Play Red Alert with AI agents — LLMs, scripted bots, or RL" +readme = "README.md" +license = {text = "GPL-3.0"} +requires-python = ">=3.10" +dependencies = [ + "openenv-core>=0.2.0", + "openra-rl-util>=0.1.0", + "grpcio>=1.60.0", + "grpcio-tools>=1.60.0", + "protobuf>=4.25.0", + "pydantic>=2.0.0", + "fastapi>=0.100.0", + "uvicorn>=0.20.0", + "pyyaml>=6.0", + "mcp>=1.2.0", + "httpx>=0.24.0", + "python-dotenv>=1.0.0", + "websockets>=12.0", +] + +[project.scripts] +openra-rl = "openra_env.cli.main:main" +openra-rl-mcp = "openra_env.mcp_server:main" + +[project.optional-dependencies] +dev = [ + "pytest>=7.0", + "pytest-asyncio>=0.21", + "ruff>=0.1.0", +] +training = [ + "trl>=0.7.0", + "torch>=2.0.0", + "transformers>=4.30.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["openra_env"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" + +[tool.ruff] +line-length = 120 +target-version = "py310" diff --git a/scripts/test_integration.py b/scripts/test_integration.py new file mode 100644 index 0000000000000000000000000000000000000000..247f4c9eea68c35cc9cc74e57b47e39d992124d3 --- /dev/null +++ b/scripts/test_integration.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +"""End-to-end integration test for OpenRA-RL. + +Tests the full reset → step × N → done cycle against a live OpenRA instance. 
async def test_bridge_connection(port: int) -> bool:
    """Test 1: Can we connect to the gRPC bridge?

    Waits for the bridge on ``port`` (up to 30 retries, 1 second apart) and,
    once it is ready, queries and logs the game state. Returns True when the
    bridge became ready, False otherwise. The client is always closed.
    """
    logger.info("--- Test 1: Bridge Connection ---")
    client = BridgeClient(port=port)
    try:
        is_ready = await client.wait_for_ready(max_retries=30, retry_interval=1.0)
        if not is_ready:
            logger.error("FAIL: Bridge not ready after 30 attempts")
            return False

        logger.info("PASS: Bridge connection established")
        game_state = await client.get_state()
        logger.info(f" Game phase: {game_state.phase}, tick: {game_state.tick}")
        return True
    finally:
        await client.close()
async def test_step_cycle(port: int, num_steps: int) -> bool:
    """Test 3: Can we run a full step cycle (send actions, receive observations)?

    Opens a session on the bridge at ``port`` and performs up to ``num_steps``
    steps, stopping early when an observation reports ``done``. Each step
    orders the first idle unit one cell along +x, or sends a no-op when no
    unit is idle, and accumulates the reward computed from the observation.

    Returns True when all steps completed without raising (including the
    case ``num_steps == 0``), False when any step raised. The bridge client
    is always closed.
    """
    logger.info(f"--- Test 3: Step Cycle ({num_steps} steps) ---")
    bridge = BridgeClient(port=port)
    reward_fn = OpenRARewardFunction()

    try:
        await bridge.connect()
        obs = await bridge.start_session()
        obs_dict = observation_to_dict(obs)
        reward_fn.reset()

        total_reward = 0.0
        game_done = False
        # Counted explicitly: the loop variable is unbound when num_steps is 0,
        # so it cannot be used for the summary below.
        steps_run = 0

        for step in range(num_steps):
            # Build a simple action: nudge the first idle unit, else no-op
            idle_unit = next((u for u in obs_dict["units"] if u["is_idle"]), None)
            if idle_unit is not None:
                commands = [{
                    "action": "move",
                    "actor_id": idle_unit["actor_id"],
                    "target_x": idle_unit["cell_x"] + 1,
                    "target_y": idle_unit["cell_y"],
                }]
            else:
                commands = [{"action": "no_op"}]

            proto_action = commands_to_proto(commands)
            obs = await bridge.step(proto_action)
            obs_dict = observation_to_dict(obs)
            steps_run = step + 1

            reward = reward_fn.compute(obs_dict)
            total_reward += reward

            if step % 10 == 0 or obs_dict["done"]:
                logger.info(
                    f" Step {step}: tick={obs_dict['tick']}, "
                    f"cash={obs_dict['economy']['cash']}, "
                    f"units={len(obs_dict['units'])}, "
                    f"enemies={len(obs_dict['visible_enemies'])}, "
                    f"reward={reward:.4f}"
                )

            if obs_dict["done"]:
                game_done = True
                logger.info(f" Game ended: result={obs_dict['result']}")
                break

        logger.info(f" Total reward after {steps_run} steps: {total_reward:.4f}")
        if game_done:
            logger.info("PASS: Full game episode completed")
        else:
            logger.info(f"PASS: {num_steps} steps executed successfully (game still running)")
        return True

    except Exception as e:
        logger.error(f"FAIL: Step cycle failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        await bridge.close()
parser.add_argument("--skip-launch", action="store_true", + help="Skip launching OpenRA (connect to existing instance)") + parser.add_argument("--mod", default="ra", help="Game mod to use") + parser.add_argument("--map", default="", help="Map to use") + args = parser.parse_args() + + process = None + results = {} + + try: + # Launch OpenRA if not skipping + if not args.skip_launch: + logger.info("=== Launching OpenRA ===") + config = OpenRAConfig( + openra_path=args.openra_path, + mod=args.mod, + map_name=args.map, + grpc_port=args.port, + ) + process = OpenRAProcessManager(config) + pid = process.launch() + logger.info(f"OpenRA launched with PID {pid}") + time.sleep(2) # Brief wait for process startup + else: + logger.info("=== Skipping OpenRA launch (--skip-launch) ===") + + # Run tests + loop = asyncio.new_event_loop() + try: + results["bridge_connection"] = loop.run_until_complete( + test_bridge_connection(args.port) + ) + + if results["bridge_connection"]: + results["session_start"] = loop.run_until_complete( + test_session_start(args.port) + ) + + results["observation_parsing"] = loop.run_until_complete( + test_observation_model_parsing(args.port) + ) + + results["step_cycle"] = loop.run_until_complete( + test_step_cycle(args.port, args.steps) + ) + finally: + loop.close() + + finally: + # Clean up + if process is not None: + logger.info("=== Shutting down OpenRA ===") + process.kill() + + # Summary + print("\n" + "=" * 50) + print("Integration Test Results") + print("=" * 50) + all_passed = True + for name, passed in results.items(): + status = "PASS" if passed else "FAIL" + print(f" {name}: {status}") + if not passed: + all_passed = False + + if not results: + print(" No tests were run!") + all_passed = False + + print("=" * 50) + if all_passed: + print("All tests PASSED") + sys.exit(0) + else: + print("Some tests FAILED") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/skill/SKILL.md b/skill/SKILL.md new file mode 100644 index 
0000000000000000000000000000000000000000..c3a16b1f29ebed414073a3131990e83055a3c485 --- /dev/null +++ b/skill/SKILL.md @@ -0,0 +1,316 @@ +--- +name: openra-rl +description: Play Command & Conquer Red Alert RTS — build bases, train armies, and defeat AI opponents using 48 MCP tools. +version: 1.1.0 +metadata: + openclaw: + emoji: "🎮" + homepage: https://github.com/yxc20089/OpenRA-RL + requires: + bins: + - docker + env: [] + install: + - kind: uv + package: openra-rl + bins: [openra-rl] + os: ["macos", "linux"] +--- + +# OpenRA-RL: Play Command & Conquer Red Alert + +You are an AI agent playing **Command & Conquer: Red Alert**, a classic real-time strategy (RTS) game. You control one faction (Allied or Soviet) and must build a base, gather resources, train an army, and destroy the enemy. + +The game runs in a Docker container. You interact through MCP tools that let you observe the battlefield, issue orders, and advance game time. + +## Quick Start + +### 1. Install + +```bash +pip install openra-rl +``` + +### 2. Start the game server + +```bash +openra-rl server start +``` + +This pulls the Docker image and starts the game server on port 8000. Verify with `openra-rl server status`. + +### 3. Configure MCP + +Add to your OpenClaw config (`~/.openclaw/openclaw.json`): + +```json +{ + "mcpServers": { + "openra-rl": { + "command": "openra-rl", + "args": ["mcp-server"] + } + } +} +``` + +### 4. Play + +Tell your agent: *"Start a game of Red Alert and try to win."* + +The agent will use the MCP tools listed below to observe and command. + +--- + +## How the Game Works + +- **Real-time**: The game runs continuously at ~25 ticks/second. Call `advance(ticks)` to let time pass. +- **Fog of war**: You can only see areas near your units/buildings. Scout to find the enemy. +- **Resources**: Harvest ore to earn credits. Credits buy buildings and units. +- **Power**: Buildings need power. Build Power Plants (`powr`) to stay powered. Low power slows production. 
+- **Tech tree**: Advanced buildings require prerequisites (e.g., War Factory needs Ore Refinery). + +--- + +## MCP Tools Reference + +### Observation (read the battlefield) + +| Tool | Purpose | +|------|---------| +| `get_game_state` | Full snapshot: economy, units, buildings, enemies, production, military stats | +| `get_economy` | Cash, ore, power balance, harvester count | +| `get_units` | Your units with position, health, type, stance, speed, attack range | +| `get_buildings` | Your buildings with production queues, power, can_produce list | +| `get_enemies` | Visible enemy units and buildings (fog-of-war limited) | +| `get_production` | Current build queue + what you can build right now | +| `get_map_info` | Map name, dimensions | +| `get_exploration_status` | % explored, quadrant breakdown, whether enemy base found | + +### Knowledge (learn the game) + +| Tool | Purpose | +|------|---------| +| `lookup_unit(unit_type)` | Stats for a unit (e.g., `lookup_unit("e1")` → Rifle Infantry) | +| `lookup_building(building_type)` | Stats for a building (e.g., `lookup_building("weap")` → War Factory) | +| `lookup_tech_tree(faction)` | Full build order for `"allied"` or `"soviet"` | +| `lookup_faction(faction)` | All units and buildings for a faction | +| `get_faction_briefing()` | Comprehensive stats dump for YOUR faction | +| `get_map_analysis()` | Resource patches, water, terrain, strategic notes | +| `batch_lookup(queries)` | Multiple lookups in one call | + +### Game Control + +| Tool | Purpose | +|------|---------| +| `advance(ticks)` | **Critical** — advances the game by N ticks. Nothing happens without this. Use 25 ticks ≈ 1 second, 250 ticks ≈ 10 seconds. 
| + +### Movement & Combat + +| Tool | Purpose | +|------|---------| +| `move_units(unit_ids, target_x, target_y)` | Move units to a position | +| `attack_move(unit_ids, target_x, target_y)` | Move and engage enemies along the way | +| `attack_target(unit_ids, target_actor_id)` | Focus-fire a specific enemy | +| `stop_units(unit_ids)` | Halt movement and attacks | +| `guard_target(unit_ids, target_actor_id)` | Guard a unit or building | +| `set_stance(unit_ids, stance)` | Set to `"holdfire"`, `"returnfire"`, `"defend"`, or `"attackanything"` | +| `harvest(unit_id, cell_x, cell_y)` | Send harvester to ore field | + +### Production + +| Tool | Purpose | +|------|---------| +| `build_unit(unit_type, count)` | Train units (e.g., `build_unit("e1", 5)` → 5 Rifle Infantry) | +| `build_structure(building_type)` | Start constructing a building (needs manual placement) | +| `build_and_place(building_type, cell_x, cell_y)` | Build + auto-place when done (preferred) | +| `place_building(building_type, cell_x, cell_y)` | Place a completed building | +| `cancel_production(item_type)` | Cancel queued production | +| `get_valid_placements(building_type)` | Get valid locations to place a building | + +### Building Management + +| Tool | Purpose | +|------|---------| +| `deploy_unit(unit_id)` | Deploy MCV into Construction Yard | +| `sell_building(building_id)` | Sell for partial refund | +| `repair_building(building_id)` | Toggle auto-repair | +| `set_rally_point(building_id, cell_x, cell_y)` | New units go here | +| `power_down(building_id)` | Toggle power to save electricity | +| `set_primary(building_id)` | Set as primary production building | + +### Unit Groups + +| Tool | Purpose | +|------|---------| +| `assign_group(group_name, unit_ids)` | Create a named group | +| `add_to_group(group_name, unit_ids)` | Add units to existing group | +| `get_groups()` | List all groups | +| `command_group(group_name, command_type, ...)` | Command entire group | + +### Compound Actions + +| 
Tool | Purpose | +|------|---------| +| `batch(actions)` | Execute multiple actions in ONE tick (no time advance) | +| `plan(steps)` | Execute steps sequentially with state refresh between each | + +### Utility + +| Tool | Purpose | +|------|---------| +| `surrender()` | Give up the current game | +| `get_replay_path()` | Path to the replay file | +| `get_terrain_at(cell_x, cell_y)` | Terrain type at a cell | + +### Planning Phase (optional) + +| Tool | Purpose | +|------|---------| +| `start_planning_phase()` | Begin pre-game strategy planning | +| `get_opponent_intel()` | AI opponent profile and counters | +| `end_planning_phase(strategy)` | Commit strategy and start playing | +| `get_planning_status()` | Check planning state | + +--- + +## How to Play (Strategy Guide) + +### Step 1: Deploy your MCV + +At game start you have a Mobile Construction Vehicle (MCV). Deploy it to create your Construction Yard: + +``` +1. Call get_units() to find your MCV (type "mcv") +2. Call deploy_unit(mcv_actor_id) +3. Call advance(50) to let it deploy +``` + +### Step 2: Build your base + +Follow this build order: + +| Order | Building | Type Code | Cost | Why | +|-------|----------|-----------|------|-----| +| 1 | Power Plant | `powr` | $300 | Powers everything | +| 2 | Barracks | `tent` (Allied) or `barr` (Soviet) | $300 | Infantry production | +| 3 | Ore Refinery | `proc` | $2000 | Income + free harvester | +| 4 | War Factory | `weap` | $2000 | Vehicle production (requires Refinery) | +| 5 | More Power | `powr` | $300 | Keep power positive | + +Use `build_and_place()` — it auto-places when construction finishes: + +``` +1. Call get_valid_placements("powr") to find a good spot +2. Call build_and_place("powr", cell_x, cell_y) +3. Call advance(250) to let it build (~10 seconds) +4. Check get_production() to confirm completion +5. Repeat for next building +``` + +**Important**: Your faction may be Allied OR Soviet. Check `get_game_state()` → `faction` field. 
Barracks type depends on faction. + +### Step 3: Train your army + +``` +1. Call build_unit("e1", 5) for 5 Rifle Infantry ($100 each) +2. Call advance(100) to let them train +3. Once War Factory is ready: build_unit("3tnk", 3) for Medium Tanks ($800 each) +4. Set rally point near base exit: set_rally_point(barracks_id, x, y) +``` + +**Key units by faction:** + +| Unit | Code | Cost | Role | +|------|------|------|------| +| Rifle Infantry | `e1` | $100 | Cheap, fast | +| Rocket Soldier | `e3` | $300 | Anti-armor | +| Medium Tank | `3tnk` | $800 | Main battle tank | +| Heavy Tank | `4tnk` | $950 | Soviet heavy armor | +| Light Tank | `1tnk` | $700 | Fast flanker | +| Artillery | `arty` | $600 | Long range | +| V2 Launcher | `v2rl` | $700 | Soviet long range | + +### Step 4: Scout the map + +Send a cheap unit to explore: + +``` +1. Train one Rifle Infantry +2. Call attack_move([unit_id], far_x, far_y) toward unexplored areas +3. Call advance(500) to let it travel +4. Call get_enemies() to see what you've found +``` + +### Step 5: Attack the enemy + +Once you have 8-10 combat units: + +``` +1. Call get_enemies() to find enemy buildings +2. Call attack_move(all_unit_ids, enemy_base_x, enemy_base_y) +3. Call advance(100), check get_game_state() for battle progress +4. If enemies visible: attack_target(unit_ids, enemy_id) to focus fire +5. Keep producing reinforcements while attacking +``` + +### Step 6: Macro (ongoing economy) + +Throughout the game: +- Keep power positive (build Power Plants when needed) +- Keep producing units — never let production idle +- Build additional Ore Refineries for more income +- Replace lost harvesters + +--- + +## Game Loop Pattern + +A good agent loop looks like this: + +``` +1. get_game_state() → read the situation +2. Decide what to do based on: + - Economy: enough cash? Power positive? + - Production: anything building? Queue empty? + - Military: under attack? Ready to attack? + - Exploration: enemy found yet? +3. 
Issue orders (build, move, attack) +4. advance(50-250) → let time pass +5. Repeat until game is won or lost +``` + +Check `get_game_state()` → `done` field. When true, `result` will be `"win"` or `"loss"`. + +--- + +## Tips + +- **Always call `advance()`** after issuing orders. Orders don't execute until game time passes. +- **Use `batch()`** to issue multiple orders in one tick (e.g., build + move + set rally). +- **Check `available_production`** before building — it lists what you CAN build right now. +- **Don't let production idle** — keep queuing units. Idle production wastes time. +- **Build near your Construction Yard** — buildings must be placed adjacent to existing structures. +- **Power matters** — if power goes negative, production slows to a crawl. +- **Use `attack_move`** instead of `move` when heading toward enemies — units will engage threats. +- **A completed building blocks the queue** until placed. Always use `build_and_place()` to avoid this. + +--- + +## Troubleshooting + +| Problem | Solution | +|---------|----------| +| Server not running | `openra-rl server start` (needs Docker) | +| Can't build anything | Deploy MCV first with `deploy_unit()` | +| Building won't place | Use `get_valid_placements()` for valid spots | +| No money | Build Ore Refinery (`proc`) for harvesters | +| Production slow | Check power with `get_economy()` — build Power Plants | +| Can't find enemy | Scout with `attack_move` to unexplored quadrants | + +## Links + +- **GitHub**: https://github.com/yxc20089/OpenRA-RL +- **PyPI**: https://pypi.org/project/openra-rl/ +- **Leaderboard**: https://huggingface.co/spaces/yxc20089/OpenRA-Bench +- **Discord**: https://discord.gg/openra-rl diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 
0000000000000000000000000000000000000000..892ed41076cf70bd995bc04df5b741f2018bbfd9 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,151 @@ +"""Shared test utilities for FastMCP tool access across versions. + +FastMCP 3.0 (released 2026-02-19) changed internal APIs — the +``_tool_manager._tools`` dict no longer exists. This module provides +version-agnostic helpers **and** a pytest autouse fixture that patches +``mcp._tool_manager._tools`` back in so existing tests work unmodified. +""" + +import types +import pytest + + +# ── Version-agnostic tool access helpers ────────────────────────────────────── + + +def _find_tools_dict(mcp) -> dict | None: + """Probe FastMCP internals to locate the canonical tool registry. + + Returns the raw dict[str, ToolObj] if found, else ``None``. + """ + # FastMCP 2.x path + if hasattr(mcp, "_tool_manager"): + tm = mcp._tool_manager + if hasattr(tm, "_tools") and isinstance(tm._tools, dict): + return tm._tools + if hasattr(tm, "tools") and isinstance(tm.tools, dict): + return tm.tools + + # FastMCP 3.x: tools stored directly on mcp + if hasattr(mcp, "_tools") and isinstance(mcp._tools, dict): + return mcp._tools + + return None + + +def _extract_fn(tool_obj): + """Extract the underlying callable from a Tool wrapper object.""" + if hasattr(tool_obj, "fn"): + return tool_obj.fn + if callable(tool_obj): + return tool_obj + return None + + +def get_tool_fn(mcp, name): + """Get a tool's callable function from a FastMCP instance by name. + + Supports FastMCP 2.x and 3.x. Returns the raw function so it can + be called directly in tests. 
+ """ + tools = _find_tools_dict(mcp) + if tools is not None: + tool = tools.get(name) + if tool is not None: + return _extract_fn(tool) + return None + + +def get_tool_names(mcp) -> set: + """Return the set of registered tool names.""" + tools = _find_tools_dict(mcp) + return set(tools.keys()) if tools else set() + + +def get_tool_count(mcp) -> int: + """Return the number of registered tools.""" + return len(get_tool_names(mcp)) + + +class ToolWrapper: + """Compatibility wrapper matching FastMCP 2.x Tool interface.""" + + def __init__(self, fn): + self.fn = fn + + +def get_tool_obj(mcp, name): + """Get a tool as an object with a ``.fn`` attribute (FastMCP 2.x compat).""" + fn = get_tool_fn(mcp, name) + return ToolWrapper(fn) if fn is not None else None + + +def get_tools_dict(mcp) -> dict: + """Return dict mapping tool names → ToolWrapper objects. + + Drop-in replacement for ``mcp._tool_manager._tools``. + """ + names = get_tool_names(mcp) + result = {} + for name in names: + fn = get_tool_fn(mcp, name) + if fn is not None: + result[name] = ToolWrapper(fn) + return result + + +# ── Autouse fixtures ────────────────────────────────────────────────────────── + +# Monkey-patch FastMCP so that mcp._tool_manager._tools works on 3.x +# This is done via a module-level patch applied when conftest is imported. + + +def _patch_fastmcp(): + """Ensure FastMCP instances expose ``_tool_manager._tools`` on 3.x.""" + try: + from fastmcp import FastMCP + except ImportError: + return # fastmcp not installed — nothing to patch + + original_tool = getattr(FastMCP, "tool", None) + if original_tool is None: + return + + # Check if _tool_manager._tools already works (FastMCP 2.x) + test_mcp = FastMCP("__patch_test__") + if hasattr(test_mcp, "_tool_manager") and hasattr(test_mcp._tool_manager, "_tools"): + if isinstance(test_mcp._tool_manager._tools, dict): + return # Already compatible, no patch needed + + # FastMCP 3.x: We need to create a compatibility shim. 
+ # Override the tool() method to also store tools in a compat dict. + _compat_registry = {} # Will be shared per-mcp instance via __dict__ + + original_init = FastMCP.__init__ + + def patched_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + # Add compat _tool_manager._tools + if not hasattr(self, "_tool_manager"): + self._tool_manager = types.SimpleNamespace() + if not hasattr(self._tool_manager, "_tools"): + self._tool_manager._tools = {} + + def patched_tool(self, *args, **kwargs): + original_decorator = original_tool(self, *args, **kwargs) + + def wrapper(fn): + result = original_decorator(fn) + # Also register in our compat dict + if hasattr(self, "_tool_manager") and hasattr(self._tool_manager, "_tools"): + self._tool_manager._tools[fn.__name__] = ToolWrapper(fn) + return result + + return wrapper + + FastMCP.__init__ = patched_init + FastMCP.tool = patched_tool + + +# Apply patch at import time (before any tests run) +_patch_fastmcp() diff --git a/tests/test_bridge.py b/tests/test_bridge.py new file mode 100644 index 0000000000000000000000000000000000000000..4edcd1b96cec4e4d5eda42ffba7b305d4412d8c3 --- /dev/null +++ b/tests/test_bridge.py @@ -0,0 +1,278 @@ +"""Tests for bridge client helper functions. + +Tests observation_to_dict and commands_to_proto conversion functions +using mock protobuf objects. 
+""" + +import pytest + +from openra_env.server.bridge_client import commands_to_proto, observation_to_dict +from openra_env.generated import rl_bridge_pb2 + + +class TestCommandsToProto: + def test_no_op(self): + result = commands_to_proto([{"action": "no_op"}]) + assert len(result.commands) == 1 + assert result.commands[0].action == rl_bridge_pb2.NO_OP + + def test_move_command(self): + result = commands_to_proto([ + {"action": "move", "actor_id": 42, "target_x": 100, "target_y": 200} + ]) + cmd = result.commands[0] + assert cmd.action == rl_bridge_pb2.MOVE + assert cmd.actor_id == 42 + assert cmd.target_x == 100 + assert cmd.target_y == 200 + + def test_attack_command(self): + result = commands_to_proto([ + {"action": "attack", "actor_id": 10, "target_actor_id": 99} + ]) + cmd = result.commands[0] + assert cmd.action == rl_bridge_pb2.ATTACK + assert cmd.actor_id == 10 + assert cmd.target_actor_id == 99 + + def test_build_command(self): + result = commands_to_proto([ + {"action": "build", "item_type": "powr"} + ]) + cmd = result.commands[0] + assert cmd.action == rl_bridge_pb2.BUILD + assert cmd.item_type == "powr" + + def test_queued_flag(self): + result = commands_to_proto([ + {"action": "move", "actor_id": 1, "target_x": 10, "target_y": 20, "queued": True} + ]) + assert result.commands[0].queued is True + + def test_multiple_commands(self): + result = commands_to_proto([ + {"action": "move", "actor_id": 1, "target_x": 10, "target_y": 20}, + {"action": "attack", "actor_id": 2, "target_actor_id": 50}, + {"action": "build", "item_type": "barr"}, + ]) + assert len(result.commands) == 3 + assert result.commands[0].action == rl_bridge_pb2.MOVE + assert result.commands[1].action == rl_bridge_pb2.ATTACK + assert result.commands[2].action == rl_bridge_pb2.BUILD + + def test_unknown_action_defaults_to_noop(self): + result = commands_to_proto([{"action": "invalid_action"}]) + assert result.commands[0].action == rl_bridge_pb2.NO_OP + + def 
test_missing_action_defaults_to_noop(self): + result = commands_to_proto([{}]) + assert result.commands[0].action == rl_bridge_pb2.NO_OP + + def test_all_action_types(self): + action_types = [ + ("no_op", rl_bridge_pb2.NO_OP), + ("move", rl_bridge_pb2.MOVE), + ("attack_move", rl_bridge_pb2.ATTACK_MOVE), + ("attack", rl_bridge_pb2.ATTACK), + ("stop", rl_bridge_pb2.STOP), + ("harvest", rl_bridge_pb2.HARVEST), + ("build", rl_bridge_pb2.BUILD), + ("train", rl_bridge_pb2.TRAIN), + ("deploy", rl_bridge_pb2.DEPLOY), + ("sell", rl_bridge_pb2.SELL), + ("repair", rl_bridge_pb2.REPAIR), + ("place_building", rl_bridge_pb2.PLACE_BUILDING), + ("cancel_production", rl_bridge_pb2.CANCEL_PRODUCTION), + ("set_rally_point", rl_bridge_pb2.SET_RALLY_POINT), + ] + for action_str, expected_enum in action_types: + result = commands_to_proto([{"action": action_str}]) + assert result.commands[0].action == expected_enum, f"Failed for {action_str}" + + def test_empty_list(self): + result = commands_to_proto([]) + assert len(result.commands) == 0 + + def test_default_values_for_missing_fields(self): + result = commands_to_proto([{"action": "move"}]) + cmd = result.commands[0] + assert cmd.actor_id == 0 + assert cmd.target_actor_id == 0 + assert cmd.target_x == 0 + assert cmd.target_y == 0 + assert cmd.item_type == "" + assert cmd.queued is False + + +class TestObservationToDict: + def _make_observation(self, **kwargs): + """Create a protobuf GameObservation with given fields.""" + obs = rl_bridge_pb2.GameObservation() + obs.tick = kwargs.get("tick", 0) + obs.done = kwargs.get("done", False) + obs.result = kwargs.get("result", "") + obs.reward = kwargs.get("reward", 0.0) + + if "economy" in kwargs: + eco = kwargs["economy"] + obs.economy.cash = eco.get("cash", 0) + obs.economy.ore = eco.get("ore", 0) + obs.economy.power_provided = eco.get("power_provided", 0) + obs.economy.power_drained = eco.get("power_drained", 0) + obs.economy.resource_capacity = eco.get("resource_capacity", 0) + 
obs.economy.harvester_count = eco.get("harvester_count", 0) + + if "military" in kwargs: + mil = kwargs["military"] + obs.military.units_killed = mil.get("units_killed", 0) + obs.military.units_lost = mil.get("units_lost", 0) + obs.military.buildings_killed = mil.get("buildings_killed", 0) + obs.military.buildings_lost = mil.get("buildings_lost", 0) + obs.military.army_value = mil.get("army_value", 0) + obs.military.active_unit_count = mil.get("active_unit_count", 0) + + if "map_info" in kwargs: + mi = kwargs["map_info"] + obs.map_info.width = mi.get("width", 0) + obs.map_info.height = mi.get("height", 0) + obs.map_info.map_name = mi.get("map_name", "") + + for u in kwargs.get("units", []): + unit = obs.units.add() + unit.actor_id = u.get("actor_id", 0) + unit.type = u.get("type", "") + unit.pos_x = u.get("pos_x", 0) + unit.pos_y = u.get("pos_y", 0) + unit.cell_x = u.get("cell_x", 0) + unit.cell_y = u.get("cell_y", 0) + unit.hp_percent = u.get("hp_percent", 1.0) + unit.is_idle = u.get("is_idle", True) + unit.current_activity = u.get("current_activity", "") + unit.owner = u.get("owner", "") + unit.can_attack = u.get("can_attack", False) + + for b in kwargs.get("buildings", []): + bldg = obs.buildings.add() + bldg.actor_id = b.get("actor_id", 0) + bldg.type = b.get("type", "") + bldg.pos_x = b.get("pos_x", 0) + bldg.pos_y = b.get("pos_y", 0) + bldg.hp_percent = b.get("hp_percent", 1.0) + bldg.owner = b.get("owner", "") + bldg.is_producing = b.get("is_producing", False) + bldg.production_progress = b.get("production_progress", 0.0) + bldg.producing_item = b.get("producing_item", "") + bldg.is_powered = b.get("is_powered", True) + + for p in kwargs.get("production", []): + prod = obs.production.add() + prod.queue_type = p.get("queue_type", "") + prod.item = p.get("item", "") + prod.progress = p.get("progress", 0.0) + prod.remaining_ticks = p.get("remaining_ticks", 0) + prod.remaining_cost = p.get("remaining_cost", 0) + prod.paused = p.get("paused", False) + + for ap in 
kwargs.get("available_production", []): + obs.available_production.append(ap) + + return obs + + def test_basic_fields(self): + obs = self._make_observation(tick=42, done=True, result="win", reward=1.5) + d = observation_to_dict(obs) + assert d["tick"] == 42 + assert d["done"] is True + assert d["result"] == "win" + assert d["reward"] == 1.5 + + def test_economy(self): + obs = self._make_observation( + economy={"cash": 5000, "power_provided": 100, "power_drained": 60, "harvester_count": 2} + ) + d = observation_to_dict(obs) + assert d["economy"]["cash"] == 5000 + assert d["economy"]["power_provided"] == 100 + assert d["economy"]["power_drained"] == 60 + assert d["economy"]["harvester_count"] == 2 + + def test_military(self): + obs = self._make_observation( + military={"units_killed": 3, "units_lost": 1, "army_value": 5000} + ) + d = observation_to_dict(obs) + assert d["military"]["units_killed"] == 3 + assert d["military"]["units_lost"] == 1 + assert d["military"]["army_value"] == 5000 + + def test_units(self): + obs = self._make_observation( + units=[ + {"actor_id": 1, "type": "e1", "pos_x": 100, "pos_y": 200, "hp_percent": 0.75, "can_attack": True}, + {"actor_id": 2, "type": "1tnk", "is_idle": False, "current_activity": "Move"}, + ] + ) + d = observation_to_dict(obs) + assert len(d["units"]) == 2 + assert d["units"][0]["actor_id"] == 1 + assert d["units"][0]["type"] == "e1" + assert d["units"][0]["hp_percent"] == pytest.approx(0.75) + assert d["units"][0]["can_attack"] is True + assert d["units"][1]["is_idle"] is False + assert d["units"][1]["current_activity"] == "Move" + + def test_buildings(self): + obs = self._make_observation( + buildings=[ + {"actor_id": 10, "type": "powr", "is_powered": True}, + {"actor_id": 20, "type": "barr", "is_producing": True, "producing_item": "e1"}, + ] + ) + d = observation_to_dict(obs) + assert len(d["buildings"]) == 2 + assert d["buildings"][0]["type"] == "powr" + assert d["buildings"][1]["is_producing"] is True + assert 
d["buildings"][1]["producing_item"] == "e1" + + def test_production(self): + obs = self._make_observation( + production=[{"queue_type": "Infantry", "item": "e1", "progress": 0.5, "remaining_ticks": 100}] + ) + d = observation_to_dict(obs) + assert len(d["production"]) == 1 + assert d["production"][0]["queue_type"] == "Infantry" + assert d["production"][0]["progress"] == pytest.approx(0.5) + + def test_visible_enemies(self): + obs = self._make_observation() + enemy = obs.visible_enemies.add() + enemy.actor_id = 99 + enemy.type = "2tnk" + enemy.owner = "Enemy" + d = observation_to_dict(obs) + assert len(d["visible_enemies"]) == 1 + assert d["visible_enemies"][0]["actor_id"] == 99 + assert d["visible_enemies"][0]["owner"] == "Enemy" + + def test_map_info(self): + obs = self._make_observation(map_info={"width": 128, "height": 128, "map_name": "Test Map"}) + d = observation_to_dict(obs) + assert d["map_info"]["width"] == 128 + assert d["map_info"]["height"] == 128 + assert d["map_info"]["map_name"] == "Test Map" + + def test_available_production(self): + obs = self._make_observation(available_production=["e1", "e3", "1tnk"]) + d = observation_to_dict(obs) + assert d["available_production"] == ["e1", "e3", "1tnk"] + + def test_empty_observation(self): + obs = self._make_observation() + d = observation_to_dict(obs) + assert d["tick"] == 0 + assert d["units"] == [] + assert d["buildings"] == [] + assert d["production"] == [] + assert d["visible_enemies"] == [] + assert d["done"] is False + assert d["result"] == "" diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..921f32ad77dcbbe1beeeb5f771afe3275fb0b0ab --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,549 @@ +"""Tests for the openra-rl CLI package.""" + +import os +import subprocess +import sys +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +import yaml + + +# ── Console 
───────────────────────────────────────────────────────── + +class TestConsole: + def test_info(self, capsys): + from openra_env.cli.console import info + info("hello") + assert "hello" in capsys.readouterr().out + + def test_success(self, capsys): + from openra_env.cli.console import success + success("done") + assert "done" in capsys.readouterr().out + + def test_error(self, capsys): + from openra_env.cli.console import error + error("fail") + assert "fail" in capsys.readouterr().err + + def test_warn(self, capsys): + from openra_env.cli.console import warn + warn("caution") + assert "caution" in capsys.readouterr().out + + def test_step(self, capsys): + from openra_env.cli.console import step + step("pulling...") + assert "pulling..." in capsys.readouterr().out + + def test_header(self, capsys): + from openra_env.cli.console import header + header("Title") + assert "Title" in capsys.readouterr().out + + def test_dim(self, capsys): + from openra_env.cli.console import dim + dim("faint text") + assert "faint text" in capsys.readouterr().out + + +# ── Docker Manager ────────────────────────────────────────────────── + +class TestDockerManager: + @patch("openra_env.cli.docker_manager.shutil.which", return_value=None) + def test_check_docker_not_installed(self, mock_which): + from openra_env.cli.docker_manager import check_docker + assert check_docker() is False + + @patch("openra_env.cli.docker_manager.shutil.which", return_value="/usr/bin/docker") + @patch("openra_env.cli.docker_manager._run") + def test_check_docker_daemon_not_running(self, mock_run, mock_which): + mock_run.return_value = MagicMock(returncode=1) + from openra_env.cli.docker_manager import check_docker + assert check_docker() is False + + @patch("openra_env.cli.docker_manager.shutil.which", return_value="/usr/bin/docker") + @patch("openra_env.cli.docker_manager._run") + def test_check_docker_ok(self, mock_run, mock_which): + mock_run.return_value = MagicMock(returncode=0) + from 
openra_env.cli.docker_manager import check_docker + assert check_docker() is True + + @patch("openra_env.cli.docker_manager._run") + def test_is_running_false(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout="") + from openra_env.cli.docker_manager import is_running + assert is_running() is False + + @patch("openra_env.cli.docker_manager._run") + def test_is_running_true(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout="openra-rl-server\n") + from openra_env.cli.docker_manager import is_running + assert is_running() is True + + @patch("openra_env.cli.docker_manager._run") + def test_image_exists_false(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout="") + from openra_env.cli.docker_manager import image_exists + assert image_exists() is False + + @patch("openra_env.cli.docker_manager._run") + def test_image_exists_true(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout="abc123\n") + from openra_env.cli.docker_manager import image_exists + assert image_exists() is True + + @patch("openra_env.cli.docker_manager.is_running", return_value=True) + def test_start_server_already_running(self, mock_running): + from openra_env.cli.docker_manager import start_server + assert start_server() is True + + @patch("openra_env.cli.docker_manager.is_running", return_value=False) + def test_stop_server_not_running(self, mock_running): + from openra_env.cli.docker_manager import stop_server + assert stop_server() is True + + @patch("openra_env.cli.docker_manager.is_running", return_value=True) + @patch("openra_env.cli.docker_manager._run") + def test_stop_server_ok(self, mock_run, mock_running): + mock_run.return_value = MagicMock(returncode=0) + from openra_env.cli.docker_manager import stop_server + assert stop_server() is True + + @patch("openra_env.cli.docker_manager._run") + def test_server_status_not_running(self, mock_run): + from openra_env.cli.docker_manager import server_status, 
is_running + with patch("openra_env.cli.docker_manager.is_running", return_value=False): + assert server_status() is None + + @patch("openra_env.cli.docker_manager.is_running", return_value=True) + @patch("openra_env.cli.docker_manager._run") + def test_server_status_running(self, mock_run, mock_running): + mock_run.return_value = MagicMock( + returncode=0, stdout="Up 5 minutes\t0.0.0.0:8000->8000/tcp" + ) + from openra_env.cli.docker_manager import server_status + status = server_status() + assert status is not None + assert "Up" in status["status"] + + def test_image_constant(self): + from openra_env.cli.docker_manager import IMAGE + assert "ghcr.io" in IMAGE + + def test_container_name(self): + from openra_env.cli.docker_manager import CONTAINER_NAME + assert CONTAINER_NAME == "openra-rl-server" + + +# ── Replay Viewer Settings ────────────────────────────────────────── + +class TestReplayViewerSettings: + def test_defaults(self, monkeypatch): + import os as _os + from openra_env.cli.docker_manager import load_replay_viewer_settings + for key in [ + "OPENRA_RL_REPLAY_RESOLUTION", "OPENRA_RL_REPLAY_RENDER", + "OPENRA_RL_REPLAY_VNC_QUALITY", "OPENRA_RL_REPLAY_VNC_COMPRESSION", + "OPENRA_RL_REPLAY_UI_SCALE", "OPENRA_RL_REPLAY_VIEWPORT_DISTANCE", + "OPENRA_RL_REPLAY_MUTE", "OPENRA_RL_REPLAY_CPU_CORES", + ]: + monkeypatch.delenv(key, raising=False) + s = load_replay_viewer_settings() + assert s.width == 1280 + assert s.height == 960 + assert s.render_mode == "auto" + assert s.vnc_quality == 8 + assert s.vnc_compression == 4 + assert s.ui_scale == 1.0 + assert s.viewport_distance == "Medium" + assert s.mute is True + assert s.cpu_cores == 4 + + def test_env_overrides(self, monkeypatch): + from openra_env.cli.docker_manager import load_replay_viewer_settings + monkeypatch.setenv("OPENRA_RL_REPLAY_RESOLUTION", "1280x720") + monkeypatch.setenv("OPENRA_RL_REPLAY_RENDER", "cpu") + monkeypatch.setenv("OPENRA_RL_REPLAY_VNC_QUALITY", "9") + 
monkeypatch.setenv("OPENRA_RL_REPLAY_VNC_COMPRESSION", "2") + monkeypatch.setenv("OPENRA_RL_REPLAY_UI_SCALE", "1.0") + monkeypatch.setenv("OPENRA_RL_REPLAY_VIEWPORT_DISTANCE", "far") + monkeypatch.setenv("OPENRA_RL_REPLAY_MUTE", "false") + s = load_replay_viewer_settings() + assert s.width == 1280 + assert s.height == 720 + assert s.render_mode == "cpu" + assert s.vnc_quality == 9 + assert s.vnc_compression == 2 + assert s.ui_scale == 1.0 + assert s.viewport_distance == "Far" + assert s.mute is False + + def test_cli_overrides_take_precedence(self, monkeypatch): + from openra_env.cli.docker_manager import load_replay_viewer_settings + monkeypatch.setenv("OPENRA_RL_REPLAY_RESOLUTION", "640x480") + monkeypatch.setenv("OPENRA_RL_REPLAY_RENDER", "cpu") + s = load_replay_viewer_settings(resolution="1920x1080", render_mode="gpu") + assert s.width == 1920 + assert s.height == 1080 + assert s.render_mode == "gpu" + + def test_invalid_resolution_raises(self): + from openra_env.cli.docker_manager import load_replay_viewer_settings + with pytest.raises(ValueError, match="resolution"): + load_replay_viewer_settings(resolution="bad") + + def test_invalid_render_mode_raises(self): + from openra_env.cli.docker_manager import load_replay_viewer_settings + with pytest.raises(ValueError, match="render mode"): + load_replay_viewer_settings(render_mode="turbo") + + def test_gpu_docker_args_cpu(self): + from openra_env.cli.docker_manager import _gpu_docker_args + variants = _gpu_docker_args("cpu", cpu_cores=4) + assert len(variants) == 1 + assert "LIBGL_ALWAYS_SOFTWARE=1" in variants[0] + assert "LP_NUM_THREADS=4" in variants[0] + + def test_gpu_docker_args_cpu_custom_cores(self): + from openra_env.cli.docker_manager import _gpu_docker_args + variants = _gpu_docker_args("cpu", cpu_cores=8) + assert "LP_NUM_THREADS=8" in variants[0] + + def test_gpu_docker_args_gpu(self): + from openra_env.cli.docker_manager import _gpu_docker_args + variants = _gpu_docker_args("gpu") + assert 
len(variants) == 4 # NVIDIA, WSL2, AMD ROCm, DRI + assert "--gpus" in variants[0] + assert "/dev/dxg" in str(variants[1]) + assert "/dev/kfd" in str(variants[2]) + assert "/dev/dri" in str(variants[3]) + + def test_gpu_docker_args_auto(self): + from openra_env.cli.docker_manager import _gpu_docker_args + variants = _gpu_docker_args("auto") + assert len(variants) == 5 # 4 GPU + 1 CPU + # GPU variants first, CPU last + assert "--gpus" in variants[0] + assert "LIBGL_ALWAYS_SOFTWARE=1" in variants[-1] + + def test_cpu_cores_env_override(self, monkeypatch): + from openra_env.cli.docker_manager import load_replay_viewer_settings + monkeypatch.setenv("OPENRA_RL_REPLAY_CPU_CORES", "8") + s = load_replay_viewer_settings() + assert s.cpu_cores == 8 + + def test_cpu_cores_cli_override(self, monkeypatch): + from openra_env.cli.docker_manager import load_replay_viewer_settings + monkeypatch.setenv("OPENRA_RL_REPLAY_CPU_CORES", "8") + s = load_replay_viewer_settings(cpu_cores=2) + assert s.cpu_cores == 2 + + def test_cpu_cores_clamped(self): + import os as _os + from openra_env.cli.docker_manager import load_replay_viewer_settings + # 0 means "all available" + s = load_replay_viewer_settings(cpu_cores=0) + assert s.cpu_cores == (_os.cpu_count() or 4) + # Clamped to max 32 + s = load_replay_viewer_settings(cpu_cores=100) + assert s.cpu_cores == 32 + + def test_settings_env_args(self): + from openra_env.cli.docker_manager import ReplayViewerSettings, _settings_env_args + s = ReplayViewerSettings(width=1280, height=720, mute=False) + args = _settings_env_args(s) + assert "-e" in args + assert "OPENRA_RL_REPLAY_RESOLUTION=1280x720" in args + assert "OPENRA_RL_REPLAY_MUTE=False" in args + + @patch("openra_env.cli.docker_manager._run") + def test_replay_viewer_exists_false(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout="") + from openra_env.cli.docker_manager import replay_viewer_exists + assert replay_viewer_exists() is False + + 
@patch("openra_env.cli.docker_manager._run") + def test_replay_viewer_exists_true(self, mock_run): + mock_run.return_value = MagicMock(returncode=0, stdout="openra-rl-replay\n") + from openra_env.cli.docker_manager import replay_viewer_exists + assert replay_viewer_exists() is True + + @patch("openra_env.cli.commands.docker") + def test_cmd_replay_watch_invalid_setting(self, mock_docker): + from openra_env.cli.commands import cmd_replay_watch + mock_docker.check_docker.return_value = True + mock_docker.load_replay_viewer_settings.side_effect = ValueError("bad resolution") + with pytest.raises(SystemExit) as exc_info: + cmd_replay_watch(resolution="bad") + assert exc_info.value.code == 1 + mock_docker.start_replay_viewer.assert_not_called() + + @patch("openra_env.cli.commands.cmd_replay_watch") + def test_main_replay_watch_with_flags(self, mock_watch): + from openra_env.cli.main import main + with patch("sys.argv", [ + "openra-rl", "replay", "watch", "demo.orarep", + "--port", "6090", + "--resolution", "1280x720", + "--render", "gpu", + "--vnc-quality", "9", + "--vnc-compression", "2", + "--cpus", "6", + ]): + main() + mock_watch.assert_called_once_with( + file="demo.orarep", + port=6090, + resolution="1280x720", + render_mode="gpu", + vnc_quality=9, + vnc_compression=2, + cpu_cores=6, + ) + + +# ── Wizard ────────────────────────────────────────────────────────── + +class TestWizard: + def test_config_path(self): + from openra_env.cli.wizard import CONFIG_DIR, CONFIG_PATH + assert CONFIG_DIR == Path.home() / ".openra-rl" + assert CONFIG_PATH == Path.home() / ".openra-rl" / "config.yaml" + + def test_providers_defined(self): + from openra_env.cli.wizard import PROVIDERS + assert "openrouter" in PROVIDERS + assert "ollama" in PROVIDERS + assert "lmstudio" in PROVIDERS + + def test_provider_openrouter_needs_key(self): + from openra_env.cli.wizard import PROVIDERS + assert PROVIDERS["openrouter"]["needs_key"] is True + + def test_provider_ollama_no_key(self): + from 
openra_env.cli.wizard import PROVIDERS + assert PROVIDERS["ollama"]["needs_key"] is False + + def test_provider_lmstudio_no_key(self): + from openra_env.cli.wizard import PROVIDERS + assert PROVIDERS["lmstudio"]["needs_key"] is False + + def test_has_saved_config_false(self, tmp_path): + from openra_env.cli import wizard + with patch.object(wizard, "CONFIG_PATH", tmp_path / "nonexistent.yaml"): + assert wizard.has_saved_config() is False + + def test_save_and_load_config(self, tmp_path): + from openra_env.cli import wizard + cfg_path = tmp_path / "config.yaml" + with patch.object(wizard, "CONFIG_PATH", cfg_path), \ + patch.object(wizard, "CONFIG_DIR", tmp_path): + wizard.save_config({"llm": {"model": "test-model"}}) + loaded = wizard.load_saved_config() + assert loaded["llm"]["model"] == "test-model" + + def test_merge_cli_into_config_provider(self): + from openra_env.cli.wizard import merge_cli_into_config + config = {"llm": {"model": "old"}} + result = merge_cli_into_config(config, provider="ollama") + assert "localhost:11434" in result["llm"]["base_url"] + assert result["provider"] == "ollama" + + def test_merge_cli_into_config_model(self): + from openra_env.cli.wizard import merge_cli_into_config + config = {} + result = merge_cli_into_config(config, model="new-model") + assert result["llm"]["model"] == "new-model" + + def test_merge_cli_into_config_api_key(self): + from openra_env.cli.wizard import merge_cli_into_config + config = {} + result = merge_cli_into_config(config, api_key="sk-test") + assert result["llm"]["api_key"] == "sk-test" + + def test_merge_cli_preserves_existing(self): + from openra_env.cli.wizard import merge_cli_into_config + config = {"llm": {"model": "existing", "base_url": "http://test"}} + result = merge_cli_into_config(config, api_key="sk-new") + assert result["llm"]["model"] == "existing" + assert result["llm"]["base_url"] == "http://test" + assert result["llm"]["api_key"] == "sk-new" + + +# ── Commands 
──────────────────────────────────────────────────────── + +class TestCommands: + def test_cmd_version(self, capsys): + from openra_env.cli.commands import cmd_version + cmd_version() + out = capsys.readouterr().out + assert "openra-rl" in out + + @patch("openra_env.cli.commands.docker") + def test_cmd_server_status_not_running(self, mock_docker, capsys): + mock_docker.server_status.return_value = None + from openra_env.cli.commands import cmd_server_status + cmd_server_status() + assert "not running" in capsys.readouterr().out + + @patch("openra_env.cli.commands.docker") + def test_cmd_server_status_running(self, mock_docker, capsys): + mock_docker.server_status.return_value = { + "status": "Up 5 minutes", + "ports": "0.0.0.0:8000->8000/tcp", + } + from openra_env.cli.commands import cmd_server_status + cmd_server_status() + assert "running" in capsys.readouterr().out + + @patch("openra_env.cli.commands.docker") + def test_cmd_server_stop(self, mock_docker): + from openra_env.cli.commands import cmd_server_stop + cmd_server_stop() + mock_docker.stop_server.assert_called_once() + + @patch("openra_env.cli.commands.docker") + def test_cmd_server_logs(self, mock_docker): + from openra_env.cli.commands import cmd_server_logs + cmd_server_logs(follow=True) + mock_docker.get_logs.assert_called_once_with(follow=True) + + @patch("openra_env.cli.commands.docker.check_docker", return_value=False) + def test_cmd_play_no_docker(self, mock_check): + from openra_env.cli.commands import cmd_play + with pytest.raises(SystemExit): + cmd_play() + + +# ── Main Entry Point ─────────────────────────────────────────────── + +class TestMain: + def test_main_no_args_shows_help(self, capsys): + from openra_env.cli.main import main + with patch("sys.argv", ["openra-rl"]): + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 0 + + def test_main_version_flag(self, capsys): + from openra_env.cli.main import main + with patch("sys.argv", ["openra-rl", 
"--version"]): + main() + assert "openra-rl" in capsys.readouterr().out + + def test_main_version_subcommand(self, capsys): + from openra_env.cli.main import main + with patch("sys.argv", ["openra-rl", "version"]): + main() + assert "openra-rl" in capsys.readouterr().out + + @patch("openra_env.cli.commands.cmd_doctor") + def test_main_doctor(self, mock_doctor): + from openra_env.cli.main import main + with patch("sys.argv", ["openra-rl", "doctor"]): + main() + mock_doctor.assert_called_once() + + @patch("openra_env.cli.commands.cmd_config") + def test_main_config(self, mock_config): + from openra_env.cli.main import main + with patch("sys.argv", ["openra-rl", "config"]): + main() + mock_config.assert_called_once() + + @patch("openra_env.cli.commands.cmd_server_stop") + def test_main_server_stop(self, mock_stop): + from openra_env.cli.main import main + with patch("sys.argv", ["openra-rl", "server", "stop"]): + main() + mock_stop.assert_called_once() + + @patch("openra_env.cli.commands.cmd_server_status") + def test_main_server_status(self, mock_status): + from openra_env.cli.main import main + with patch("sys.argv", ["openra-rl", "server", "status"]): + main() + mock_status.assert_called_once() + + @patch("openra_env.cli.commands.cmd_server_logs") + def test_main_server_logs_follow(self, mock_logs): + from openra_env.cli.main import main + with patch("sys.argv", ["openra-rl", "server", "logs", "--follow"]): + main() + mock_logs.assert_called_once_with(follow=True) + + @patch("openra_env.cli.commands.cmd_play") + def test_main_play_with_flags(self, mock_play): + from openra_env.cli.main import main + with patch("sys.argv", [ + "openra-rl", "play", + "--provider", "ollama", + "--model", "qwen3:32b", + "--verbose", + "--port", "9000", + ]): + main() + mock_play.assert_called_once_with( + provider="ollama", + model="qwen3:32b", + api_key=None, + difficulty="normal", + verbose=True, + port=9000, + server_url=None, + local=False, + image_version=None, + ) + + 
class TestMCPServer:
    """Import-level checks on the MCP server module and its registered tools."""

    @staticmethod
    def _registered_tools():
        """Return the FastMCP tool registry, or ``{}`` if there is no tool manager.

        Reads the private ``_tool_manager._tools`` mapping of the ``mcp``
        instance exported by ``openra_env.mcp_server``.
        """
        from openra_env.mcp_server import mcp

        if hasattr(mcp, '_tool_manager'):
            return mcp._tool_manager._tools
        return {}

    def test_mcp_server_module_imports(self):
        from openra_env.mcp_server import mcp as server

        assert server.name == "openra-rl"

    def test_format_dict(self):
        from openra_env.mcp_server import _format

        rendered = _format({"key": "value"})
        for fragment in ('"key"', '"value"'):
            assert fragment in rendered

    def test_format_string(self):
        from openra_env.mcp_server import _format

        rendered = _format("hello")
        assert rendered == "hello"

    def test_server_url_default(self):
        from openra_env.mcp_server import _server_url as default_url

        assert default_url == "http://localhost:8000"

    def test_all_tools_registered(self):
        """The core gameplay tools must be present in the registry."""
        tools = self._registered_tools()
        # At minimum these core tools should exist
        core_tools = (
            "start_game", "get_game_state", "advance",
            "build_unit", "build_structure", "move_units",
            "attack_move", "deploy_unit", "surrender",
        )
        for name in core_tools:
            assert name in tools, f"Tool {name} not registered"

    def test_tool_count(self):
        """The registry holds at least 40 tools."""
        tools = self._registered_tools()
        registered = len(tools)
        assert registered >= 40, f"Expected 40+ tools, got {len(tools)}"
class TestDefaults:
    """Checks on the values exposed by a config built without any input."""

    def test_default_config_has_sane_values(self):
        defaults = OpenRARLConfig()

        game, llm = defaults.game, defaults.llm
        assert game.mod == "ra"
        assert game.grpc_port == 9999
        assert defaults.opponent.bot_type == "easy"
        assert defaults.planning.enabled is True
        assert defaults.reward.victory == 1.0
        assert llm.model == "qwen/qwen3-coder-next"
        assert defaults.agent.max_time_s == 1800

    def test_all_tool_categories_enabled_by_default(self):
        categories = OpenRARLConfig().tools.categories
        for field in ToolCategoriesConfig.model_fields:
            enabled = getattr(categories, field)
            assert enabled is True, f"Category {field} should default to True"

    def test_all_alerts_enabled_by_default(self):
        alerts = OpenRARLConfig().alerts
        # max_alerts is an int, not a bool toggle, so it is excluded here.
        toggles = (name for name in AlertsConfig.model_fields if name != "max_alerts")
        for field in toggles:
            enabled = getattr(alerts, field)
            assert enabled is True, f"Alert {field} should default to True"

    def test_disabled_tools_list_empty_by_default(self):
        disabled = OpenRARLConfig().tools.disabled
        assert disabled == []

    def test_load_config_no_file_returns_defaults(self):
        """load_config() with a missing file and no env vars returns defaults."""
        with _clean_env():
            loaded = load_config(config_path="__nonexistent__.yaml")

        assert loaded.game.mod == "ra"
        assert loaded.llm.base_url == "https://openrouter.ai/api/v1/chat/completions"
class TestEnvVarPrecedence:
    """Environment variables must win over YAML values and built-in defaults."""

    @staticmethod
    def _load(config_path="__nonexistent__.yaml", **env):
        """Load the config with only the given config-related env vars set."""
        with _clean_env(**env):
            return load_config(config_path=config_path)

    def test_env_var_overrides_yaml(self):
        yaml_values = {"opponent": {"bot_type": "easy"}}
        with _temp_yaml(yaml_values) as yaml_path:
            loaded = self._load(yaml_path, BOT_TYPE="hard")
        assert loaded.opponent.bot_type == "hard"

    def test_env_var_overrides_default(self):
        loaded = self._load(BOT_TYPE="hard")
        assert loaded.opponent.bot_type == "hard"

    def test_openra_path_env(self):
        loaded = self._load(OPENRA_PATH="/custom/openra")
        assert loaded.game.openra_path == "/custom/openra"

    def test_planning_enabled_env(self):
        loaded = self._load(PLANNING_ENABLED="false")
        assert loaded.planning.enabled is False

    def test_record_replays_env(self):
        loaded = self._load(RECORD_REPLAYS="yes")
        assert loaded.game.record_replays is True

    def test_openrouter_api_key_env(self):
        loaded = self._load(OPENROUTER_API_KEY="sk-test-123")
        assert loaded.llm.api_key == "sk-test-123"

    def test_llm_api_key_overrides_openrouter(self):
        """LLM_API_KEY should take precedence over OPENROUTER_API_KEY."""
        loaded = self._load(OPENROUTER_API_KEY="sk-old", LLM_API_KEY="sk-new")
        assert loaded.llm.api_key == "sk-new"

    def test_llm_base_url_env(self):
        loaded = self._load(LLM_BASE_URL="http://localhost:1234/v1/chat/completions")
        assert loaded.llm.base_url == "http://localhost:1234/v1/chat/completions"

    def test_llm_model_env(self):
        loaded = self._load(LLM_MODEL="my-local-model")
        assert loaded.llm.model == "my-local-model"

    def test_max_time_env(self):
        loaded = self._load(MAX_TIME="3600")
        assert loaded.agent.max_time_s == 3600
class TestCoercion:
    """``_coerce_value`` turns env-var strings into bool / int / float / str."""

    @pytest.mark.parametrize(
        ("val", "expected"),
        [
            ("true", True),
            ("True", True),
            ("TRUE", True),
            ("1", True),
            ("yes", True),
            ("Yes", True),
            ("false", False),
            ("False", False),
            ("FALSE", False),
            ("0", False),
            ("no", False),
            ("No", False),
        ],
    )
    def test_bool_coercion(self, val, expected):
        """Boolean spellings (any case, plus 1/0 and yes/no) become real bools."""
        coerced = _coerce_value(val)
        assert coerced is expected

    def test_int_coercion(self):
        assert _coerce_value("42") == 42
        coerced = _coerce_value("42")
        assert isinstance(coerced, int)

    def test_float_coercion(self):
        assert _coerce_value("3.14") == 3.14
        coerced = _coerce_value("3.14")
        assert isinstance(coerced, float)

    def test_string_passthrough(self):
        untouched = _coerce_value("hello")
        assert untouched == "hello"
class TestToolFiltering:
    """``should_register_tool`` honours category switches and the disabled list."""

    def test_all_tools_enabled_by_default(self):
        tools_cfg = ToolsConfig()
        for name in TOOL_CATEGORIES:
            assert should_register_tool(name, tools_cfg) is True

    def test_disable_category(self):
        tools_cfg = ToolsConfig(categories=ToolCategoriesConfig(knowledge=False))
        knowledge_tools = (
            "lookup_unit", "lookup_building", "lookup_tech_tree", "lookup_faction",
        )
        for name in knowledge_tools:
            assert should_register_tool(name, tools_cfg) is False
        # Other categories unaffected
        for name in ("advance", "move_units"):
            assert should_register_tool(name, tools_cfg) is True

    def test_disable_individual_tool(self):
        tools_cfg = ToolsConfig(disabled=["surrender", "sell_building"])
        for name in ("surrender", "sell_building"):
            assert should_register_tool(name, tools_cfg) is False
        # Other utility tools still enabled
        assert should_register_tool("get_replay_path", tools_cfg) is True

    def test_disabled_list_overrides_category_enable(self):
        movement_on = ToolCategoriesConfig(movement=True)
        tools_cfg = ToolsConfig(categories=movement_on, disabled=["move_units"])
        assert should_register_tool("move_units", tools_cfg) is False
        assert should_register_tool("attack_move", tools_cfg) is True

    def test_unknown_tool_defaults_to_enabled(self):
        verdict = should_register_tool("some_future_tool", ToolsConfig())
        assert verdict is True

    def test_all_tools_have_categories(self):
        """Every tool in TOOL_CATEGORIES should map to a valid category field."""
        known = set(ToolCategoriesConfig.model_fields)
        for tool_name, category in TOOL_CATEGORIES.items():
            assert category in known, f"Tool {tool_name} maps to unknown category {category}"

    def test_tool_count(self):
        """Verify we have all 48 tools mapped."""
        mapped = len(TOOL_CATEGORIES)
        assert mapped == 48
class TestBackwardsCompat:
    """Guards that the unified config stays compatible with older entry points."""

    def test_load_config_with_no_args(self):
        """load_config() with a missing file and a clean env returns a config object."""
        with _clean_env():
            loaded = load_config(config_path="__nonexistent__.yaml")
        assert isinstance(loaded, OpenRARLConfig)

    def test_reward_config_matches_reward_weights(self):
        """RewardConfig fields should match the existing RewardWeights dataclass."""
        from openra_env.reward import RewardWeights

        legacy = RewardWeights()
        current = RewardConfig()
        shared_fields = (
            "survival",
            "economic_efficiency",
            "aggression",
            "defense",
            "victory",
            "defeat",
        )
        for field in shared_fields:
            assert getattr(current, field) == getattr(legacy, field)
# Environment variables that load_config() may read; cleared for every test
# that runs inside ``_clean_env`` so the host environment cannot leak in.
_CONFIG_ENV_VARS = [
    "OPENRA_PATH", "RECORD_REPLAYS", "BOT_TYPE", "AI_SLOT",
    "PLANNING_ENABLED", "PLANNING_MAX_TURNS", "PLANNING_MAX_TIME",
    "OPENROUTER_API_KEY", "OPENROUTER_MODEL",
    "LLM_BASE_URL", "LLM_API_KEY", "LLM_MODEL",
    "OPENRA_URL", "MAX_TIME", "LLM_AGENT_LOG",
]


class _clean_env:
    """Context manager that temporarily clears config-related env vars and sets new ones.

    On entry every name in ``_CONFIG_ENV_VARS`` is removed from
    ``os.environ`` and each keyword override is written as ``str(value)``.
    On exit every touched name (tracked variables *and* override keys) is put
    back exactly as it was before entry: restored if it had a value, removed
    if it did not.

    Args:
        **overrides: Environment variables to set for the duration of the
            block. Keys need not be listed in ``_CONFIG_ENV_VARS``.
    """

    def __init__(self, **overrides):
        self._overrides = overrides
        self._saved: dict[str, str | None] = {}

    def __enter__(self):
        # Snapshot tracked vars AND override keys. Previously only the tracked
        # list was saved, so overriding an untracked variable destroyed its
        # original value on exit.
        self._saved = {
            name: os.environ.get(name)
            for name in (*_CONFIG_ENV_VARS, *self._overrides)
        }
        for name in _CONFIG_ENV_VARS:
            os.environ.pop(name, None)
        for name, value in self._overrides.items():
            os.environ[name] = str(value)
        return self

    def __exit__(self, *args):
        # Restore-or-remove, so values written inside the block (by overrides
        # or by the code under test) never outlive it.
        for name, previous in self._saved.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
suffix=".yaml", delete=False) as f: + yaml.dump(data, f) + path = f.name + try: + yield path + finally: + Path(path).unlink(missing_ok=True) + + return _ctx() + + +# ── PromptsConfig Tests ────────────────────────────────────────────── + + +class TestPromptsConfig: + """Tests for the PromptsConfig system.""" + + def test_default_prompts_have_values(self): + """PromptsConfig defaults should have non-empty values for key fields.""" + p = PromptsConfig() + assert "end_planning_phase" in p.planning_nudge + assert "tool" in p.no_tool_nudge.lower() + assert "{building}" in p.power_warning + assert "{count}" in p.alerts.idle_army + + def test_prompts_in_root_config(self): + """OpenRARLConfig should have prompts field with defaults.""" + config = OpenRARLConfig() + assert isinstance(config.prompts, PromptsConfig) + assert isinstance(config.prompts.alerts, AlertPromptsConfig) + assert config.prompts.planning_complete == "Planning complete. Game is now live." + + def test_prompts_from_yaml(self): + """Override prompts via config YAML.""" + data = { + "prompts": { + "no_tool_nudge": "Please call a tool now.", + "alerts": { + "low_power": "Power is low: {balance}", + }, + }, + } + with _temp_yaml(data) as path: + config = load_config(config_path=path) + assert config.prompts.no_tool_nudge == "Please call a tool now." 
+ assert config.prompts.alerts.low_power == "Power is low: {balance}" + # Other fields keep defaults + assert "combat units idle" in config.prompts.alerts.idle_army + + def test_alert_template_format(self): + """Alert templates should render with .format().""" + p = AlertPromptsConfig() + result = p.low_power.format(balance="-30") + assert "LOW POWER" in result + assert "-30" in result + + def test_placement_template_format(self): + """Placement templates should render with .format().""" + p = PromptsConfig() + result = p.placement_failed.format(building="powr", reason="no valid position") + assert "powr" in result + assert "no valid position" in result + + def test_planning_prompt_template(self): + """Planning prompt template should accept all expected variables.""" + p = PromptsConfig() + result = p.planning_prompt.format( + max_turns=10, map_name="test", map_width=64, map_height=64, + base_x=10, base_y=10, enemy_x=50, enemy_y=50, + faction="russia", side="Soviet", + opponent_summary="Easy AI", planning_nudge=p.planning_nudge, + ) + assert "10 turns" in result + assert "russia" in result + assert "end_planning_phase" in result + + def test_backward_compat_system_prompt_migration(self): + """agent.system_prompt should migrate to prompts.system_prompt.""" + config = OpenRARLConfig(agent=AgentConfig(system_prompt="My custom prompt")) + assert config.prompts.system_prompt == "My custom prompt" + + def test_prompts_system_prompt_takes_precedence(self): + """prompts.system_prompt should win over agent.system_prompt.""" + config = OpenRARLConfig( + agent=AgentConfig(system_prompt="agent version"), + prompts=PromptsConfig(system_prompt="prompts version"), + ) + assert config.prompts.system_prompt == "prompts version" + + def test_backward_compat_system_prompt_file_migration(self): + """agent.system_prompt_file should migrate to prompts.system_prompt_file.""" + config = OpenRARLConfig(agent=AgentConfig(system_prompt_file="/tmp/test.txt")) + assert 
config.prompts.system_prompt_file == "/tmp/test.txt" + + def test_env_var_prompts_file(self): + """PROMPTS_FILE env var should set prompts.prompts_file.""" + with patch.dict(os.environ, {"PROMPTS_FILE": "/tmp/prompts.yaml"}, clear=False): + config = load_config(config_path="/nonexistent/config.yaml") + assert config.prompts.prompts_file == "/tmp/prompts.yaml" + + def test_game_start_template(self): + """Game start template should render correctly.""" + p = PromptsConfig() + result = p.game_start.format( + strategy_section="\n\nRush strategy", + briefing="Map: test", + barracks_type="barr", + mcv_note=" Your MCV is unit 42.", + ) + assert "Game started!" in result + assert "Rush strategy" in result + assert "barr" in result + assert "unit 42" in result + + def test_insufficient_funds_template(self): + """Insufficient funds template should render correctly.""" + p = PromptsConfig() + result = p.insufficient_funds.format(available=500, item="3tnk", cost=950) + assert "500" in result + assert "3tnk" in result + assert "950" in result + + def test_build_queued_template(self): + """Build queued template should render correctly.""" + p = PromptsConfig() + result = p.build_queued.format(building="powr", cost=300, ticks=180, seconds=7.2) + assert "powr" in result + assert "300" in result + assert "180" in result + assert "auto-places" in result + + def test_build_unit_queued_template(self): + """Build unit queued template should render correctly.""" + p = PromptsConfig() + result = p.build_unit_queued.format( + count=3, unit="e1", cost=100, ticks_each=60, + ticks_total=180, seconds_total=7.2) + assert "3x" in result + assert "e1" in result + assert "60" in result + assert "180" in result + + def test_build_already_pending_template(self): + """Build already pending template should render correctly.""" + p = PromptsConfig() + result = p.build_already_pending.format(building="powr") + assert "powr" in result + assert "already queued" in result + + def 
class TestCompressionConfig:
    """Defaults and YAML loading for the context-compression settings."""

    def test_defaults(self):
        """All three compression sections are included by default."""
        c = CompressionConfig()
        assert c.include_strategy is True
        assert c.include_military is True
        assert c.include_production is True

    def test_disable_strategy(self):
        c = CompressionConfig(include_strategy=False)
        assert c.include_strategy is False

    def test_llm_compression_strategy_default(self):
        llm = LLMConfig()
        assert llm.compression_strategy == "sliding_window"

    def test_llm_compression_trigger_default(self):
        llm = LLMConfig()
        assert llm.compression_trigger == 0

    def test_compression_strategy_none(self):
        llm = LLMConfig(compression_strategy="none")
        assert llm.compression_strategy == "none"

    def test_compression_trigger_custom(self):
        llm = LLMConfig(compression_trigger=60)
        assert llm.compression_trigger == 60

    def test_prompts_compression_field(self):
        p = PromptsConfig()
        assert isinstance(p.compression, CompressionConfig)
        assert p.compression.include_strategy is True

    def test_move_eta_template(self):
        p = PromptsConfig()
        result = p.move_eta.format(ticks=183, seconds=7.3)
        assert "183" in result
        assert "7.3" in result

    def test_full_config_compression_yaml(self):
        """Compression fields round-trip through YAML config loading."""
        data = {
            "llm": {
                "compression_strategy": "none",
                "compression_trigger": 60,
                "keep_last_messages": 20,
            },
            "prompts": {
                "compression": {
                    "include_strategy": False,
                    "include_military": True,
                    "include_production": False,
                }
            },
        }
        # Use the shared helpers instead of a hand-rolled temp file. The old
        # code unlinked the file outside any ``finally`` (leaking it whenever
        # load_config raised) and read it while the write handle was still
        # open. _temp_yaml closes the file first and always removes it;
        # _clean_env keeps host env vars out of the loaded config, matching
        # the other YAML tests in this module.
        with _temp_yaml(data) as path, _clean_env():
            cfg = load_config(path)

        assert cfg.llm.compression_strategy == "none"
        assert cfg.llm.compression_trigger == 60
        assert cfg.llm.keep_last_messages == 20
        assert cfg.prompts.compression.include_strategy is False
        assert cfg.prompts.compression.include_production is False
        assert cfg.prompts.compression.include_military is True
OpenRAConfig(openra_path=openra_path, bot_type="hard") + manager = OpenRAProcessManager(config) + cmd = manager._build_command() + bots_arg = [a for a in cmd if "Launch.Bots" in a][0] + assert "normal" in bots_arg + assert "hard" not in bots_arg + + def test_build_command_maps_brutal(self): + from openra_env.server.openra_process import OpenRAConfig, OpenRAProcessManager + openra_path = str(Path(__file__).parent.parent / "OpenRA") + config = OpenRAConfig(openra_path=openra_path, bot_type="brutal") + manager = OpenRAProcessManager(config) + cmd = manager._build_command() + bots_arg = [a for a in cmd if "Launch.Bots" in a][0] + assert "rush" in bots_arg + assert "brutal" not in bots_arg + + def test_build_command_no_enemy_with_empty_slot(self): + from openra_env.server.openra_process import OpenRAConfig, OpenRAProcessManager + openra_path = str(Path(__file__).parent.parent / "OpenRA") + config = OpenRAConfig(openra_path=openra_path, ai_slot="") + manager = OpenRAProcessManager(config) + cmd = manager._build_command() + bots_arg = [a for a in cmd if "Launch.Bots" in a][0] + assert bots_arg == "Launch.Bots=Multi1:rl-agent" + + def test_default_config_spawns_enemy(self): + from openra_env.server.openra_process import OpenRAConfig, OpenRAProcessManager + openra_path = str(Path(__file__).parent.parent / "OpenRA") + config = OpenRAConfig(openra_path=openra_path) + manager = OpenRAProcessManager(config) + cmd = manager._build_command() + bots_arg = [a for a in cmd if "Launch.Bots" in a][0] + # Default should include enemy (Multi0:normal) + assert "Multi0" in bots_arg + assert "normal" in bots_arg diff --git a/tests/test_environment.py b/tests/test_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..7b8bc9786d60b6e7d6acf34ead0c562ecbd0873f --- /dev/null +++ b/tests/test_environment.py @@ -0,0 +1,286 @@ +"""Tests for OpenRAEnvironment using mocked bridge and process manager.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import 
def _make_proto_observation(tick=0, cash=1000, done=False, result=""):
    """Build a small protobuf ``GameObservation`` fixture.

    The observation carries fixed economy, military and map values plus one
    unit (actor 1, type ``e1``) and one building (actor 10, type ``powr``),
    both owned by ``Player``. Only ``tick``, ``cash``, ``done`` and ``result``
    vary with the arguments; ``reward`` is always 0.0.
    """
    observation = rl_bridge_pb2.GameObservation()
    observation.tick = tick

    # Scalar fields of each sub-message, written in the original order.
    sections = (
        (observation.economy, {
            "cash": cash,
            "ore": 0,
            "power_provided": 100,
            "power_drained": 50,
            "resource_capacity": 2000,
            "harvester_count": 1,
        }),
        (observation.military, {
            "units_killed": 0,
            "units_lost": 0,
            "buildings_killed": 0,
            "buildings_lost": 0,
            "army_value": 500,
            "active_unit_count": 3,
        }),
        (observation.map_info, {
            "width": 64,
            "height": 64,
            "map_name": "Test Map",
        }),
    )
    for message, values in sections:
        for field_name, field_value in values.items():
            setattr(message, field_name, field_value)

    observation.done = done
    observation.result = result
    observation.reward = 0.0

    # Add a unit
    observation.units.add(
        actor_id=1,
        type="e1",
        pos_x=100,
        pos_y=200,
        cell_x=4,
        cell_y=8,
        hp_percent=1.0,
        is_idle=True,
        owner="Player",
    )

    # Add a building
    observation.buildings.add(
        actor_id=10,
        type="powr",
        pos_x=50,
        pos_y=50,
        hp_percent=1.0,
        owner="Player",
        is_powered=True,
    )

    return observation
mock_game_state.player_faction = "england" + mock_game_state.enemy_faction = "russia" + mock_bridge.get_state = AsyncMock(return_value=mock_game_state) + + mock_process = MockProcess.return_value + mock_process.kill = MagicMock() + mock_process.launch = MagicMock(return_value=12345) + + env = OpenRAEnvironment(openra_path="/fake/path") + env._bridge = mock_bridge + env._process = mock_process + + obs = env.reset() + + # reset() now returns a minimal observation (game is paused, + # session not yet started). Full obs available after session starts. + assert obs.tick == 0 + assert obs.economy.cash == 0 # Minimal obs — no economy data yet + mock_process.kill.assert_called_once() + mock_process.launch.assert_called_once() + # start_session should NOT be called during reset (deferred) + mock_bridge.start_session.assert_not_called() + + @patch("openra_env.server.openra_environment.OpenRAProcessManager") + @patch("openra_env.server.openra_environment.BridgeClient") + def test_reset_with_seed(self, MockBridge, MockProcess): + mock_bridge = MockBridge.return_value + mock_bridge.close = AsyncMock() + mock_bridge.wait_for_ready = AsyncMock(return_value=True) + mock_bridge.start_session = AsyncMock(return_value=_make_proto_observation()) + mock_bridge.session_started = False + mock_bridge.get_state = AsyncMock(return_value=MagicMock(tick=0, player_faction="", enemy_faction="")) + + mock_process = MockProcess.return_value + mock_process.kill = MagicMock() + mock_process.launch = MagicMock(return_value=12345) + + env = OpenRAEnvironment(openra_path="/fake/path") + env._bridge = mock_bridge + env._process = mock_process + + env.reset(seed=42) + assert env._config.seed == 42 + + @patch("openra_env.server.openra_environment.OpenRAProcessManager") + @patch("openra_env.server.openra_environment.BridgeClient") + def test_reset_with_episode_id(self, MockBridge, MockProcess): + mock_bridge = MockBridge.return_value + mock_bridge.close = AsyncMock() + mock_bridge.wait_for_ready = 
AsyncMock(return_value=True) + mock_bridge.start_session = AsyncMock(return_value=_make_proto_observation()) + mock_bridge.session_started = False + mock_bridge.get_state = AsyncMock(return_value=MagicMock(tick=0, player_faction="", enemy_faction="")) + + mock_process = MockProcess.return_value + mock_process.kill = MagicMock() + mock_process.launch = MagicMock(return_value=12345) + + env = OpenRAEnvironment(openra_path="/fake/path") + env._bridge = mock_bridge + env._process = mock_process + + env.reset(episode_id="custom-ep-001") + assert env.state.episode_id == "custom-ep-001" + + @patch("openra_env.server.openra_environment.OpenRAProcessManager") + @patch("openra_env.server.openra_environment.BridgeClient") + def test_reset_raises_if_bridge_not_ready(self, MockBridge, MockProcess): + mock_bridge = MockBridge.return_value + mock_bridge.close = AsyncMock() + mock_bridge.wait_for_ready = AsyncMock(return_value=False) + + mock_process = MockProcess.return_value + mock_process.kill = MagicMock() + mock_process.launch = MagicMock() + + env = OpenRAEnvironment(openra_path="/fake/path") + env._bridge = mock_bridge + env._process = mock_process + + with pytest.raises(RuntimeError, match="gRPC bridge failed to start"): + env.reset() + + +class TestOpenRAEnvironmentStep: + def _setup_env(self, MockBridge, MockProcess): + mock_bridge = MockBridge.return_value + mock_bridge.close = AsyncMock() + mock_bridge.wait_for_ready = AsyncMock(return_value=True) + mock_bridge.start_session = AsyncMock(return_value=_make_proto_observation(tick=0)) + + mock_process = MockProcess.return_value + mock_process.kill = MagicMock() + mock_process.launch = MagicMock(return_value=12345) + + env = OpenRAEnvironment(openra_path="/fake/path") + env._bridge = mock_bridge + env._process = mock_process + return env, mock_bridge, mock_process + + @patch("openra_env.server.openra_environment.OpenRAProcessManager") + @patch("openra_env.server.openra_environment.BridgeClient") + def 
test_step_returns_observation(self, MockBridge, MockProcess): + env, mock_bridge, _ = self._setup_env(MockBridge, MockProcess) + env.reset() + + mock_bridge.step = AsyncMock(return_value=_make_proto_observation(tick=10, cash=1500)) + + action = OpenRAAction(commands=[CommandModel(action=ActionType.NO_OP)]) + obs = env.step(action) + + assert obs.tick == 10 + assert obs.economy.cash == 1500 + assert env.state.step_count == 1 + assert env.state.game_tick == 10 + + @patch("openra_env.server.openra_environment.OpenRAProcessManager") + @patch("openra_env.server.openra_environment.BridgeClient") + def test_step_increments_step_count(self, MockBridge, MockProcess): + env, mock_bridge, _ = self._setup_env(MockBridge, MockProcess) + env.reset() + + mock_bridge.step = AsyncMock(return_value=_make_proto_observation(tick=10)) + action = OpenRAAction(commands=[CommandModel(action=ActionType.NO_OP)]) + + env.step(action) + assert env.state.step_count == 1 + + mock_bridge.step = AsyncMock(return_value=_make_proto_observation(tick=20)) + env.step(action) + assert env.state.step_count == 2 + + @patch("openra_env.server.openra_environment.OpenRAProcessManager") + @patch("openra_env.server.openra_environment.BridgeClient") + def test_step_with_multiple_commands(self, MockBridge, MockProcess): + env, mock_bridge, _ = self._setup_env(MockBridge, MockProcess) + env.reset() + + mock_bridge.step = AsyncMock(return_value=_make_proto_observation(tick=10)) + + action = OpenRAAction(commands=[ + CommandModel(action=ActionType.MOVE, actor_id=1, target_x=10, target_y=20), + CommandModel(action=ActionType.BUILD, item_type="powr"), + ]) + obs = env.step(action) + assert obs.tick == 10 + + @patch("openra_env.server.openra_environment.OpenRAProcessManager") + @patch("openra_env.server.openra_environment.BridgeClient") + def test_step_terminal_observation(self, MockBridge, MockProcess): + env, mock_bridge, _ = self._setup_env(MockBridge, MockProcess) + env.reset() + + mock_bridge.step = AsyncMock( + 
return_value=_make_proto_observation(tick=1000, done=True, result="win") + ) + + action = OpenRAAction(commands=[CommandModel(action=ActionType.NO_OP)]) + obs = env.step(action) + + assert obs.done is True + assert obs.result == "win" + assert obs.reward > 0 # Should include victory reward + + +class TestOpenRAEnvironmentState: + @patch("openra_env.server.openra_environment.OpenRAProcessManager") + @patch("openra_env.server.openra_environment.BridgeClient") + def test_initial_state(self, MockBridge, MockProcess): + env = OpenRAEnvironment(openra_path="/fake/path") + state = env.state + assert state.step_count == 0 + assert state.game_tick == 0 + + @patch("openra_env.server.openra_environment.OpenRAProcessManager") + @patch("openra_env.server.openra_environment.BridgeClient") + def test_state_after_reset(self, MockBridge, MockProcess): + mock_bridge = MockBridge.return_value + mock_bridge.close = AsyncMock() + mock_bridge.wait_for_ready = AsyncMock(return_value=True) + mock_bridge.start_session = AsyncMock(return_value=_make_proto_observation()) + + mock_process = MockProcess.return_value + mock_process.kill = MagicMock() + mock_process.launch = MagicMock() + + env = OpenRAEnvironment(openra_path="/fake/path", map_name="test_map") + env._bridge = mock_bridge + env._process = mock_process + + env.reset(episode_id="ep-001") + + assert env.state.episode_id == "ep-001" + assert env.state.map_name == "test_map" + assert env.state.step_count == 0 + + +class TestOpenRAEnvironmentClose: + @patch("openra_env.server.openra_environment.OpenRAProcessManager") + @patch("openra_env.server.openra_environment.BridgeClient") + def test_close_cleans_up(self, MockBridge, MockProcess): + mock_bridge = MockBridge.return_value + mock_bridge.close = AsyncMock() + + mock_process = MockProcess.return_value + mock_process.kill = MagicMock() + + env = OpenRAEnvironment(openra_path="/fake/path") + env._bridge = mock_bridge + env._process = mock_process + + env.close() + + 
mock_bridge.close.assert_called_once() + mock_process.kill.assert_called_once() diff --git a/tests/test_llm_agent.py b/tests/test_llm_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..37aec1460b0465b38100e0301c0b7ee502997251 --- /dev/null +++ b/tests/test_llm_agent.py @@ -0,0 +1,248 @@ +"""Tests for llm_agent helper functions.""" + +import pytest + +from openra_env.agent import _bench_export_policy, _format_llm_api_error, _sanitize_messages +from openra_env.config import LLMConfig + + +class TestSanitizeMessages: + """Tests for _sanitize_messages — merges consecutive same-role messages.""" + + def test_empty(self): + assert _sanitize_messages([]) == [] + + def test_no_merge_needed(self): + msgs = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + result = _sanitize_messages(msgs) + assert len(result) == 3 + assert [m["role"] for m in result] == ["system", "user", "assistant"] + + def test_consecutive_user_merged(self): + msgs = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "first"}, + {"role": "user", "content": "second"}, + ] + result = _sanitize_messages(msgs) + assert len(result) == 2 + assert result[1]["role"] == "user" + assert "first" in result[1]["content"] + assert "second" in result[1]["content"] + + def test_three_consecutive_user_merged(self): + msgs = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "a"}, + {"role": "user", "content": "b"}, + {"role": "user", "content": "c"}, + ] + result = _sanitize_messages(msgs) + assert len(result) == 2 + assert result[1]["content"] == "a\n\nb\n\nc" + + def test_does_not_mutate_original(self): + msgs = [ + {"role": "user", "content": "first"}, + {"role": "user", "content": "second"}, + ] + _sanitize_messages(msgs) + # Original messages should be untouched + assert msgs[0]["content"] == "first" + assert msgs[1]["content"] == "second" + assert len(msgs) == 2 + + 
def test_mixed_roles_preserved(self): + msgs = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "u1"}, + {"role": "assistant", "content": "a1"}, + {"role": "user", "content": "u2"}, + {"role": "user", "content": "u3"}, + {"role": "assistant", "content": "a2"}, + ] + result = _sanitize_messages(msgs) + assert [m["role"] for m in result] == ["system", "user", "assistant", "user", "assistant"] + assert result[3]["content"] == "u2\n\nu3" + + def test_tool_then_user_gets_bridge_assistant(self): + """Mistral requires tool → assistant → user, not tool → user.""" + msgs = [ + {"role": "assistant", "content": "", "tool_calls": [{"id": "1"}]}, + {"role": "tool", "content": "result1", "tool_call_id": "1"}, + {"role": "user", "content": "briefing"}, + ] + result = _sanitize_messages(msgs) + assert len(result) == 4 + assert [m["role"] for m in result] == ["assistant", "tool", "assistant", "user"] + assert result[2]["content"] # bridge message is non-empty + + def test_tool_then_assistant_no_extra_bridge(self): + """When tool → assistant already exists, no bridge is inserted.""" + msgs = [ + {"role": "assistant", "content": "", "tool_calls": [{"id": "1"}]}, + {"role": "tool", "content": "result1", "tool_call_id": "1"}, + {"role": "assistant", "content": "Got the result."}, + ] + result = _sanitize_messages(msgs) + assert len(result) == 3 + assert [m["role"] for m in result] == ["assistant", "tool", "assistant"] + + def test_real_world_scenario(self): + """Simulates: nudge (user) → next turn briefing (user) → should merge.""" + msgs = [ + {"role": "system", "content": "You are playing Red Alert."}, + {"role": "user", "content": "STRATEGIC BRIEFING: ..."}, + {"role": "assistant", "content": "I will deploy the MCV."}, + {"role": "user", "content": "Continue playing. 
Use game tools."}, + {"role": "user", "content": "TURN BRIEFING: Funds 5000, ..."}, + ] + result = _sanitize_messages(msgs) + assert len(result) == 4 + roles = [m["role"] for m in result] + assert roles == ["system", "user", "assistant", "user"] + assert "Continue playing" in result[3]["content"] + assert "TURN BRIEFING" in result[3]["content"] + + def test_game_loop_tool_then_briefing(self): + """Real scenario: tool results from turn N, then briefing user msg for turn N+1.""" + msgs = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "initial briefing"}, + {"role": "assistant", "content": "", "tool_calls": [{"id": "c1"}]}, + {"role": "tool", "content": '{"ok": true}', "tool_call_id": "c1"}, + {"role": "user", "content": "TURN BRIEFING: tick 500"}, + ] + result = _sanitize_messages(msgs) + roles = [m["role"] for m in result] + assert roles == ["system", "user", "assistant", "tool", "assistant", "user"] + assert result[4]["role"] == "assistant" # bridge + assert result[5]["content"] == "TURN BRIEFING: tick 500" + + +class TestFormatLLMApiError: + """Tests for provider error mapping helper.""" + + def test_openrouter_tool_route_error_has_actionable_hint(self): + cfg = LLMConfig( + base_url="https://openrouter.ai/api/v1/chat/completions", + model="liquid/lfm-2.5-1.2b-thinking:free", + ) + msg = _format_llm_api_error( + 404, + ( + '{"error":{"message":"No endpoints found that support tool use.' 
+ ' To learn more about provider routing","code":404}}' + ), + cfg, + ) + assert "supports tool calling" in msg + assert "OpenRA-RL requires tool-calling models" in msg + assert "not ':free'" in msg + + def test_auth_error_message_preserved(self): + cfg = LLMConfig(model="foo/bar") + msg = _format_llm_api_error(401, "unauthorized", cfg) + assert "Authentication failed (401)" in msg + + +class TestToolCallingPreflight: + """Tests for startup preflight capability checks.""" + + @pytest.mark.asyncio + async def test_openrouter_unsupported_tools_is_blocked(self, monkeypatch): + from openra_env import agent as agent_mod + + cfg = LLMConfig( + base_url="https://openrouter.ai/api/v1/chat/completions", + model="liquid/lfm-2.5-1.2b-thinking:free", + ) + + async def _fake_chat_completion(*args, **kwargs): + raise RuntimeError("No endpoints found that support tool use.") + + monkeypatch.setattr(agent_mod, "chat_completion", _fake_chat_completion) + ok, err = await agent_mod._preflight_tool_calling_support(cfg) + assert ok is False + assert "support tool use" in err.lower() + + @pytest.mark.asyncio + async def test_non_openrouter_skips_preflight_call(self, monkeypatch): + from openra_env import agent as agent_mod + + cfg = LLMConfig( + base_url="http://localhost:11434/v1/chat/completions", + model="qwen3:4b", + ) + called = False + + async def _fake_chat_completion(*args, **kwargs): + nonlocal called + called = True + return {} + + monkeypatch.setattr(agent_mod, "chat_completion", _fake_chat_completion) + ok, err = await agent_mod._preflight_tool_calling_support(cfg) + assert ok is True + assert err == "" + assert called is False + + +class TestBenchExportPolicy: + """Tests for when bench export/upload is allowed.""" + + def test_always_exports_locally_even_on_error(self): + should_export, should_upload, reason = _bench_export_policy(encountered_agent_error=True) + assert should_export is True + assert should_upload is False + assert "runtime [error]" in reason.lower() + + def 
test_allow_export_and_upload_when_no_runtime_error(self): + should_export, should_upload, reason = _bench_export_policy(encountered_agent_error=False) + assert should_export is True + assert should_upload is True + assert reason == "" + + +class TestRunAgentPreflightAbort: + """Regression tests for tool-capability preflight abort path.""" + + @pytest.mark.asyncio + async def test_openrouter_tool_capability_failure_aborts_before_reset(self, monkeypatch, capsys): + from types import SimpleNamespace + from openra_env import agent as agent_mod + + cfg = SimpleNamespace( + agent=SimpleNamespace(server_url="http://localhost:8000", max_turns=0, max_time_s=1800), + llm=LLMConfig( + base_url="https://openrouter.ai/api/v1/chat/completions", + model="liquid/lfm-2.5-1.2b-thinking:free", + request_timeout_s=120.0, + ), + ) + + client_constructed = False + + class _FailIfConstructedClient: + def __init__(self, *args, **kwargs): + nonlocal client_constructed + client_constructed = True + raise AssertionError("OpenRAMCPClient should not be constructed on preflight failure") + + async def _fake_preflight(_llm_config): + return False, "No endpoints found that support tool use." + + monkeypatch.setattr(agent_mod, "_preflight_tool_calling_support", _fake_preflight) + monkeypatch.setattr(agent_mod, "OpenRAMCPClient", _FailIfConstructedClient) + + await agent_mod.run_agent(cfg, verbose=False) + + out = capsys.readouterr().out + assert "Checking model route for tool-calling support..." in out + assert "Aborting before game launch (no match started)." in out + assert "Resetting environment (launching OpenRA)..." 
not in out + assert client_constructed is False diff --git a/tests/test_mcp_tools.py b/tests/test_mcp_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a90cd2c8a2efab88cd85436843e23490044a70 --- /dev/null +++ b/tests/test_mcp_tools.py @@ -0,0 +1,4918 @@ +"""Tests for MCP tool registration, game data module, and environment integration.""" + +import asyncio +from pathlib import Path + +import pytest +from unittest.mock import MagicMock, patch, AsyncMock + +from openra_env.game_data import ( + RA_BUILDINGS, + RA_FACTIONS, + RA_TECH_TREE, + RA_UNITS, + get_all_building_types, + get_all_buildings_for_side, + get_all_unit_types, + get_all_units_for_side, + get_building_stats, + get_faction_info, + get_tech_tree, + get_unit_stats, +) +from openra_env.models import ActionType, CommandModel, OpenRAAction +from openra_env.server.openra_environment import OpenRAEnvironment + + +# ─── Game Data Tests ────────────────────────────────────────────────────────── + + +class TestUnitData: + def test_all_units_have_required_fields(self): + required = {"name", "category", "cost", "hp", "speed", "armor", "side", "prerequisites", "description"} + for unit_type, data in RA_UNITS.items(): + missing = required - set(data.keys()) + assert not missing, f"Unit '{unit_type}' missing fields: {missing}" + + def test_unit_costs_positive(self): + for unit_type, data in RA_UNITS.items(): + assert data["cost"] > 0, f"Unit '{unit_type}' has non-positive cost" + + def test_unit_hp_positive(self): + for unit_type, data in RA_UNITS.items(): + assert data["hp"] > 0, f"Unit '{unit_type}' has non-positive HP" + + def test_unit_sides_valid(self): + valid_sides = {"both", "allied", "soviet"} + for unit_type, data in RA_UNITS.items(): + assert data["side"] in valid_sides, f"Unit '{unit_type}' has invalid side: {data['side']}" + + def test_unit_categories_valid(self): + valid = {"infantry", "vehicle", "aircraft", "ship"} + for unit_type, data in RA_UNITS.items(): + assert 
data["category"] in valid, f"Unit '{unit_type}' has invalid category" + + def test_known_units_exist(self): + for key in ["e1", "e3", "1tnk", "3tnk", "harv", "mcv", "mig", "heli"]: + assert key in RA_UNITS, f"Expected unit '{key}' not found" + + def test_get_unit_stats_found(self): + result = get_unit_stats("e1") + assert result is not None + assert result["name"] == "Rifle Infantry" + assert result["cost"] == 100 + + def test_get_unit_stats_not_found(self): + assert get_unit_stats("nonexistent") is None + + def test_get_unit_stats_case_insensitive(self): + assert get_unit_stats("E1") is not None # Lowercased internally + assert get_unit_stats("e1") is not None + + +class TestBuildingData: + def test_all_buildings_have_required_fields(self): + required = {"name", "cost", "hp", "power", "side", "prerequisites", "produces", "description"} + for bldg_type, data in RA_BUILDINGS.items(): + missing = required - set(data.keys()) + assert not missing, f"Building '{bldg_type}' missing fields: {missing}" + + def test_building_costs_positive(self): + for bldg_type, data in RA_BUILDINGS.items(): + assert data["cost"] > 0, f"Building '{bldg_type}' has non-positive cost" + + def test_building_sides_valid(self): + valid_sides = {"both", "allied", "soviet"} + for bldg_type, data in RA_BUILDINGS.items(): + assert data["side"] in valid_sides, f"Building '{bldg_type}' has invalid side" + + def test_known_buildings_exist(self): + for key in ["fact", "powr", "barr", "tent", "proc", "weap", "dome"]: + assert key in RA_BUILDINGS, f"Expected building '{key}' not found" + + def test_power_plants_provide_power(self): + assert RA_BUILDINGS["powr"]["power"] > 0 + assert RA_BUILDINGS["apwr"]["power"] > 0 + + def test_production_buildings_consume_power(self): + for key in ["barr", "tent", "weap"]: + assert RA_BUILDINGS[key]["power"] < 0 + + def test_get_building_stats_found(self): + result = get_building_stats("powr") + assert result is not None + assert result["name"] == "Power Plant" + assert 
result["power"] == 100 + + def test_get_building_stats_not_found(self): + assert get_building_stats("nonexistent") is None + + +class TestTechTree: + def test_both_sides_present(self): + assert "soviet" in RA_TECH_TREE + assert "allied" in RA_TECH_TREE + + def test_soviet_starts_with_power(self): + assert RA_TECH_TREE["soviet"][0] == "powr" + + def test_allied_starts_with_power(self): + assert RA_TECH_TREE["allied"][0] == "powr" + + def test_all_tech_tree_entries_are_valid_buildings(self): + for side, entries in RA_TECH_TREE.items(): + for entry in entries: + assert entry in RA_BUILDINGS, f"Tech tree entry '{entry}' not in RA_BUILDINGS" + + def test_get_tech_tree_by_side(self): + result = get_tech_tree("soviet") + assert "soviet" in result + assert "allied" not in result + + def test_get_tech_tree_by_faction(self): + result = get_tech_tree("russia") + assert "soviet" in result + + def test_get_tech_tree_all(self): + result = get_tech_tree() + assert "soviet" in result + assert "allied" in result + + +class TestFactionData: + def test_all_factions_present(self): + for faction in ["england", "france", "germany", "russia", "ukraine"]: + assert faction in RA_FACTIONS + + def test_faction_sides_valid(self): + for faction, data in RA_FACTIONS.items(): + assert data["side"] in {"allied", "soviet"} + + def test_allied_factions(self): + for f in ["england", "france", "germany"]: + assert RA_FACTIONS[f]["side"] == "allied" + + def test_soviet_factions(self): + for f in ["russia", "ukraine"]: + assert RA_FACTIONS[f]["side"] == "soviet" + + def test_get_faction_info_returns_units_and_buildings(self): + result = get_faction_info("russia") + assert result is not None + assert "available_units" in result + assert "available_buildings" in result + assert len(result["available_units"]) > 5 + assert len(result["available_buildings"]) > 5 + + def test_get_faction_info_not_found(self): + assert get_faction_info("nonexistent") is None + + def test_faction_specific_units(self): + russia 
= get_faction_info("russia") + assert "ttnk" in russia["available_units"] + + germany = get_faction_info("germany") + assert "ctnk" in germany["available_units"] + + def test_get_all_unit_types(self): + types = get_all_unit_types() + assert len(types) > 10 + assert "e1" in types + assert types == sorted(types) # Should be sorted + + def test_get_all_building_types(self): + types = get_all_building_types() + assert len(types) > 10 + assert "powr" in types + assert types == sorted(types) + + +class TestBulkHelpers: + def test_get_all_units_for_soviet(self): + units = get_all_units_for_side("soviet") + assert len(units) > 10 + assert "e1" in units # both sides + assert "3tnk" in units # soviet only + assert "1tnk" not in units # allied only + for utype, data in units.items(): + assert "cost" in data + assert "hp" in data + + def test_get_all_units_for_allied(self): + units = get_all_units_for_side("allied") + assert len(units) > 10 + assert "e1" in units # both sides + assert "1tnk" in units # allied only + assert "3tnk" not in units # soviet only + + def test_get_all_buildings_for_soviet(self): + buildings = get_all_buildings_for_side("soviet") + assert len(buildings) > 10 + assert "powr" in buildings # both sides + assert "barr" in buildings # soviet only + assert "tent" not in buildings # allied only + + def test_get_all_buildings_for_allied(self): + buildings = get_all_buildings_for_side("allied") + assert len(buildings) > 10 + assert "powr" in buildings + assert "tent" in buildings + assert "barr" not in buildings + + +# ─── MCP Tool Registration Tests ───────────────────────────────────────────── + + +class TestMCPToolRegistration: + @pytest.fixture + def env(self): + """Create an OpenRAEnvironment instance (doesn't launch OpenRA).""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + # Manually initialize just the MCP parts + from fastmcp import FastMCP + mcp = 
FastMCP("openra-test") + env._last_obs = None + env._register_tools(mcp) + return env, mcp + + def test_tools_registered(self, env): + _, mcp = env + from tests.conftest import get_tool_names + tool_names = get_tool_names(mcp) + + # Read tools + assert "get_game_state" in tool_names + assert "get_economy" in tool_names + assert "get_units" in tool_names + assert "get_buildings" in tool_names + assert "get_enemies" in tool_names + assert "get_production" in tool_names + assert "get_map_info" in tool_names + assert "get_exploration_status" in tool_names + + # Knowledge tools + assert "lookup_unit" in tool_names + assert "lookup_building" in tool_names + assert "lookup_tech_tree" in tool_names + assert "lookup_faction" in tool_names + + # Action tools + assert "advance" in tool_names + assert "move_units" in tool_names + assert "attack_move" in tool_names + assert "attack_target" in tool_names + assert "stop_units" in tool_names + assert "build_unit" in tool_names + assert "build_structure" in tool_names + assert "place_building" in tool_names + assert "deploy_unit" in tool_names + assert "sell_building" in tool_names + assert "repair_building" in tool_names + assert "set_rally_point" in tool_names + assert "guard_target" in tool_names + assert "set_stance" in tool_names + assert "harvest" in tool_names + assert "power_down" in tool_names + assert "set_primary" in tool_names + assert "cancel_production" in tool_names + assert "get_replay_path" in tool_names + + def test_tool_count(self, env): + _, mcp = env + from tests.conftest import get_tool_count + count = get_tool_count(mcp) + # 7 read + 1 exploration + 1 terrain + 4 knowledge + 3 bulk + 4 planning + 27 action + 1 replay = 48 + assert count == 48, f"Expected 48 tools, got {count}" + + +class TestMCPReadTools: + """Test read tools return cached observation data.""" + + @pytest.fixture + def env_with_obs(self): + """Create env with a cached observation.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from 
openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + from fastmcp import FastMCP + mcp = FastMCP("openra-test") + env._register_tools(mcp) + + # Planning phase attributes (required by get_game_state) + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = True + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + + env._last_obs = { + "tick": 100, + "done": False, + "result": "", + "economy": { + "cash": 5000, + "ore": 1000, + "power_provided": 200, + "power_drained": 80, + "resource_capacity": 5000, + "harvester_count": 2, + }, + "military": { + "units_killed": 3, + "units_lost": 1, + "buildings_killed": 0, + "buildings_lost": 0, + "army_value": 3500, + "active_unit_count": 5, + }, + "units": [ + { + "actor_id": 10, + "type": "1tnk", + "pos_x": 1000, + "pos_y": 2000, + "cell_x": 10, + "cell_y": 20, + "hp_percent": 0.8, + "is_idle": True, + "current_activity": "", + "owner": "Multi0", + "can_attack": True, + "facing": 0, + "experience_level": 0, + "stance": 3, + "speed": 113, + "attack_range": 5120, + "passenger_count": -1, + "is_building": False, + }, + ], + "buildings": [ + { + "actor_id": 1, + "type": "powr", + "pos_x": 500, + "pos_y": 500, + "hp_percent": 1.0, + "owner": "Multi0", + "is_producing": False, + "production_progress": 0.0, + "producing_item": "", + "is_powered": True, + "is_repairing": False, + "sell_value": 150, + "rally_x": -1, + "rally_y": -1, + "power_amount": 100, + "can_produce": [], + "cell_x": 5, + "cell_y": 5, + }, + ], + "production": [], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 128, "height": 128, "map_name": "Test Map"}, + "available_production": ["e1", "e3"], + } + return env, mcp + + def test_get_game_state_returns_summary(self, env_with_obs): + env, mcp = env_with_obs + # Get the tool function directly + tool = mcp._tool_manager._tools["get_game_state"] + result 
= tool.fn() + assert result["tick"] == 100 + assert result["own_units"] == 1 + assert result["own_buildings"] == 1 + + def test_get_economy_returns_economy(self, env_with_obs): + env, mcp = env_with_obs + tool = mcp._tool_manager._tools["get_economy"] + result = tool.fn() + assert result["cash"] == 5000 + assert result["power_provided"] == 200 + + def test_get_units_returns_unit_list(self, env_with_obs): + env, mcp = env_with_obs + tool = mcp._tool_manager._tools["get_units"] + result = tool.fn() + assert len(result) == 1 + assert result[0]["type"] == "1tnk" + assert result[0]["actor_id"] == 10 + + def test_get_buildings_returns_building_list(self, env_with_obs): + env, mcp = env_with_obs + tool = mcp._tool_manager._tools["get_buildings"] + result = tool.fn() + assert len(result) == 1 + assert result[0]["type"] == "powr" + assert result[0]["power_amount"] == 100 + + def test_get_enemies_empty(self, env_with_obs): + env, mcp = env_with_obs + tool = mcp._tool_manager._tools["get_enemies"] + result = tool.fn() + assert result["units"] == [] + assert result["buildings"] == [] + + def test_get_production_empty(self, env_with_obs): + env, mcp = env_with_obs + tool = mcp._tool_manager._tools["get_production"] + result = tool.fn() + assert result["queue"] == [] + assert result["available"] == ["e1", "e3"] + + +class TestMCPKnowledgeTools: + """Test game knowledge tools return static data.""" + + @pytest.fixture + def mcp(self): + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + from fastmcp import FastMCP + mcp = FastMCP("openra-test") + env._last_obs = None + env._register_tools(mcp) + return mcp + + def test_lookup_unit_found(self, mcp): + tool = mcp._tool_manager._tools["lookup_unit"] + result = tool.fn("3tnk") + assert result["name"] == "Heavy Tank" + assert result["cost"] == 1150 + + def test_lookup_unit_not_found(self, mcp): + tool = mcp._tool_manager._tools["lookup_unit"] + 
result = tool.fn("nonexistent") + assert "error" in result + assert "available_types" in result + + def test_lookup_building_found(self, mcp): + tool = mcp._tool_manager._tools["lookup_building"] + result = tool.fn("weap") + assert result["name"] == "War Factory" + + def test_lookup_tech_tree(self, mcp): + tool = mcp._tool_manager._tools["lookup_tech_tree"] + result = tool.fn("soviet") + assert "soviet" in result + + def test_lookup_faction(self, mcp): + tool = mcp._tool_manager._tools["lookup_faction"] + result = tool.fn("russia") + assert result["side"] == "soviet" + assert "available_units" in result + + +# ─── New Action Type Tests ──────────────────────────────────────────────────── + + +class TestNewActionTypes: + def test_power_down_action(self): + cmd = CommandModel(action=ActionType.POWER_DOWN, actor_id=42) + assert cmd.action == ActionType.POWER_DOWN + assert cmd.actor_id == 42 + + def test_set_primary_action(self): + cmd = CommandModel(action=ActionType.SET_PRIMARY, actor_id=99) + assert cmd.action == ActionType.SET_PRIMARY + + def test_action_in_openra_action(self): + action = OpenRAAction(commands=[ + CommandModel(action=ActionType.POWER_DOWN, actor_id=1), + CommandModel(action=ActionType.SET_PRIMARY, actor_id=2), + ]) + assert len(action.commands) == 2 + + +class TestBridgeActionMapping: + def test_new_action_types_in_bridge_map(self): + from openra_env.server.bridge_client import commands_to_proto + from openra_env.generated import rl_bridge_pb2 + + proto = commands_to_proto([ + {"action": "power_down", "actor_id": 10}, + {"action": "set_primary", "actor_id": 20}, + ]) + assert len(proto.commands) == 2 + assert proto.commands[0].action == rl_bridge_pb2.POWER_DOWN + assert proto.commands[0].actor_id == 10 + assert proto.commands[1].action == rl_bridge_pb2.SET_PRIMARY + assert proto.commands[1].actor_id == 20 + + +# ─── Process Manager Replay Config Test ────────────────────────────────────── + + +class TestReplayConfig: + def 
class TestMCPBotPatterns:
    """Checks for the helper functions imported from examples.llm_agent."""

    def test_tool_schema_to_openai_conversion(self):
        """Two fake MCP tools come back as two OpenAI-style function entries."""
        from examples.llm_agent import mcp_tools_to_openai

        # Stand-in for an MCP Tool object: only the three attributes set below.
        class FakeTool:
            def __init__(self, name, description, input_schema):
                self.name = name
                self.description = description
                self.input_schema = input_schema

        state_schema = {"type": "object", "properties": {}}
        move_schema = {
            "type": "object",
            "properties": {
                "unit_ids": {"type": "array", "items": {"type": "integer"}},
                "target_x": {"type": "integer"},
                "target_y": {"type": "integer"},
            },
            "required": ["unit_ids", "target_x", "target_y"],
        }
        fake_tools = [
            FakeTool("get_game_state", "Get game state", state_schema),
            FakeTool("move_units", "Move units to position", move_schema),
        ]

        converted = mcp_tools_to_openai(fake_tools)

        assert len(converted) == 2
        assert converted[0]["type"] == "function"
        assert converted[0]["function"]["name"] == "get_game_state"
        assert converted[1]["function"]["name"] == "move_units"
        move_params = converted[1]["function"]["parameters"]
        assert "properties" in move_params
        assert "unit_ids" in move_params["properties"]

    def test_openai_schema_has_required_fields(self):
        """A converted tool carries type plus function name, description and parameters."""
        from examples.llm_agent import mcp_tools_to_openai

        class FakeTool:
            def __init__(self):
                self.name = "test_tool"
                self.description = "A test tool"
                self.input_schema = {"type": "object", "properties": {"x": {"type": "integer"}}}

        converted = mcp_tools_to_openai([FakeTool()])
        entry = converted[0]
        assert entry["type"] == "function"
        for field in ("name", "description", "parameters"):
            assert field in entry["function"]

    def test_compress_history_keeps_system_prompt(self):
        """Compressing 101 messages with keep_last=10 keeps the system prompt first."""
        from examples.llm_agent import compress_history

        system_prompt = {"role": "system", "content": "You are a bot"}
        user_turns = [{"role": "user", "content": f"msg {i}"} for i in range(100)]
        history = [system_prompt] + user_turns

        shortened = compress_history(history, keep_last=10)

        head = shortened[0]
        assert head["role"] == "system"
        assert head["content"] == "You are a bot"
        # Expected layout: system + summary + 10 recent
        assert len(shortened) == 12

    def test_compress_history_noop_when_short(self):
        """A three-message history keeps its length under keep_last=10."""
        from examples.llm_agent import compress_history

        history = [
            {"role": "system", "content": "You are a bot"},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi"},
        ]

        shortened = compress_history(history, keep_last=10)
        assert len(shortened) == 3  # unchanged
+ bot = ScriptedBot() + assert hasattr(bot, "_handle_power_management") + assert hasattr(bot, "_powered_down") + + def test_set_primary_handler_exists(self): + from examples.scripted_bot import ScriptedBot + bot = ScriptedBot() + assert hasattr(bot, "_handle_set_primary") + assert hasattr(bot, "_primary_set") + + def test_power_management_no_action_when_positive(self): + """No power down when power balance is positive.""" + from examples.scripted_bot import ScriptedBot + from openra_env.models import OpenRAObservation, EconomyInfo, BuildingInfoModel + + bot = ScriptedBot() + obs = OpenRAObservation( + economy=EconomyInfo(power_provided=200, power_drained=80), + buildings=[BuildingInfoModel(actor_id=1, type="dome", is_powered=True)], + ) + commands = bot._handle_power_management(obs) + assert len(commands) == 0 + + def test_power_management_powers_down_when_negative(self): + """Powers down non-essential building when power balance is negative.""" + from examples.scripted_bot import ScriptedBot + from openra_env.models import OpenRAObservation, EconomyInfo, BuildingInfoModel + + bot = ScriptedBot() + obs = OpenRAObservation( + economy=EconomyInfo(power_provided=50, power_drained=100), + buildings=[BuildingInfoModel(actor_id=1, type="dome", is_powered=True)], + ) + commands = bot._handle_power_management(obs) + assert len(commands) == 1 + assert commands[0].action == ActionType.POWER_DOWN + assert commands[0].actor_id == 1 + + def test_set_primary_with_multiple_barracks(self): + """Sets primary on newest barracks when 2+ exist.""" + from examples.scripted_bot import ScriptedBot + from openra_env.models import OpenRAObservation, BuildingInfoModel + + bot = ScriptedBot() + obs = OpenRAObservation( + buildings=[ + BuildingInfoModel(actor_id=10, type="tent"), + BuildingInfoModel(actor_id=20, type="tent"), + ], + ) + commands = bot._handle_set_primary(obs) + assert len(commands) == 1 + assert commands[0].action == ActionType.SET_PRIMARY + assert commands[0].actor_id == 20 
class TestProductionValidation:
    """Checks build_unit / build_structure / build_and_place against available_production."""

    @staticmethod
    def _tool_fn(mcp, tool_name):
        """Return the callable behind a tool registered on the FastMCP instance."""
        return mcp._tool_manager._tools[tool_name].fn

    @staticmethod
    def _stub_execute(env, production_entry):
        """Swap env._execute_commands for a canned tick-501 summary (no real bridge)."""

        def fake_execute(cmds):
            # The economy is read at call time, exactly as the inline stub did.
            return {
                "tick": 501, "done": False, "result": "",
                "economy": env._last_obs["economy"],
                "own_units": 0, "own_buildings": 1,
                "visible_enemies": 0,
                "production": [production_entry],
            }

        env._execute_commands = fake_execute

    @pytest.fixture
    def env_with_allied_obs(self):
        """Build an env whose observation lists 1tnk but not 3tnk, with tools registered."""
        env = OpenRAEnvironment.__new__(OpenRAEnvironment)
        from openra_env.config import OpenRARLConfig
        env._app_config = OpenRARLConfig()
        from fastmcp import FastMCP
        mcp = FastMCP("openra-test")

        # Stub attributes needed by the tools; fresh containers on every call.
        stub_attrs = {
            "_planning_active": False,
            "_planning_strategy": "",
            "_planning_enabled": False,
            "_planning_max_turns": 10,
            "_planning_max_time_s": 60.0,
            "_planning_start_time": 0.0,
            "_planning_turns_used": 0,
            "_pending_placements": {},
            "_attempted_placements": {},
            "_placement_results": [],
            "_player_faction": "england",
        }
        for attr_name, attr_value in stub_attrs.items():
            setattr(env, attr_name, attr_value)

        economy = {
            "cash": 3000,
            "ore": 500,
            "power_provided": 200,
            "power_drained": 80,
            "resource_capacity": 4000,
            "harvester_count": 2,
        }
        military = {
            "units_killed": 0,
            "units_lost": 0,
            "buildings_killed": 0,
            "buildings_lost": 0,
            "army_value": 1000,
            "active_unit_count": 3,
        }
        construction_yard = {
            "actor_id": 1, "type": "fact", "pos_x": 500, "pos_y": 500,
            "hp_percent": 1.0, "owner": "Multi0", "is_producing": False,
            "production_progress": 0.0, "producing_item": "",
            "is_powered": True, "is_repairing": False, "sell_value": 0,
            "rally_x": -1, "rally_y": -1, "power_amount": 0,
            "can_produce": [], "cell_x": 5, "cell_y": 5,
        }
        env._last_obs = {
            "tick": 500,
            "done": False,
            "result": "",
            "economy": economy,
            "military": military,
            "units": [],
            "buildings": [construction_yard],
            "production": [],
            "visible_enemies": [],
            "visible_enemy_buildings": [],
            "map_info": {"width": 128, "height": 128, "map_name": "Test Map"},
            # Allied production: has 1tnk, e1, e3, powr, tent, proc — NO 3tnk
            "available_production": [
                "e1", "e3", "e6", "spy", "medi",
                "1tnk", "arty", "harv", "jeep", "truk",
                "powr", "tent", "proc", "weap", "gun", "dome",
            ],
        }

        # The observation is already in place, so refreshing is a no-op.
        env._refresh_obs = lambda: None

        env._register_tools(mcp)
        return env, mcp

    def test_build_unit_rejects_wrong_faction(self, env_with_allied_obs):
        """build_unit('3tnk') returns an error naming 3tnk plus the available units."""
        env, mcp = env_with_allied_obs
        build_unit = self._tool_fn(mcp, "build_unit")
        outcome = build_unit(unit_type="3tnk")
        assert "error" in outcome
        assert "3tnk" in outcome["error"]
        assert "available_units" in outcome
        # The list holds units from the observation, not buildings
        assert "1tnk" in outcome["available_units"]
        assert "powr" not in outcome["available_units"]

    def test_build_unit_accepts_valid_faction_unit(self, env_with_allied_obs):
        """build_unit('1tnk') returns the stubbed summary without an error."""
        env, mcp = env_with_allied_obs
        self._stub_execute(env, "1tnk@0%")
        build_unit = self._tool_fn(mcp, "build_unit")
        outcome = build_unit(unit_type="1tnk")
        assert "error" not in outcome
        assert outcome["tick"] == 501

    def test_build_unit_accepts_e1_for_allied(self, env_with_allied_obs):
        """build_unit('e1') returns no error."""
        env, mcp = env_with_allied_obs
        self._stub_execute(env, "e1@0%")
        build_unit = self._tool_fn(mcp, "build_unit")
        outcome = build_unit(unit_type="e1")
        assert "error" not in outcome

    def test_build_structure_rejects_unavailable(self, env_with_allied_obs):
        """build_structure('tsla') returns an error and lists available buildings."""
        env, mcp = env_with_allied_obs
        build_structure = self._tool_fn(mcp, "build_structure")
        outcome = build_structure(building_type="tsla")  # Soviet Tesla Coil
        assert "error" in outcome
        assert "available_buildings" in outcome
        assert "powr" in outcome["available_buildings"]

    def test_build_structure_accepts_valid(self, env_with_allied_obs):
        """build_structure('powr') returns no error."""
        env, mcp = env_with_allied_obs
        self._stub_execute(env, "powr@0%")
        build_structure = self._tool_fn(mcp, "build_structure")
        outcome = build_structure(building_type="powr")
        assert "error" not in outcome

    def test_build_and_place_rejects_unavailable(self, env_with_allied_obs):
        """build_and_place('tsla') returns an error and lists available buildings."""
        env, mcp = env_with_allied_obs
        build_and_place = self._tool_fn(mcp, "build_and_place")
        outcome = build_and_place(building_type="tsla")
        assert "error" in outcome
        assert "available_buildings" in outcome

    def test_build_and_place_accepts_valid(self, env_with_allied_obs):
        """build_and_place('proc') returns no error and records a pending placement."""
        env, mcp = env_with_allied_obs
        self._stub_execute(env, "proc@0%")
        build_and_place = self._tool_fn(mcp, "build_and_place")
        outcome = build_and_place(building_type="proc")
        assert "error" not in outcome
        assert "proc" in env._pending_placements

    def test_build_unit_error_lists_units_not_buildings(self, env_with_allied_obs):
        """The build_unit error payload lists units and leaves buildings out."""
        env, mcp = env_with_allied_obs
        build_unit = self._tool_fn(mcp, "build_unit")
        outcome = build_unit(unit_type="v2rl")  # Soviet V2 Launcher
        assert "error" in outcome
        offered = outcome["available_units"]
        # Units are present
        for unit_code in ("e1", "1tnk"):
            assert unit_code in offered
        # Buildings are absent
        for building_code in ("powr", "tent", "proc"):
            assert building_code not in offered

    def test_build_structure_error_lists_buildings_not_units(self, env_with_allied_obs):
        """The build_structure error payload lists buildings and leaves units out."""
        env, mcp = env_with_allied_obs
        build_structure = self._tool_fn(mcp, "build_structure")
        outcome = build_structure(building_type="tsla")
        assert "error" in outcome
        offered = outcome["available_buildings"]
        # Buildings are present
        for building_code in ("powr", "tent"):
            assert building_code in offered
        # Units are absent
        for unit_code in ("e1", "1tnk"):
            assert unit_code not in offered
4000 capacity + "power_provided": 300, + "power_drained": 190, + "resource_capacity": 4000, + "harvester_count": 2, + }, + "military": { + "units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 500, "active_unit_count": 2, + }, + "units": [ + { + "actor_id": 10, "type": "e1", "pos_x": 1000, "pos_y": 2000, + "cell_x": 10, "cell_y": 20, "hp_percent": 1.0, + "is_idle": False, "current_activity": "", + "owner": "Multi0", "can_attack": True, "facing": 0, + "experience_level": 0, "stance": 3, "speed": 56, + "attack_range": 5120, "passenger_count": -1, "is_building": False, + }, + ], + "buildings": [ + { + "actor_id": 1, "type": "fact", "pos_x": 500, "pos_y": 500, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 0, + "rally_x": -1, "rally_y": -1, "power_amount": 0, + "can_produce": [], "cell_x": 5, "cell_y": 5, + }, + { + "actor_id": 2, "type": "proc", "pos_x": 600, "pos_y": 600, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 700, + "rally_x": -1, "rally_y": -1, "power_amount": -30, + "can_produce": [], "cell_x": 6, "cell_y": 6, + }, + { + "actor_id": 3, "type": "powr", "pos_x": 400, "pos_y": 400, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 150, + "rally_x": -1, "rally_y": -1, "power_amount": 100, + "can_produce": [], "cell_x": 4, "cell_y": 4, + }, + ], + "production": [], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 128, "height": 128, "map_name": "Test Map"}, + "available_production": ["e1", "powr", "proc"], + } + + env._register_tools(mcp) + return env, mcp + + def test_ore_cap_alert_fires(self, env_with_full_ore): 
class TestWaterBuildingGuard:
    """Checks how _process_pending_placements treats a pending 'spen' placement."""

    @pytest.fixture
    def env_with_water_building(self):
        """Build an env with 'spen' pending and a finished 'spen' in the Building queue."""
        env = OpenRAEnvironment.__new__(OpenRAEnvironment)
        from openra_env.config import OpenRARLConfig
        env._app_config = OpenRARLConfig()
        from fastmcp import FastMCP
        mcp = FastMCP("openra-test")

        # Planning-related stubs
        env._planning_active = False
        env._planning_enabled = False
        env._planning_strategy = ""
        env._planning_max_turns = 10
        env._planning_max_time_s = 60.0
        env._planning_start_time = 0.0
        env._planning_turns_used = 0

        # Placement bookkeeping: 'spen' is already waiting to be placed
        env._pending_placements = {"spen": {"cell_x": 0, "cell_y": 0}}
        env._attempted_placements = {}
        env._placement_results = []
        env._player_faction = "russia"

        construction_yard = {
            "actor_id": 1, "type": "fact", "pos_x": 500, "pos_y": 500,
            "hp_percent": 1.0, "owner": "Multi0", "is_producing": False,
            "production_progress": 0.0, "producing_item": "",
            "is_powered": True, "is_repairing": False, "sell_value": 0,
            "rally_x": -1, "rally_y": -1, "power_amount": 0,
            "can_produce": [], "cell_x": 5, "cell_y": 5,
        }
        finished_spen = {
            "queue_type": "Building",
            "item": "spen",
            "progress": 1.0,
            "remaining_ticks": 0,
            "remaining_cost": 0,
            "paused": False,
        }
        env._last_obs = {
            "tick": 10000,
            "done": False,
            "result": "",
            "economy": {
                "cash": 2000, "ore": 1000,
                "power_provided": 300, "power_drained": 200,
                "resource_capacity": 4000, "harvester_count": 2,
            },
            "military": {
                "units_killed": 0, "units_lost": 0,
                "buildings_killed": 0, "buildings_lost": 0,
                "army_value": 2000, "active_unit_count": 5,
            },
            "units": [],
            "buildings": [construction_yard],
            "production": [finished_spen],
            "visible_enemies": [],
            "visible_enemy_buildings": [],
            "map_info": {"width": 128, "height": 128, "map_name": "Test Map"},
            "available_production": ["powr", "barr", "proc", "spen"],
        }

        env._register_tools(mcp)
        return env, mcp

    def test_water_building_skips_auto_placement(self, env_with_water_building):
        """Processing drops 'spen' from pending and records one WATER BUILDING message."""
        env, mcp = env_with_water_building
        assert "spen" in env._pending_placements

        # Run the placement pass
        env._process_pending_placements()

        # 'spen' is no longer pending
        assert "spen" not in env._pending_placements
        # Exactly one message was recorded, and it names the building
        messages = env._placement_results
        assert len(messages) == 1
        assert "WATER BUILDING" in messages[0]
        assert "spen" in messages[0]

    def test_water_building_not_in_attempted(self, env_with_water_building):
        """Processing does not add 'spen' to the attempted-placement tracking."""
        env, mcp = env_with_water_building
        env._process_pending_placements()
        assert "spen" not in env._attempted_placements
openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._pending_placements = {"powr": {"cell_x": 5, "cell_y": 5}} + env._attempted_placements = {} + env._placement_results = [] + env._player_faction = "england" + env._prev_buildings = {} + env._prev_unit_ids = {} + + obs_dict = { + "tick": 100, + "done": False, + "result": "", + "economy": { + "cash": 5000, "ore": 0, + "power_provided": 0, "power_drained": 0, + "resource_capacity": 4000, "harvester_count": 1, + }, + "military": { + "units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 0, "active_unit_count": 0, + }, + "units": [], + "buildings": [ + { + "actor_id": 1, "type": "fact", "pos_x": 5120, "pos_y": 5120, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 0, + "rally_x": -1, "rally_y": -1, "power_amount": 0, + "can_produce": [], "cell_x": 5, "cell_y": 5, + }, + ], + "production": [ + { + "queue_type": "Building", + "item": "powr", + "progress": 1.0, + "remaining_ticks": 0, + "remaining_cost": 0, + "paused": False, + }, + ], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 128, "height": 128, "map_name": "Test Map"}, + "available_production": ["powr", "proc"], + } + env._last_obs = obs_dict + + # Track whether _process_pending_placements was called + placement_called = [] + + def mock_process_pending(): + placement_called.append(True) + + env._process_pending_placements = mock_process_pending + + # Patch run_coroutine_threadsafe to return obs_dict directly + mock_future = MagicMock() + mock_future.result.return_value = obs_dict + + with patch("asyncio.run_coroutine_threadsafe", return_value=mock_future): + env._loop = MagicMock() + from openra_env.models import CommandModel, ActionType + result = env._execute_commands([CommandModel(action=ActionType.NO_OP)]) + + assert 
class TestDeadUnitFiltering:
    """S3: checks that _resolve_unit_ids drops ids missing from the observation and warns."""

    @staticmethod
    def _dead_warnings(env):
        """Return the recorded placement results that mention DEAD UNITS."""
        return [entry for entry in env._placement_results if "DEAD UNITS" in entry]

    @pytest.fixture
    def env_with_units(self):
        """Build an env whose observation holds units 10, 11 and 12; group 'alpha' also lists 99."""
        env = OpenRAEnvironment.__new__(OpenRAEnvironment)
        from openra_env.config import OpenRARLConfig
        env._app_config = OpenRARLConfig()
        env._placement_results = []
        env._unit_groups = {"alpha": [10, 11, 99]}  # 99 is dead
        living_units = [
            {"actor_id": 10, "type": "e1", "can_attack": True, "is_idle": True},
            {"actor_id": 11, "type": "e1", "can_attack": True, "is_idle": False},
            {"actor_id": 12, "type": "e1", "can_attack": False, "is_idle": True},
        ]
        env._last_obs = {"units": living_units}
        return env

    def test_list_filters_dead_ids(self, env_with_units):
        """A list of ints keeps 10 and 11 and warns once about 50 and 99."""
        env = env_with_units
        resolved = env._resolve_unit_ids([10, 11, 50, 99], env._last_obs)
        assert resolved == [10, 11]
        # One warning naming both missing ids
        warnings = self._dead_warnings(env)
        assert len(warnings) == 1
        assert "50" in warnings[0]
        assert "99" in warnings[0]

    def test_string_ids_filter_dead(self, env_with_units):
        """A comma-separated string keeps only 10 and warns once."""
        env = env_with_units
        resolved = env._resolve_unit_ids("10,99,50", env._last_obs)
        assert resolved == [10]
        warnings = self._dead_warnings(env)
        assert len(warnings) == 1

    def test_bracketed_string_filters_dead(self, env_with_units):
        """A bracketed string such as '[10, 50]' keeps only 10."""
        env = env_with_units
        resolved = env._resolve_unit_ids("[10, 50]", env._last_obs)
        assert resolved == [10]

    def test_group_filters_dead(self, env_with_units):
        """The named group 'alpha' resolves to 10 and 11 and warns once about 99."""
        env = env_with_units
        resolved = env._resolve_unit_ids("alpha", env._last_obs)
        assert resolved == [10, 11]
        warnings = self._dead_warnings(env)
        assert len(warnings) == 1
        assert "99" in warnings[0]

    def test_all_combat_returns_living(self, env_with_units):
        """'all_combat' resolves to the units with can_attack set and records nothing."""
        env = env_with_units
        resolved = env._resolve_unit_ids("all_combat", env._last_obs)
        assert resolved == [10, 11]
        assert len(env._placement_results) == 0  # no warnings

    def test_all_ids_dead(self, env_with_units):
        """When every requested id is missing, the result is empty and one warning is recorded."""
        env = env_with_units
        resolved = env._resolve_unit_ids([50, 99], env._last_obs)
        assert resolved == []
        warnings = self._dead_warnings(env)
        assert len(warnings) == 1

    def test_no_dead_no_warning(self, env_with_units):
        """When every requested id is present, nothing is recorded."""
        env = env_with_units
        resolved = env._resolve_unit_ids([10, 11], env._last_obs)
        assert resolved == [10, 11]
        assert len(env._placement_results) == 0
"resource_capacity": 4000, "harvester_count": 1, + }, + "military": { + "units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 0, "active_unit_count": 0, + }, + "units": [], + "buildings": [ + { + "actor_id": 1, "type": "fact", "pos_x": 500, "pos_y": 500, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 0, + "rally_x": -1, "rally_y": -1, "power_amount": 0, + "can_produce": [], "cell_x": 5, "cell_y": 5, + }, + ], + "production": [], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 128, "height": 128, "map_name": "Test Map"}, + "available_production": ["e1", "e2", "powr", "proc", "barr"], + } + + # Mock _refresh_obs to be a no-op (obs already set) + env._refresh_obs = lambda: None + + env._register_tools(mcp) + return env, mcp + + def test_build_unit_rejects_no_funds(self, env_broke): + """build_unit returns error when funds are insufficient.""" + env, mcp = env_broke + tool = mcp._tool_manager._tools["build_unit"] + result = tool.fn(unit_type="e1", count=1) + assert "error" in result + assert "Insufficient funds" in result["error"] + assert "$0" in result["error"] + + def test_build_unit_allows_when_funded(self, env_broke): + """build_unit succeeds when funds are sufficient.""" + env, mcp = env_broke + env._last_obs["economy"]["cash"] = 500 + + # Mock _execute_commands since we don't have a real bridge + env._execute_commands = lambda cmds: { + "tick": 5001, "done": False, "result": "", + "economy": env._last_obs["economy"], + "own_units": 0, "own_buildings": 1, + "visible_enemies": 0, "production": [], + } + + tool = mcp._tool_manager._tools["build_unit"] + result = tool.fn(unit_type="e1", count=1) + assert "error" not in result + assert "tick" in result + + +class TestStalledProductionAlert: + """S2: get_game_state should alert when production stalled at $0.""" + + 
@pytest.fixture + def env_stalled(self): + """Create env with stalled production and $0 funds.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + from fastmcp import FastMCP + mcp = FastMCP("openra-test") + + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = False + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + env._pending_placements = {} + env._attempted_placements = {} + env._placement_results = [] + env._player_faction = "russia" + # Pre-seed with same progress to simulate stall + env._last_production_progress = {"weap": 0.56} + + env._last_obs = { + "tick": 10000, + "done": False, + "result": "", + "economy": { + "cash": 0, "ore": 0, + "power_provided": 200, "power_drained": 100, + "resource_capacity": 4000, "harvester_count": 1, + }, + "military": { + "units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 500, "active_unit_count": 2, + }, + "units": [ + { + "actor_id": 10, "type": "e1", "pos_x": 1000, "pos_y": 2000, + "cell_x": 10, "cell_y": 20, "hp_percent": 1.0, + "is_idle": True, "current_activity": "", + "owner": "Multi0", "can_attack": True, "facing": 0, + "experience_level": 0, "stance": 3, "speed": 56, + "attack_range": 5120, "passenger_count": -1, "is_building": False, + }, + ], + "buildings": [ + { + "actor_id": 1, "type": "fact", "pos_x": 500, "pos_y": 500, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 0, + "rally_x": -1, "rally_y": -1, "power_amount": 0, + "can_produce": [], "cell_x": 5, "cell_y": 5, + }, + ], + "production": [ + { + "queue_type": "Building", + "item": "weap", + "progress": 0.56, + "remaining_ticks": 300, + "remaining_cost": 1000, + "paused": False, + }, + ], 
+ "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 128, "height": 128, "map_name": "Test Map"}, + "available_production": ["e1", "powr", "proc"], + } + + env._register_tools(mcp) + return env, mcp + + def test_stalled_alert_fires(self, env_stalled): + """Alert fires when production progress unchanged and $0 funds.""" + env, mcp = env_stalled + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + alerts = result.get("alerts", []) + stalled_alerts = [a for a in alerts if "STALLED" in a] + assert len(stalled_alerts) == 1 + assert "weap" in stalled_alerts[0] + assert "$0" in stalled_alerts[0] + + def test_stalled_alert_not_on_first_call(self, env_stalled): + """Alert does NOT fire on first call (no previous progress to compare).""" + env, mcp = env_stalled + # Clear the pre-seeded progress so it's like first call + env._last_production_progress = {} + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + alerts = result.get("alerts", []) + stalled_alerts = [a for a in alerts if "STALLED" in a] + assert len(stalled_alerts) == 0 + + def test_stalled_alert_not_when_funded(self, env_stalled): + """Alert does NOT fire when player has funds (even if progress same).""" + env, mcp = env_stalled + env._last_obs["economy"]["cash"] = 1000 + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + alerts = result.get("alerts", []) + stalled_alerts = [a for a in alerts if "STALLED" in a] + assert len(stalled_alerts) == 0 + + def test_stalled_alert_not_when_progressing(self, env_stalled): + """Alert does NOT fire when progress is advancing (even at $0).""" + env, mcp = env_stalled + # Previous was 0.50, current is 0.56 → progressing + env._last_production_progress = {"weap": 0.50} + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + alerts = result.get("alerts", []) + stalled_alerts = [a for a in alerts if "STALLED" in a] + assert len(stalled_alerts) == 0 + + def 
test_progress_snapshot_updated(self, env_stalled): + """_last_production_progress is updated after each call.""" + env, mcp = env_stalled + env._last_production_progress = {} + tool = mcp._tool_manager._tools["get_game_state"] + tool.fn() + assert "weap" in env._last_production_progress + assert abs(env._last_production_progress["weap"] - 0.56) < 0.01 + + +class TestBuildingStuckAlertText: + """S5: BUILDING STUCK alert should suggest get_valid_placements, not 'auto-cancel'.""" + + @pytest.fixture + def env_stuck_building(self): + """Create env with a stuck building in attempted placements.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + from fastmcp import FastMCP + mcp = FastMCP("openra-test") + + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = False + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + env._pending_placements = {} + env._attempted_placements = {"powr": 5} # 5 failed attempts + env._placement_results = [] + env._player_faction = "russia" + env._last_production_progress = {} + + env._last_obs = { + "tick": 6000, + "done": False, + "result": "", + "economy": { + "cash": 2000, "ore": 500, + "power_provided": 100, "power_drained": 100, + "resource_capacity": 4000, "harvester_count": 1, + }, + "military": { + "units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 0, "active_unit_count": 0, + }, + "units": [], + "buildings": [ + { + "actor_id": 1, "type": "fact", "pos_x": 500, "pos_y": 500, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 0, + "rally_x": -1, "rally_y": -1, "power_amount": 0, + "can_produce": [], "cell_x": 5, "cell_y": 5, + }, + ], + "production": [ + { + "queue_type": 
"Building", + "item": "powr", + "progress": 1.0, + "remaining_ticks": 0, + "remaining_cost": 0, + "paused": False, + }, + ], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 128, "height": 128, "map_name": "Test Map"}, + "available_production": ["powr", "proc"], + } + + env._register_tools(mcp) + return env, mcp + + def test_stuck_alert_suggests_valid_placements(self, env_stuck_building): + """BUILDING STUCK alert should be factual (no prescriptive tool suggestions).""" + env, mcp = env_stuck_building + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + alerts = result.get("alerts", []) + stuck_alerts = [a for a in alerts if "BUILDING STUCK" in a] + assert len(stuck_alerts) == 1 + assert "auto-placement failing" in stuck_alerts[0] + + +# ── Round 3 Tests ────────────────────────────────────────────────────────── + + +class TestUnderAttackAlertCap: + """S1: UNDER ATTACK alerts should be capped when >3 attackers.""" + + @pytest.fixture + def env_base(self): + """Create env with buildings and variable enemy counts.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + from fastmcp import FastMCP + mcp = FastMCP("openra-test") + + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = False + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + env._pending_placements = {} + env._attempted_placements = {} + env._placement_results = [] + env._player_faction = "russia" + env._last_production_progress = {} + env._prev_buildings = {} + env._prev_unit_ids = {} + + env._last_obs = { + "tick": 8000, + "done": False, + "result": "", + "economy": { + "cash": 2000, "ore": 500, + "power_provided": 200, "power_drained": 100, + "resource_capacity": 4000, "harvester_count": 2, + }, + "military": { + "units_killed": 0, "units_lost": 0, + 
"buildings_killed": 0, "buildings_lost": 0, + "army_value": 1000, "active_unit_count": 3, + }, + "units": [ + { + "actor_id": 10, "type": "e1", "pos_x": 5120, "pos_y": 5120, + "cell_x": 5, "cell_y": 5, "hp_percent": 1.0, + "is_idle": True, "current_activity": "", + "owner": "Multi0", "can_attack": True, "facing": 0, + "experience_level": 0, "stance": 3, "speed": 56, + "attack_range": 5120, "passenger_count": -1, "is_building": False, + }, + ], + "buildings": [ + { + "actor_id": 1, "type": "fact", "pos_x": 5120, "pos_y": 5120, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 0, + "rally_x": -1, "rally_y": -1, "power_amount": 0, + "can_produce": [], "cell_x": 5, "cell_y": 5, + }, + { + "actor_id": 2, "type": "barr", "pos_x": 6144, "pos_y": 5120, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 0, + "rally_x": -1, "rally_y": -1, "power_amount": 0, + "can_produce": [], "cell_x": 6, "cell_y": 5, + }, + ], + "production": [], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 128, "height": 128, "map_name": "Test Map"}, + "available_production": ["e1", "powr", "proc"], + } + + env._register_tools(mcp) + return env, mcp + + def _make_enemy(self, actor_id, etype, cell_x, cell_y): + return { + "actor_id": actor_id, "type": etype, + "pos_x": cell_x * 1024, "pos_y": cell_y * 1024, + "cell_x": cell_x, "cell_y": cell_y, "hp_percent": 1.0, + "is_idle": False, "current_activity": "", "owner": "Multi1", + "can_attack": True, "facing": 0, "experience_level": 0, + "stance": 3, "speed": 56, "attack_range": 5120, + "passenger_count": -1, "is_building": False, + } + + def test_few_attackers_individual_alerts(self, env_base): + """≤3 attackers near base → individual alerts.""" + env, mcp = env_base + 
env._last_obs["visible_enemies"] = [ + self._make_enemy(100, "e1", 5, 6), + self._make_enemy(101, "e3", 6, 6), + ] + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + attack_alerts = [a for a in result["alerts"] if "UNDER ATTACK" in a] + assert len(attack_alerts) == 2 + assert any("e1" in a for a in attack_alerts) + assert any("e3" in a for a in attack_alerts) + + def test_many_attackers_summarized(self, env_base): + """>3 attackers near base → one summary alert with type breakdown.""" + env, mcp = env_base + env._last_obs["visible_enemies"] = [ + self._make_enemy(100, "e1", 5, 6), + self._make_enemy(101, "e1", 5, 7), + self._make_enemy(102, "e3", 6, 6), + self._make_enemy(103, "e3", 7, 5), + self._make_enemy(104, "e4", 6, 4), + ] + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + attack_alerts = [a for a in result["alerts"] if "UNDER ATTACK" in a] + assert len(attack_alerts) == 1 + assert "5 enemies" in attack_alerts[0] + assert "e1" in attack_alerts[0] + assert "e3" in attack_alerts[0] + + def test_far_enemies_no_alert(self, env_base): + """Enemies far from base → no UNDER ATTACK alert.""" + env, mcp = env_base + env._last_obs["visible_enemies"] = [ + self._make_enemy(100, "e1", 50, 50), # far away + ] + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + attack_alerts = [a for a in result["alerts"] if "UNDER ATTACK" in a] + assert len(attack_alerts) == 0 + + +class TestLossTracking: + """S2: Loss tracking should detect destroyed buildings and units.""" + + def test_building_destroyed_alert(self): + """DESTROYED alert fires when a building disappears between observations.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._placement_results = [] + env._prev_buildings = {1: "fact", 2: "weap", 3: "barr"} + env._prev_unit_ids = {} + env._last_obs = { + "buildings": [ + {"actor_id": 1, "type": "fact"}, + 
{"actor_id": 3, "type": "barr"}, + ], + "units": [], + } + env._update_loss_tracking() + destroyed = [r for r in env._placement_results if "DESTROYED" in r] + assert len(destroyed) == 1 + assert "weap" in destroyed[0] + + def test_units_lost_alert(self): + """UNITS LOST alert fires with type breakdown when units disappear.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._placement_results = [] + env._prev_buildings = {} + env._prev_unit_ids = {10: "e1", 11: "e1", 12: "e3", 13: "3tnk", 14: "e1"} + env._last_obs = { + "buildings": [], + "units": [ + {"actor_id": 10, "type": "e1"}, + {"actor_id": 14, "type": "e1"}, + ], + } + env._update_loss_tracking() + lost = [r for r in env._placement_results if "UNITS LOST" in r] + assert len(lost) == 1 + assert "3 destroyed" in lost[0] + assert "e1" in lost[0] + assert "e3" in lost[0] + assert "3tnk" in lost[0] + + def test_no_losses_no_alert(self): + """No losses → no alerts.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._placement_results = [] + env._prev_buildings = {1: "fact"} + env._prev_unit_ids = {10: "e1"} + env._last_obs = { + "buildings": [{"actor_id": 1, "type": "fact"}], + "units": [{"actor_id": 10, "type": "e1"}], + } + env._update_loss_tracking() + assert len(env._placement_results) == 0 + + def test_first_observation_no_alert(self): + """First observation (empty prev) → no alerts, just snapshot.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._placement_results = [] + env._prev_buildings = {} + env._prev_unit_ids = {} + env._last_obs = { + "buildings": [{"actor_id": 1, "type": "fact"}], + "units": [{"actor_id": 10, "type": "e1"}], + } + env._update_loss_tracking() + assert len(env._placement_results) == 0 + # Should have 
updated snapshots + assert env._prev_buildings == {1: "fact"} + assert env._prev_unit_ids == {10: "e1"} + + def test_multiple_buildings_destroyed(self): + """Multiple buildings destroyed → multiple DESTROYED alerts.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._placement_results = [] + env._prev_buildings = {1: "fact", 2: "weap", 3: "barr", 4: "kenn"} + env._prev_unit_ids = {} + env._last_obs = { + "buildings": [{"actor_id": 1, "type": "fact"}], + "units": [], + } + env._update_loss_tracking() + destroyed = [r for r in env._placement_results if "DESTROYED" in r] + assert len(destroyed) == 3 # weap, barr, kenn + + def test_snapshots_updated_after_tracking(self): + """_prev_buildings and _prev_unit_ids updated after tracking.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._placement_results = [] + env._prev_buildings = {1: "fact", 2: "weap"} + env._prev_unit_ids = {10: "e1"} + env._last_obs = { + "buildings": [{"actor_id": 1, "type": "fact"}], + "units": [{"actor_id": 10, "type": "e1"}, {"actor_id": 11, "type": "3tnk"}], + } + env._update_loss_tracking() + assert env._prev_buildings == {1: "fact"} + assert env._prev_unit_ids == {10: "e1", 11: "3tnk"} + + +class TestPrereqDiagnosis: + """S3: Production unavailable should diagnose missing prerequisites.""" + + @pytest.fixture + def env_no_kenn(self): + """Create env without kennel — dog should explain missing prereq.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + from fastmcp import FastMCP + mcp = FastMCP("openra-test") + + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = False + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + 
env._planning_turns_used = 0 + env._pending_placements = {} + env._attempted_placements = {} + env._placement_results = [] + env._player_faction = "russia" + env._last_production_progress = {} + env._prev_buildings = {} + env._prev_unit_ids = {} + + env._last_obs = { + "tick": 5000, + "done": False, + "result": "", + "economy": { + "cash": 2000, "ore": 1000, + "power_provided": 200, "power_drained": 100, + "resource_capacity": 4000, "harvester_count": 2, + }, + "military": { + "units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 0, "active_unit_count": 0, + }, + "units": [], + "buildings": [ + { + "actor_id": 1, "type": "fact", "pos_x": 500, "pos_y": 500, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 0, + "rally_x": -1, "rally_y": -1, "power_amount": 0, + "can_produce": [], "cell_x": 5, "cell_y": 5, + }, + { + "actor_id": 2, "type": "barr", "pos_x": 600, "pos_y": 500, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 0, + "rally_x": -1, "rally_y": -1, "power_amount": 0, + "can_produce": [], "cell_x": 6, "cell_y": 5, + }, + ], + "production": [], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 128, "height": 128, "map_name": "Test Map"}, + "available_production": ["e1", "e2", "powr", "proc", "barr"], + } + + env._refresh_obs = lambda: None + env._register_tools(mcp) + return env, mcp + + def test_dog_missing_kennel(self, env_no_kenn): + """build_unit('dog') without kennel explains missing prerequisite.""" + env, mcp = env_no_kenn + tool = mcp._tool_manager._tools["build_unit"] + result = tool.fn(unit_type="dog", count=1) + assert "error" in result + assert "kenn" in result["error"] + assert "missing_prerequisites" in result + assert "kenn" 
in result["missing_prerequisites"] + + def test_3tnk_missing_fix(self, env_no_kenn): + """build_unit('3tnk') without fix explains missing prerequisites.""" + env, mcp = env_no_kenn + # Add weap but not fix + env._last_obs["buildings"].append({ + "actor_id": 3, "type": "weap", "pos_x": 700, "pos_y": 500, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 0, + "rally_x": -1, "rally_y": -1, "power_amount": 0, + "can_produce": [], "cell_x": 7, "cell_y": 5, + }) + tool = mcp._tool_manager._tools["build_unit"] + result = tool.fn(unit_type="3tnk", count=1) + assert "error" in result + assert "fix" in result["error"] + assert "missing_prerequisites" in result + + def test_building_missing_prereq(self, env_no_kenn): + """build_structure for a building needing dome explains missing prereq.""" + env, mcp = env_no_kenn + tool = mcp._tool_manager._tools["build_structure"] + result = tool.fn(building_type="afld") + assert "error" in result + # afld requires dome and weap + assert "missing_prerequisites" in result + + def test_diagnose_unknown_type(self): + """_diagnose_unavailable for unknown type returns generic message.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._last_obs = {"buildings": []} + result = env._diagnose_unavailable("zzzz") + assert "not a known" in result["reason"] + + +# ── Round 4 Tests ────────────────────────────────────────────────────────── + + +class TestUnitFeedback: + """S1: move/attack_move/attack_target should return commanded_units feedback.""" + + @pytest.fixture + def env_with_units(self): + """Create env with units for move command testing.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + from fastmcp import FastMCP + mcp = 
FastMCP("openra-test") + + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = False + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + env._pending_placements = {} + env._attempted_placements = {} + env._placement_results = [] + env._player_faction = "russia" + env._last_production_progress = {} + env._prev_buildings = {} + env._prev_unit_ids = {} + env._unit_groups = {} + + env._last_obs = { + "tick": 3000, + "done": False, + "result": "", + "economy": { + "cash": 2000, "ore": 500, + "power_provided": 200, "power_drained": 80, + "resource_capacity": 4000, "harvester_count": 1, + }, + "military": { + "units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 0, "active_unit_count": 0, + }, + "units": [ + { + "actor_id": 142, "type": "e1", "pos_x": 13000, "pos_y": 14000, + "hp_percent": 1.0, "owner": "Multi0", "is_idle": False, + "current_activity": "MoveTo", "can_attack": True, + "stance": 3, "cell_x": 13, "cell_y": 14, + "facing": 128, "experience_level": 0, "speed": 56, + "attack_range": 5120, "passenger_count": 0, "ammo": -1, + "is_building": False, + }, + { + "actor_id": 143, "type": "e1", "pos_x": 12000, "pos_y": 14000, + "hp_percent": 1.0, "owner": "Multi0", "is_idle": True, + "current_activity": "IdleDefault", "can_attack": True, + "stance": 3, "cell_x": 12, "cell_y": 14, + "facing": 256, "experience_level": 0, "speed": 56, + "attack_range": 5120, "passenger_count": 0, "ammo": -1, + "is_building": False, + }, + { + "actor_id": 154, "type": "dog", "pos_x": 50000, "pos_y": 30000, + "hp_percent": 1.0, "owner": "Multi0", "is_idle": False, + "current_activity": "AttackMoveActivity", "can_attack": True, + "stance": 3, "cell_x": 50, "cell_y": 30, + "facing": 64, "experience_level": 0, "speed": 99, + "attack_range": 1024, "passenger_count": 0, "ammo": -1, + "is_building": False, + }, + ], + "buildings": [ + { + 
"actor_id": 1, "type": "fact", "pos_x": 5000, "pos_y": 5000, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 0, + "rally_x": -1, "rally_y": -1, "power_amount": 0, + "can_produce": [], "cell_x": 5, "cell_y": 5, + }, + ], + "production": [], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 128, "height": 128, "map_name": "Test Map"}, + "available_production": ["e1", "e2", "dog", "powr", "proc"], + } + + env._refresh_obs = lambda: None + env._execute_commands = lambda cmds: { + "tick": 3050, "done": False, "result": "", + "economy": env._last_obs["economy"], + "own_units": 3, "own_buildings": 1, + "visible_enemies": 0, + "production": [], + } + env._register_tools(mcp) + return env, mcp + + def test_move_units_returns_feedback(self, env_with_units): + """move_units should include commanded_units with positions.""" + env, mcp = env_with_units + tool = mcp._tool_manager._tools["move_units"] + result = tool.fn(unit_ids="142,143", target_x=50, target_y=20) + assert "commanded_units" in result + assert len(result["commanded_units"]) == 2 + unit_142 = next(u for u in result["commanded_units"] if u["id"] == 142) + assert unit_142["type"] == "e1" + assert unit_142["cell_x"] == 13 + assert unit_142["cell_y"] == 14 + assert unit_142["activity"] == "MoveTo" + + def test_attack_move_returns_feedback(self, env_with_units): + """attack_move should include commanded_units with positions.""" + env, mcp = env_with_units + tool = mcp._tool_manager._tools["attack_move"] + result = tool.fn(unit_ids="154", target_x=90, target_y=40) + assert "commanded_units" in result + assert len(result["commanded_units"]) == 1 + assert result["commanded_units"][0]["type"] == "dog" + assert result["commanded_units"][0]["cell_x"] == 50 + assert result["commanded_units"][0]["cell_y"] == 30 + + def test_attack_move_all_combat(self, env_with_units): + 
"""attack_move with all_combat includes all 3 combat units.""" + env, mcp = env_with_units + tool = mcp._tool_manager._tools["attack_move"] + result = tool.fn(unit_ids="all_combat", target_x=90, target_y=40) + assert "commanded_units" in result + assert len(result["commanded_units"]) == 3 + ids = {u["id"] for u in result["commanded_units"]} + assert ids == {142, 143, 154} + + def test_attack_target_returns_feedback(self, env_with_units): + """attack_target should include commanded_units.""" + env, mcp = env_with_units + tool = mcp._tool_manager._tools["attack_target"] + result = tool.fn(unit_ids="142,143", target_actor_id=999) + assert "commanded_units" in result + assert len(result["commanded_units"]) == 2 + + def test_stop_units_returns_feedback(self, env_with_units): + """stop_units should include commanded_units.""" + env, mcp = env_with_units + tool = mcp._tool_manager._tools["stop_units"] + result = tool.fn(unit_ids="154") + assert "commanded_units" in result + assert len(result["commanded_units"]) == 1 + assert result["commanded_units"][0]["id"] == 154 + + def test_command_group_returns_feedback(self, env_with_units): + """command_group should include commanded_units with positions.""" + env, mcp = env_with_units + # Set up a group + env._unit_groups["scouts"] = [142, 154] + tool = mcp._tool_manager._tools["command_group"] + result = tool.fn(group_name="scouts", command="attack_move", target_x=90, target_y=40) + assert "commanded_units" in result + assert len(result["commanded_units"]) == 2 + ids = {u["id"] for u in result["commanded_units"]} + assert ids == {142, 154} + + def test_feedback_includes_activity(self, env_with_units): + """commanded_units should include the current_activity field.""" + env, mcp = env_with_units + tool = mcp._tool_manager._tools["move_units"] + result = tool.fn(unit_ids="142", target_x=50, target_y=20) + assert result["commanded_units"][0]["activity"] == "MoveTo" + + def test_feedback_excludes_dead_units(self, env_with_units): + 
"""If a commanded unit died during execution, it shouldn't appear in feedback.""" + env, mcp = env_with_units + # Unit 143 exists in obs but 999 doesn't — simulate commanding a valid + dead unit + # _resolve_unit_ids filters dead ones, so feedback should only have living + tool = mcp._tool_manager._tools["move_units"] + result = tool.fn(unit_ids="143", target_x=50, target_y=20) + assert len(result["commanded_units"]) == 1 + assert result["commanded_units"][0]["id"] == 143 + + +class TestFactDestroyedDiagnosis: + """S2: _diagnose_unavailable should detect missing Construction Yard.""" + + def test_powr_without_fact(self): + """build_and_place('powr') without fact says 'No Construction Yard'.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._last_obs = { + "buildings": [ + {"actor_id": 2, "type": "barr", "cell_x": 5, "cell_y": 5}, + ], + } + result = env._diagnose_unavailable("powr") + assert "No Construction Yard" in result["reason"] + assert "MCV" in result["reason"] + + def test_fact_present_uses_normal_diagnosis(self): + """With fact present, _diagnose_unavailable uses normal prereq check.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._last_obs = { + "buildings": [ + {"actor_id": 1, "type": "fact", "cell_x": 5, "cell_y": 5}, + ], + } + # powr has no explicit prereqs and fact is present → should NOT say "No Construction Yard" + result = env._diagnose_unavailable("powr") + assert "No Construction Yard" not in result["reason"] + + def test_afld_without_fact(self): + """Any building type without fact should say 'No Construction Yard'.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._last_obs = { + "buildings": [ + {"actor_id": 2, "type": "barr", "cell_x": 5, "cell_y": 5}, + 
{"actor_id": 3, "type": "weap", "cell_x": 6, "cell_y": 5}, + ], + } + result = env._diagnose_unavailable("afld") + assert "No Construction Yard" in result["reason"] + + def test_unit_without_fact_still_normal(self): + """Units (not buildings) should NOT get the fact check.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._last_obs = { + "buildings": [ + {"actor_id": 2, "type": "barr", "cell_x": 5, "cell_y": 5}, + ], + } + # dog is a unit, not a building — should use normal prereq diagnosis + result = env._diagnose_unavailable("dog") + assert "No Construction Yard" not in result["reason"] + assert "kenn" in result["reason"] + + +# ── Round 5 Tests ────────────────────────────────────────────────────── + +class TestBatchValidation: + """S1/S2: batch() should reject unsupported actions and validate build_unit.""" + + @pytest.fixture + def env_with_batch(self): + """Create env for batch testing.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + from fastmcp import FastMCP + mcp = FastMCP("openra-test") + + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = False + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + env._pending_placements = {} + env._attempted_placements = {} + env._placement_results = [] + env._player_faction = "russia" + env._last_production_progress = {} + env._prev_buildings = {} + env._prev_unit_ids = {} + env._unit_groups = {} + env._enemy_ever_seen = False + + env._last_obs = { + "tick": 3000, "done": False, "result": "", + "economy": { + "cash": 500, "ore": 100, + "power_provided": 200, "power_drained": 80, + "resource_capacity": 4000, "harvester_count": 1, + }, + "military": { + "units_killed": 0, "units_lost": 0, + "buildings_killed": 0, 
"buildings_lost": 0, + "army_value": 0, "active_unit_count": 0, + }, + "units": [ + { + "actor_id": 150, "type": "e1", "pos_x": 10000, "pos_y": 10000, + "hp_percent": 1.0, "owner": "Multi0", "is_idle": True, + "current_activity": "", "can_attack": True, + "stance": 3, "cell_x": 10, "cell_y": 10, + "facing": 128, "experience_level": 0, "speed": 56, + "attack_range": 5120, "passenger_count": 0, "ammo": -1, + "is_building": False, + }, + ], + "buildings": [ + { + "actor_id": 1, "type": "fact", "pos_x": 5000, "pos_y": 5000, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 0, + "rally_x": -1, "rally_y": -1, "power_amount": 0, + "can_produce": [], "cell_x": 5, "cell_y": 5, + }, + ], + "production": [], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 128, "height": 128, "map_name": "Test Map"}, + "available_production": ["e1", "e2", "dog", "powr", "proc"], + } + + env._refresh_obs = lambda: None + env._execute_commands = lambda cmds: { + "tick": 3050, "done": False, "result": "", + "economy": env._last_obs["economy"], + "own_units": 1, "own_buildings": 1, + "visible_enemies": 0, + "production": [], + } + env._register_tools(mcp) + return env, mcp + + def test_batch_rejects_advance(self, env_with_batch): + """advance inside batch should be marked SKIPPED.""" + env, mcp = env_with_batch + tool = mcp._tool_manager._tools["batch"] + result = tool.fn(actions=[ + {"tool": "advance", "ticks": 100}, + {"tool": "attack_move", "unit_ids": "150", "target_x": 50, "target_y": 50}, + ]) + assert "advance:SKIPPED" in str(result.get("actions", [])) + assert "attack_move" in result.get("actions", []) + + def test_batch_build_unit_unavailable(self, env_with_batch): + """build_unit for unavailable unit should be marked FAILED.""" + env, mcp = env_with_batch + tool = mcp._tool_manager._tools["batch"] + result = tool.fn(actions=[ + 
{"tool": "build_unit", "unit_type": "mig", "count": 1}, + {"tool": "attack_move", "unit_ids": "150", "target_x": 50, "target_y": 50}, + ]) + assert "build_unit:FAILED" in result.get("actions", []) + assert "attack_move" in result.get("actions", []) + + def test_batch_all_unsupported_returns_error(self, env_with_batch): + """All unsupported actions should return error with SKIPPED list.""" + env, mcp = env_with_batch + tool = mcp._tool_manager._tools["batch"] + result = tool.fn(actions=[ + {"tool": "advance", "ticks": 100}, + {"tool": "get_game_state"}, + ]) + assert "error" in result + assert "advance:SKIPPED" in str(result.get("actions", [])) + + +class TestLossTrackingFixes: + """S3/S4: MCV deployment and husk decay should not be counted as losses.""" + + def test_mcv_deploy_not_loss(self): + """MCV disappearing + fact appearing should NOT trigger UNITS LOST.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._placement_results = [] + env._prev_buildings = {} + env._prev_unit_ids = {120: "mcv"} + env._last_obs = { + "buildings": [{"actor_id": 1, "type": "fact"}], + "units": [], + } + env._update_loss_tracking() + loss_alerts = [r for r in env._placement_results if "UNITS LOST" in r] + assert len(loss_alerts) == 0, f"MCV deployment should not be a loss: {loss_alerts}" + + def test_husk_decay_not_loss(self): + """Husk disappearing should NOT trigger UNITS LOST.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._placement_results = [] + env._prev_buildings = {1: "fact"} + env._prev_unit_ids = {200: "2tnk.husk"} + env._last_obs = { + "buildings": [{"actor_id": 1, "type": "fact"}], + "units": [], + } + env._update_loss_tracking() + loss_alerts = [r for r in env._placement_results if "UNITS LOST" in r] + assert len(loss_alerts) == 0, f"Husk decay should not be a loss: 
class TestNoScoutingHistory:
    """S6: NO SCOUTING alert should not fire after enemy has been seen.

    Both tests replay a local copy of the alert condition against a bare
    environment object; they do not call any tool on the environment.
    """

    @staticmethod
    def _replay_alert_rule(env, obs):
        """Apply the NO SCOUTING condition to *obs* and return the alerts raised.

        Sets ``env._enemy_ever_seen`` when the observation lists any enemy,
        then emits "NO SCOUTING" only past tick 750, with nothing visible and
        no enemy ever seen.
        """
        raised = []
        if obs.get("visible_enemies") or obs.get("visible_enemy_buildings"):
            env._enemy_ever_seen = True
        nothing_in_sight = (
            obs["tick"] > 750
            and not obs["visible_enemies"]
            and not obs.get("visible_enemy_buildings")
        )
        if nothing_in_sight and not env._enemy_ever_seen:
            raised.append("NO SCOUTING")
        return raised

    def test_no_scouting_fires_before_contact(self):
        """Alert fires when enemies never seen."""
        env = OpenRAEnvironment.__new__(OpenRAEnvironment)
        from openra_env.config import OpenRARLConfig
        env._app_config = OpenRARLConfig()
        env._enemy_ever_seen = False
        env._last_production_progress = {}
        env._pending_placements = {}
        env._attempted_placements = {}
        env._placement_results = []
        observation = {
            "tick": 1000,
            "visible_enemies": [],
            "visible_enemy_buildings": [],
            "units": [], "buildings": [],
            "production": [], "economy": {"cash": 1000, "ore": 0},
        }
        assert len(self._replay_alert_rule(env, observation)) == 1

    def test_no_scouting_suppressed_after_contact(self):
        """Alert suppressed once enemy has been seen."""
        env = OpenRAEnvironment.__new__(OpenRAEnvironment)
        from openra_env.config import OpenRARLConfig
        env._app_config = OpenRARLConfig()
        env._enemy_ever_seen = True  # enemy was seen before
        observation = {"tick": 5000, "visible_enemies": [], "visible_enemy_buildings": []}
        assert len(self._replay_alert_rule(env, observation)) == 0
class TestAdvanceClamping:
    """S8: advance() should report when ticks are clamped."""

    @pytest.fixture
    def env_with_advance(self):
        """Yield ``(env, mcp)`` with tools registered and the bridge mocked.

        An event loop runs in a daemon thread and is exposed as ``env._loop``;
        ``env._bridge.wait_ticks`` is an async stub, and
        ``openra_environment.observation_to_dict`` is replaced so every
        conversion returns the fixed observation below.

        Teardown sits in ``finally`` and restores the patched function first,
        so neither a failing setup step nor a failing loop shutdown can leave
        the module patched for later tests.
        """
        import threading

        from fastmcp import FastMCP
        from openra_env.config import OpenRARLConfig
        from openra_env.server import openra_environment

        env = OpenRAEnvironment.__new__(OpenRAEnvironment)
        env._app_config = OpenRARLConfig()
        mcp = FastMCP("openra-test")

        env._planning_active = False
        env._planning_strategy = ""
        env._planning_enabled = False
        env._planning_max_turns = 10
        env._planning_max_time_s = 60.0
        env._planning_start_time = 0.0
        env._planning_turns_used = 0
        env._pending_placements = {}
        env._attempted_placements = {}
        env._placement_results = []
        env._player_faction = "russia"
        env._last_production_progress = {}
        env._prev_buildings = {}
        env._prev_unit_ids = {}
        env._unit_groups = {}
        env._enemy_ever_seen = False
        env._state = MagicMock()

        obs_dict = {
            "tick": 5000, "done": False, "result": "",
            "economy": {"cash": 1000, "ore": 500, "power_provided": 200,
                        "power_drained": 80, "resource_capacity": 4000,
                        "harvester_count": 1},
            "units": [], "buildings": [],
            "visible_enemies": [],
            "visible_enemy_buildings": [],
            "production": [],
            "map_info": {"width": 128, "height": 128},
        }
        env._last_obs = obs_dict

        async def mock_wait_ticks(t):
            return MagicMock()

        mock_bridge = MagicMock()
        mock_bridge.wait_ticks = mock_wait_ticks
        env._bridge = mock_bridge

        # Capture the real function before anything is started or patched.
        original_fn = openra_environment.observation_to_dict

        # Mock the async bridge with a running loop in a background thread
        loop = asyncio.new_event_loop()
        thread = threading.Thread(target=loop.run_forever, daemon=True)
        thread.start()
        try:
            env._loop = loop
            openra_environment.observation_to_dict = lambda proto: obs_dict
            env._register_tools(mcp)
            yield env, mcp
        finally:
            openra_environment.observation_to_dict = original_fn
            loop.call_soon_threadsafe(loop.stop)
            thread.join(timeout=2)
            # close() raises on a loop that is still running; skip it if the
            # thread did not stop within the timeout.
            if not thread.is_alive():
                loop.close()

    def test_advance_clamp_note(self, env_with_advance):
        """advance(1500) should include clamping note."""
        import re

        _, mcp = env_with_advance
        tool = mcp._tool_manager._tools["advance"]
        result = tool.fn(ticks=1500)
        assert "note" in result
        note = result["note"]
        assert "1500" in note
        # "500" is a substring of "1500", so a plain containment check proves
        # nothing; require 500 to appear as a number of its own.
        assert re.search(r"(?<!\d)500(?!\d)", note)

    def test_advance_no_note_within_limit(self, env_with_advance):
        """advance(100) should NOT include clamping note."""
        _, mcp = env_with_advance
        tool = mcp._tool_manager._tools["advance"]
        result = tool.fn(ticks=100)
        assert "note" not in result
──────────────────────────────────────────────── + + +def _make_spatial(width, height, channels=9, fog_values=None): + """Build a base64-encoded spatial map for testing. + + fog_values: dict mapping (x, y) -> fog float (default 0.0 = shroud). + Channel 3 = passability (1.0 for all), channel 4 = fog. + """ + import base64 + import struct + + data = [] + for y in range(height): + for x in range(width): + cell = [0.0] * channels + cell[3] = 1.0 # passable + if fog_values and (x, y) in fog_values: + cell[4] = fog_values[(x, y)] + data.extend(cell) + raw = struct.pack(f"{len(data)}f", *data) + return base64.b64encode(raw).decode() + + +def _make_env_with_tools(obs_dict): + """Create an env + mcp with tools registered and a given observation.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + from fastmcp import FastMCP + mcp = FastMCP("openra-test") + env._last_obs = obs_dict + env._refresh_obs = lambda: None + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = False + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + env._pending_placements = {} + env._attempted_placements = {} + env._placement_results = [] + env._player_faction = "russia" + env._last_production_progress = {} + env._prev_buildings = {} + env._prev_unit_ids = {} + env._unit_groups = {} + env._enemy_ever_seen = False + env._register_tools(mcp) + return env, mcp + + +# ── Type-based unit selector tests ────────────────────────────────────────── + + +class TestTypeBasedUnitSelectors: + """Test type: and category selectors in _resolve_unit_ids.""" + + def _make_obs(self): + return { + "units": [ + {"actor_id": 1, "type": "e1", "is_idle": True, "can_attack": True}, + {"actor_id": 2, "type": "e1", "is_idle": False, "can_attack": True}, + {"actor_id": 3, "type": "e3", "is_idle": True, "can_attack": True}, + 
class TestMapAnalysisExploration:
    """Test that get_map_analysis includes exploration stats."""

    @staticmethod
    def _analyze(spatial, channels=9):
        """Run get_map_analysis on a 4x4 observation carrying *spatial*."""
        observation = {
            "units": [],
            "buildings": [{"cell_x": 1, "cell_y": 1}],
            "visible_enemies": [],
            "visible_enemy_buildings": [],
            "map_info": {"width": 4, "height": 4, "map_name": "Test"},
            "spatial_map": spatial,
            "spatial_channels": channels,
        }
        _, mcp = _make_env_with_tools(observation)
        return mcp._tool_manager._tools["get_map_analysis"].fn()

    @staticmethod
    def _fog_grid(is_visible):
        """Return a 4x4 fog dict: 1.0 where ``is_visible(x, y)``, else 0.0."""
        return {
            (x, y): 1.0 if is_visible(x, y) else 0.0
            for y in range(4)
            for x in range(4)
        }

    def test_exploration_stats_present(self):
        """get_map_analysis should include exploration section."""
        # 4x4 map, all fog = 1.0 (visible)
        everything = self._fog_grid(lambda x, y: True)
        result = self._analyze(_make_spatial(4, 4, fog_values=everything))
        assert "exploration" in result
        stats = result["exploration"]
        assert stats["explored_percent"] == 100.0
        assert stats["unexplored_percent"] == 0.0
        assert stats["visible_percent"] == 100.0

    def test_exploration_partial(self):
        """Half-explored map should show ~50%."""
        # 4x4 map, top half (y<2) visible, bottom half shroud
        top_half = self._fog_grid(lambda x, y: y < 2)
        result = self._analyze(_make_spatial(4, 4, fog_values=top_half))
        assert result["exploration"]["explored_percent"] == 50.0

    def test_quadrant_explored_percent(self):
        """Quadrant summary should include explored_percent."""
        everything = self._fog_grid(lambda x, y: True)
        result = self._analyze(_make_spatial(4, 4, fog_values=everything))
        for quadrant in ("NW", "NE", "SW", "SE"):
            summary = result["quadrant_summary"][quadrant]
            assert "explored_percent" in summary
            assert summary["explored_percent"] == 100.0

    def test_no_spatial_data_no_exploration(self):
        """Without spatial data, exploration section should not appear."""
        result = self._analyze("", channels=0)
        assert "exploration" not in result
_, mcp = _make_env_with_tools(obs) + tool = mcp._tool_manager._tools["get_exploration_status"] + result = tool.fn() + assert result["explored_percent"] == 0.0 + assert result["unexplored_percent"] == 100.0 + assert result["enemy_found"] is False + + def test_quadrant_exploration(self): + """Per-quadrant exploration should be reported.""" + # Only NW quadrant explored (x<2, y<2) + fog = {} + for y in range(4): + for x in range(4): + fog[(x, y)] = 1.0 if (x < 2 and y < 2) else 0.0 + spatial = _make_spatial(4, 4, fog_values=fog) + obs = { + "units": [], + "buildings": [{"cell_x": 1, "cell_y": 1}], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 4, "height": 4, "map_name": "Test"}, + "spatial_map": spatial, + "spatial_channels": 9, + } + _, mcp = _make_env_with_tools(obs) + tool = mcp._tool_manager._tools["get_exploration_status"] + result = tool.fn() + assert result["quadrant_exploration"]["NW"]["explored_percent"] == 100.0 + assert result["quadrant_exploration"]["NE"]["explored_percent"] == 0.0 + assert result["quadrant_exploration"]["SW"]["explored_percent"] == 0.0 + assert result["quadrant_exploration"]["SE"]["explored_percent"] == 0.0 + + def test_idle_counts(self): + """idle_combat_count and idle_infantry_count are correct.""" + spatial = _make_spatial(4, 4) + obs = { + "units": [ + {"actor_id": 1, "type": "e1", "cell_x": 1, "cell_y": 1, + "is_idle": True, "can_attack": True}, + {"actor_id": 2, "type": "e1", "cell_x": 2, "cell_y": 1, + "is_idle": True, "can_attack": True}, + {"actor_id": 3, "type": "1tnk", "cell_x": 3, "cell_y": 1, + "is_idle": True, "can_attack": True}, + {"actor_id": 4, "type": "harv", "cell_x": 1, "cell_y": 3, + "is_idle": False, "can_attack": False}, + ], + "buildings": [{"cell_x": 1, "cell_y": 1}], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 4, "height": 4, "map_name": "Test"}, + "spatial_map": spatial, + "spatial_channels": 9, + } + _, mcp = _make_env_with_tools(obs) + 
class TestFactualNoScoutingAlert:
    """NO_SCOUTING alert should be fact-based, not prescriptive."""

    def _make_obs_with_fog(self, tick=1000, fog_values=None, units=None):
        """Return a full 4x4 observation at *tick* with the given fog layout.

        The spatial map is built once and reused. ``units`` falls back to a
        single idle, attack-capable "e1" when omitted (or empty).
        """
        spatial = _make_spatial(4, 4, fog_values=fog_values or {})
        return {
            "tick": tick,
            "done": False,
            "result": "",
            "economy": {
                "cash": 5000, "ore": 1000, "power_provided": 200,
                "power_drained": 80, "resource_capacity": 5000, "harvester_count": 1,
            },
            "military": {
                "units_killed": 0, "units_lost": 0, "buildings_killed": 0,
                "buildings_lost": 0, "army_value": 0, "active_unit_count": 0,
            },
            "units": units or [
                {"actor_id": 1, "type": "e1", "cell_x": 1, "cell_y": 1,
                 "pos_x": 1024, "pos_y": 1024, "hp_percent": 1.0,
                 "is_idle": True, "current_activity": "", "owner": "Multi1",
                 "can_attack": True, "facing": 0, "experience_level": 0,
                 "stance": 3, "speed": 56, "attack_range": 5120,
                 "passenger_count": -1, "is_building": False},
            ],
            "buildings": [
                {"actor_id": 100, "type": "fact", "pos_x": 2048, "pos_y": 2048,
                 "hp_percent": 1.0, "owner": "Multi1", "is_producing": False,
                 "production_progress": 0.0, "producing_item": "", "is_powered": True,
                 "is_repairing": False, "sell_value": 500, "rally_x": -1, "rally_y": -1,
                 "power_amount": 0, "can_produce": ["powr"], "cell_x": 2, "cell_y": 2},
            ],
            "production": [],
            "visible_enemies": [],
            "visible_enemy_buildings": [],
            "map_info": {"width": 4, "height": 4, "map_name": "Test"},
            "spatial_map": spatial,
            "spatial_channels": 9,
            "available_production": ["e1"],
        }

    @staticmethod
    def _scouting_alerts(obs, enemy_ever_seen=False):
        """Call get_game_state on *obs* and return only the NO SCOUTING alerts."""
        env, mcp = _make_env_with_tools(obs)
        if enemy_ever_seen:
            env._enemy_ever_seen = True
        result = mcp._tool_manager._tools["get_game_state"].fn()
        return [a for a in result["alerts"] if "NO SCOUTING" in a]

    def test_no_scouting_alert_is_factual(self):
        """Alert should state facts: % explored and idle count."""
        scouting_alerts = self._scouting_alerts(self._make_obs_with_fog(tick=1000))
        assert len(scouting_alerts) == 1
        alert = scouting_alerts[0]
        # Should contain factual info
        assert "enemy not found" in alert
        assert "% of map explored" in alert
        assert "idle combat units available" in alert
        # Should NOT contain prescriptive language
        assert "send a unit" not in alert
        assert "explore the map" not in alert.replace("% of map explored", "")

    def test_no_scouting_alert_shows_exploration_percent(self):
        """Alert should show actual exploration percentage."""
        # Half map explored
        fog = {(x, y): 1.0 if y < 2 else 0.0 for y in range(4) for x in range(4)}
        obs = self._make_obs_with_fog(tick=1000, fog_values=fog)
        scouting_alerts = self._scouting_alerts(obs)
        assert len(scouting_alerts) == 1
        assert "50.0%" in scouting_alerts[0]

    def test_no_scouting_suppressed_after_enemy_found(self):
        """Alert should not appear after enemy has been seen."""
        obs = self._make_obs_with_fog(tick=1000)
        assert len(self._scouting_alerts(obs, enemy_ever_seen=True)) == 0

    def test_no_scouting_suppressed_early(self):
        """Alert should not appear before tick 750."""
        obs = self._make_obs_with_fog(tick=500)
        assert len(self._scouting_alerts(obs)) == 0
class TestBuildConfirmationNotes:
    """Build tools should return factual confirmation notes with tick estimates."""

    @pytest.fixture
    def env_build(self):
        """Return ``(env, mcp)`` with one construction yard and command execution stubbed.

        ``_execute_commands`` is replaced by a lambda that returns a fixed
        tick-101 summary, so the build tools run without a game bridge.
        """
        economy = {
            "cash": 10000, "ore": 0,
            "power_provided": 200, "power_drained": 50,
            "resource_capacity": 5000, "harvester_count": 1,
        }
        construction_yard = {
            "actor_id": 1, "type": "fact", "pos_x": 500, "pos_y": 500,
            "hp_percent": 1.0, "owner": "Multi0", "is_producing": False,
            "production_progress": 0.0, "producing_item": "",
            "is_powered": True, "is_repairing": False, "sell_value": 0,
            "rally_x": -1, "rally_y": -1, "power_amount": 0,
            "can_produce": [], "cell_x": 5, "cell_y": 5,
        }
        obs = {
            "tick": 100, "done": False, "result": "",
            "economy": economy,
            "military": {
                "units_killed": 0, "units_lost": 0,
                "buildings_killed": 0, "buildings_lost": 0,
                "army_value": 0, "active_unit_count": 0,
            },
            "units": [],
            "buildings": [construction_yard],
            "production": [],
            "visible_enemies": [], "visible_enemy_buildings": [],
            "map_info": {"width": 128, "height": 128, "map_name": "Test"},
            "available_production": ["e1", "e3", "powr", "proc", "barr", "tent"],
        }
        env, mcp = _make_env_with_tools(obs)
        env._execute_commands = lambda cmds: {
            "tick": 101, "done": False, "result": "",
            "economy": obs["economy"],
            "own_units": 0, "own_buildings": 1,
            "visible_enemies": 0, "production": [],
        }
        return env, mcp

    def test_build_unit_returns_note(self, env_build):
        """build_unit reports the unit, the count and both tick estimates."""
        _, mcp = env_build
        result = mcp._tool_manager._tools["build_unit"].fn(unit_type="e1", count=3)
        assert "note" in result
        note = result["note"]
        assert "e1" in note
        assert "3x" in note
        # e1 costs $100 → 60 ticks each, 180 total
        assert "60" in note
        assert "180" in note

    def test_build_structure_returns_note(self, env_build):
        """build_structure reports the building and its tick estimate."""
        _, mcp = env_build
        result = mcp._tool_manager._tools["build_structure"].fn(building_type="powr")
        assert "note" in result
        assert "powr" in result["note"]
        # powr costs $300 → 180 ticks
        assert "180" in result["note"]

    def test_build_and_place_returns_note(self, env_build):
        """build_and_place mentions auto-placement and the tick estimate."""
        _, mcp = env_build
        result = mcp._tool_manager._tools["build_and_place"].fn(building_type="powr")
        assert "note" in result
        assert "auto-places" in result["note"]
        assert "180" in result["note"]
class TestAlertPriorityAndCap:
    """Alerts should be sorted by priority and capped by max_alerts."""

    @staticmethod
    def _multi_alert_obs():
        """Return a fresh observation intended to raise several alerts at once.

        Power drain (100) exceeds supply (50), the construction yard is at 30%
        health, and five idle "e1" units share stance 1. A new dict is built
        on every call so tests cannot affect each other through shared state.
        """
        def building(actor_id, kind, pos_y, cell_y, hp=1.0, sell=0, power=0):
            # All four buildings differ only in these fields.
            return {
                "actor_id": actor_id, "type": kind, "pos_x": 500, "pos_y": pos_y,
                "hp_percent": hp, "owner": "Multi0", "is_producing": False,
                "production_progress": 0.0, "producing_item": "",
                "is_powered": True, "is_repairing": False, "sell_value": sell,
                "rally_x": -1, "rally_y": -1, "power_amount": power,
                "can_produce": [], "cell_x": 5, "cell_y": cell_y,
            }

        return {
            "tick": 1000, "done": False, "result": "",
            "economy": {
                "cash": 5000, "ore": 0,
                "power_provided": 50, "power_drained": 100,
                "resource_capacity": 5000, "harvester_count": 1,
            },
            "military": {
                "units_killed": 0, "units_lost": 0,
                "buildings_killed": 0, "buildings_lost": 0,
                "army_value": 500, "active_unit_count": 5,
            },
            "units": [
                {"actor_id": i, "type": "e1", "pos_x": 100, "pos_y": 100,
                 "cell_x": 1, "cell_y": 1, "hp_percent": 1.0, "is_idle": True,
                 "current_activity": "", "owner": "Multi0", "can_attack": True,
                 "facing": 0, "experience_level": 0, "stance": 1, "speed": 71,
                 "attack_range": 5120, "passenger_count": -1, "is_building": False}
                for i in range(10, 15)
            ],
            "buildings": [
                building(1, "fact", 500, 5, hp=0.3),
                building(2, "powr", 600, 6, sell=150, power=100),
                building(3, "barr", 700, 7),
                building(4, "proc", 800, 8),
            ],
            "production": [],
            "visible_enemies": [],
            "visible_enemy_buildings": [],
            "map_info": {"width": 128, "height": 128, "map_name": "Test"},
            "available_production": [],
        }

    def test_alerts_sorted_by_priority(self):
        """Higher priority alerts (lower number) come first."""
        env, mcp = _make_env_with_tools(self._multi_alert_obs())
        tool = mcp._tool_manager._tools["get_game_state"]
        result = tool.fn()
        alerts = result["alerts"]
        # Should have: LOW POWER (priority 2), DAMAGED (priority 5),
        # IDLE ARMY (priority 7), STANCE (priority 7)
        assert any("LOW POWER" in a for a in alerts)
        assert any("DAMAGED" in a for a in alerts)
        # LOW POWER should come before DAMAGED
        low_power_idx = next(i for i, a in enumerate(alerts) if "LOW POWER" in a)
        damaged_idx = next(i for i, a in enumerate(alerts) if "DAMAGED" in a)
        assert low_power_idx < damaged_idx

    def test_max_alerts_caps_output(self):
        """max_alerts limits the number of alerts returned."""
        env, mcp = _make_env_with_tools(self._multi_alert_obs())
        # Set max_alerts to 2
        env._app_config.alerts.max_alerts = 2
        tool = mcp._tool_manager._tools["get_game_state"]
        result = tool.fn()
        assert len(result["alerts"]) <= 2

    def test_max_alerts_zero_means_unlimited(self):
        """max_alerts=0 means no cap (default)."""
        env, mcp = _make_env_with_tools(self._multi_alert_obs())
        assert env._app_config.alerts.max_alerts == 0
        tool = mcp._tool_manager._tools["get_game_state"]
        result = tool.fn()
        # Should have multiple alerts (LOW POWER, DAMAGED, IDLE ARMY, STANCE, etc.)
        assert len(result["alerts"]) >= 3
result["production_items"] + assert "powr@45%(~99 ticks)" in items[0] + assert "e1@80%(~12 ticks)" in items[1] + + +class TestEstimateBuildTicks: + """Test the _estimate_build_ticks helper.""" + + def test_powr_300_cost(self): + from openra_env.server.openra_environment import _estimate_build_ticks + assert _estimate_build_ticks(300) == 180 # 300 * 60 / 100 + + def test_e1_100_cost(self): + from openra_env.server.openra_environment import _estimate_build_ticks + assert _estimate_build_ticks(100) == 60 + + def test_proc_2000_cost(self): + from openra_env.server.openra_environment import _estimate_build_ticks + assert _estimate_build_ticks(2000) == 1200 + + def test_zero_cost(self): + from openra_env.server.openra_environment import _estimate_build_ticks + assert _estimate_build_ticks(0) == 0 + + +# ─── Movement ETA Tests ───────────────────────────────────────────────────── + + +class TestMovementETA: + """Test the _estimate_move_ticks helper and ETA in unit feedback.""" + + def test_estimate_basic(self): + from openra_env.server.openra_environment import _estimate_move_ticks + # e1 speed=56, moving 20 cells: 20*1024/56 = 365 + assert _estimate_move_ticks(56, 0, 0, 10, 10) == 20 * 1024 // 56 + + def test_estimate_zero_speed(self): + from openra_env.server.openra_environment import _estimate_move_ticks + assert _estimate_move_ticks(0, 0, 0, 10, 10) == 0 + + def test_estimate_same_position(self): + from openra_env.server.openra_environment import _estimate_move_ticks + assert _estimate_move_ticks(56, 5, 5, 5, 5) == 0 + + def test_estimate_fast_unit(self): + from openra_env.server.openra_environment import _estimate_move_ticks + # 1tnk speed=113, 10 cells + eta = _estimate_move_ticks(113, 0, 0, 5, 5) + assert eta == 10 * 1024 // 113 + + def test_unit_feedback_includes_eta(self): + """_add_unit_feedback adds eta_ticks when target is provided.""" + from openra_env.config import OpenRARLConfig + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + env._app_config = 
OpenRARLConfig() + env._last_obs = { + "units": [ + {"actor_id": 1, "type": "e1", "cell_x": 10, "cell_y": 10, + "speed": 56, "current_activity": "Move"}, + ] + } + result = {} + env._add_unit_feedback(result, [1], target_x=20, target_y=10) + assert "commanded_units" in result + unit = result["commanded_units"][0] + assert "eta_ticks" in unit + assert "eta_seconds" in unit + assert unit["eta_ticks"] == 10 * 1024 // 56 + assert "note" in result + assert "ticks" in result["note"] + + def test_unit_feedback_no_eta_without_target(self): + """_add_unit_feedback omits eta when no target provided.""" + from openra_env.config import OpenRARLConfig + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + env._app_config = OpenRARLConfig() + env._last_obs = { + "units": [ + {"actor_id": 1, "type": "e1", "cell_x": 10, "cell_y": 10, + "speed": 56, "current_activity": "Idle"}, + ] + } + result = {} + env._add_unit_feedback(result, [1]) + unit = result["commanded_units"][0] + assert "eta_ticks" not in unit + assert "note" not in result + + def test_unit_feedback_slowest_eta_in_note(self): + """ETA note uses the slowest unit's arrival time.""" + from openra_env.config import OpenRARLConfig + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + env._app_config = OpenRARLConfig() + env._last_obs = { + "units": [ + {"actor_id": 1, "type": "e1", "cell_x": 0, "cell_y": 0, + "speed": 56, "current_activity": "Move"}, + {"actor_id": 2, "type": "1tnk", "cell_x": 0, "cell_y": 0, + "speed": 113, "current_activity": "Move"}, + ] + } + result = {} + env._add_unit_feedback(result, [1, 2], target_x=10, target_y=0) + # e1 is slower, so its ETA should be in the note + e1_eta = 10 * 1024 // 56 + assert str(e1_eta) in result["note"] + + +# ─── Enhanced Compression Tests ────────────────────────────────────────────── + + +class TestEnhancedCompression: + """Test the enhanced compress_history function.""" + + def test_trigger_threshold_default(self): + """Default trigger = keep_last * 2.""" + from 
openra_env.agent import compress_history + messages = [ + {"role": "system", "content": "sys"}, + *[{"role": "user", "content": f"m{i}"} for i in range(50)], + ] + # keep_last=40, trigger=0 → threshold=80. 51 < 80, no compression. + result = compress_history(messages, keep_last=40) + assert len(result) == 51 + + def test_trigger_threshold_custom(self): + """Custom trigger fires earlier.""" + from openra_env.agent import compress_history + messages = [ + {"role": "system", "content": "sys"}, + *[{"role": "user", "content": f"m{i}"} for i in range(50)], + ] + # trigger=30, 51 > 30, should compress to keep_last=10 + system + summary + result = compress_history(messages, keep_last=10, trigger=30) + assert len(result) == 12 # system + summary + 10 recent + + def test_strategy_extraction(self): + """Compression summary includes planning strategy.""" + from openra_env.agent import compress_history + messages = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "Game started!\nStrategy: Rush with tanks"}, + *[{"role": "user", "content": f"m{i}"} for i in range(60)], + ] + result = compress_history(messages, keep_last=10, trigger=20) + summary = result[1]["content"] + assert "Strategy: Rush with tanks" in summary + + def test_strategy_disabled(self): + """Compression skips strategy when include_strategy=False.""" + from openra_env.agent import compress_history + from openra_env.config import CompressionConfig + messages = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "Strategy: Rush with tanks"}, + *[{"role": "user", "content": f"m{i}"} for i in range(60)], + ] + comp = CompressionConfig(include_strategy=False) + result = compress_history(messages, keep_last=10, trigger=20, compression=comp) + summary = result[1]["content"] + assert "Strategy:" not in summary + + def test_military_stats_extraction(self): + """Compression summary includes military stats from state snapshots.""" + import json + from openra_env.agent import 
compress_history + state = { + "tick": 5000, "economy": {"cash": 1200}, + "own_units": 8, "own_buildings": 6, + "military": {"units_killed": 3, "units_lost": 1} + } + messages = [ + {"role": "system", "content": "sys"}, + {"role": "tool", "content": json.dumps(state)}, + *[{"role": "user", "content": f"m{i}"} for i in range(60)], + ] + result = compress_history(messages, keep_last=10, trigger=20) + summary = result[1]["content"] + assert "3 kills" in summary + assert "1 loss" in summary + + def test_military_disabled(self): + """Military stats skipped when include_military=False.""" + import json + from openra_env.agent import compress_history + from openra_env.config import CompressionConfig + state = { + "tick": 5000, "economy": {"cash": 1200}, + "own_units": 8, "own_buildings": 6, + "military": {"units_killed": 3, "units_lost": 1} + } + messages = [ + {"role": "system", "content": "sys"}, + {"role": "tool", "content": json.dumps(state)}, + *[{"role": "user", "content": f"m{i}"} for i in range(60)], + ] + comp = CompressionConfig(include_military=False) + result = compress_history(messages, keep_last=10, trigger=20, compression=comp) + summary = result[1]["content"] + assert "kills" not in summary + + def test_production_tracking(self): + """Compression summary tracks produced unit types.""" + import json + from openra_env.agent import compress_history + messages = [ + {"role": "system", "content": "sys"}, + {"role": "tool", "content": json.dumps({"note": "'e1' ($100 each) queued. ~60 ticks per unit"})}, + {"role": "tool", "content": json.dumps({"note": "'1tnk' ($800 each) queued. 
~480 ticks per unit"})}, + *[{"role": "user", "content": f"m{i}"} for i in range(60)], + ] + result = compress_history(messages, keep_last=10, trigger=20) + summary = result[1]["content"] + assert "Units produced:" in summary + assert "e1" in summary + assert "1tnk" in summary + + def test_production_disabled(self): + """Production tracking skipped when include_production=False.""" + import json + from openra_env.agent import compress_history + from openra_env.config import CompressionConfig + messages = [ + {"role": "system", "content": "sys"}, + {"role": "tool", "content": json.dumps({"note": "'e1' ($100 each) queued. ~60 ticks per unit"})}, + *[{"role": "user", "content": f"m{i}"} for i in range(60)], + ] + comp = CompressionConfig(include_production=False) + result = compress_history(messages, keep_last=10, trigger=20, compression=comp) + summary = result[1]["content"] + assert "Units produced:" not in summary + + def test_error_tracking(self): + """Compression summary includes recent errors.""" + import json + from openra_env.agent import compress_history + messages = [ + {"role": "system", "content": "sys"}, + {"role": "tool", "content": json.dumps({"placement_failed": True})}, + *[{"role": "user", "content": f"m{i}"} for i in range(60)], + ] + result = compress_history(messages, keep_last=10, trigger=20) + summary = result[1]["content"] + assert "placement failed" in summary + + +# ─── State Briefing Format Tests ───────────────────────────────────────────── + + +class TestStateBriefingFormat: + """Test format_state_briefing shows unit activity and destination.""" + + def test_idle_unit_no_arrow(self): + from openra_env.agent import format_state_briefing + state = { + "tick": 100, "economy": {"cash": 500, "ore": 0, "harvester_count": 1}, + "power_balance": 10, "own_units": 1, "own_buildings": 0, + "units_summary": [ + {"id": 1, "type": "e1", "cell_x": 10, "cell_y": 10, + "idle": True, "can_attack": True, "stance": 0, "activity": "Idle"} + ], + 
"buildings_summary": [], "enemy_summary": [], "production_items": [], + "alerts": [], + } + text = format_state_briefing(state) + assert "1@(10,10)" in text + assert "→" not in text.split("Units:")[1].split("|")[0] # no arrow for idle + + def test_moving_unit_with_target(self): + from openra_env.agent import format_state_briefing + state = { + "tick": 200, "economy": {"cash": 500, "ore": 0, "harvester_count": 1}, + "power_balance": 10, "own_units": 1, "own_buildings": 0, + "units_summary": [ + {"id": 1, "type": "e1", "cell_x": 10, "cell_y": 10, + "idle": False, "can_attack": True, "stance": 0, "activity": "Move", + "target_x": 30, "target_y": 20} + ], + "buildings_summary": [], "enemy_summary": [], "production_items": [], + "alerts": [], + } + text = format_state_briefing(state) + assert "1@(10,10)→(30,20)" in text + + def test_moving_unit_without_target_shows_activity(self): + from openra_env.agent import format_state_briefing + state = { + "tick": 300, "economy": {"cash": 500, "ore": 0, "harvester_count": 1}, + "power_balance": 10, "own_units": 1, "own_buildings": 0, + "units_summary": [ + {"id": 1, "type": "e1", "cell_x": 10, "cell_y": 10, + "idle": False, "can_attack": True, "stance": 0, "activity": "AttackMove"} + ], + "buildings_summary": [], "enemy_summary": [], "production_items": [], + "alerts": [], + } + text = format_state_briefing(state) + # Should show short activity tag + assert "→att" in text + + def test_mixed_idle_and_moving(self): + from openra_env.agent import format_state_briefing + state = { + "tick": 400, "economy": {"cash": 500, "ore": 0, "harvester_count": 1}, + "power_balance": 10, "own_units": 2, "own_buildings": 0, + "units_summary": [ + {"id": 1, "type": "e1", "cell_x": 5, "cell_y": 5, + "idle": True, "can_attack": True, "stance": 0, "activity": "Idle"}, + {"id": 2, "type": "e1", "cell_x": 10, "cell_y": 10, + "idle": False, "can_attack": True, "stance": 0, "activity": "Move", + "target_x": 20, "target_y": 15}, + ], + "buildings_summary": 
[], "enemy_summary": [], "production_items": [], + "alerts": [], + } + text = format_state_briefing(state) + assert "1@(5,5)" in text # idle, no arrow + assert "2@(10,10)→(20,15)" in text # moving with target + + +# ─── Defense Placement Bias Tests ──────────────────────────────────────────── + + +class TestDefensePlacementBias: + """Defense buildings should be placed toward the enemy, not behind the base.""" + + def test_defense_placed_toward_enemy(self): + """A gun turret should be placed on the enemy side of the CY.""" + from openra_env.config import OpenRARLConfig + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + env._app_config = OpenRARLConfig() + + # CY at (10, 10), enemy is to the right (high x) + obs = { + "buildings": [{"actor_id": 1, "type": "fact", "cell_x": 10, "cell_y": 10}], + "visible_enemies": [{"actor_id": 99, "type": "e1", "cell_x": 50, "cell_y": 10}], + "visible_enemy_buildings": [], + "map_info": {"width": 64, "height": 64}, + } + candidates = env._find_placement_candidates("gun", obs) + assert len(candidates) > 0 + # Top candidate should be to the RIGHT of CY (toward enemy at x=50) + best = candidates[0] + assert best["cell_x"] > 10, f"Defense placed at x={best['cell_x']}, expected > 10 (toward enemy)" + + def test_non_defense_closest_to_cy(self): + """A non-defense building (powr) should still sort by distance from CY.""" + from openra_env.config import OpenRARLConfig + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + env._app_config = OpenRARLConfig() + + obs = { + "buildings": [{"actor_id": 1, "type": "fact", "cell_x": 10, "cell_y": 10}], + "visible_enemies": [{"actor_id": 99, "type": "e1", "cell_x": 50, "cell_y": 10}], + "visible_enemy_buildings": [], + "map_info": {"width": 64, "height": 64}, + } + candidates = env._find_placement_candidates("powr", obs) + assert len(candidates) > 0 + # powr is NOT defense — should be sorted by distance, not biased toward enemy + best = candidates[0] + assert best["distance"] <= 4, f"Non-defense 
building placed too far: dist={best['distance']}" + + def test_defense_uses_estimated_enemy_when_none_visible(self): + """Defense bias works even with no visible enemies (uses map opposite corner).""" + from openra_env.config import OpenRARLConfig + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + env._app_config = OpenRARLConfig() + + # CY at (10, 10) on 64x64 map — enemy estimated at ~(54, 54) + obs = { + "buildings": [{"actor_id": 1, "type": "fact", "cell_x": 10, "cell_y": 10}], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 64, "height": 64}, + } + candidates = env._find_placement_candidates("pbox", obs) + assert len(candidates) > 0 + best = candidates[0] + # Should bias toward bottom-right (enemy direction) + assert best["cell_x"] >= 10 and best["cell_y"] >= 10 + + +# ─── Minimap Tests ────────────────────────────────────────────────────────── + + +class TestRenderMinimap: + """Test _render_minimap() ASCII minimap generation.""" + + def _make_spatial(self, width, height, channels=9, fill=None): + """Build a spatial_map (base64 encoded float32 tensor). + + fill: dict mapping (x, y, channel) -> float value. + All unset values default to 0.0. 
+ """ + import base64 + import struct + + data = bytearray(width * height * channels * 4) + fill = fill or {} + for (x, y, ch), val in fill.items(): + idx = ((y * width + x) * channels + ch) * 4 + struct.pack_into("f", data, idx, val) + return base64.b64encode(bytes(data)).decode() + + def test_empty_obs_returns_empty(self): + from openra_env.server.openra_environment import _render_minimap + assert _render_minimap({}) == "" + assert _render_minimap({"map_info": {"width": 0, "height": 0}}) == "" + + def test_no_spatial_data_returns_empty(self): + from openra_env.server.openra_environment import _render_minimap + obs = { + "map_info": {"width": 10, "height": 10}, + "spatial_channels": 9, + "spatial_map": "", + } + assert _render_minimap(obs) == "" + + def test_all_unexplored(self): + from openra_env.server.openra_environment import _render_minimap + # 4x4 map, all fog=0 -> all '#' + obs = { + "map_info": {"width": 4, "height": 4}, + "spatial_channels": 9, + "spatial_map": self._make_spatial(4, 4), + "buildings": [], "units": [], + "visible_enemies": [], "visible_enemy_buildings": [], + } + result = _render_minimap(obs, max_cols=4) + lines = result.strip().split("\n") + # header + 4 rows + legend = 6 lines + assert len(lines) == 6 + for row in lines[1:5]: + assert all(c == "#" for c in row), f"Expected all '#', got: {row}" + + def test_explored_shows_dot(self): + from openra_env.server.openra_environment import _render_minimap + # 4x4 map, all explored (fog ch4 > 0.25) + fill = {} + for y in range(4): + for x in range(4): + fill[(x, y, 4)] = 1.0 # fog = fully visible + fill[(x, y, 3)] = 1.0 # passable + obs = { + "map_info": {"width": 4, "height": 4}, + "spatial_channels": 9, + "spatial_map": self._make_spatial(4, 4, fill=fill), + "buildings": [], "units": [], + "visible_enemies": [], "visible_enemy_buildings": [], + } + result = _render_minimap(obs, max_cols=4) + lines = result.strip().split("\n") + for row in lines[1:5]: + assert all(c == "." 
for c in row), f"Expected '.', got: {row}" + + def test_water_shows_tilde(self): + from openra_env.server.openra_environment import _render_minimap + fill = {} + for y in range(4): + for x in range(4): + fill[(x, y, 4)] = 1.0 # explored + fill[(x, y, 3)] = 0.0 # impassable (water) + obs = { + "map_info": {"width": 4, "height": 4}, + "spatial_channels": 9, + "spatial_map": self._make_spatial(4, 4, fill=fill), + "buildings": [], "units": [], + "visible_enemies": [], "visible_enemy_buildings": [], + } + result = _render_minimap(obs, max_cols=4) + lines = result.strip().split("\n") + for row in lines[1:5]: + assert all(c == "~" for c in row), f"Expected '~', got: {row}" + + def test_resources_show_dollar(self): + from openra_env.server.openra_environment import _render_minimap + fill = { + (1, 1, 4): 1.0, # explored + (1, 1, 3): 1.0, # passable + (1, 1, 2): 0.5, # has resources + } + obs = { + "map_info": {"width": 4, "height": 4}, + "spatial_channels": 9, + "spatial_map": self._make_spatial(4, 4, fill=fill), + "buildings": [], "units": [], + "visible_enemies": [], "visible_enemy_buildings": [], + } + result = _render_minimap(obs, max_cols=4) + lines = result.strip().split("\n") + assert lines[2][1] == "$" # row 1, col 1 + + def test_own_building_overlay(self): + from openra_env.server.openra_environment import _render_minimap + fill = {} + for y in range(4): + for x in range(4): + fill[(x, y, 4)] = 1.0 + fill[(x, y, 3)] = 1.0 + obs = { + "map_info": {"width": 4, "height": 4}, + "spatial_channels": 9, + "spatial_map": self._make_spatial(4, 4, fill=fill), + "buildings": [{"cell_x": 2, "cell_y": 1}], + "units": [], + "visible_enemies": [], "visible_enemy_buildings": [], + } + result = _render_minimap(obs, max_cols=4) + lines = result.strip().split("\n") + assert lines[2][2] == "B" # row 1, col 2 + + def test_own_unit_overlay(self): + from openra_env.server.openra_environment import _render_minimap + fill = {} + for y in range(4): + for x in range(4): + fill[(x, y, 4)] = 
1.0 + fill[(x, y, 3)] = 1.0 + obs = { + "map_info": {"width": 4, "height": 4}, + "spatial_channels": 9, + "spatial_map": self._make_spatial(4, 4, fill=fill), + "buildings": [], + "units": [{"cell_x": 0, "cell_y": 0}], + "visible_enemies": [], "visible_enemy_buildings": [], + } + result = _render_minimap(obs, max_cols=4) + lines = result.strip().split("\n") + assert lines[1][0] == "@" # row 0, col 0 + + def test_enemy_building_overlay(self): + from openra_env.server.openra_environment import _render_minimap + fill = {} + for y in range(4): + for x in range(4): + fill[(x, y, 4)] = 1.0 + fill[(x, y, 3)] = 1.0 + obs = { + "map_info": {"width": 4, "height": 4}, + "spatial_channels": 9, + "spatial_map": self._make_spatial(4, 4, fill=fill), + "buildings": [], + "units": [], + "visible_enemies": [], + "visible_enemy_buildings": [{"cell_x": 3, "cell_y": 3}], + } + result = _render_minimap(obs, max_cols=4) + lines = result.strip().split("\n") + assert lines[4][3] == "X" # row 3, col 3 + + def test_enemy_unit_overlay(self): + from openra_env.server.openra_environment import _render_minimap + fill = {} + for y in range(4): + for x in range(4): + fill[(x, y, 4)] = 1.0 + fill[(x, y, 3)] = 1.0 + obs = { + "map_info": {"width": 4, "height": 4}, + "spatial_channels": 9, + "spatial_map": self._make_spatial(4, 4, fill=fill), + "buildings": [], + "units": [], + "visible_enemies": [{"cell_x": 1, "cell_y": 2}], + "visible_enemy_buildings": [], + } + result = _render_minimap(obs, max_cols=4) + lines = result.strip().split("\n") + assert lines[3][1] == "!" 
# row 2, col 1 + + def test_priority_enemy_over_own(self): + """Enemy unit should override own building at same cell.""" + from openra_env.server.openra_environment import _render_minimap + fill = {} + for y in range(4): + for x in range(4): + fill[(x, y, 4)] = 1.0 + fill[(x, y, 3)] = 1.0 + obs = { + "map_info": {"width": 4, "height": 4}, + "spatial_channels": 9, + "spatial_map": self._make_spatial(4, 4, fill=fill), + "buildings": [{"cell_x": 1, "cell_y": 1}], + "units": [{"cell_x": 1, "cell_y": 1}], + "visible_enemies": [{"cell_x": 1, "cell_y": 1}], + "visible_enemy_buildings": [], + } + result = _render_minimap(obs, max_cols=4) + lines = result.strip().split("\n") + assert lines[2][1] == "!" # enemy unit wins + + def test_downsampling(self): + """Large map should downsample to ~max_cols width.""" + from openra_env.server.openra_environment import _render_minimap + fill = {} + for y in range(64): + for x in range(128): + fill[(x, y, 4)] = 1.0 + fill[(x, y, 3)] = 1.0 + obs = { + "map_info": {"width": 128, "height": 64}, + "spatial_channels": 9, + "spatial_map": self._make_spatial(128, 64, fill=fill), + "buildings": [], "units": [], + "visible_enemies": [], "visible_enemy_buildings": [], + } + result = _render_minimap(obs, max_cols=28) + lines = result.strip().split("\n") + # First line is header + assert "Map (" in lines[0] + # Data rows should be ~28 chars wide (ceil(128/5)=26) + data_rows = lines[1:-1] # skip header and legend + for row in data_rows: + assert len(row) <= 28 + + def test_header_and_legend(self): + from openra_env.server.openra_environment import _render_minimap + obs = { + "map_info": {"width": 4, "height": 4}, + "spatial_channels": 9, + "spatial_map": self._make_spatial(4, 4), + "buildings": [], "units": [], + "visible_enemies": [], "visible_enemy_buildings": [], + } + result = _render_minimap(obs, max_cols=4) + assert result.startswith("Map (") + assert "YOUR:" in result + assert "ENEMY:" in result + + def 
test_get_game_state_includes_minimap(self): + """get_game_state result should have minimap and enemy_buildings_summary.""" + from openra_env.config import OpenRARLConfig + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + env._app_config = OpenRARLConfig() + from fastmcp import FastMCP + mcp = FastMCP("openra-test") + env._register_tools(mcp) + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = True + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + + env._last_obs = { + "tick": 100, "done": False, "result": "", + "economy": {"cash": 1000, "ore": 0, "power_provided": 100, + "power_drained": 50, "resource_capacity": 5000, + "harvester_count": 1}, + "military": {"units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 0, "active_unit_count": 0}, + "units": [], "buildings": [], "production": [], + "visible_enemies": [], + "visible_enemy_buildings": [ + {"actor_id": 50, "type": "powr", "cell_x": 30, "cell_y": 30, + "hp_percent": 1.0, "owner": "Multi1"}, + ], + "map_info": {"width": 8, "height": 8, "map_name": "Test"}, + "spatial_channels": 9, + "spatial_map": "", + "available_production": [], + } + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + assert "minimap" in result + assert "enemy_buildings_summary" in result + assert len(result["enemy_buildings_summary"]) == 1 + assert result["enemy_buildings_summary"][0]["type"] == "powr" + + def test_minimap_disabled_by_config(self): + """When alerts.minimap=False, minimap should be empty.""" + from openra_env.config import OpenRARLConfig + cfg = OpenRARLConfig() + cfg.alerts.minimap = False + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + env._app_config = cfg + from fastmcp import FastMCP + mcp = FastMCP("openra-test") + env._register_tools(mcp) + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = True + 
env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + + import base64 + import struct + # 4x4 fully explored map + data = bytearray(4 * 4 * 9 * 4) + for y in range(4): + for x in range(4): + idx = ((y * 4 + x) * 9 + 4) * 4 + struct.pack_into("f", data, idx, 1.0) + idx = ((y * 4 + x) * 9 + 3) * 4 + struct.pack_into("f", data, idx, 1.0) + spatial = base64.b64encode(bytes(data)).decode() + + env._last_obs = { + "tick": 100, "done": False, "result": "", + "economy": {"cash": 1000, "ore": 0, "power_provided": 100, + "power_drained": 50, "resource_capacity": 5000, + "harvester_count": 1}, + "military": {"units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 0, "active_unit_count": 0}, + "units": [], "buildings": [], "production": [], + "visible_enemies": [], "visible_enemy_buildings": [], + "map_info": {"width": 4, "height": 4, "map_name": "Test"}, + "spatial_channels": 9, + "spatial_map": spatial, + "available_production": [], + } + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + assert result["minimap"] == "" + + +class TestBriefingMinimap: + """Test format_state_briefing includes minimap and enemy buildings.""" + + def test_briefing_includes_minimap(self): + from openra_env.agent import format_state_briefing + state = { + "tick": 100, "economy": {"cash": 500, "ore": 0, "harvester_count": 1}, + "power_balance": 10, "own_units": 0, "own_buildings": 0, + "units_summary": [], "buildings_summary": [], + "enemy_summary": [], "enemy_buildings_summary": [], + "production_items": [], "alerts": [], + "minimap": "Map (4x4, 1cell=1x1):\n....\n....\n....\n....\nYOUR: B=building @=unit | ENEMY: X=building !=unit | terrain: .=land ~=water $=ore #=unexplored", + } + text = format_state_briefing(state) + assert "Map (4x4" in text + assert "YOUR:" in text + + def test_briefing_omits_empty_minimap(self): + from openra_env.agent import 
format_state_briefing + state = { + "tick": 100, "economy": {"cash": 500, "ore": 0, "harvester_count": 1}, + "power_balance": 10, "own_units": 0, "own_buildings": 0, + "units_summary": [], "buildings_summary": [], + "enemy_summary": [], "enemy_buildings_summary": [], + "production_items": [], "alerts": [], + "minimap": "", + } + text = format_state_briefing(state) + assert "Map (" not in text + + def test_briefing_includes_enemy_buildings(self): + from openra_env.agent import format_state_briefing + state = { + "tick": 100, "economy": {"cash": 500, "ore": 0, "harvester_count": 1}, + "power_balance": 10, "own_units": 0, "own_buildings": 0, + "units_summary": [], "buildings_summary": [], + "enemy_summary": [], + "enemy_buildings_summary": [ + {"id": 50, "type": "powr", "cell_x": 40, "cell_y": 40}, + {"id": 51, "type": "fact", "cell_x": 42, "cell_y": 40}, + ], + "production_items": [], "alerts": [], + "minimap": "", + } + text = format_state_briefing(state) + assert "powr" in text + assert "fact" in text + assert "center" in text # center position shown + + def test_briefing_enemy_units_and_buildings(self): + from openra_env.agent import format_state_briefing + state = { + "tick": 100, "economy": {"cash": 500, "ore": 0, "harvester_count": 1}, + "power_balance": 10, "own_units": 0, "own_buildings": 0, + "units_summary": [], "buildings_summary": [], + "enemy_summary": [ + {"id": 99, "type": "e1", "cell_x": 30, "cell_y": 30}, + ], + "enemy_buildings_summary": [ + {"id": 50, "type": "powr", "cell_x": 40, "cell_y": 40}, + ], + "production_items": [], "alerts": [], + "minimap": "", + } + text = format_state_briefing(state) + assert "1xe1" in text + assert "1xpowr" in text + + +# ─── Actor Existence & Production Queue Validation Tests ───────────────────── + + +class TestActorValidation: + """Test that actor-based tools validate actor existence before sending commands.""" + + @pytest.fixture + def env_with_actors(self): + """Create env with known units and buildings.""" + 
env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + from fastmcp import FastMCP + mcp = FastMCP("openra-test") + + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = False + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + env._pending_placements = {} + env._attempted_placements = {} + env._placement_results = [] + env._player_faction = "england" + + env._last_obs = { + "tick": 500, + "done": False, + "result": "", + "economy": { + "cash": 3000, "ore": 500, "power_provided": 200, + "power_drained": 80, "resource_capacity": 4000, + "harvester_count": 1, + }, + "military": { + "units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 1000, "active_unit_count": 2, + }, + "units": [ + { + "actor_id": 10, "type": "mcv", "pos_x": 1000, "pos_y": 2000, + "cell_x": 10, "cell_y": 20, "hp_percent": 1.0, + "is_idle": True, "current_activity": "", + "owner": "Multi0", "can_attack": False, "facing": 0, + "experience_level": 0, "stance": 3, "speed": 56, + "attack_range": 0, "passenger_count": -1, "is_building": False, + }, + { + "actor_id": 20, "type": "harv", "pos_x": 2000, "pos_y": 3000, + "cell_x": 20, "cell_y": 30, "hp_percent": 1.0, + "is_idle": False, "current_activity": "Harvest", + "owner": "Multi0", "can_attack": False, "facing": 0, + "experience_level": 0, "stance": 3, "speed": 40, + "attack_range": 0, "passenger_count": -1, "is_building": False, + }, + ], + "buildings": [ + { + "actor_id": 100, "type": "fact", "pos_x": 500, "pos_y": 500, + "hp_percent": 1.0, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 500, + "rally_x": -1, "rally_y": -1, "power_amount": 0, + "can_produce": [], "cell_x": 5, "cell_y": 5, + }, + { + "actor_id": 
101, "type": "powr", "pos_x": 600, "pos_y": 600, + "hp_percent": 0.8, "owner": "Multi0", "is_producing": False, + "production_progress": 0.0, "producing_item": "", + "is_powered": True, "is_repairing": False, "sell_value": 150, + "rally_x": -1, "rally_y": -1, "power_amount": 100, + "can_produce": [], "cell_x": 6, "cell_y": 6, + }, + ], + "production": [ + {"type": "e1", "progress": 50, "paused": False}, + ], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 128, "height": 128, "map_name": "Test Map"}, + "available_production": ["e1", "e3", "powr", "tent", "proc"], + } + + env._refresh_obs = lambda: None + + env._register_tools(mcp) + return env, mcp + + # ── deploy_unit ── + + def test_deploy_unit_rejects_missing_actor(self, env_with_actors): + env, mcp = env_with_actors + tool = mcp._tool_manager._tools["deploy_unit"] + result = tool.fn(unit_id=999) + assert "error" in result + assert "999" in result["error"] + assert "your_units" in result + + def test_deploy_unit_accepts_valid_actor(self, env_with_actors): + env, mcp = env_with_actors + env._execute_commands = lambda cmds: {"tick": 501, "done": False, "result": ""} + tool = mcp._tool_manager._tools["deploy_unit"] + result = tool.fn(unit_id=10) + assert "error" not in result + + # ── sell_building ── + + def test_sell_building_rejects_missing_actor(self, env_with_actors): + env, mcp = env_with_actors + tool = mcp._tool_manager._tools["sell_building"] + result = tool.fn(building_id=999) + assert "error" in result + assert "999" in result["error"] + assert "your_buildings" in result + + def test_sell_building_accepts_valid_actor(self, env_with_actors): + env, mcp = env_with_actors + env._execute_commands = lambda cmds: {"tick": 501, "done": False, "result": ""} + tool = mcp._tool_manager._tools["sell_building"] + result = tool.fn(building_id=100) + assert "error" not in result + + # ── repair_building ── + + def test_repair_building_rejects_missing_actor(self, env_with_actors): + 
env, mcp = env_with_actors + tool = mcp._tool_manager._tools["repair_building"] + result = tool.fn(building_id=999) + assert "error" in result + assert "your_buildings" in result + + def test_repair_building_accepts_valid_actor(self, env_with_actors): + env, mcp = env_with_actors + env._execute_commands = lambda cmds: {"tick": 501, "done": False, "result": ""} + tool = mcp._tool_manager._tools["repair_building"] + result = tool.fn(building_id=101) + assert "error" not in result + + # ── set_rally_point ── + + def test_set_rally_point_rejects_missing_actor(self, env_with_actors): + env, mcp = env_with_actors + tool = mcp._tool_manager._tools["set_rally_point"] + result = tool.fn(building_id=999, cell_x=10, cell_y=10) + assert "error" in result + assert "your_buildings" in result + + def test_set_rally_point_accepts_valid_actor(self, env_with_actors): + env, mcp = env_with_actors + env._execute_commands = lambda cmds: {"tick": 501, "done": False, "result": ""} + tool = mcp._tool_manager._tools["set_rally_point"] + result = tool.fn(building_id=100, cell_x=10, cell_y=10) + assert "error" not in result + + # ── harvest ── + + def test_harvest_rejects_missing_actor(self, env_with_actors): + env, mcp = env_with_actors + tool = mcp._tool_manager._tools["harvest"] + result = tool.fn(unit_id=999) + assert "error" in result + assert "your_units" in result + + def test_harvest_accepts_valid_actor(self, env_with_actors): + env, mcp = env_with_actors + env._execute_commands = lambda cmds: {"tick": 501, "done": False, "result": ""} + tool = mcp._tool_manager._tools["harvest"] + result = tool.fn(unit_id=20) + assert "error" not in result + + # ── power_down ── + + def test_power_down_rejects_missing_actor(self, env_with_actors): + env, mcp = env_with_actors + tool = mcp._tool_manager._tools["power_down"] + result = tool.fn(building_id=999) + assert "error" in result + assert "your_buildings" in result + + def test_power_down_accepts_valid_actor(self, env_with_actors): + env, mcp = 
env_with_actors + env._execute_commands = lambda cmds: {"tick": 501, "done": False, "result": ""} + tool = mcp._tool_manager._tools["power_down"] + result = tool.fn(building_id=101) + assert "error" not in result + + # ── set_primary ── + + def test_set_primary_rejects_missing_actor(self, env_with_actors): + env, mcp = env_with_actors + tool = mcp._tool_manager._tools["set_primary"] + result = tool.fn(building_id=999) + assert "error" in result + assert "your_buildings" in result + + def test_set_primary_accepts_valid_actor(self, env_with_actors): + env, mcp = env_with_actors + env._execute_commands = lambda cmds: {"tick": 501, "done": False, "result": ""} + tool = mcp._tool_manager._tools["set_primary"] + result = tool.fn(building_id=100) + assert "error" not in result + + # ── cancel_production ── + + def test_cancel_production_rejects_item_not_in_queue(self, env_with_actors): + env, mcp = env_with_actors + tool = mcp._tool_manager._tools["cancel_production"] + result = tool.fn(item_type="3tnk") + assert "error" in result + assert "3tnk" in result["error"] + assert "current_queue" in result + + def test_cancel_production_accepts_item_in_queue(self, env_with_actors): + env, mcp = env_with_actors + env._execute_commands = lambda cmds: {"tick": 501, "done": False, "result": ""} + tool = mcp._tool_manager._tools["cancel_production"] + result = tool.fn(item_type="e1") + assert "error" not in result + + def test_cancel_production_case_insensitive(self, env_with_actors): + env, mcp = env_with_actors + env._execute_commands = lambda cmds: {"tick": 501, "done": False, "result": ""} + tool = mcp._tool_manager._tools["cancel_production"] + result = tool.fn(item_type="E1") + assert "error" not in result + + +class TestEmptyProductionValidation: + """Test that production tools reject commands when no production buildings exist.""" + + @pytest.fixture + def env_no_production(self): + """Create env with empty available_production (all factories destroyed).""" + env = 
OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + from fastmcp import FastMCP + mcp = FastMCP("openra-test") + + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = False + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + env._pending_placements = {} + env._attempted_placements = {} + env._placement_results = [] + env._player_faction = "england" + + env._last_obs = { + "tick": 8000, + "done": False, + "result": "", + "economy": { + "cash": 500, "ore": 100, "power_provided": 0, + "power_drained": 0, "resource_capacity": 4000, + "harvester_count": 0, + }, + "military": { + "units_killed": 0, "units_lost": 3, + "buildings_killed": 0, "buildings_lost": 4, + "army_value": 0, "active_unit_count": 0, + }, + "units": [], + "buildings": [], + "production": [], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 128, "height": 128, "map_name": "Test Map"}, + "available_production": [], # Empty! All production buildings destroyed. 
+ } + + env._refresh_obs = lambda: None + env._register_tools(mcp) + return env, mcp + + def test_build_unit_empty_production_returns_error(self, env_no_production): + env, mcp = env_no_production + tool = mcp._tool_manager._tools["build_unit"] + result = tool.fn(unit_type="e1") + assert "error" in result + assert "No production" in result["error"] + + def test_build_structure_empty_production_returns_error(self, env_no_production): + env, mcp = env_no_production + tool = mcp._tool_manager._tools["build_structure"] + result = tool.fn(building_type="powr") + assert "error" in result + assert "available_buildings" in result + assert result["available_buildings"] == [] + + def test_build_and_place_empty_production_returns_error(self, env_no_production): + env, mcp = env_no_production + tool = mcp._tool_manager._tools["build_and_place"] + result = tool.fn(building_type="powr") + assert "error" in result + assert "available_buildings" in result + assert result["available_buildings"] == [] + + +class TestActionToCommandsValidation: + """Test that _action_to_commands validates actors and production in batch/plan context.""" + + @pytest.fixture + def env_for_batch(self): + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + env._pending_placements = {} + return env + + def test_build_unit_empty_production_returns_empty(self, env_for_batch): + env = env_for_batch + obs = {"available_production": [], "units": [], "buildings": [], "production": []} + result = env._action_to_commands({"tool": "build_unit", "unit_type": "e1"}, obs) + assert result == [] + + def test_build_unit_unavailable_returns_empty(self, env_for_batch): + env = env_for_batch + obs = {"available_production": ["e1", "e3"], "units": [], "buildings": [], "production": []} + result = env._action_to_commands({"tool": "build_unit", "unit_type": "3tnk"}, obs) + assert result == [] + + def 
test_build_unit_available_returns_commands(self, env_for_batch): + env = env_for_batch + obs = {"available_production": ["e1", "e3"], "units": [], "buildings": [], "production": []} + result = env._action_to_commands({"tool": "build_unit", "unit_type": "e1"}, obs) + assert len(result) == 1 + assert result[0].action == ActionType.TRAIN + + def test_build_structure_empty_production_returns_empty(self, env_for_batch): + env = env_for_batch + obs = {"available_production": [], "units": [], "buildings": [], "production": []} + result = env._action_to_commands({"tool": "build_structure", "building_type": "powr"}, obs) + assert result == [] + + def test_build_and_place_empty_production_returns_empty(self, env_for_batch): + env = env_for_batch + obs = {"available_production": [], "units": [], "buildings": [], "production": []} + result = env._action_to_commands({"tool": "build_and_place", "building_type": "powr", "cell_x": 5, "cell_y": 5}, obs) + assert result == [] + + def test_deploy_unit_missing_actor_returns_empty(self, env_for_batch): + env = env_for_batch + obs = {"units": [{"actor_id": 10, "type": "mcv"}], "buildings": [], "production": []} + result = env._action_to_commands({"tool": "deploy_unit", "unit_id": 999}, obs) + assert result == [] + + def test_deploy_unit_valid_actor_returns_command(self, env_for_batch): + env = env_for_batch + obs = {"units": [{"actor_id": 10, "type": "mcv"}], "buildings": [], "production": []} + result = env._action_to_commands({"tool": "deploy_unit", "unit_id": 10}, obs) + assert len(result) == 1 + assert result[0].action == ActionType.DEPLOY + + def test_repair_building_missing_actor_returns_empty(self, env_for_batch): + env = env_for_batch + obs = {"units": [], "buildings": [{"actor_id": 100, "type": "fact"}], "production": []} + result = env._action_to_commands({"tool": "repair_building", "building_id": 999}, obs) + assert result == [] + + def test_set_rally_point_missing_actor_returns_empty(self, env_for_batch): + env = 
env_for_batch + obs = {"units": [], "buildings": [{"actor_id": 100, "type": "fact"}], "production": []} + result = env._action_to_commands({"tool": "set_rally_point", "building_id": 999, "cell_x": 5, "cell_y": 5}, obs) + assert result == [] + + def test_harvest_missing_actor_returns_empty(self, env_for_batch): + env = env_for_batch + obs = {"units": [{"actor_id": 20, "type": "harv"}], "buildings": [], "production": []} + result = env._action_to_commands({"tool": "harvest", "unit_id": 999}, obs) + assert result == [] + + def test_cancel_production_item_not_in_queue_returns_empty(self, env_for_batch): + env = env_for_batch + obs = {"units": [], "buildings": [], "production": [{"type": "e1", "progress": 50}]} + result = env._action_to_commands({"tool": "cancel_production", "item_type": "3tnk"}, obs) + assert result == [] + + def test_cancel_production_item_in_queue_returns_command(self, env_for_batch): + env = env_for_batch + obs = {"units": [], "buildings": [], "production": [{"type": "e1", "progress": 50}]} + result = env._action_to_commands({"tool": "cancel_production", "item_type": "e1"}, obs) + assert len(result) == 1 + assert result[0].action == ActionType.CANCEL_PRODUCTION + + +# ─── Exploration & Reward Vector Tests ────────────────────────────────────── + + +class TestExplorationPercent: + """Test that get_game_state includes explored_percent from spatial tensor.""" + + def _make_env(self): + from openra_env.config import OpenRARLConfig + from fastmcp import FastMCP + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + env._app_config = OpenRARLConfig() + env._accumulated_reward_vector = {} + mcp = FastMCP("openra-test") + env._register_tools(mcp) + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = True + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + env._placement_results = [] + env._prev_buildings = {} + env._prev_unit_ids = {} + return 
env, mcp + + def test_explored_percent_no_spatial_data(self): + """Without spatial data, explored_percent should be 0.""" + env, mcp = self._make_env() + env._last_obs = { + "tick": 50, "done": False, "result": "", + "economy": {"cash": 1000, "ore": 0, "power_provided": 100, + "power_drained": 50, "resource_capacity": 5000, + "harvester_count": 1}, + "military": {"units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 0, "active_unit_count": 0}, + "units": [], "buildings": [], "production": [], + "visible_enemies": [], "visible_enemy_buildings": [], + "map_info": {"width": 4, "height": 4, "map_name": "Test"}, + "spatial_channels": 0, "spatial_map": "", + "available_production": [], + } + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + assert "explored_percent" in result + assert result["explored_percent"] == 0.0 + + def test_explored_percent_with_spatial_data(self): + """With spatial data, explored_percent should reflect fog channel.""" + import base64 + import struct + + env, mcp = self._make_env() + w, h, ch = 4, 4, 9 + # Build spatial tensor: 4x4 map, 9 channels + # Channel 4 is fog: >0.25 = explored + data = bytearray(w * h * ch * 4) + for i in range(w * h): + offset = (i * ch + 4) * 4 + # Half explored (first 8 cells), half shroud (last 8) + val = 1.0 if i < 8 else 0.0 + struct.pack_into("f", data, offset, val) + + env._last_obs = { + "tick": 50, "done": False, "result": "", + "economy": {"cash": 1000, "ore": 0, "power_provided": 100, + "power_drained": 50, "resource_capacity": 5000, + "harvester_count": 1}, + "military": {"units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 0, "active_unit_count": 0}, + "units": [], "buildings": [], "production": [], + "visible_enemies": [], "visible_enemy_buildings": [], + "map_info": {"width": w, "height": h, "map_name": "Test"}, + "spatial_channels": ch, + "spatial_map": base64.b64encode(bytes(data)).decode(), + 
"available_production": [], + } + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + assert result["explored_percent"] == 50.0 + + def test_explored_percent_fully_explored(self): + """All cells explored → 100%.""" + import base64 + import struct + + env, mcp = self._make_env() + w, h, ch = 2, 2, 9 + data = bytearray(w * h * ch * 4) + for i in range(w * h): + struct.pack_into("f", data, (i * ch + 4) * 4, 1.0) + + env._last_obs = { + "tick": 10, "done": False, "result": "", + "economy": {"cash": 0, "ore": 0, "power_provided": 0, + "power_drained": 0, "resource_capacity": 0, + "harvester_count": 0}, + "military": {"units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 0, "active_unit_count": 0}, + "units": [], "buildings": [], "production": [], + "visible_enemies": [], "visible_enemy_buildings": [], + "map_info": {"width": w, "height": h, "map_name": "Test"}, + "spatial_channels": ch, + "spatial_map": base64.b64encode(bytes(data)).decode(), + "available_production": [], + } + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + assert result["explored_percent"] == 100.0 + + +class TestRewardVectorAccumulation: + """Test that reward vector accumulates across steps and appears in get_game_state.""" + + def test_reward_vector_in_game_state(self): + """get_game_state should include reward_vector field.""" + from openra_env.config import OpenRARLConfig + from fastmcp import FastMCP + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + env._app_config = OpenRARLConfig() + env._accumulated_reward_vector = {"combat": 0.5, "economy": 0.3} + mcp = FastMCP("openra-test") + env._register_tools(mcp) + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = True + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + env._placement_results = [] + env._prev_buildings = {} + env._prev_unit_ids = 
{} + env._last_obs = { + "tick": 50, "done": False, "result": "", + "economy": {"cash": 1000, "ore": 0, "power_provided": 100, + "power_drained": 50, "resource_capacity": 5000, + "harvester_count": 1}, + "military": {"units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 0, "active_unit_count": 0}, + "units": [], "buildings": [], "production": [], + "visible_enemies": [], "visible_enemy_buildings": [], + "map_info": {"width": 4, "height": 4, "map_name": "Test"}, + "spatial_channels": 0, "spatial_map": "", + "available_production": [], + } + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + assert "reward_vector" in result + assert result["reward_vector"]["combat"] == 0.5 + assert result["reward_vector"]["economy"] == 0.3 + + def test_reward_vector_empty_when_no_steps(self): + """Before any steps, reward_vector should be empty dict.""" + from openra_env.config import OpenRARLConfig + from fastmcp import FastMCP + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + env._app_config = OpenRARLConfig() + env._accumulated_reward_vector = {} + mcp = FastMCP("openra-test") + env._register_tools(mcp) + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = True + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + env._placement_results = [] + env._prev_buildings = {} + env._prev_unit_ids = {} + env._last_obs = { + "tick": 0, "done": False, "result": "", + "economy": {"cash": 0, "ore": 0, "power_provided": 0, + "power_drained": 0, "resource_capacity": 0, + "harvester_count": 0}, + "military": {"units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 0, "active_unit_count": 0}, + "units": [], "buildings": [], "production": [], + "visible_enemies": [], "visible_enemy_buildings": [], + "map_info": {"width": 4, "height": 4, "map_name": "Test"}, + "spatial_channels": 
0, "spatial_map": "", + "available_production": [], + } + tool = mcp._tool_manager._tools["get_game_state"] + result = tool.fn() + assert result["reward_vector"] == {} + + def test_accumulated_vector_sums_correctly(self): + """Simulating multiple accumulation steps.""" + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + env._accumulated_reward_vector = {} + + # Step 1 + vec1 = {"combat": 0.1, "economy": 0.2, "intelligence": 0.05} + for k, v in vec1.items(): + env._accumulated_reward_vector[k] = env._accumulated_reward_vector.get(k, 0.0) + v + + # Step 2 + vec2 = {"combat": 0.3, "economy": -0.1, "tempo": 0.5} + for k, v in vec2.items(): + env._accumulated_reward_vector[k] = env._accumulated_reward_vector.get(k, 0.0) + v + + assert abs(env._accumulated_reward_vector["combat"] - 0.4) < 1e-9 + assert abs(env._accumulated_reward_vector["economy"] - 0.1) < 1e-9 + assert abs(env._accumulated_reward_vector["intelligence"] - 0.05) < 1e-9 + assert abs(env._accumulated_reward_vector["tempo"] - 0.5) < 1e-9 + + +class TestStartPlanningRewardDimensions: + """Test that start_planning_phase includes reward_dimensions.""" + + def test_reward_dimensions_in_planning(self): + from openra_env.config import OpenRARLConfig + from openra_env.models import OpenRAState + from fastmcp import FastMCP + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + env._app_config = OpenRARLConfig() + env._accumulated_reward_vector = {} + mcp = FastMCP("openra-test") + env._planning_active = False + env._planning_strategy = "" + env._planning_enabled = True + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + env._player_faction = "russia" + env._enemy_faction = "england" + env._unit_groups = {} + env._pending_placements = {} + env._attempted_placements = {} + env._placement_results = [] + env._PLACEABLE_QUEUE_TYPES = {"Building", "Defense"} + env._state = OpenRAState() + + class FakeConfig: + bot_type = "normal" + 
env._config = FakeConfig() + + env._register_tools(mcp) + env._last_obs = { + "tick": 0, "done": False, "result": "", + "economy": {"cash": 5000, "ore": 0, "power_provided": 100, + "power_drained": 0, "resource_capacity": 5000, + "harvester_count": 1}, + "military": {"units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 0, "active_unit_count": 0}, + "units": [], + "buildings": [ + {"actor_id": 1, "type": "fact", "pos_x": 5120, "pos_y": 5120, + "hp_percent": 1.0, "owner": "Multi0", "can_produce": ["powr"], + "cell_x": 5, "cell_y": 5}, + ], + "production": [], + "visible_enemies": [], "visible_enemy_buildings": [], + "map_info": {"width": 64, "height": 64, "map_name": "Test"}, + "spatial_channels": 0, "spatial_map": "", + "available_production": ["powr"], + } + tool = mcp._tool_manager._tools["start_planning_phase"] + result = tool.fn() + assert "reward_dimensions" in result + rd = result["reward_dimensions"] + assert "combat" in rd + assert "economy" in rd + assert "intelligence" in rd + assert "outcome" in rd + assert len(rd) == 8 + + +# ── Bench Export Tests ────────────────────────────────────────────────────── + + +class TestBenchExportJson: + """Tests for the bench export JSON built in agent.py scorecard.""" + + def _build_submission(self, mil=None, final=None, replay=None, model="test/model"): + """Build a bench submission dict the same way agent.py does.""" + from datetime import datetime, timezone + + mil = mil or {"kills_cost": 1000, "deaths_cost": 500, "assets_value": 8000} + final = final or {"result": "loss", "tick": 5000, "explored_percent": 45.0, "reward_vector": {"combat": 0.5}} + replay = replay or {"path": "/tmp/test.orarep"} + + return { + "agent_name": model, + "agent_type": "LLM", + "opponent": "Beginner", + "games": 1, + "result": final.get("result", ""), + "win": final.get("result") == "win", + "ticks": final.get("tick", 0), + "kills_cost": mil.get("kills_cost", 0), + "deaths_cost": 
mil.get("deaths_cost", 0), + "kd_ratio": round(mil.get("kills_cost", 0) / max(mil.get("deaths_cost", 1), 1), 2), + "assets_value": mil.get("assets_value", 0), + "explored_percent": final.get("explored_percent", 0), + "reward_vector": final.get("reward_vector", {}), + "replay_path": replay.get("path", ""), + "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), + } + + def test_required_fields_present(self): + """Bench submission JSON must have all fields the API expects.""" + sub = self._build_submission() + required = {"agent_name", "agent_type", "opponent", "result", "ticks", + "kills_cost", "deaths_cost", "assets_value"} + missing = required - set(sub.keys()) + assert not missing, f"Missing required fields: {missing}" + + def test_kd_ratio_handles_zero_deaths(self): + """K/D ratio should not crash when deaths_cost is 0.""" + sub = self._build_submission(mil={"kills_cost": 500, "deaths_cost": 0, "assets_value": 3000}) + assert sub["kd_ratio"] == 500.0 + + def test_win_flag_matches_result(self): + """win boolean should be True only when result is 'win'.""" + loss = self._build_submission(final={"result": "loss", "tick": 100, "explored_percent": 0, "reward_vector": {}}) + assert loss["win"] is False + + win = self._build_submission(final={"result": "win", "tick": 100, "explored_percent": 0, "reward_vector": {}}) + assert win["win"] is True + + def test_json_serializable(self): + """Submission dict must be fully JSON-serializable.""" + import json + sub = self._build_submission() + serialized = json.dumps(sub) + roundtripped = json.loads(serialized) + assert roundtripped["agent_name"] == "test/model" + assert roundtripped["kills_cost"] == 1000 + + def test_model_slug_in_filename(self): + """Export filename should contain a sanitized model slug.""" + from datetime import datetime, timezone + model = "qwen/qwen3-coder-next" + slug = model.replace("/", "_")[:40] + ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + filename = 
f"bench-{slug}-{ts}.json" + assert "qwen_qwen3-coder-next" in filename + assert "/" not in filename diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000000000000000000000000000000000000..41cb487216ee4edd73cf24730b760abdf3e149b9 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,312 @@ +"""Tests for OpenRA-RL Pydantic models.""" + +import pytest + +from openra_env.models import ( + ActionType, + BuildingInfoModel, + CommandModel, + EconomyInfo, + MapInfoModel, + MilitaryInfo, + OpenRAAction, + OpenRAObservation, + OpenRAState, + ProductionInfoModel, + UnitInfoModel, +) + + +class TestActionType: + def test_enum_values(self): + assert ActionType.NO_OP == "no_op" + assert ActionType.MOVE == "move" + assert ActionType.ATTACK == "attack" + assert ActionType.BUILD == "build" + assert ActionType.TRAIN == "train" + + def test_enum_from_string(self): + assert ActionType("move") == ActionType.MOVE + assert ActionType("no_op") == ActionType.NO_OP + + def test_all_action_types_exist(self): + expected = { + "no_op", "move", "attack_move", "attack", "stop", + "harvest", "build", "train", "deploy", "sell", + "repair", "place_building", "cancel_production", "set_rally_point", + "guard", "set_stance", "enter_transport", "unload", + "power_down", "set_primary", "surrender", + } + actual = {a.value for a in ActionType} + assert actual == expected + + +class TestCommandModel: + def test_minimal_command(self): + cmd = CommandModel(action=ActionType.NO_OP) + assert cmd.action == ActionType.NO_OP + assert cmd.actor_id == 0 + assert cmd.target_x == 0 + assert cmd.queued is False + + def test_move_command(self): + cmd = CommandModel( + action=ActionType.MOVE, + actor_id=42, + target_x=100, + target_y=200, + ) + assert cmd.action == ActionType.MOVE + assert cmd.actor_id == 42 + assert cmd.target_x == 100 + assert cmd.target_y == 200 + + def test_attack_command(self): + cmd = CommandModel( + action=ActionType.ATTACK, + actor_id=10, + 
target_actor_id=99, + ) + assert cmd.target_actor_id == 99 + + def test_build_command(self): + cmd = CommandModel( + action=ActionType.BUILD, + item_type="powr", + ) + assert cmd.item_type == "powr" + + def test_serialization_roundtrip(self): + cmd = CommandModel( + action=ActionType.MOVE, + actor_id=5, + target_x=10, + target_y=20, + queued=True, + ) + data = cmd.model_dump() + restored = CommandModel(**data) + assert restored == cmd + + +class TestOpenRAAction: + def test_empty_action(self): + action = OpenRAAction() + assert action.commands == [] + + def test_single_command(self): + action = OpenRAAction( + commands=[CommandModel(action=ActionType.NO_OP)] + ) + assert len(action.commands) == 1 + + def test_multiple_commands(self): + action = OpenRAAction( + commands=[ + CommandModel(action=ActionType.MOVE, actor_id=1, target_x=10, target_y=20), + CommandModel(action=ActionType.ATTACK, actor_id=2, target_actor_id=99), + CommandModel(action=ActionType.BUILD, item_type="powr"), + ] + ) + assert len(action.commands) == 3 + assert action.commands[0].action == ActionType.MOVE + assert action.commands[1].action == ActionType.ATTACK + assert action.commands[2].action == ActionType.BUILD + + def test_serialization_roundtrip(self): + action = OpenRAAction( + commands=[ + CommandModel(action=ActionType.MOVE, actor_id=1, target_x=10, target_y=20), + ] + ) + data = action.model_dump() + restored = OpenRAAction(**data) + assert len(restored.commands) == 1 + assert restored.commands[0].actor_id == 1 + + +class TestEconomyInfo: + def test_defaults(self): + eco = EconomyInfo() + assert eco.cash == 0 + assert eco.ore == 0 + assert eco.power_provided == 0 + assert eco.power_drained == 0 + assert eco.resource_capacity == 0 + assert eco.harvester_count == 0 + + def test_with_values(self): + eco = EconomyInfo(cash=5000, power_provided=100, power_drained=60, harvester_count=2) + assert eco.cash == 5000 + assert eco.power_provided == 100 + assert eco.power_drained == 60 + assert 
eco.harvester_count == 2 + + +class TestMilitaryInfo: + def test_defaults(self): + mil = MilitaryInfo() + assert mil.units_killed == 0 + assert mil.units_lost == 0 + assert mil.army_value == 0 + + def test_with_values(self): + mil = MilitaryInfo(units_killed=5, units_lost=2, army_value=3000) + assert mil.units_killed == 5 + assert mil.units_lost == 2 + assert mil.army_value == 3000 + + +class TestUnitInfoModel: + def test_required_fields(self): + unit = UnitInfoModel(actor_id=1, type="e1") + assert unit.actor_id == 1 + assert unit.type == "e1" + assert unit.hp_percent == 1.0 + assert unit.is_idle is True + + def test_full_unit(self): + unit = UnitInfoModel( + actor_id=42, + type="1tnk", + pos_x=1024, + pos_y=2048, + cell_x=4, + cell_y=8, + hp_percent=0.75, + is_idle=False, + current_activity="Attack", + owner="Nod", + can_attack=True, + ) + assert unit.hp_percent == 0.75 + assert unit.is_idle is False + assert unit.can_attack is True + + +class TestBuildingInfoModel: + def test_required_fields(self): + bldg = BuildingInfoModel(actor_id=10, type="powr") + assert bldg.actor_id == 10 + assert bldg.type == "powr" + assert bldg.is_powered is True + + def test_producing_building(self): + bldg = BuildingInfoModel( + actor_id=20, + type="barr", + is_producing=True, + production_progress=0.5, + producing_item="e1", + ) + assert bldg.is_producing is True + assert bldg.producing_item == "e1" + + +class TestProductionInfoModel: + def test_required_fields(self): + prod = ProductionInfoModel(queue_type="Infantry", item="e1") + assert prod.queue_type == "Infantry" + assert prod.item == "e1" + assert prod.progress == 0.0 + assert prod.paused is False + + +class TestMapInfoModel: + def test_defaults(self): + m = MapInfoModel() + assert m.width == 0 + assert m.height == 0 + assert m.map_name == "" + + def test_with_values(self): + m = MapInfoModel(width=128, height=128, map_name="Allied vs Soviet") + assert m.width == 128 + assert m.map_name == "Allied vs Soviet" + + +class 
TestOpenRAObservation: + def test_default_observation(self): + obs = OpenRAObservation() + assert obs.tick == 0 + assert obs.units == [] + assert obs.buildings == [] + assert obs.done is False + assert obs.result == "" + + def test_full_observation(self): + obs = OpenRAObservation( + tick=150, + economy=EconomyInfo(cash=5000, power_provided=100), + military=MilitaryInfo(units_killed=3), + units=[ + UnitInfoModel(actor_id=1, type="e1"), + UnitInfoModel(actor_id=2, type="1tnk"), + ], + buildings=[ + BuildingInfoModel(actor_id=10, type="powr"), + ], + production=[ + ProductionInfoModel(queue_type="Infantry", item="e1", progress=0.5), + ], + visible_enemies=[ + UnitInfoModel(actor_id=99, type="e1", owner="Enemy"), + ], + map_info=MapInfoModel(width=128, height=128), + available_production=["e1", "e3", "1tnk"], + done=False, + reward=0.5, + result="", + ) + assert obs.tick == 150 + assert len(obs.units) == 2 + assert len(obs.buildings) == 1 + assert len(obs.production) == 1 + assert len(obs.visible_enemies) == 1 + assert obs.economy.cash == 5000 + assert obs.available_production == ["e1", "e3", "1tnk"] + + def test_terminal_observation(self): + obs = OpenRAObservation(done=True, result="win", reward=1.0) + assert obs.done is True + assert obs.result == "win" + + def test_serialization_roundtrip(self): + obs = OpenRAObservation( + tick=100, + economy=EconomyInfo(cash=3000), + units=[UnitInfoModel(actor_id=1, type="e1")], + ) + data = obs.model_dump() + restored = OpenRAObservation(**data) + assert restored.tick == 100 + assert restored.economy.cash == 3000 + assert len(restored.units) == 1 + + +class TestOpenRAState: + def test_defaults(self): + state = OpenRAState() + assert state.game_tick == 0 + assert state.map_name == "" + assert state.opponent_type == "bot_normal" + assert state.step_count == 0 + + def test_with_values(self): + state = OpenRAState( + episode_id="abc123", + step_count=50, + game_tick=500, + map_name="Test Map", + opponent_type="bot_hard", + ) + 
assert state.episode_id == "abc123" + assert state.step_count == 50 + assert state.game_tick == 500 + + def test_serialization_roundtrip(self): + state = OpenRAState(episode_id="test", game_tick=100) + data = state.model_dump() + restored = OpenRAState(**data) + assert restored.episode_id == "test" + assert restored.game_tick == 100 diff --git a/tests/test_planning.py b/tests/test_planning.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3081a2e68ccdaa9604b9f2f83a6801fcd4b948 --- /dev/null +++ b/tests/test_planning.py @@ -0,0 +1,722 @@ +"""Tests for pre-game planning phase: opponent intel module and planning MCP tools.""" + +import time + +import pytest + +from openra_env.opponent_intel import ( + AI_PROFILES, + get_opponent_profile, + get_opponent_summary, +) + + +# ─── Opponent Intel Module Tests ───────────────────────────────────────────── + + +class TestAIProfiles: + def test_all_difficulties_present(self): + assert "beginner" in AI_PROFILES + assert "easy" in AI_PROFILES + assert "medium" in AI_PROFILES + assert "normal" in AI_PROFILES + assert "hard" in AI_PROFILES + + def test_profiles_have_required_fields(self): + required = { + "difficulty", + "display_name", + "aggressiveness", + "expansion_tendency", + "unit_diversity", + "build_order_quality", + "estimated_win_rate_vs_new_player", + "typical_first_attack_tick", + "behavioral_traits", + "recommended_counters", + "typical_army_composition", + "recent_match_history", + } + for difficulty, profile in AI_PROFILES.items(): + missing = required - set(profile.keys()) + assert not missing, f"Profile '{difficulty}' missing fields: {missing}" + + def test_win_rates_are_valid(self): + for difficulty, profile in AI_PROFILES.items(): + rate = profile["estimated_win_rate_vs_new_player"] + assert 0.0 <= rate <= 1.0, f"Profile '{difficulty}' has invalid win rate: {rate}" + + def test_attack_ticks_are_positive(self): + for difficulty, profile in AI_PROFILES.items(): + assert 
profile["typical_first_attack_tick"] > 0 + + def test_army_composition_sums_to_one(self): + for difficulty, profile in AI_PROFILES.items(): + total = sum(profile["typical_army_composition"].values()) + assert abs(total - 1.0) < 0.01, f"Profile '{difficulty}' army composition sums to {total}" + + def test_match_history_has_results(self): + for difficulty, profile in AI_PROFILES.items(): + history = profile["recent_match_history"] + assert len(history) >= 3, f"Profile '{difficulty}' has too few matches" + for match in history: + assert match["result"] in ("win", "loss") + assert match["duration_ticks"] > 0 + assert match["score"] > 0 + + def test_normal_ai_is_aggressive(self): + """Normal AI should be aggressive per user requirements.""" + profile = AI_PROFILES["normal"] + assert profile["aggressiveness"] == "high" + assert profile["expansion_tendency"] == "high" + + def test_difficulty_ordering(self): + """Harder difficulties should have higher win rates and earlier attacks.""" + beginner = AI_PROFILES["beginner"] + easy = AI_PROFILES["easy"] + medium = AI_PROFILES["medium"] + normal = AI_PROFILES["normal"] + hard = AI_PROFILES["hard"] + assert beginner["estimated_win_rate_vs_new_player"] < easy["estimated_win_rate_vs_new_player"] + assert easy["estimated_win_rate_vs_new_player"] < medium["estimated_win_rate_vs_new_player"] + assert medium["estimated_win_rate_vs_new_player"] < normal["estimated_win_rate_vs_new_player"] + assert normal["estimated_win_rate_vs_new_player"] < hard["estimated_win_rate_vs_new_player"] + assert beginner["typical_first_attack_tick"] > easy["typical_first_attack_tick"] + assert easy["typical_first_attack_tick"] > medium["typical_first_attack_tick"] + assert medium["typical_first_attack_tick"] > normal["typical_first_attack_tick"] + assert normal["typical_first_attack_tick"] > hard["typical_first_attack_tick"] + + +class TestGetOpponentProfile: + def test_get_by_difficulty(self): + for key in ("beginner", "easy", "medium", "normal", "hard"): 
+ profile = get_opponent_profile(key) + assert profile is not None + assert profile["difficulty"].lower() == key + + def test_strips_bot_prefix(self): + profile = get_opponent_profile("bot_normal") + assert profile is not None + assert profile["difficulty"] == "Normal" + + def test_case_insensitive(self): + assert get_opponent_profile("NORMAL") is not None + assert get_opponent_profile("Bot_Hard") is not None + + def test_unknown_returns_none(self): + assert get_opponent_profile("impossible") is None + assert get_opponent_profile("") is None + + +class TestGetOpponentSummary: + def test_summary_contains_key_sections(self): + summary = get_opponent_summary("normal") + assert "Opponent Scouting Report" in summary + assert "Aggressiveness" in summary + assert "Behavioral traits" in summary + assert "Recommended counters" in summary + assert "Win rate" in summary + + def test_summary_for_each_difficulty(self): + for key in ("easy", "normal", "hard"): + summary = get_opponent_summary(key) + assert len(summary) > 100, f"Summary for '{key}' seems too short" + + def test_unknown_returns_error_message(self): + summary = get_opponent_summary("nonexistent") + assert "Unknown" in summary + + def test_normal_summary_mentions_aggression(self): + summary = get_opponent_summary("normal") + assert "aggressive" in summary.lower() + + def test_normal_summary_mentions_expansion(self): + summary = get_opponent_summary("normal") + assert "second base" in summary.lower() or "expand" in summary.lower() + + +# ─── Planning MCP Tool Tests ──────────────────────────────────────────────── + + +class TestPlanningTools: + """Test planning phase MCP tools on a bare OpenRAEnvironment.""" + + @pytest.fixture + def env_with_obs(self): + """Create env with planning support and a cached observation.""" + from fastmcp import FastMCP + from openra_env.server.openra_environment import OpenRAEnvironment + + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + mcp = FastMCP("openra-test") + + # Set up 
planning state + env._planning_enabled = True + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_active = False + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + env._planning_strategy = "" + + # Set up required env state + env._player_faction = "russia" + env._enemy_faction = "england" + env._unit_groups = {} + env._pending_placements = {} + env._attempted_placements = {} + env._placement_results = [] + env._PLACEABLE_QUEUE_TYPES = {"Building", "Defense"} + + class FakeConfig: + bot_type = "normal" + + env._config = FakeConfig() + + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + + from openra_env.models import OpenRAState + env._state = OpenRAState() + + # Cached observation + env._last_obs = { + "tick": 0, + "done": False, + "result": "", + "economy": { + "cash": 10000, + "ore": 0, + "power_provided": 0, + "power_drained": 0, + "resource_capacity": 5000, + "harvester_count": 0, + }, + "military": { + "units_killed": 0, + "units_lost": 0, + "buildings_killed": 0, + "buildings_lost": 0, + "army_value": 0, + "active_unit_count": 1, + }, + "units": [ + { + "actor_id": 100, + "type": "mcv", + "pos_x": 32768, + "pos_y": 32768, + "cell_x": 32, + "cell_y": 32, + "hp_percent": 1.0, + "is_idle": True, + "current_activity": "", + "owner": "Multi0", + "can_attack": False, + "facing": 0, + "experience_level": 0, + "stance": 1, + "speed": 56, + "attack_range": 0, + "passenger_count": 0, + "ammo": -1, + "is_building": False, + }, + ], + "buildings": [], + "production": [], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 64, "height": 64, "map_name": "singles"}, + "available_production": [], + "spatial_map": "", + "spatial_channels": 0, + } + + # _refresh_obs is a no-op for testing (observation already cached) + env._refresh_obs = lambda: None + + env._register_tools(mcp) + return env, mcp + + def _get_tool(self, mcp, name): + """Get a tool function from the MCP 
tool manager.""" + from tests.conftest import get_tool_fn + return get_tool_fn(mcp, name) + + def test_planning_tools_registered(self, env_with_obs): + _, mcp = env_with_obs + from tests.conftest import get_tool_names + tool_names = get_tool_names(mcp) + + assert "get_opponent_intel" in tool_names + assert "start_planning_phase" in tool_names + assert "end_planning_phase" in tool_names + assert "get_planning_status" in tool_names + + def test_tool_count_increased(self, env_with_obs): + _, mcp = env_with_obs + from tests.conftest import get_tool_count + count = get_tool_count(mcp) + # 7 read + 1 exploration + 1 terrain + 4 knowledge + 3 bulk + 4 planning + 27 action + 1 replay = 48 + assert count == 48, f"Expected 48 tools, got {count}" + + def test_get_opponent_intel(self, env_with_obs): + env, mcp = env_with_obs + fn = self._get_tool(mcp, "get_opponent_intel") + assert fn is not None + result = fn() + assert result["difficulty"] == "Normal" + assert result["aggressiveness"] == "high" + assert result["your_faction"] == "russia" + assert result["enemy_faction"] == "england" + + def test_start_planning_phase(self, env_with_obs): + env, mcp = env_with_obs + fn = self._get_tool(mcp, "start_planning_phase") + result = fn() + assert result["planning_active"] is True + assert result["max_turns"] == 10 + assert "map" in result + assert "base_position" in result + assert "enemy_estimated_position" in result + assert result["your_faction"] == "russia" + assert result["your_side"] == "soviet" + assert result["enemy_faction"] == "england" + assert "tech_tree" in result + assert "opponent_intel" in result + assert "opponent_summary" in result + assert "instructions" in result + assert len(result["starting_units"]) == 1 + assert result["starting_units"][0]["type"] == "mcv" + assert env._planning_active is True + + def test_start_planning_when_disabled(self, env_with_obs): + env, mcp = env_with_obs + env._planning_enabled = False + fn = self._get_tool(mcp, "start_planning_phase") 
+ result = fn() + assert result["planning_enabled"] is False + assert "message" in result + assert env._planning_active is False + + def test_double_start_planning(self, env_with_obs): + env, mcp = env_with_obs + fn = self._get_tool(mcp, "start_planning_phase") + fn() # First start + result = fn() # Second start + assert "error" in result + assert "already active" in result["error"].lower() + + def test_end_planning_phase(self, env_with_obs): + env, mcp = env_with_obs + start_fn = self._get_tool(mcp, "start_planning_phase") + end_fn = self._get_tool(mcp, "end_planning_phase") + + start_fn() + result = end_fn(strategy="Rush with tanks, build 2 refineries early") + + assert result["planning_complete"] is True + assert result["strategy_recorded"] is True + assert "tanks" in result["strategy"] + assert result["planning_duration_seconds"] >= 0 + assert env._planning_active is False + assert env._planning_strategy == "Rush with tanks, build 2 refineries early" + assert env._state.planning_strategy == "Rush with tanks, build 2 refineries early" + + def test_end_planning_without_start(self, env_with_obs): + env, mcp = env_with_obs + fn = self._get_tool(mcp, "end_planning_phase") + result = fn(strategy="some strategy") + assert "error" in result + + def test_end_planning_empty_strategy(self, env_with_obs): + env, mcp = env_with_obs + start_fn = self._get_tool(mcp, "start_planning_phase") + end_fn = self._get_tool(mcp, "end_planning_phase") + + start_fn() + result = end_fn() + + assert result["planning_complete"] is True + assert result["strategy_recorded"] is False + assert result["strategy"] == "" + + def test_get_planning_status_before_start(self, env_with_obs): + env, mcp = env_with_obs + fn = self._get_tool(mcp, "get_planning_status") + result = fn() + assert result["planning_active"] is False + assert "strategy" in result + + def test_get_planning_status_during_planning(self, env_with_obs): + env, mcp = env_with_obs + start_fn = self._get_tool(mcp, 
"start_planning_phase") + status_fn = self._get_tool(mcp, "get_planning_status") + + start_fn() + result = status_fn() + + assert result["planning_active"] is True + assert result["turns_used"] == 0 + assert result["turns_remaining"] == 10 + assert result["time_elapsed_seconds"] >= 0 + assert result["time_remaining_seconds"] > 0 + + def test_get_planning_status_when_disabled(self, env_with_obs): + env, mcp = env_with_obs + env._planning_enabled = False + fn = self._get_tool(mcp, "get_planning_status") + result = fn() + assert result["planning_enabled"] is False + + def test_game_state_includes_planning_indicator(self, env_with_obs): + env, mcp = env_with_obs + start_fn = self._get_tool(mcp, "start_planning_phase") + get_state_fn = self._get_tool(mcp, "get_game_state") + + start_fn() + result = get_state_fn() + + assert result.get("planning_active") is True + assert result.get("planning_turns_remaining") == 10 + + def test_game_state_includes_strategy_after_planning(self, env_with_obs): + env, mcp = env_with_obs + start_fn = self._get_tool(mcp, "start_planning_phase") + end_fn = self._get_tool(mcp, "end_planning_phase") + get_state_fn = self._get_tool(mcp, "get_game_state") + + start_fn() + end_fn(strategy="Build tanks and attack early") + result = get_state_fn() + + assert result.get("planning_active") is None or result.get("planning_active") is False + assert result.get("planning_strategy") == "Build tanks and attack early" + + def test_planning_base_position_from_mcv(self, env_with_obs): + env, mcp = env_with_obs + fn = self._get_tool(mcp, "start_planning_phase") + result = fn() + # MCV is at cell (32, 32) + assert result["base_position"]["x"] == 32 + assert result["base_position"]["y"] == 32 + + def test_planning_enemy_position_estimate(self, env_with_obs): + env, mcp = env_with_obs + fn = self._get_tool(mcp, "start_planning_phase") + result = fn() + # Map is 64x64, base at (32, 32), enemy should be opposite + assert result["enemy_estimated_position"]["x"] == 32 
# 64 - 32 + assert result["enemy_estimated_position"]["y"] == 32 # 64 - 32 + + def test_start_planning_includes_key_units(self, env_with_obs): + env, mcp = env_with_obs + fn = self._get_tool(mcp, "start_planning_phase") + result = fn() + assert "key_units" in result + assert len(result["key_units"]) > 0 + # Key units should have full stats (cost, hp, etc.) + for utype, udata in result["key_units"].items(): + assert "cost" in udata + assert "hp" in udata + assert "name" in udata + + def test_start_planning_includes_key_buildings(self, env_with_obs): + env, mcp = env_with_obs + fn = self._get_tool(mcp, "start_planning_phase") + result = fn() + assert "key_buildings" in result + assert len(result["key_buildings"]) > 0 + for btype, bdata in result["key_buildings"].items(): + assert "cost" in bdata + assert "name" in bdata + + def test_start_planning_instructions_mention_new_tools(self, env_with_obs): + env, mcp = env_with_obs + fn = self._get_tool(mcp, "start_planning_phase") + result = fn() + instructions = result["instructions"] + assert "get_faction_briefing" in instructions + assert "get_map_analysis" in instructions + assert "batch_lookup" in instructions + + +# ─── Bulk Knowledge Tool Tests ───────────────────────────────────────────── + + +class TestBulkKnowledgeTools: + """Test bulk knowledge tools: get_faction_briefing, get_map_analysis, batch_lookup.""" + + @pytest.fixture + def env_with_obs(self): + """Create env with planning support and a cached observation.""" + from fastmcp import FastMCP + from openra_env.server.openra_environment import OpenRAEnvironment + + env = OpenRAEnvironment.__new__(OpenRAEnvironment) + mcp = FastMCP("openra-test") + + env._planning_enabled = True + env._planning_max_turns = 10 + env._planning_max_time_s = 60.0 + env._planning_active = False + env._planning_start_time = 0.0 + env._planning_turns_used = 0 + env._planning_strategy = "" + + env._player_faction = "russia" + env._enemy_faction = "england" + env._unit_groups = {} + 
env._pending_placements = {} + env._attempted_placements = {} + env._placement_results = [] + env._PLACEABLE_QUEUE_TYPES = {"Building", "Defense"} + + class FakeConfig: + bot_type = "normal" + + env._config = FakeConfig() + + from openra_env.config import OpenRARLConfig + env._app_config = OpenRARLConfig() + + from openra_env.models import OpenRAState + env._state = OpenRAState() + + env._last_obs = { + "tick": 0, + "done": False, + "result": "", + "economy": { + "cash": 10000, "ore": 0, "power_provided": 0, + "power_drained": 0, "resource_capacity": 5000, "harvester_count": 0, + }, + "military": { + "units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + "army_value": 0, "active_unit_count": 1, + }, + "units": [ + { + "actor_id": 100, "type": "mcv", + "pos_x": 32768, "pos_y": 32768, + "cell_x": 32, "cell_y": 32, + "hp_percent": 1.0, "is_idle": True, + "current_activity": "", "owner": "Multi0", + "can_attack": False, "facing": 0, + "experience_level": 0, "stance": 1, + "speed": 56, "attack_range": 0, + "passenger_count": 0, "ammo": -1, "is_building": False, + }, + ], + "buildings": [], + "production": [], + "visible_enemies": [], + "visible_enemy_buildings": [], + "map_info": {"width": 64, "height": 64, "map_name": "singles"}, + "available_production": [], + "spatial_map": "", + "spatial_channels": 0, + } + + env._refresh_obs = lambda: None + env._register_tools(mcp) + return env, mcp + + def _get_tool(self, mcp, name): + from tests.conftest import get_tool_fn + return get_tool_fn(mcp, name) + + def test_bulk_tools_registered(self, env_with_obs): + _, mcp = env_with_obs + from tests.conftest import get_tool_names + tool_names = get_tool_names(mcp) + assert "get_faction_briefing" in tool_names + assert "get_map_analysis" in tool_names + assert "batch_lookup" in tool_names + + def test_tool_count_with_bulk_tools(self, env_with_obs): + _, mcp = env_with_obs + from tests.conftest import get_tool_count + count = get_tool_count(mcp) + # 7 read 
+ 1 exploration + 1 terrain + 4 knowledge + 3 bulk + 4 planning + 27 action + 1 replay = 48 + assert count == 48, f"Expected 48 tools, got {count}" + + # ── get_faction_briefing ── + + def test_faction_briefing_returns_units(self, env_with_obs): + _, mcp = env_with_obs + fn = self._get_tool(mcp, "get_faction_briefing") + result = fn() + assert result["faction"] == "russia" + assert result["side"] == "soviet" + assert "units" in result + assert len(result["units"]) > 10 # Soviet has ~20 units + assert "e1" in result["units"] # Rifle Infantry is available to both sides + assert "3tnk" in result["units"] # Heavy Tank is soviet + + def test_faction_briefing_returns_buildings(self, env_with_obs): + _, mcp = env_with_obs + fn = self._get_tool(mcp, "get_faction_briefing") + result = fn() + assert "buildings" in result + assert len(result["buildings"]) > 10 + assert "powr" in result["buildings"] + assert "barr" in result["buildings"] + assert "weap" in result["buildings"] + + def test_faction_briefing_returns_tech_tree(self, env_with_obs): + _, mcp = env_with_obs + fn = self._get_tool(mcp, "get_faction_briefing") + result = fn() + assert "tech_tree" in result + assert len(result["tech_tree"]) > 5 + assert result["tech_tree"][0] == "powr" + + def test_faction_briefing_units_have_full_stats(self, env_with_obs): + _, mcp = env_with_obs + fn = self._get_tool(mcp, "get_faction_briefing") + result = fn() + e1 = result["units"]["e1"] + assert e1["name"] == "Rifle Infantry" + assert e1["cost"] == 100 + assert e1["hp"] == 5000 + assert "speed" in e1 + assert "description" in e1 + + def test_faction_briefing_excludes_wrong_side(self, env_with_obs): + _, mcp = env_with_obs + fn = self._get_tool(mcp, "get_faction_briefing") + result = fn() + # Allied-only units should NOT be in soviet briefing + assert "1tnk" not in result["units"] # Light Tank is allied-only + assert "tent" not in result["buildings"] # Allied barracks + + def test_faction_briefing_allied(self, env_with_obs): + env, 
mcp = env_with_obs + env._player_faction = "england" + fn = self._get_tool(mcp, "get_faction_briefing") + result = fn() + assert result["side"] == "allied" + assert "1tnk" in result["units"] + assert "tent" in result["buildings"] + assert "3tnk" not in result["units"] + + # ── get_map_analysis ── + + def test_map_analysis_no_spatial(self, env_with_obs): + _, mcp = env_with_obs + fn = self._get_tool(mcp, "get_map_analysis") + result = fn() + assert result["map_name"] == "singles" + assert result["width"] == 64 + assert result["height"] == 64 + assert "base_position" in result + assert "enemy_estimated_position" in result + assert "note" in result # No spatial data available + + def test_map_analysis_with_spatial(self, env_with_obs): + """Test map analysis with a small synthetic spatial map.""" + import base64 + import struct + + env, mcp = env_with_obs + w, h, channels = 4, 4, 9 # Small test map + data = [0.0] * (w * h * channels) + + # Set up terrain: all passable, some resources + for y in range(h): + for x in range(w): + base_idx = (y * w + x) * channels + data[base_idx + 0] = 1.0 # terrain + data[base_idx + 3] = 1.0 # passable + + # Add resources at (1,1) and (2,2) + data[(1 * w + 1) * channels + 2] = 5.0 + data[(2 * w + 2) * channels + 2] = 3.0 + + # Make (3,3) water (impassable) + data[(3 * w + 3) * channels + 3] = 0.0 + + raw_bytes = struct.pack(f"{len(data)}f", *data) + env._last_obs["spatial_map"] = base64.b64encode(raw_bytes).decode("ascii") + env._last_obs["spatial_channels"] = channels + env._last_obs["map_info"] = {"width": w, "height": h, "map_name": "test_map"} + + fn = self._get_tool(mcp, "get_map_analysis") + result = fn() + + assert result["map_name"] == "test_map" + assert result["width"] == w + assert result["height"] == h + assert "passable_ratio" in result + assert "has_water" in result + assert "map_type" in result + assert "resource_patches" in result + assert "quadrant_summary" in result + assert "strategic_notes" in result + assert 
result["passable_ratio"] > 0.5 + + # ── batch_lookup ── + + def test_batch_lookup_units(self, env_with_obs): + _, mcp = env_with_obs + fn = self._get_tool(mcp, "batch_lookup") + result = fn(queries=[ + {"type": "unit", "name": "e1"}, + {"type": "unit", "name": "3tnk"}, + ]) + assert result["count"] == 2 + assert result["results"][0]["name"] == "Rifle Infantry" + assert result["results"][1]["name"] == "Heavy Tank" + + def test_batch_lookup_buildings(self, env_with_obs): + _, mcp = env_with_obs + fn = self._get_tool(mcp, "batch_lookup") + result = fn(queries=[ + {"type": "building", "name": "powr"}, + {"type": "building", "name": "weap"}, + ]) + assert result["count"] == 2 + assert result["results"][0]["name"] == "Power Plant" + assert result["results"][1]["name"] == "War Factory" + + def test_batch_lookup_mixed(self, env_with_obs): + _, mcp = env_with_obs + fn = self._get_tool(mcp, "batch_lookup") + result = fn(queries=[ + {"type": "unit", "name": "e1"}, + {"type": "building", "name": "powr"}, + {"type": "faction", "name": "russia"}, + {"type": "tech_tree", "name": "soviet"}, + ]) + assert result["count"] == 4 + assert result["results"][0]["name"] == "Rifle Infantry" + assert result["results"][1]["name"] == "Power Plant" + assert result["results"][2]["display_name"] == "Russia" + assert "soviet" in result["results"][3] + + def test_batch_lookup_unknown_item(self, env_with_obs): + _, mcp = env_with_obs + fn = self._get_tool(mcp, "batch_lookup") + result = fn(queries=[ + {"type": "unit", "name": "nonexistent"}, + {"type": "unit", "name": "e1"}, + ]) + assert result["count"] == 2 + assert "error" in result["results"][0] + assert result["results"][1]["name"] == "Rifle Infantry" + + def test_batch_lookup_unknown_type(self, env_with_obs): + _, mcp = env_with_obs + fn = self._get_tool(mcp, "batch_lookup") + result = fn(queries=[{"type": "invalid", "name": "x"}]) + assert "error" in result["results"][0] + + def test_batch_lookup_empty(self, env_with_obs): + _, mcp = 
env_with_obs + fn = self._get_tool(mcp, "batch_lookup") + result = fn(queries=[]) + assert result["count"] == 0 + assert result["results"] == [] diff --git a/tests/test_reward.py b/tests/test_reward.py new file mode 100644 index 0000000000000000000000000000000000000000..eb70aa892bc54c1f017785c5dd63dda4c242bf9f --- /dev/null +++ b/tests/test_reward.py @@ -0,0 +1,259 @@ +"""Tests for the OpenRA-RL reward function.""" + +import pytest + +from openra_env.reward import OpenRARewardFunction, RewardState, RewardWeights +from openra_rl_util.reward_vector import RewardVector + + +def make_obs( + cash=0, + units_killed=0, + units_lost=0, + buildings_killed=0, + buildings_lost=0, + army_value=0, + done=False, + result="", +): + return { + "economy": {"cash": cash}, + "military": { + "units_killed": units_killed, + "units_lost": units_lost, + "buildings_killed": buildings_killed, + "buildings_lost": buildings_lost, + "army_value": army_value, + }, + "done": done, + "result": result, + } + + +class TestRewardWeights: + def test_defaults(self): + w = RewardWeights() + assert w.survival == 0.001 + assert w.economic_efficiency == 0.01 + assert w.aggression == 0.1 + assert w.defense == 0.05 + assert w.victory == 1.0 + assert w.defeat == -1.0 + + def test_custom_weights(self): + w = RewardWeights(survival=0.0, victory=10.0) + assert w.survival == 0.0 + assert w.victory == 10.0 + + +class TestRewardState: + def test_defaults(self): + s = RewardState() + assert s.prev_cash == 0 + assert s.prev_units_killed == 0 + assert s.prev_units_lost == 0 + + +class TestOpenRARewardFunction: + def test_survival_reward(self): + rf = OpenRARewardFunction() + obs = make_obs() + reward = rf.compute(obs) + assert reward == pytest.approx(0.001) + + def test_cash_increase_reward(self): + rf = OpenRARewardFunction() + # First tick establishes baseline + rf.compute(make_obs(cash=1000)) + # Cash increased by 2000 + reward = rf.compute(make_obs(cash=3000)) + # survival (0.001) + economic_efficiency * 
(2000/1000) = 0.001 + 0.01 * 2.0 = 0.021 + assert reward == pytest.approx(0.021) + + def test_cash_decrease_no_reward(self): + rf = OpenRARewardFunction() + rf.compute(make_obs(cash=5000)) + # Cash decreased — no economic reward (only survival) + reward = rf.compute(make_obs(cash=3000)) + assert reward == pytest.approx(0.001) + + def test_kill_reward(self): + rf = OpenRARewardFunction() + rf.compute(make_obs()) + # Killed 2 enemy units + reward = rf.compute(make_obs(units_killed=2)) + # survival + aggression * 2 = 0.001 + 0.1 * 2 = 0.201 + assert reward == pytest.approx(0.201) + + def test_building_kill_reward(self): + rf = OpenRARewardFunction() + rf.compute(make_obs()) + # Killed 1 enemy building + reward = rf.compute(make_obs(buildings_killed=1)) + # survival + aggression * 1 = 0.001 + 0.1 = 0.101 + assert reward == pytest.approx(0.101) + + def test_loss_penalty(self): + rf = OpenRARewardFunction() + rf.compute(make_obs()) + # Lost 3 units + reward = rf.compute(make_obs(units_lost=3)) + # survival - defense * 3 = 0.001 - 0.05 * 3 = -0.149 + assert reward == pytest.approx(-0.149) + + def test_building_loss_penalty(self): + rf = OpenRARewardFunction() + rf.compute(make_obs()) + reward = rf.compute(make_obs(buildings_lost=2)) + # survival - defense * 2 = 0.001 - 0.05 * 2 = -0.099 + assert reward == pytest.approx(-0.099) + + def test_victory_reward(self): + rf = OpenRARewardFunction() + reward = rf.compute(make_obs(done=True, result="win")) + # survival + victory = 0.001 + 1.0 = 1.001 + assert reward == pytest.approx(1.001) + + def test_defeat_penalty(self): + rf = OpenRARewardFunction() + reward = rf.compute(make_obs(done=True, result="lose")) + # survival + defeat = 0.001 + (-1.0) = -0.999 + assert reward == pytest.approx(-0.999) + + def test_draw_no_terminal_reward(self): + rf = OpenRARewardFunction() + reward = rf.compute(make_obs(done=True, result="draw")) + # Only survival, no terminal reward for draw + assert reward == pytest.approx(0.001) + + def 
test_reset_clears_state(self): + rf = OpenRARewardFunction() + rf.compute(make_obs(cash=5000, units_killed=10)) + rf.reset() + # After reset, deltas computed from zero baseline + reward = rf.compute(make_obs(cash=1000, units_killed=1)) + # survival + econ*(1000/1000) + aggression*1 = 0.001 + 0.01 + 0.1 = 0.111 + assert reward == pytest.approx(0.111) + + def test_custom_weights(self): + weights = RewardWeights(survival=0.0, aggression=1.0, defense=0.0, victory=100.0) + rf = OpenRARewardFunction(weights=weights) + rf.compute(make_obs()) + reward = rf.compute(make_obs(units_killed=5)) + # aggression * 5 = 5.0 (no survival, no defense) + assert reward == pytest.approx(5.0) + + def test_combined_scenario(self): + rf = OpenRARewardFunction() + rf.compute(make_obs(cash=1000)) + # Next tick: gained cash, killed 1 enemy, lost 1 unit + reward = rf.compute(make_obs(cash=2000, units_killed=1, units_lost=1)) + # survival + econ*(1000/1000) + aggression*1 - defense*1 + # = 0.001 + 0.01 + 0.1 - 0.05 = 0.061 + assert reward == pytest.approx(0.061) + + def test_delta_tracking_across_steps(self): + rf = OpenRARewardFunction() + rf.compute(make_obs(units_killed=0)) + rf.compute(make_obs(units_killed=2)) # killed 2 + # Now units_killed is still 2, so delta is 0 + reward = rf.compute(make_obs(units_killed=2)) + assert reward == pytest.approx(0.001) # only survival + + def test_empty_observation(self): + rf = OpenRARewardFunction() + reward = rf.compute({}) + # Should handle missing keys gracefully, just survival + assert reward == pytest.approx(0.001) + + +class TestRewardVectorIntegration: + """Test reward vector mode integration.""" + + def test_vector_disabled_by_default(self): + rf = OpenRARewardFunction() + assert rf.vector_enabled is False + assert rf.compute_vector(make_obs()) is None + + def test_vector_enabled(self): + rf = OpenRARewardFunction(vector_enabled=True) + assert rf.vector_enabled is True + v = rf.compute_vector(make_full_obs()) + assert isinstance(v, RewardVector) 
+ + def test_compute_all_without_vector(self): + rf = OpenRARewardFunction() + scalar, vec_dict = rf.compute_all(make_obs()) + assert isinstance(scalar, float) + assert vec_dict is None + + def test_compute_all_with_vector(self): + rf = OpenRARewardFunction(vector_enabled=True) + scalar, vec_dict = rf.compute_all(make_full_obs()) + assert isinstance(scalar, float) + assert isinstance(vec_dict, dict) + assert "combat" in vec_dict + assert "economy" in vec_dict + assert "outcome" in vec_dict + assert len(vec_dict) == 8 + + def test_vector_reset(self): + rf = OpenRARewardFunction(vector_enabled=True) + rf.compute_all(make_full_obs()) + rf.reset() + # After reset, should compute cleanly + _, vec_dict = rf.compute_all(make_full_obs()) + assert vec_dict is not None + + def test_vector_win_outcome(self): + rf = OpenRARewardFunction(vector_enabled=True) + rf.compute_vector(make_full_obs()) # baseline + v = rf.compute_vector(make_full_obs(done=True, result="win")) + assert v.outcome == 1.0 + + def test_vector_lose_outcome(self): + rf = OpenRARewardFunction(vector_enabled=True) + rf.compute_vector(make_full_obs()) # baseline + v = rf.compute_vector(make_full_obs(done=True, result="lose")) + assert v.outcome == -1.0 + + def test_vector_combat_from_kills(self): + rf = OpenRARewardFunction(vector_enabled=True) + rf.compute_vector(make_full_obs()) # baseline + v = rf.compute_vector(make_full_obs(kills_cost=5000)) + assert v.combat > 0 + + def test_vector_dimensions_bounded(self): + rf = OpenRARewardFunction(vector_enabled=True) + v = rf.compute_vector(make_full_obs()) + for name, val in v.as_dict().items(): + assert -1.0 <= val <= 1.0, f"{name} = {val} out of bounds" + + +def make_full_obs( + cash=5000, kills_cost=0, deaths_cost=0, assets_value=0, + done=False, result="", harvester_count=1, +): + """Create a full observation dict suitable for RewardVectorComputer.""" + return { + "military": { + "units_killed": 0, "units_lost": 0, + "buildings_killed": 0, "buildings_lost": 0, + 
"army_value": 0, "active_unit_count": 0, + "kills_cost": kills_cost, "deaths_cost": deaths_cost, + "assets_value": assets_value, "experience": 0, "order_count": 0, + }, + "economy": { + "cash": cash, "ore": 0, + "power_provided": 100, "power_drained": 0, + "resource_capacity": 2000, "harvester_count": harvester_count, + }, + "units": [], + "buildings": [], + "visible_enemies": [], + "visible_enemy_buildings": [], + "production_queues": [], + "done": done, + "result": result, + "tick": 0, + }