Supra-A2A-Nano-Exp

Status: experimental / educational prototype. Not a polished product.

Supra-A2A-Nano-Exp is a small proof-of-concept any-to-any model from SupraLabs: a single autoregressive GPT that reads and writes both text and images by treating them as one unified sequence of discrete tokens. Text uses standard BPE tokens; images are discretized into a small set of learned "visual words" by a convolutional VQ-VAE and appended to the same vocabulary. The whole multimodal stream is modeled by one Transformer with one set of weights; there is no separate vision encoder or diffusion head bolted on.

This release exists mainly to demonstrate the idea of unified tokenization across modalities on consumer-scale hardware. At ~30M total parameters and a 384-token context, do not expect coherent long-form text or photorealistic images. Treat it as a transparent, hackable example architecture rather than a capable generator.

How it works

Every input (text, image, or video) is serialized into one token stream, with each segment wrapped in control tags:

<TEXT>some text here</TEXT>
<IMAGE><FRAME>[64 visual tokens]</IMAGE>
<VIDEO><FRAME>[visual tokens]<FRAME>[visual tokens]...</VIDEO>
  • Text tokens come from a GPT-2-style BPE tokenizer (50,257 tokens) plus 7 control tokens (<TEXT>, </TEXT>, <IMAGE>, </IMAGE>, <VIDEO>, </VIDEO>, <FRAME>), for 50,264 text-side ids total.
  • Visual tokens come from a VQ-VAE codebook of 256 entries. An image is encoded by 3 strided convolutions (/8 downsampling total), and each resulting spatial cell is snapped to its nearest codebook vector. A 64x64 image becomes an 8x8 grid, i.e. 64 visual tokens.
  • These 256 visual codes are appended right after the text vocabulary, so the GPT's combined vocabulary is exactly 50,264 + 256 = 50,520 tokens. One embedding table, one output head, one model for both modalities.
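The two mechanics in the bullets above (snapping each encoder cell to its nearest codebook vector, then shifting the code index into the shared vocabulary) can be sketched in pure Python. This is an illustration on a toy 2-D codebook, not the checkpoint's 256x64 codebook, and the helper names are hypothetical. It computes squared distances directly, which selects the same index as the expanded ||z||^2 - 2 z.e + ||e||^2 form used in the script's VectorQuantizer.

```python
from typing import List, Sequence

TEXT_VOCAB = 50_257 + 7         # GPT-2 BPE + 7 control tokens = 50,264
CODEBOOK_SIZE = 256
VISUAL_OFFSET = TEXT_VOCAB      # visual code c becomes unified id VISUAL_OFFSET + c
CONTEXT = 384


def nearest_code(vec: Sequence[float], codebook: Sequence[Sequence[float]]) -> int:
    """Index of the codebook vector with the smallest squared L2 distance to `vec`."""
    def sq_dist(i: int) -> float:
        return sum((a - b) ** 2 for a, b in zip(vec, codebook[i]))
    return min(range(len(codebook)), key=sq_dist)


def cells_to_unified_ids(cells: Sequence[Sequence[float]],
                         codebook: Sequence[Sequence[float]]) -> List[int]:
    """Snap each encoder cell (row-major order) to a code, then shift into the shared vocab."""
    return [VISUAL_OFFSET + nearest_code(c, codebook) for c in cells]


def image_span_length(height: int = 64, width: int = 64) -> int:
    """Tokens taken by <IMAGE><FRAME>[codes]</IMAGE> at /8 downsampling."""
    if height % 8 or width % 8:
        raise ValueError("height and width must be multiples of 8")
    return 3 + (height // 8) * (width // 8)


codebook = [[0.0, 0.0], [1.0, 1.0], [5.0, -5.0]]  # toy 3-entry, 2-D codebook
cells = [[0.9, 1.2], [4.0, -4.5], [0.1, -0.2]]
print(cells_to_unified_ids(cells, codebook))                # [50265, 50266, 50264]
print(image_span_length(), CONTEXT - image_span_length())   # 67 317
```

The last line shows the budget: one 64x64 image costs 67 of the 384 context positions, leaving 317 for text and further images or frames.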

Because images are projected into the same id space as text, the GPT can in principle attend across modalities: condition image generation on a text prompt, or condition text generation on image content, using the exact same attention mechanism it uses for next-token text prediction.
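A rough sketch of how the script forces a modality while sampling image tokens (the allowed_range argument of its _sample_step): ids outside [lo, hi) are excluded before the softmax, so the text prompt can steer which visual code is chosen, but a text token can never be emitted inside an image. This pure-Python version (restricted_sample is a hypothetical name) mirrors that logic on a plain list of logits; one small difference is that it truncates top-k ties, whereas the script keeps every logit equal to the k-th value.

```python
import math
import random
from typing import Optional, Sequence, Tuple


def restricted_sample(
    logits: Sequence[float],
    allowed: Tuple[int, int],
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> int:
    """Sample one token id from `logits`, considering only ids in [lo, hi)."""
    lo, hi = allowed
    rng = rng or random.Random()
    t = max(temperature, 1e-5)  # same temperature floor the script uses
    # Drop every id outside the allowed range (e.g. all text ids while drawing an image).
    cands = [(i, logits[i] / t) for i in range(lo, hi)]
    if top_k is not None:
        cands = sorted(cands, key=lambda p: p[1], reverse=True)[: max(top_k, 1)]
    m = max(score for _, score in cands)
    weights = [math.exp(score - m) for _, score in cands]  # unnormalized softmax
    return rng.choices([i for i, _ in cands], weights=weights, k=1)[0]


# Ids 0-1 play the role of text tokens, ids 2-4 the role of visual codes.
logits = [10.0, 9.0, 0.0, 1.0, 2.0]
print(restricted_sample(logits, allowed=(2, 5), top_k=1))  # 4
```

Even though the two "text" ids have by far the highest logits, the result is always drawn from the visual range.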

Architecture

Component | Detail
--- | ---
GPT backbone | 4 Transformer blocks, pre-norm, fused QKV attention, causal
Embedding dim | 256
Context length | 384 tokens
Attention heads | 4 (assumed; see the note below)
MLP expansion | 4x (256 -> 1024 -> 256), GELU
Combined vocabulary | 50,520 (50,264 text + 256 visual)
GPT parameters | ~29.1M learnable (~29.7M counting the causal-mask buffers stored in the checkpoint)
VQ-VAE | 3-layer conv encoder / decoder, /8 downsampling, 256x64 codebook
VQ-VAE parameters | ~0.22M
Total parameters | ~29.3M learnable (~29.9M including the mask buffers)
Precision | fp32
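The parameter figures can be re-derived from layer shapes. This sketch assumes the architecture that run_supra_a2a.py rebuilds: biased Linear layers, affine LayerNorms, an untied output head without bias, and 4x4 kernels in all six VQ-VAE convolutions (a ConvTranspose2d weight has shape (in, out, k, k), so the same count formula applies). The function names are hypothetical. Note that the learnable total comes to about 29.1M; a figure of about 29.7M is reproduced only when the four 384x384 causal-mask buffers, which the checkpoint stores alongside the weights, are counted as well.

```python
def gpt_param_count(vocab: int = 50_520, d: int = 256, ctx: int = 384,
                    n_layer: int = 4, mlp_mult: int = 4) -> int:
    emb = vocab * d + ctx * d                              # tok_emb + pos_emb
    ln = 2 * d                                             # LayerNorm weight + bias
    attn = (d * 3 * d + 3 * d) + (d * d + d)               # fused qkv + output proj
    mlp = (d * mlp_mult * d + mlp_mult * d) + (mlp_mult * d * d + d)
    block = 2 * ln + attn + mlp                            # ln1, ln2, attention, MLP
    head = vocab * d                                       # untied, bias=False
    return emb + n_layer * block + ln + head               # + final ln_f


def causal_mask_elements(ctx: int = 384, n_layer: int = 4) -> int:
    """Non-learnable tril buffers, one (ctx x ctx) mask per block."""
    return n_layer * ctx * ctx


def conv_params(c_in: int, c_out: int, k: int = 4) -> int:
    return c_in * c_out * k * k + c_out                   # weight + bias


def vqvae_param_count(code_dim: int = 64, codebook: int = 256) -> int:
    enc = conv_params(3, 32) + conv_params(32, 64) + conv_params(64, code_dim)
    dec = conv_params(code_dim, 64) + conv_params(64, 32) + conv_params(32, 3)
    return enc + codebook * code_dim + dec


gpt, masks, vq = gpt_param_count(), causal_mask_elements(), vqvae_param_count()
print(f"GPT learnable:   {gpt / 1e6:.2f}M")            # 29.12M
print(f"GPT incl. masks: {(gpt + masks) / 1e6:.2f}M")  # 29.71M
print(f"VQ-VAE:          {vq / 1e6:.2f}M")             # 0.22M
```

Roughly 88% of the GPT's weights sit in the two 50,520 x 256 matrices (token embedding and output head), which is why the model is so large relative to its 4 small blocks.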

Note on attention heads: the checkpoint stores Q, K and V as a single fused Linear layer, so the head count cannot be recovered from the weights alone. run_supra_a2a.py defaults to 4 heads (64 dimensions per head, the GPT-2 convention). If you trained this checkpoint yourself with a different head count, change N_HEAD at the top of the script (or pass n_head= to SupraA2A). Loading succeeds either way because the tensor shapes are identical, but a wrong value silently produces incorrect attention.
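To see why a wrong head count loads cleanly yet changes behavior, here is a small NumPy sketch of fused-QKV causal attention in the same layout as the script's CausalSelfAttention (output projection omitted). The weights are random stand-ins, not the checkpoint, and fused_causal_attention is a hypothetical name. The same (3C, C) weight is valid for any head count that divides C; only the way the C channels are grouped into heads changes.

```python
import numpy as np


def fused_causal_attention(x: np.ndarray, w_qkv: np.ndarray, b_qkv: np.ndarray,
                           n_head: int) -> np.ndarray:
    """Causal self-attention from one fused QKV projection.

    x: (T, C); w_qkv: (3C, C) and b_qkv: (3C,), laid out like torch.nn.Linear.
    """
    t, c = x.shape
    hd = c // n_head
    qkv = x @ w_qkv.T + b_qkv                        # (T, 3C)
    q, k, v = np.split(qkv, 3, axis=1)               # each (T, C), like qkv.split(c, dim=2)
    # Carve the C channels into n_head groups of hd: (T, C) -> (n_head, T, hd)
    q, k, v = [a.reshape(t, n_head, hd).transpose(1, 0, 2) for a in (q, k, v)]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(hd)  # (n_head, T, T)
    causal = np.tril(np.ones((t, t), dtype=bool))
    scores = np.where(causal, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    y = weights @ v                                  # (n_head, T, hd)
    return y.transpose(1, 0, 2).reshape(t, c)


rng = np.random.default_rng(0)
C, T = 256, 6
w_qkv = rng.normal(0.0, 0.02, size=(3 * C, C))  # stand-in for the stored fused weight
b_qkv = np.zeros(3 * C)
x = rng.normal(size=(T, C))

y4 = fused_causal_attention(x, w_qkv, b_qkv, n_head=4)
y8 = fused_causal_attention(x, w_qkv, b_qkv, n_head=8)
print(y4.shape, y8.shape)     # identical shapes
print(np.allclose(y4, y8))    # False: the outputs differ
```

The first position always matches, because it can only attend to itself; every later position mixes values with weights that depend on how channels are split into heads.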

Note on pixel normalization: similarly, the VQ-VAE decoder's final activation has no learnable parameters, so it isn't visible in the checkpoint either. The script defaults to sigmoid (assumes training in [0, 1]). If reconstructions look off, try VQVAE_OUTPUT_ACTIVATION = "tanh".
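Whichever activation is assumed, decoder output has to end up in [0, 1] before it is scaled to 8-bit pixels, and the encoder should see inputs in the range the model was trained on. A sketch of both mappings (the function names are hypothetical): a tanh-trained decoder's [-1, 1] output needs the affine rescale (x + 1) / 2, and such a model presumably also expected [-1, 1] inputs. Merely clamping tanh output to [0, 1] would send the whole negative half of its range to black.

```python
import math


def decoder_to_pixel(raw: float, activation: str = "sigmoid") -> int:
    """Map one raw decoder output value to an 8-bit pixel."""
    if activation == "sigmoid":
        unit = 1.0 / (1.0 + math.exp(-raw))      # already in (0, 1)
    elif activation == "tanh":
        unit = (math.tanh(raw) + 1.0) / 2.0      # rescale [-1, 1] -> [0, 1]
    else:
        unit = raw                                # "none": assume [0, 1] already
    return round(min(max(unit, 0.0), 1.0) * 255)


def naive_tanh_pixel(raw: float) -> int:
    """What goes wrong if tanh output is only clamped: every negative value becomes 0."""
    return round(min(max(math.tanh(raw), 0.0), 1.0) * 255)


def pixel_to_encoder_input(p: int, activation: str = "sigmoid") -> float:
    """Scale an 8-bit pixel into the range the encoder presumably saw in training."""
    unit = p / 255.0
    return unit * 2.0 - 1.0 if activation == "tanh" else unit


print(decoder_to_pixel(-2.0, "tanh"), naive_tanh_pixel(-2.0))  # 5 0
```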

Files in this repository

File | Description
--- | ---
model.safetensors | GPT backbone weights (state_dict, fp32)
vqvae.safetensors | VQ-VAE encoder/decoder/codebook weights (state_dict, fp32)
tokenizer.json | Fast BPE tokenizer (GPT-2 base vocab + 7 control tokens)
tokenizer_config.json | Tokenizer metadata (special tokens, GPT2Tokenizer class)
run_supra_a2a.py | Plug-and-play inference script (see below)
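The .safetensors files can be inspected without PyTorch. In the safetensors format, a file starts with an 8-byte little-endian header length, followed by a JSON header that maps each tensor name to its dtype, shape and byte offsets (plus an optional "__metadata__" entry), and then the raw data. The sketch below reads only that header. describe_gpt is a hypothetical helper that applies the same rules the script's SupraA2A constructor uses (tok_emb.weight, pos_emb.weight, and the distinct blocks.N indices).

```python
import json
import struct
from pathlib import Path
from typing import Dict, List


def read_safetensors_shapes(path: Path) -> Dict[str, List[int]]:
    """Return {tensor_name: shape} from a .safetensors header, without reading tensor data."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # u64, little-endian
        header = json.loads(f.read(header_len))
    return {name: info["shape"] for name, info in header.items() if name != "__metadata__"}


def describe_gpt(shapes: Dict[str, List[int]]) -> Dict[str, int]:
    """Hypothetical helper: recover GPT hyperparameters the way run_supra_a2a.py does."""
    vocab, n_embd = shapes["tok_emb.weight"]
    block_size = shapes["pos_emb.weight"][0]
    n_layer = len({k.split(".")[1] for k in shapes if k.startswith("blocks.")})
    return {"vocab": vocab, "n_embd": n_embd, "block_size": block_size, "n_layer": n_layer}
```

For example, describe_gpt(read_safetensors_shapes(Path("model.safetensors"))) should report the 50,520-entry vocabulary, 256-dim embeddings, 384-token context and 4 blocks listed in the architecture table, provided the downloaded checkpoint matches this card.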

Samples

Prompt | Response
--- | ---
Text2Text: Once upon a time | Once upon a time, there was a little girl named Lily. She loved to visit her grandma's house. Her grandma was very bossy and always told her what to do. One day, Lily went to visit her grandma and something strange happened. The sky turned dark and it started to rain. Lily was scared, but her grandma told her not to worry. They sat inside and played games until the rain stopped.
Text2Image: A dog running on snow | [generated image]
Text2Video: A snow scene | [generated video] https://cdn-uploads.huggingface.co/production/uploads/68df176c403a7bf9e8ae85a8/Lv1UqRG1m8KkiT46X8RNC.mp4

Usage

pip install torch transformers huggingface_hub safetensors pillow numpy

run_supra_a2a.py (included in the repository):
#!/usr/bin/env python3
"""
Supra-A2A-Nano-Exp - inference runner
======================================

An experimental any-to-any model from SupraLabs: a single autoregressive GPT
operates over one unified vocabulary that mixes text (BPE, GPT-2 style) and
discrete visual codes from a convolutional VQ-VAE. Text and images are
serialized into the same token stream, delimited by control tokens
(<TEXT>, <IMAGE>, <VIDEO>, <FRAME>, and their closing tags).

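This script reconstructs the architecture from the raw state_dicts
(no config.json ships with the weights; n_head and the VQ-VAE output
activation are assumptions, see Config) and exposes a SupraA2A class with
high-level methods for text completion, image reconstruction (a VQ-VAE
sanity check), text-conditioned image and video generation, and the
image/video-to-text, -image and -video conversions shown below.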

Quick start:
    python run_supra_a2a.py --mode text --prompt "Once upon a time"
    python run_supra_a2a.py --mode chat
    python run_supra_a2a.py --mode text2image --prompt "<TEXT>a red square</TEXT><IMAGE>" --out gen.png
    python run_supra_a2a.py --mode text2video --prompt "<TEXT>A snow scene</TEXT><VIDEO>" --frames 4 --out video.gif

    python run_supra_a2a.py --mode reconstruct --image photo.png --out recon.png

    python run_supra_a2a.py --mode image2text --image gen.png
    python run_supra_a2a.py --mode image2image --image photo.png --prompt "<TEXT>make it red</TEXT><IMAGE>" --out edit.png
    python run_supra_a2a.py --mode image2video --image photo.png --prompt "<TEXT>make it move</TEXT><VIDEO>" --frames 4 --out animated.gif

    python run_supra_a2a.py --mode video2text --image video.gif
    python run_supra_a2a.py --mode video2image --image video.gif --prompt "<TEXT>extract style</TEXT><IMAGE>" --out frame.png
    python run_supra_a2a.py --mode video2video --image video.gif --prompt "<TEXT>change style</TEXT><VIDEO>" --frames 4 --out transformed.gif
"""

from __future__ import annotations

import argparse
import math
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from safetensors.torch import load_file as load_safetensors

# --------------------------------------------------------------------------- #
# Config
# --------------------------------------------------------------------------- #

MODEL_WEIGHT_STEMS = ["model", "vqvae"]  # each may exist as .safetensors or .pt
TOKENIZER_FILES = ["tokenizer.json", "tokenizer_config.json"]
# Patterns used for the Hub fallback download; covers both weight formats so
# this works whether the repo ships .safetensors (preferred), .pt, or both.
HUB_ALLOW_PATTERNS = [f"{s}.safetensors" for s in MODEL_WEIGHT_STEMS] + [
    f"{s}.pt" for s in MODEL_WEIGHT_STEMS
] + TOKENIZER_FILES
DEFAULT_REPO_ID = "SupraLabs/Supra-A2A-Nano-Exp"

# NOTE: the GPT checkpoint does not store n_head explicitly (qkv is a single
# fused Linear layer). 256 / 4 = 64 head_dim (the GPT-2 convention) is the
# most likely value and is the default here, but it cannot be verified from
# the weights alone. A wrong value will NOT break loading (shapes still
# match) but will silently produce incorrect attention. If generations look
# off, try N_HEAD = 8 (head_dim 32) instead.
N_HEAD = 4

# Pixel normalization used by the VQ-VAE decoder's final layer. There is no
# parametrized activation after the last ConvTranspose2d in the checkpoint,
# so this is also an assumption rather than a certainty. Default: sigmoid
# (assumes images were trained in [0, 1]). If reconstructions look washed
# out / inverted, try "tanh" (assumes [-1, 1] training).
VQVAE_OUTPUT_ACTIVATION = "sigmoid"  # "sigmoid" | "tanh" | "none"


# --------------------------------------------------------------------------- #
# Weight resolution (local dir -> Hugging Face Hub fallback)
# --------------------------------------------------------------------------- #

def _has_all_weights(d: Path) -> bool:
    weights_ok = all(
        (d / f"{stem}.safetensors").exists() or (d / f"{stem}.pt").exists()
        for stem in MODEL_WEIGHT_STEMS
    )
    tok_ok = all((d / f).exists() for f in TOKENIZER_FILES)
    return weights_ok and tok_ok


def resolve_weights_dir(weights_dir: Path, repo_id: str) -> Path:
    """Return a directory containing the tokenizer files plus, for each model
    in MODEL_WEIGHT_STEMS, either a .safetensors or a .pt file.

    Looks locally first; if anything is missing, tries to download from the
    Hub via huggingface_hub.snapshot_download (works once the
    SupraLabs/Supra-A2A-Nano-Exp repo is public).
    """
    if _has_all_weights(weights_dir):
        return weights_dir

    print(f"[info] weight files incomplete in {weights_dir}")
    print(f"[info] trying to download from huggingface.co/{repo_id} ...")
    try:
        from huggingface_hub import snapshot_download
    except ImportError as e:
        raise RuntimeError(
            "huggingface_hub is not installed and local weights are incomplete. "
            "Run: pip install huggingface_hub"
        ) from e

    try:
        downloaded = snapshot_download(repo_id=repo_id, allow_patterns=HUB_ALLOW_PATTERNS)
    except Exception as e:
        raise RuntimeError(
            f"Could not find the weights locally or on the Hub ({repo_id}). "
            f"Place tokenizer files plus {MODEL_WEIGHT_STEMS} (.safetensors or .pt) "
            f"manually in {weights_dir}.\nOriginal error: {e}"
        ) from e
    downloaded = Path(downloaded)
    if not _has_all_weights(downloaded):
        raise RuntimeError(f"Download from {repo_id} completed but required files are still missing.")
    return downloaded


def load_state_dict_any(weights_dir: Path, stem: str) -> dict:
    """Load a checkpoint by stem name, preferring .safetensors over legacy .pt."""
    st_path = weights_dir / f"{stem}.safetensors"
    pt_path = weights_dir / f"{stem}.pt"
    if st_path.exists():
        return load_safetensors(str(st_path))
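    if pt_path.exists():
        # Legacy pickle checkpoint: weights_only=False runs the full unpickler, which can
        # execute arbitrary code, so only load .pt files from a source you trust.
        return torch.load(pt_path, map_location="cpu", weights_only=False)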
    raise FileNotFoundError(f"Neither {st_path.name} nor {pt_path.name} found in {weights_dir}")


# --------------------------------------------------------------------------- #
# VQ-VAE (conv encoder / codebook / transposed-conv decoder)
# --------------------------------------------------------------------------- #

class VectorQuantizer(nn.Module):
    """Discrete codebook with nearest-neighbor lookup (inference only, no EMA)."""

    def __init__(self, num_codes: int, dim: int):
        super().__init__()
        self.num_codes = num_codes
        self.dim = dim
        self.embedding = nn.Embedding(num_codes, dim)

    def encode(self, z: torch.Tensor) -> torch.Tensor:
        """z: (B, C, H, W) -> discrete indices (B, H, W)."""
        b, c, h, w = z.shape
        z_flat = z.permute(0, 2, 3, 1).reshape(-1, c)
        codebook = self.embedding.weight  # (num_codes, dim)
        dist = (
            z_flat.pow(2).sum(1, keepdim=True)
            - 2 * z_flat @ codebook.t()
            + codebook.pow(2).sum(1)
        )
        idx = dist.argmin(dim=1)
        return idx.view(b, h, w)

    def decode(self, idx: torch.Tensor) -> torch.Tensor:
        """idx: (B, H, W) -> z_q (B, C, H, W)."""
        z_q = self.embedding(idx)  # (B, H, W, C)
        return z_q.permute(0, 3, 1, 2).contiguous()


class VQVAE(nn.Module):
    """Conv autoencoder with /8 downsampling, codebook size/dim read from the checkpoint."""

    def __init__(self, codebook_size: int = 256, code_dim: int = 64):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, code_dim, kernel_size=4, stride=2, padding=1),
        )
        self.vq = VectorQuantizer(codebook_size, code_dim)
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(code_dim, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
        )

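    def _final_activation(self, x: torch.Tensor) -> torch.Tensor:
        """Map raw decoder output to [0, 1] pixels under the assumed training range."""
        if VQVAE_OUTPUT_ACTIVATION == "sigmoid":
            return torch.sigmoid(x)
        if VQVAE_OUTPUT_ACTIVATION == "tanh":
            # tanh yields [-1, 1]; rescale so callers always receive [0, 1]
            return (torch.tanh(x) + 1.0) / 2.0
        return x

    @torch.no_grad()
    def encode_to_indices(self, img: torch.Tensor) -> torch.Tensor:
        """img: (B, 3, H, W) in [0, 1], H and W multiples of 8."""
        if VQVAE_OUTPUT_ACTIVATION == "tanh":
            img = img * 2.0 - 1.0  # a tanh-trained VQ-VAE presumably saw [-1, 1] inputs
        z = self.enc(img)
        return self.vq.encode(z)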

    @torch.no_grad()
    def decode_from_indices(self, idx: torch.Tensor) -> torch.Tensor:
        """idx: (B, H, W) -> image (B, 3, H*8, W*8) in [0, 1]."""
        z_q = self.vq.decode(idx)
        x = self.dec(z_q)
        return self._final_activation(x).clamp(0, 1)


# --------------------------------------------------------------------------- #
# GPT (nanoGPT-style: pre-norm, fused qkv, 4x MLP)
# --------------------------------------------------------------------------- #

class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd: int, n_head: int, block_size: int):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        mask = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
        self.register_buffer("mask", mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, c = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.split(c, dim=2)
        hd = c // self.n_head
        q = q.view(b, t, self.n_head, hd).transpose(1, 2)
        k = k.view(b, t, self.n_head, hd).transpose(1, 2)
        v = v.view(b, t, self.n_head, hd).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(hd)
        att = att.masked_fill(self.mask[:, :, :t, :t] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.proj(y)


class Block(nn.Module):
    def __init__(self, n_embd: int, n_head: int, block_size: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPT(nn.Module):
    def __init__(self, vocab_size: int, n_embd: int, block_size: int, n_layer: int, n_head: int):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, block_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        b, t = idx.shape
        assert t <= self.block_size, f"sequence length ({t}) exceeds model context ({self.block_size})"
        pos = torch.arange(t, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        return self.head(x)


# --------------------------------------------------------------------------- #
# Unified pipeline
# --------------------------------------------------------------------------- #

@dataclass
class VocabLayout:
    text_vocab_size: int     # BPE + control tokens (<TEXT>, <IMAGE>, ...)
    visual_offset: int       # = text_vocab_size: where visual codes start
    visual_vocab_size: int   # VQ-VAE codebook size
    total_vocab_size: int    # text_vocab_size + visual_vocab_size


class SupraA2A:
    """High-level wrapper: tokenizer + VQ-VAE + GPT sharing one unified vocabulary."""

    def __init__(self, weights_dir: Path, device: Optional[str] = None, n_head: int = N_HEAD):
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

        from transformers import PreTrainedTokenizerFast

        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(weights_dir / "tokenizer.json"))
        self.tokenizer.pad_token = "<|endoftext|>"
        self.tokenizer.bos_token = "<|endoftext|>"
        self.tokenizer.eos_token = "<|endoftext|>"
        self.tokenizer.unk_token = "<|endoftext|>"

        vq_state = load_state_dict_any(weights_dir, "vqvae")
        gpt_state = load_state_dict_any(weights_dir, "model")

        codebook_size, code_dim = vq_state["vq.embedding.weight"].shape
        self.vqvae = VQVAE(codebook_size=codebook_size, code_dim=code_dim)
        self.vqvae.load_state_dict(vq_state, strict=True)
        self.vqvae.eval().to(self.device)

        total_vocab, n_embd = gpt_state["tok_emb.weight"].shape
        block_size = gpt_state["pos_emb.weight"].shape[0]
        n_layer = len({k.split(".")[1] for k in gpt_state if k.startswith("blocks.")})

        self.gpt = GPT(total_vocab, n_embd, block_size, n_layer, n_head)
        self.gpt.load_state_dict(gpt_state, strict=True)
        self.gpt.eval().to(self.device)

        text_vocab_size = len(self.tokenizer)
        self.vocab = VocabLayout(
            text_vocab_size=text_vocab_size,
            visual_offset=text_vocab_size,
            visual_vocab_size=codebook_size,
            total_vocab_size=total_vocab,
        )
        expected_total = self.vocab.text_vocab_size + self.vocab.visual_vocab_size
        if expected_total != total_vocab:
            raise ValueError(
                f"Vocabulary mismatch: tokenizer + codebook = {expected_total}, "
                f"but the GPT's tok_emb expects {total_vocab}. "
                "Make sure you're using the tokenizer.json that matches these weights."
            )

        print(
            f"[ok] SupraA2A loaded on {self.device} | "
            f"GPT: {n_layer}L/{n_embd}d/ctx{block_size} | "
            f"vocab text={text_vocab_size} + visual={codebook_size} = {total_vocab} | "
            f"VQ-VAE: {codebook_size}x{code_dim} codes (/8 downsample)"
        )

    # ------------------------------------------------------------------ #
    # Text
    # ------------------------------------------------------------------ #

    def encode_text(self, text: str) -> List[int]:
        return self.tokenizer.encode(text)

    def decode_text(self, ids: List[int]) -> str:
        # visual token ids (outside the tokenizer's range) are filtered out
        text_ids = [i for i in ids if i < self.vocab.text_vocab_size]
        return self.tokenizer.decode(text_ids)

    # ------------------------------------------------------------------ #
    # Image <-> visual tokens
    # ------------------------------------------------------------------ #

    def image_to_tokens(self, image: Image.Image) -> Tuple[List[int], Tuple[int, int]]:
        """Convert a PIL image into visual token ids (already offset into the unified vocab).

        Height/width must be multiples of 8 (three /2 downsampling stages).
        """
        w, h = image.size
        if w % 8 or h % 8:
            raise ValueError(f"Image {w}x{h} must have dimensions that are multiples of 8.")
        img = image.convert("RGB")
        arr = np.array(img, dtype=np.uint8)  # copy (writable), avoids a torch warning
        tensor = torch.from_numpy(arr).float().permute(2, 0, 1) / 255.0
        tensor = tensor.unsqueeze(0).to(self.device)

        with torch.no_grad():
            idx = self.vqvae.encode_to_indices(tensor)  # (1, h/8, w/8)
        grid_h, grid_w = idx.shape[1], idx.shape[2]
        flat = (idx.view(-1) + self.vocab.visual_offset).tolist()
        return flat, (grid_h, grid_w)

    def tokens_to_image(self, token_ids: List[int], grid: Tuple[int, int]) -> Image.Image:
        """Convert visual token ids (offset) back into a PIL image."""
        grid_h, grid_w = grid
        expected = grid_h * grid_w
        if len(token_ids) != expected:
            raise ValueError(f"Expected {expected} visual tokens for grid {grid}, got {len(token_ids)}.")
        raw = [i - self.vocab.visual_offset for i in token_ids]
        if any(i < 0 or i >= self.vocab.visual_vocab_size for i in raw):
            raise ValueError("Token id(s) outside the visual codebook range.")
        idx = torch.tensor(raw, device=self.device).view(1, grid_h, grid_w)
        with torch.no_grad():
            img_t = self.vqvae.decode_from_indices(idx)[0].cpu()
        arr = (img_t.permute(1, 2, 0).numpy() * 255).round().astype("uint8")
        return Image.fromarray(arr)

    def reconstruct_image(self, image: Image.Image) -> Image.Image:
        """Encode -> decode round-trip through the VQ-VAE only (no GPT). Useful sanity check."""
        tokens, grid = self.image_to_tokens(image)
        return self.tokens_to_image(tokens, grid)

    # ------------------------------------------------------------------ #
    # Autoregressive sampling
    # ------------------------------------------------------------------ #

    @torch.no_grad()
    def _sample_step(
        self,
        ids: List[int],
        temperature: float,
        top_k: Optional[int],
        allowed_range: Optional[Tuple[int, int]] = None,
    ) -> int:
        ctx = ids[-self.gpt.block_size :]
        x = torch.tensor([ctx], device=self.device)
        logits = self.gpt(x)[0, -1] / max(temperature, 1e-5)

        if allowed_range is not None:
            lo, hi = allowed_range
            mask = torch.full_like(logits, float("-inf"))
            mask[lo:hi] = 0
            logits = logits + mask

        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[-1]] = float("-inf")

        probs = F.softmax(logits, dim=-1)
        return int(torch.multinomial(probs, num_samples=1).item())

    def generate(
        self,
        prompt_ids: List[int],
        max_new_tokens: int = 64,
        temperature: float = 0.8,
        top_k: Optional[int] = 40,
        stop_token_id: Optional[int] = None,
    ) -> List[int]:
        ids = list(prompt_ids)
        for _ in range(max_new_tokens):
            next_id = self._sample_step(ids, temperature, top_k)
            ids.append(next_id)
            if stop_token_id is not None and next_id == stop_token_id:
                break
        return ids

    def complete_text(self, prompt: str, max_new_tokens: int = 64, temperature: float = 0.8, top_k: int = 40) -> str:
        eot_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
        prompt_ids = self.encode_text(prompt)
        out_ids = self.generate(prompt_ids, max_new_tokens, temperature, top_k, stop_token_id=eot_id)
        return self.decode_text(out_ids)

    def generate_image(
        self,
        prompt: str,
        grid: Tuple[int, int] = (8, 8),
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> Image.Image:
        """Generate grid_h*grid_w visual tokens conditioned on `prompt` and decode to an image.

        While sampling image tokens, the sampler is restricted to the visual
        codebook's id range, so it can never "leak" a text token into the image.
        """
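        ids = self.encode_text(prompt)
        # Images are serialized as <IMAGE><FRAME>[codes]; add <FRAME> if the prompt stops at <IMAGE>.
        if ids and ids[-1] == self.tokenizer.convert_tokens_to_ids("<IMAGE>"):
            ids.append(self.tokenizer.convert_tokens_to_ids("<FRAME>"))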
        lo, hi = self.vocab.visual_offset, self.vocab.visual_offset + self.vocab.visual_vocab_size
        n_tokens = grid[0] * grid[1]
        for _ in range(n_tokens):
            next_id = self._sample_step(ids, temperature, top_k, allowed_range=(lo, hi))
            ids.append(next_id)
        visual_ids = ids[-n_tokens:]
        return self.tokens_to_image(visual_ids, grid)

    def generate_video(
        self,
        prompt: str,
        num_frames: int = 4,
        grid: Tuple[int, int] = (8, 8),
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> List[Image.Image]:
        frame_token_id = self.tokenizer.convert_tokens_to_ids("<FRAME>")
        
        ids = self.encode_text(prompt)
        
        lo, hi = self.vocab.visual_offset, self.vocab.visual_offset + self.vocab.visual_vocab_size
        n_tokens_per_frame = grid[0] * grid[1]
        
        frames = []
        
        for f in range(num_frames):
            print(f"[info] Generating frame {f+1}/{num_frames}...")
            ids.append(frame_token_id)
            
            for _ in range(n_tokens_per_frame):
                next_id = self._sample_step(ids, temperature, top_k, allowed_range=(lo, hi))
                ids.append(next_id)
                
            frame_tokens = ids[-n_tokens_per_frame:]
            img = self.tokens_to_image(frame_tokens, grid)
            frames.append(img)
            
        return frames
    
    def image_to_text(self, image: Image.Image, max_new_tokens: int = 64, temperature: float = 0.8, top_k: int = 40) -> str:
        visual_ids, _ = self.image_to_tokens(image)
        
        image_start_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
        frame_token_id = self.tokenizer.convert_tokens_to_ids("<FRAME>")
        image_end_id = self.tokenizer.convert_tokens_to_ids("</IMAGE>")
        text_start_id = self.tokenizer.convert_tokens_to_ids("<TEXT>")
        eot_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
        
        prompt_ids = [image_start_id, frame_token_id] + visual_ids + [image_end_id, text_start_id]
        
        out_ids = self.generate(prompt_ids, max_new_tokens, temperature, top_k, stop_token_id=eot_id)
        
        gen_ids = out_ids[len(prompt_ids):]
        return self.decode_text(gen_ids)

    def video_to_text(self, frames: List[Image.Image], max_new_tokens: int = 64, temperature: float = 0.8, top_k: int = 40) -> str:
        video_start_id = self.tokenizer.convert_tokens_to_ids("<VIDEO>")
        frame_token_id = self.tokenizer.convert_tokens_to_ids("<FRAME>")
        video_end_id = self.tokenizer.convert_tokens_to_ids("</VIDEO>")
        text_start_id = self.tokenizer.convert_tokens_to_ids("<TEXT>")
        eot_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
        
        prompt_ids = [video_start_id]
        for frame in frames:
            visual_ids, _ = self.image_to_tokens(frame)
            prompt_ids += [frame_token_id] + visual_ids
        prompt_ids += [video_end_id, text_start_id]
        
        out_ids = self.generate(prompt_ids, max_new_tokens, temperature, top_k, stop_token_id=eot_id)
        gen_ids = out_ids[len(prompt_ids):]
        return self.decode_text(gen_ids)
    
    def image_to_image(self, image: Image.Image, prompt: str, grid: Tuple[int, int] = (8, 8), temperature: float = 1.0, top_k: Optional[int] = None) -> Image.Image:
        visual_ids, _ = self.image_to_tokens(image)
        image_start_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
        frame_token_id = self.tokenizer.convert_tokens_to_ids("<FRAME>")
        image_end_id = self.tokenizer.convert_tokens_to_ids("</IMAGE>")
        
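        text_ids = self.encode_text(prompt)
        if not text_ids or text_ids[-1] != image_start_id:  # CLI prompts already end with <IMAGE>
            text_ids.append(image_start_id)
        # <FRAME> after <IMAGE> matches the documented <IMAGE><FRAME>[codes] layout
        prompt_ids = [image_start_id, frame_token_id] + visual_ids + [image_end_id] + text_ids + [frame_token_id]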
        
        lo, hi = self.vocab.visual_offset, self.vocab.visual_offset + self.vocab.visual_vocab_size
        n_tokens = grid[0] * grid[1]
        for _ in range(n_tokens):
            next_id = self._sample_step(prompt_ids, temperature, top_k, allowed_range=(lo, hi))
            prompt_ids.append(next_id)
        return self.tokens_to_image(prompt_ids[-n_tokens:], grid)

    def image_to_video(self, image: Image.Image, prompt: str, num_frames: int = 4, grid: Tuple[int, int] = (8, 8), temperature: float = 1.0, top_k: Optional[int] = None) -> List[Image.Image]:
        visual_ids, _ = self.image_to_tokens(image)
        image_start_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
        frame_token_id = self.tokenizer.convert_tokens_to_ids("<FRAME>")
        image_end_id = self.tokenizer.convert_tokens_to_ids("</IMAGE>")
        video_start_id = self.tokenizer.convert_tokens_to_ids("<VIDEO>")
        
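        text_ids = self.encode_text(prompt)
        if not text_ids or text_ids[-1] != video_start_id:  # CLI prompts already end with <VIDEO>
            text_ids.append(video_start_id)
        prompt_ids = [image_start_id, frame_token_id] + visual_ids + [image_end_id] + text_ids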
        
        lo, hi = self.vocab.visual_offset, self.vocab.visual_offset + self.vocab.visual_vocab_size
        n_tokens_per_frame = grid[0] * grid[1]
        frames = []
        for f in range(num_frames):
            print(f"[info] Generating frame {f+1}/{num_frames}...")
            prompt_ids.append(frame_token_id)
            for _ in range(n_tokens_per_frame):
                next_id = self._sample_step(prompt_ids, temperature, top_k, allowed_range=(lo, hi))
                prompt_ids.append(next_id)
            frames.append(self.tokens_to_image(prompt_ids[-n_tokens_per_frame:], grid))
        return frames

    def video_to_image(self, frames: List[Image.Image], prompt: str, grid: Tuple[int, int] = (8, 8), temperature: float = 1.0, top_k: Optional[int] = None) -> Image.Image:
        video_start_id = self.tokenizer.convert_tokens_to_ids("<VIDEO>")
        frame_token_id = self.tokenizer.convert_tokens_to_ids("<FRAME>")
        video_end_id = self.tokenizer.convert_tokens_to_ids("</VIDEO>")
        image_start_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
        
        prompt_ids = [video_start_id]
        for frame in frames:
            visual_ids, _ = self.image_to_tokens(frame)
            prompt_ids += [frame_token_id] + visual_ids
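        text_ids = self.encode_text(prompt)
        if not text_ids or text_ids[-1] != image_start_id:  # CLI prompts already end with <IMAGE>
            text_ids.append(image_start_id)
        # <FRAME> after <IMAGE> matches the documented <IMAGE><FRAME>[codes] layout
        prompt_ids += [video_end_id] + text_ids + [frame_token_id]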
        
        lo, hi = self.vocab.visual_offset, self.vocab.visual_offset + self.vocab.visual_vocab_size
        n_tokens = grid[0] * grid[1]
        for _ in range(n_tokens):
            next_id = self._sample_step(prompt_ids, temperature, top_k, allowed_range=(lo, hi))
            prompt_ids.append(next_id)
        return self.tokens_to_image(prompt_ids[-n_tokens:], grid)

    def video_to_video(self, frames: List[Image.Image], prompt: str, num_frames: int = 4, grid: Tuple[int, int] = (8, 8), temperature: float = 1.0, top_k: Optional[int] = None) -> List[Image.Image]:
        video_start_id = self.tokenizer.convert_tokens_to_ids("<VIDEO>")
        frame_token_id = self.tokenizer.convert_tokens_to_ids("<FRAME>")
        video_end_id = self.tokenizer.convert_tokens_to_ids("</VIDEO>")
        
        prompt_ids = [video_start_id]
        for frame in frames:
            visual_ids, _ = self.image_to_tokens(frame)
            prompt_ids += [frame_token_id] + visual_ids
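        text_ids = self.encode_text(prompt)
        if not text_ids or text_ids[-1] != video_start_id:  # CLI prompts already end with <VIDEO>
            text_ids.append(video_start_id)
        prompt_ids += [video_end_id] + text_ids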
        
        lo, hi = self.vocab.visual_offset, self.vocab.visual_offset + self.vocab.visual_vocab_size
        n_tokens_per_frame = grid[0] * grid[1]
        out_frames = []
        for f in range(num_frames):
            print(f"[info] Generating frame {f+1}/{num_frames}...")
            prompt_ids.append(frame_token_id)
            for _ in range(n_tokens_per_frame):
                next_id = self._sample_step(prompt_ids, temperature, top_k, allowed_range=(lo, hi))
                prompt_ids.append(next_id)
            out_frames.append(self.tokens_to_image(prompt_ids[-n_tokens_per_frame:], grid))
        return out_frames

# --------------------------------------------------------------------------- #
# CLI
# --------------------------------------------------------------------------- #

def main() -> None:
    parser = argparse.ArgumentParser(description="Supra-A2A-Nano-Exp - inference runner")
    parser.add_argument("--mode", choices=[
        "text", "chat", "reconstruct", 
        "text2image", "text2video", 
        "image2text", "video2text",
        "image2image", "image2video", "video2image", "video2video"
    ], default="text")
    parser.add_argument("--weights_dir", type=Path, default=Path(__file__).resolve().parent)
    parser.add_argument("--repo_id", default=DEFAULT_REPO_ID)
    parser.add_argument("--prompt", default="<TEXT>")
    parser.add_argument("--image", type=Path, help="input image, or animated GIF for video2* modes (required by reconstruct, image2* and video2* modes)")
    parser.add_argument("--out", type=Path, default=Path("output.png"))
    parser.add_argument("--grid", type=int, nargs=2, default=[8, 8], metavar=("H", "W"))
    parser.add_argument("--frames", type=int, default=4, help="Number of generated video frames")
    parser.add_argument("--max_new_tokens", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--device", default=None)
    parser.add_argument("--seed", type=int, default=None)
    args = parser.parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    weights_dir = resolve_weights_dir(args.weights_dir, args.repo_id)
    model = SupraA2A(weights_dir, device=args.device)

    if args.mode == "text":
        print(model.complete_text(args.prompt, args.max_new_tokens, args.temperature, args.top_k))

    elif args.mode == "chat":
        print("Chat mode - Ctrl+C to exit.")
        while True:
            try:
                prompt = input("\n> ")
            except (KeyboardInterrupt, EOFError):
                break
            print(model.complete_text(prompt, args.max_new_tokens, args.temperature, args.top_k))

    elif args.mode == "reconstruct":
        if not args.image:
            sys.exit("--image is required in reconstruct mode")
        img = Image.open(args.image)
        recon = model.reconstruct_image(img)
        recon.save(args.out)
        print(f"[ok] reconstruction saved to {args.out}")

    elif args.mode == "text2image":
        img = model.generate_image(
            args.prompt, grid=tuple(args.grid), temperature=args.temperature, top_k=args.top_k
        )
        img.save(args.out)
        print(f"[ok] generated image saved to {args.out}")

    elif args.mode == "text2video":
        frames = model.generate_video(
            args.prompt, num_frames=args.frames, grid=tuple(args.grid), temperature=args.temperature, top_k=args.top_k
        )
        
        out_gif = args.out.with_suffix(".gif")
        frames[0].save(
            out_gif,
            save_all=True,
            append_images=frames[1:],
            duration=200,
            loop=0
        )
        print(f"[ok] Video saved as animated GIF in: {out_gif}")

    elif args.mode == "image2text":
        if not args.image:
            sys.exit("--image is needed for image2text")
        img = Image.open(args.image)
        description = model.image_to_text(img, args.max_new_tokens, args.temperature, args.top_k)
        print(f"\n[Caption for {args.image.name}]:\n{description}")

    elif args.mode == "video2text":
        if not args.image:
            sys.exit("--image is needed for video2text (pass e.g. the animated .gif)")
        
        gif = Image.open(args.image)
        frames = []
        try:
            while True:
                frames.append(gif.convert("RGB"))
                gif.seek(gif.tell() + 1)
        except EOFError:
            pass
            
        print(f"[info] Processing {len(frames)} frames from video/GIF...")
        description = model.video_to_text(frames, args.max_new_tokens, args.temperature, args.top_k)
        print(f"\n[Caption for video {args.image.name}]:\n{description}")

    elif args.mode == "image2image":
        if not args.image:
            sys.exit("--image is needed for image2image")
        img = Image.open(args.image)
        out_img = model.image_to_image(img, args.prompt, grid=tuple(args.grid), temperature=args.temperature, top_k=args.top_k)
        out_img.save(args.out)
        print(f"[ok] Image-to-Image result saved as: {args.out}")

    elif args.mode == "image2video":
        if not args.image:
            sys.exit("--image is needed for image2video")
        img = Image.open(args.image)
        frames = model.image_to_video(img, args.prompt, num_frames=args.frames, grid=tuple(args.grid), temperature=args.temperature, top_k=args.top_k)
        out_gif = args.out.with_suffix(".gif")
        frames[0].save(out_gif, save_all=True, append_images=frames[1:], duration=200, loop=0)
        print(f"[ok] Image-to-Video result saved as: {out_gif}")

    elif args.mode == "video2image":
        if not args.image:
            sys.exit("--image (animated .gif) is needed for video2image")
        gif = Image.open(args.image)
        frames = []
        try:
            while True:
                frames.append(gif.convert("RGB"))
                gif.seek(gif.tell() + 1)
        except EOFError:
            pass

        # Only the first 3 frames of the GIF are used as input.
        frames = frames[:3]
        print(f"[info] Processing {len(frames)} input frames...")
        out_img = model.video_to_image(frames, args.prompt, grid=tuple(args.grid), temperature=args.temperature, top_k=args.top_k)
        out_img.save(args.out)
        print(f"[ok] Video-to-Image result saved as: {args.out}")

    elif args.mode == "video2video":
        if not args.image:
            sys.exit("--image (animated .gif) is needed for video2video")
        gif = Image.open(args.image)
        frames = []
        try:
            while True:
                frames.append(gif.convert("RGB"))
                gif.seek(gif.tell() + 1)
        except EOFError:
            pass

        # Only the first 3 frames of the GIF are used as input.
        frames = frames[:3]
        print(f"[info] Processing {len(frames)} input frames...")
        out_frames = model.video_to_video(frames, args.prompt, num_frames=args.frames, grid=tuple(args.grid), temperature=args.temperature, top_k=args.top_k)
        out_gif = args.out.with_suffix(".gif")
        out_frames[0].save(out_gif, save_all=True, append_images=out_frames[1:], duration=200, loop=0)
        print(f"[ok] Video-to-Video result saved as: {out_gif}")

if __name__ == "__main__":
    main()
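
Both `video_to_image` and `video_to_video` call `self._sample_step(..., allowed_range=(lo, hi))`, whose body is not part of this excerpt. Because `(lo, hi)` spans exactly the 256 visual ids, it appears that image positions are sampled only from visual tokens. The following is a minimal, framework-free sketch of that kind of constrained sampling, assuming the logits for one step are already available as a plain list. `sample_constrained` is a hypothetical name; the actual runner uses PyTorch, also runs the GPT to obtain the logits, and may order temperature and top-k differently.

```python
import math
import random
from typing import List, Optional, Tuple


def sample_constrained(
    logits: List[float],
    allowed_range: Tuple[int, int],
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical sketch: sample one token id from logits, restricted to [lo, hi)."""
    rng = rng if rng is not None else random.Random()
    lo, hi = allowed_range
    # Only ids inside the allowed (e.g. visual) range are candidates; every other
    # id, including all text tokens, effectively gets probability zero.
    candidates = [(i, logits[i] / max(temperature, 1e-8)) for i in range(lo, hi)]
    # Optional top-k filtering among the allowed candidates.
    if top_k is not None and top_k > 0:
        candidates = sorted(candidates, key=lambda c: c[1], reverse=True)[:top_k]
    # Numerically stable softmax weights over the surviving candidates.
    best = max(score for _, score in candidates)
    weights = [math.exp(score - best) for _, score in candidates]
    return rng.choices([i for i, _ in candidates], weights=weights, k=1)[0]


if __name__ == "__main__":
    visual_offset, visual_vocab_size = 50_264, 256  # sizes stated in this card
    logits = [0.0] * (visual_offset + visual_vocab_size)
    logits[42] = 100.0               # a text token the model strongly prefers
    logits[visual_offset + 7] = 5.0  # best-scoring visual token
    lo, hi = visual_offset, visual_offset + visual_vocab_size
    print(sample_constrained(logits, (lo, hi), top_k=1))  # 50271
```

Even though the text token has by far the highest logit, it can never be emitted inside an image, which is what keeps every generated 8x8 grid decodable by the VQ-VAE.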

Actions

Modalities:

Text --> Image
Text --> Video
Text --> Text

Image --> Image
Image --> Video
Image --> Text

Video --> Image
Video --> Video
Video --> Text
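
Each mode above boils down to a prompt layout followed by autoregressive sampling. For the video-conditioned modes, the code listing builds the prompt as `<VIDEO>`, then `<FRAME>` plus that frame's visual tokens for every input frame, then `</VIDEO>`, the encoded text prompt, and finally an opening `<IMAGE>` (video2image) or `<VIDEO>` (video2video) tag. The sketch below reproduces that layout symbolically. `video_prompt_layout` is a hypothetical helper, not part of `run_supra_a2a.py`; placeholder strings such as `v0_0` stand in for visual ids, and `<prompt>` stands in for all BPE ids of the text prompt.

```python
from typing import List


def video_prompt_layout(n_frames: int, tokens_per_frame: int, target: str) -> List[str]:
    """Hypothetical helper mirroring the prompt built by video_to_image / video_to_video."""
    if target not in ("image", "video"):
        raise ValueError("target must be 'image' or 'video'")
    layout = ["<VIDEO>"]
    for f in range(n_frames):
        # Each input frame is introduced by <FRAME>, followed by its visual tokens.
        layout += ["<FRAME>"] + [f"v{f}_{i}" for i in range(tokens_per_frame)]
    # "<prompt>" is a placeholder for the encoded text prompt.
    layout += ["</VIDEO>", "<prompt>", "<IMAGE>" if target == "image" else "<VIDEO>"]
    return layout


if __name__ == "__main__":
    # Two tiny 2-token "frames" keep the printout readable.
    print(video_prompt_layout(2, 2, "image"))
    # ['<VIDEO>', '<FRAME>', 'v0_0', 'v0_1', '<FRAME>', 'v1_0', 'v1_1', '</VIDEO>', '<prompt>', '<IMAGE>']
```

In video2video, generation then continues after the final `<VIDEO>` by appending a `<FRAME>` tag and sampling one grid of visual tokens per output frame.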

# Text
python run_supra_a2a.py --mode text --prompt "Once upon a time"

# Chat
python run_supra_a2a.py --mode chat

# Text-To-Image
python run_supra_a2a.py --mode text2image --prompt "<TEXT>a red square</TEXT><IMAGE>" --out gen.png

# Text-To-Video
python run_supra_a2a.py --mode text2video --prompt "<TEXT>A snow scene</TEXT><VIDEO>" --frames 4 --out video.gif

# Reconstruct (VQ-VAE round trip only, useful as a sanity check)
python run_supra_a2a.py --mode reconstruct --image gen.png --out gen_recon.png

# Image-To-Text
python run_supra_a2a.py --mode image2text --image gen.png

# Image-To-Image
python run_supra_a2a.py --mode image2image --image gen.png --prompt "<TEXT>make it red</TEXT><IMAGE>" --out edit.png

# Image-To-Video
python run_supra_a2a.py --mode image2video --image gen.png --prompt "<TEXT>make it move</TEXT><VIDEO>" --frames 4 --out animated.gif

# Video-To-Text
python run_supra_a2a.py --mode video2text --image video.gif

# Video-To-Image
python run_supra_a2a.py --mode video2image --image video.gif --prompt "<TEXT>extract style</TEXT><IMAGE>" --out frame.png

# Video-To-Video
python run_supra_a2a.py --mode video2video --image video.gif --prompt "<TEXT>change style</TEXT><VIDEO>" --frames 4 --out transformed.gif
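
The `reconstruct` mode exercises only the VQ-VAE: each cell of the latent grid is snapped to its nearest codebook entry and decoded again. The code listing also shows that visual codes occupy the global ids `[visual_offset, visual_offset + visual_vocab_size)`, and `tokens_to_image` receives such global ids. The sketch below illustrates both bookkeeping steps in plain Python, assuming squared Euclidean distance for the nearest-neighbour lookup and a visual offset of 50,264 (the text-side vocabulary size given earlier in this card). The toy 2-D codebook and all function names are hypothetical; the real encoder works on learned latent vectors in PyTorch.

```python
from typing import List, Sequence

VISUAL_OFFSET = 50_264  # assumption: 50,257 BPE tokens + 7 control tokens


def quantize(latents: Sequence[Sequence[float]], codebook: Sequence[Sequence[float]]) -> List[int]:
    """Return the index of the nearest codebook vector for each latent cell."""
    def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
        return sum((x - y) ** 2 for x, y in zip(a, b))

    return [min(range(len(codebook)), key=lambda k: sq_dist(z, codebook[k])) for z in latents]


def to_global_ids(codes: List[int]) -> List[int]:
    # Visual codes are appended after the text vocabulary in the GPT's id space.
    return [VISUAL_OFFSET + c for c in codes]


def to_local_codes(ids: List[int], visual_vocab_size: int = 256) -> List[int]:
    # Inverse mapping; rejects ids that are not visual tokens.
    if any(not VISUAL_OFFSET <= i < VISUAL_OFFSET + visual_vocab_size for i in ids):
        raise ValueError("non-visual token id in image sequence")
    return [i - VISUAL_OFFSET for i in ids]


if __name__ == "__main__":
    codebook = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]  # toy 3-entry, 2-D codebook
    latents = [(0.1, -0.2), (0.9, 0.2), (0.2, 0.8)]
    codes = quantize(latents, codebook)
    ids = to_global_ids(codes)
    print(codes, ids)  # [0, 1, 2] [50264, 50265, 50266]
    assert to_local_codes(ids) == codes
```

A reconstruction that looks badly wrong points at the VQ-VAE weights or preprocessing rather than the GPT, since no GPT sampling is involved in this mode.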

run_supra_a2a.py is self-contained: it rebuilds the GPT and VQ-VAE modules directly from the raw state_dicts, validates that the tokenizer vocabulary and the visual codebook line up with the GPT's embedding table, and falls back to downloading missing files from this Hub repo automatically. No other custom code or config file is required.
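
The vocabulary validation comes down to arithmetic on the sizes given earlier in this card: 50,257 BPE tokens plus 7 control tokens give 50,264 text-side ids, and appending the 256-entry codebook gives 50,520 embedding rows. A minimal sketch of such a check follows; `check_vocab_alignment` and its arguments are hypothetical and not the script's actual validation code.

```python
def check_vocab_alignment(bpe_size: int, n_control: int, codebook_size: int, embedding_rows: int) -> int:
    """Hypothetical check; returns the visual offset if all sizes are consistent."""
    text_vocab = bpe_size + n_control  # ids 0 .. text_vocab - 1 are text-side
    visual_offset = text_vocab         # first visual code id
    if visual_offset + codebook_size != embedding_rows:
        raise ValueError(
            f"vocab mismatch: {text_vocab} text ids + {codebook_size} codes "
            f"!= {embedding_rows} embedding rows"
        )
    return visual_offset


if __name__ == "__main__":
    print(check_vocab_alignment(50_257, 7, 256, 50_520))  # 50264
```

Failing loudly here is preferable to silently decoding shifted ids, which would map text tokens onto codebook entries (or vice versa) without any error.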

Limitations

  • Trained at nano scale (~30M params, 384-token context) for proof-of-concept purposes, not benchmark performance. Text generations will be incoherent past a sentence or two, and generated images will be low-resolution and abstract rather than photorealistic.
  • On the image side, only square inputs whose side length is a multiple of 8 (the VQ-VAE's downsampling factor) are supported.
  • No instruction-tuning or RLHF; this is a base, unaligned model trained purely on the next-token objective across the unified token stream.
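
The multiple-of-8 constraint and the 384-token context together bound what fits in one sequence. Assuming input frames are encoded at 64x64 like the default 8x8 output grid, each frame costs 64 visual tokens plus one `<FRAME>` tag in the video layouts from the code listing. The hypothetical helpers below (not part of `run_supra_a2a.py`) count tokens for the video2video layout while ignoring the text prompt; they check only the multiple-of-8 condition, not squareness. With 3 input frames and the default `--frames 4`, the count already reaches 458 tokens, more than the 384-token context the model was trained with.

```python
CONTEXT = 384    # context length stated in this card
DOWNSAMPLE = 8   # VQ-VAE downsampling factor


def tokens_per_frame(height: int, width: int) -> int:
    """Visual tokens for one image of the given pixel size."""
    if height % DOWNSAMPLE or width % DOWNSAMPLE:
        raise ValueError(f"{height}x{width} is not a multiple of {DOWNSAMPLE}")
    return (height // DOWNSAMPLE) * (width // DOWNSAMPLE)


def video2video_length(n_in: int, n_out: int, height: int = 64, width: int = 64, text_tokens: int = 0) -> int:
    """Sequence length for the video2video layout: every frame is <FRAME> + visual tokens."""
    per_frame = 1 + tokens_per_frame(height, width)
    # <VIDEO> inputs </VIDEO> text <VIDEO> outputs
    return 1 + n_in * per_frame + 1 + text_tokens + 1 + n_out * per_frame


if __name__ == "__main__":
    print(tokens_per_frame(64, 64))  # 64
    n = video2video_length(3, 4)
    print(n, n <= CONTEXT)           # 458 False
```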

About SupraLabs

SupraLabs is a small open-source AI lab building tiny language and multimodal models from scratch on consumer hardware, released openly so others can study, fine-tune, and build on them. Supra-A2A-Nano-Exp is part of that broader effort alongside the Supra language model family and the SupraVID text-to-video project.

Model size: 29.7M params (Safetensors, tensor type F32)