Supra-A2A-Nano-Exp
Status: experimental / educational prototype. Not a polished product.
Supra-A2A-Nano-Exp is a small proof-of-concept any-to-any model from SupraLabs: a single autoregressive GPT that reads and writes both text and images by treating them as one unified sequence of discrete tokens. Text uses standard BPE tokens; images are discretized into a small set of learned "visual words" by a convolutional VQ-VAE and appended to the same vocabulary. The whole multimodal stream is modeled by one Transformer with one set of weights, with no separate vision encoder or diffusion head bolted on.
This release exists mainly to demonstrate the idea of unified tokenization across modalities on consumer-scale hardware. With ~30M total parameters and a 384-token context, the model cannot be expected to produce coherent long-form text or photorealistic images. Treat it as a transparent, hackable example architecture rather than a capable generator.
How it works
Every input, text or image, is serialized into one token stream wrapped in control tags:
<TEXT>some text here</TEXT>
<IMAGE><FRAME>[64 visual tokens]</IMAGE>
<VIDEO><FRAME>[visual tokens]<FRAME>[visual tokens]...</VIDEO>
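The layouts above imply a simple token budget. The following is a rough sketch, assuming 8x8 frames (64 visual tokens each, preceded by one `<FRAME>` marker) and the 384-token context mentioned earlier. The helper names are illustrative and are not part of run_supra_a2a.py.

```python
CONTEXT_LENGTH = 384          # model context, in tokens
TOKENS_PER_FRAME = 1 + 64     # one <FRAME> marker + an 8x8 grid of visual tokens


def image_block_tokens() -> int:
    """Tokens used by <IMAGE><FRAME>[64 visual tokens]</IMAGE>."""
    return 1 + TOKENS_PER_FRAME + 1


def video_block_tokens(num_frames: int) -> int:
    """Tokens used by <VIDEO>, num_frames frames, and </VIDEO>."""
    return 1 + num_frames * TOKENS_PER_FRAME + 1


def max_frames(prompt_tokens: int) -> int:
    """Largest frame count whose video block still fits next to the prompt."""
    remaining = CONTEXT_LENGTH - prompt_tokens - 2  # reserve <VIDEO> and </VIDEO>
    return max(remaining // TOKENS_PER_FRAME, 0)


print(image_block_tokens())          # 67
print(video_block_tokens(4))         # 262
print(max_frames(prompt_tokens=10))  # 5
```

Under these assumptions a short text prompt leaves room for about five frames; longer sequences only keep their most recent 384 tokens in view.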
- Text tokens come from a GPT-2-style BPE tokenizer (50,257 tokens) plus 7 control tokens (<TEXT>, </TEXT>, <IMAGE>, </IMAGE>, <VIDEO>, </VIDEO>, <FRAME>), for 50,264 text-side ids total.
- Visual tokens come from a VQ-VAE codebook of 256 entries. An image is encoded by 3 strided convolutions (/8 downsampling total), and each resulting spatial cell is snapped to its nearest codebook vector. A 64x64 image becomes an 8x8 grid, i.e. 64 visual tokens.
- These 256 visual codes are appended right after the text vocabulary, so the GPT's combined vocabulary is exactly 50,264 + 256 = 50,520 tokens. One embedding table, one output head, one model for both modalities.
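The arithmetic in this list can be checked with a few lines of Python. This is a minimal sketch: the convolution settings (kernel 4, stride 2, padding 1) are taken from the VQVAE class in run_supra_a2a.py, while the function names are made up for illustration.

```python
TEXT_VOCAB_SIZE = 50_257 + 7   # GPT-2 BPE vocabulary + 7 control tokens
CODEBOOK_SIZE = 256            # VQ-VAE codebook entries


def conv_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Output side length of one strided convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def grid_side(image_side: int, stages: int = 3) -> int:
    """Side length of the token grid after `stages` stride-2 convolutions."""
    for _ in range(stages):
        image_side = conv_out(image_side)
    return image_side


def to_unified_id(code: int) -> int:
    """Shift a raw codebook index into the combined vocabulary."""
    if not 0 <= code < CODEBOOK_SIZE:
        raise ValueError(f"codebook index out of range: {code}")
    return TEXT_VOCAB_SIZE + code


def to_visual_code(token_id: int) -> int:
    """Inverse of to_unified_id; rejects text-side ids."""
    code = token_id - TEXT_VOCAB_SIZE
    if not 0 <= code < CODEBOOK_SIZE:
        raise ValueError(f"not a visual token id: {token_id}")
    return code


side = grid_side(64)
print(side, side * side)                      # 8 64
print(to_unified_id(0), to_unified_id(255))   # 50264 50519
print(TEXT_VOCAB_SIZE + CODEBOOK_SIZE)        # 50520
```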
Because images are projected into the same id space as text, the GPT can in principle attend across modalities: condition image generation on a text prompt, or condition text generation on image content, using the exact same attention mechanism it uses for next-token text prediction.
Architecture
| Component | Detail |
|---|---|
| GPT backbone | 4 Transformer blocks, pre-norm, fused QKV attention, causal |
| Embedding dim | 256 |
| Context length | 384 tokens |
| Attention heads | 4 (assumed — see note below) |
| MLP expansion | 4x (256 -> 1024 -> 256), GELU |
| Combined vocabulary | 50,520 (50,264 text + 256 visual) |
| GPT parameters | ~29.7M |
| VQ-VAE | 3-layer conv encoder / decoder, /8 downsampling, 256x64 codebook |
| VQ-VAE parameters | ~0.22M |
| Total parameters | ~29.9M |
| Precision | fp32 |
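The parameter figures in the table can be approximately reproduced by hand. The sketch below assumes the layer shapes built in run_supra_a2a.py (separate, untied tok_emb and head; biases on all Linear and conv layers except the output head). Under those assumptions the trainable GPT weights sum to about 29.1M, and the ~29.7M figure is reached only when the four 384x384 causal-mask buffers stored in the state_dict are counted as well. That reading is an inference from the arithmetic, not something the checkpoint documents.

```python
def linear(n_in: int, n_out: int, bias: bool = True) -> int:
    return n_in * n_out + (n_out if bias else 0)


def conv(c_in: int, c_out: int, k: int = 4) -> int:
    # Same count for Conv2d and ConvTranspose2d: k*k weights per channel pair + bias.
    return c_in * c_out * k * k + c_out


def layer_norm(d: int) -> int:
    return 2 * d  # scale + shift


VOCAB, D, CTX, LAYERS = 50_520, 256, 384, 4

block = (
    layer_norm(D) + linear(D, 3 * D) + linear(D, D)          # attention half
    + layer_norm(D) + linear(D, 4 * D) + linear(4 * D, D)    # MLP half
)
gpt_weights = (
    VOCAB * D                      # token embedding
    + CTX * D                      # learned position embedding
    + LAYERS * block
    + layer_norm(D)                # final LayerNorm
    + linear(D, VOCAB, bias=False) # output head
)
mask_buffers = LAYERS * CTX * CTX  # causal masks saved alongside the weights

vqvae = (
    conv(3, 32) + conv(32, 64) + conv(64, 64)   # encoder
    + 256 * 64                                  # codebook
    + conv(64, 64) + conv(64, 32) + conv(32, 3) # decoder
)

print(gpt_weights)                 # 29124096
print(gpt_weights + mask_buffers)  # 29713920
print(vqvae)                       # 216323
```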
Note on attention heads: the checkpoint stores QKV as a single fused
Linear layer, so the head count isn't recoverable from the weights alone.
run_supra_a2a.py defaults to 4 heads (64-dim per head, the GPT-2
convention). If you trained this checkpoint yourself and used a different
head count, change N_HEAD at the top of the script — loading will succeed
either way (shapes match), but the wrong value will silently produce
incorrect attention.
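A short sketch of why the head count leaves no trace in the checkpoint: the fused QKV weight shape depends only on the embedding width, while the head count only changes how the projected vectors are split at run time. The function names are illustrative.

```python
N_EMBD = 256


def qkv_weight_shape(n_embd: int) -> tuple:
    """Shape of the fused QKV Linear weight, stored as (out_features, in_features)."""
    return (3 * n_embd, n_embd)


def head_dim(n_embd: int, n_head: int) -> int:
    """Per-head width; the embedding must split evenly across heads."""
    if n_embd % n_head:
        raise ValueError("n_embd must be divisible by n_head")
    return n_embd // n_head


for n_head in (2, 4, 8, 16):
    # The weight shape is identical in every row; only head_dim changes.
    print(n_head, head_dim(N_EMBD, n_head), qkv_weight_shape(N_EMBD))
```

Every head count that divides 256 loads cleanly, which is why a wrong N_HEAD shows up only as degraded output and never as a shape error.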
Note on pixel normalization: similarly, the VQ-VAE decoder's final
activation has no learnable parameters, so it isn't visible in the
checkpoint either. The script defaults to sigmoid (assumes training in
[0, 1]). If reconstructions look off, try VQVAE_OUTPUT_ACTIVATION = "tanh".
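To make the two conventions concrete, here is a small sketch of how a raw decoder output would map to an 8-bit pixel under each assumption. The (x + 1) / 2 rescale for "tanh" is an assumption about how a model trained in [-1, 1] would be mapped back to pixels; the script as shipped clamps the activation output to [0, 1] instead and always feeds the encoder inputs in [0, 1], so a [-1, 1] model may need additional changes. The function names are hypothetical.

```python
import math


def to_unit_range(x: float, activation: str) -> float:
    """Map a raw decoder output to [0, 1] under the assumed training range."""
    if activation == "sigmoid":      # trained on pixels in [0, 1]
        return 1.0 / (1.0 + math.exp(-x))
    if activation == "tanh":         # trained on pixels in [-1, 1]; rescale (assumption)
        return (math.tanh(x) + 1.0) / 2.0
    if activation == "none":         # raw output, just clip
        return min(max(x, 0.0), 1.0)
    raise ValueError(f"unknown activation: {activation}")


def to_uint8(x: float, activation: str) -> int:
    """Convert a raw decoder output to an 8-bit pixel value."""
    return round(to_unit_range(x, activation) * 255)


for raw in (-2.0, 0.0, 2.0):
    print(raw, to_uint8(raw, "sigmoid"), to_uint8(raw, "tanh"), to_uint8(raw, "none"))
```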
Files in this repository
| File | Description |
|---|---|
| model.safetensors | GPT backbone weights (state_dict, fp32) |
| vqvae.safetensors | VQ-VAE encoder/decoder/codebook weights (state_dict, fp32) |
| tokenizer.json | Fast BPE tokenizer (GPT-2 base vocab + 7 control tokens) |
| tokenizer_config.json | Tokenizer metadata (special tokens, GPT2Tokenizer class) |
| run_supra_a2a.py | Plug-and-play inference script (see below) |
Samples
| Prompt | Response |
|---|---|
| Text2Text: Once upon a time | Once upon a time, there was a little girl named Lily. She loved to visit her grandma's house. Her grandma was very bossy and always told her what to do. One day, Lily went to visit her grandma and something strange happened. The sky turned dark and it started to rain. Lily was scared, but her grandma told her not to worry. They sat inside and played games until the rain stopped. |
| Text2Image: A dog running on snow | ![]() |
| Text2Video: A snow scene | https://cdn-uploads.huggingface.co/production/uploads/68df176c403a7bf9e8ae85a8/Lv1UqRG1m8KkiT46X8RNC.mp4 |
Usage
pip install torch transformers huggingface_hub safetensors pillow numpy
#!/usr/bin/env python3
"""
Supra-A2A-Nano-Exp - inference runner
======================================
An experimental any-to-any model from SupraLabs: a single autoregressive GPT
operates over one unified vocabulary that mixes text (BPE, GPT-2 style) and
discrete visual codes from a convolutional VQ-VAE. Text and images are
serialized into the same token stream, delimited by control tokens
(<TEXT>, <IMAGE>, <VIDEO>, <FRAME>, and their closing tags).
This script reconstructs the exact architecture from the raw state_dicts
(no config.json ships with the weights) and exposes a SupraA2A class with
high-level methods for text completion, image reconstruction (a VQ-VAE
sanity check), and text-conditioned image generation.
Quick start:
python run_supra_a2a.py --mode text --prompt "Once upon a time"
python run_supra_a2a.py --mode chat
python run_supra_a2a.py --mode text2image --prompt "<TEXT>a red square</TEXT><IMAGE>" --out gen.png
python run_supra_a2a.py --mode text2video --prompt "<TEXT>A snow scene</TEXT><VIDEO>" --frames 4 --out video.gif
python run_supra_a2a.py --mode reconstruct --image photo.png --out recon.png
python run_supra_a2a.py --mode image2text --image gen.png
python run_supra_a2a.py --mode image2image --image photo.png --prompt "<TEXT>make it red</TEXT><IMAGE>" --out edit.png
python run_supra_a2a.py --mode image2video --image photo.png --prompt "<TEXT>make it move</TEXT><VIDEO>" --frames 4 --out animated.gif
python run_supra_a2a.py --mode video2text --image video.gif
python run_supra_a2a.py --mode video2image --image video.gif --prompt "<TEXT>extract style</TEXT><IMAGE>" --out frame.png
python run_supra_a2a.py --mode video2video --image video.gif --prompt "<TEXT>change style</TEXT><VIDEO>" --frames 4 --out transformed.gif
"""
from __future__ import annotations
import argparse
import math
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from safetensors.torch import load_file as load_safetensors
# --------------------------------------------------------------------------- #
# Config
# --------------------------------------------------------------------------- #
MODEL_WEIGHT_STEMS = ["model", "vqvae"] # each may exist as .safetensors or .pt
TOKENIZER_FILES = ["tokenizer.json", "tokenizer_config.json"]
# Patterns used for the Hub fallback download; covers both weight formats so
# this works whether the repo ships .safetensors (preferred), .pt, or both.
HUB_ALLOW_PATTERNS = [f"{s}.safetensors" for s in MODEL_WEIGHT_STEMS] + [
f"{s}.pt" for s in MODEL_WEIGHT_STEMS
] + TOKENIZER_FILES
DEFAULT_REPO_ID = "SupraLabs/Supra-A2A-Nano-Exp"
# NOTE: the GPT checkpoint does not store n_head explicitly (qkv is a single
# fused Linear layer). 256 / 4 = 64 head_dim (the GPT-2 convention) is the
# most likely value and is the default here, but it cannot be verified from
# the weights alone. A wrong value will NOT break loading (shapes still
# match) but will silently produce incorrect attention. If generations look
# off, try N_HEAD = 8 (head_dim 32) instead.
N_HEAD = 4
# Pixel normalization used by the VQ-VAE decoder's final layer. There is no
# parametrized activation after the last ConvTranspose2d in the checkpoint,
# so this is also an assumption rather than a certainty. Default: sigmoid
# (assumes images were trained in [0, 1]). If reconstructions look washed
# out / inverted, try "tanh" (assumes [-1, 1] training).
VQVAE_OUTPUT_ACTIVATION = "sigmoid" # "sigmoid" | "tanh" | "none"
# --------------------------------------------------------------------------- #
# Weight resolution (local dir -> Hugging Face Hub fallback)
# --------------------------------------------------------------------------- #
def _has_all_weights(d: Path) -> bool:
    weights_ok = all(
        (d / f"{stem}.safetensors").exists() or (d / f"{stem}.pt").exists()
        for stem in MODEL_WEIGHT_STEMS
    )
    tok_ok = all((d / f).exists() for f in TOKENIZER_FILES)
    return weights_ok and tok_ok


def resolve_weights_dir(weights_dir: Path, repo_id: str) -> Path:
    """Return a directory containing the tokenizer files plus, for each model
    in MODEL_WEIGHT_STEMS, either a .safetensors or a .pt file.

    Looks locally first; if anything is missing, tries to download from the
    Hub via huggingface_hub.snapshot_download (works once the
    SupraLabs/Supra-A2A-Nano-Exp repo is public).
    """
    if _has_all_weights(weights_dir):
        return weights_dir
    print(f"[info] weight files incomplete in {weights_dir}")
    print(f"[info] trying to download from huggingface.co/{repo_id} ...")
    try:
        from huggingface_hub import snapshot_download
    except ImportError as e:
        raise RuntimeError(
            "huggingface_hub is not installed and local weights are incomplete. "
            "Run: pip install huggingface_hub"
        ) from e
    try:
        downloaded = snapshot_download(repo_id=repo_id, allow_patterns=HUB_ALLOW_PATTERNS)
    except Exception as e:
        raise RuntimeError(
            f"Could not find the weights locally or on the Hub ({repo_id}). "
            f"Place tokenizer files plus {MODEL_WEIGHT_STEMS} (.safetensors or .pt) "
            f"manually in {weights_dir}.\nOriginal error: {e}"
        ) from e
    downloaded = Path(downloaded)
    if not _has_all_weights(downloaded):
        raise RuntimeError(f"Download from {repo_id} completed but required files are still missing.")
    return downloaded


def load_state_dict_any(weights_dir: Path, stem: str) -> dict:
    """Load a checkpoint by stem name, preferring .safetensors over legacy .pt."""
    st_path = weights_dir / f"{stem}.safetensors"
    pt_path = weights_dir / f"{stem}.pt"
    if st_path.exists():
        return load_safetensors(str(st_path))
    if pt_path.exists():
        # weights_only=False unpickles arbitrary objects: only load .pt files you trust.
        return torch.load(pt_path, map_location="cpu", weights_only=False)
    raise FileNotFoundError(f"Neither {st_path.name} nor {pt_path.name} found in {weights_dir}")
# --------------------------------------------------------------------------- #
# VQ-VAE (conv encoder / codebook / transposed-conv decoder)
# --------------------------------------------------------------------------- #
class VectorQuantizer(nn.Module):
    """Discrete codebook with nearest-neighbor lookup (inference only, no EMA)."""

    def __init__(self, num_codes: int, dim: int):
        super().__init__()
        self.num_codes = num_codes
        self.dim = dim
        self.embedding = nn.Embedding(num_codes, dim)

    def encode(self, z: torch.Tensor) -> torch.Tensor:
        """z: (B, C, H, W) -> discrete indices (B, H, W)."""
        b, c, h, w = z.shape
        z_flat = z.permute(0, 2, 3, 1).reshape(-1, c)
        codebook = self.embedding.weight  # (num_codes, dim)
        # Squared Euclidean distance, expanded as |z|^2 - 2 z.e + |e|^2.
        dist = (
            z_flat.pow(2).sum(1, keepdim=True)
            - 2 * z_flat @ codebook.t()
            + codebook.pow(2).sum(1)
        )
        idx = dist.argmin(dim=1)
        return idx.view(b, h, w)

    def decode(self, idx: torch.Tensor) -> torch.Tensor:
        """idx: (B, H, W) -> z_q (B, C, H, W)."""
        z_q = self.embedding(idx)  # (B, H, W, C)
        return z_q.permute(0, 3, 1, 2).contiguous()


class VQVAE(nn.Module):
    """Conv autoencoder with /8 downsampling, codebook size/dim read from the checkpoint."""

    def __init__(self, codebook_size: int = 256, code_dim: int = 64):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, code_dim, kernel_size=4, stride=2, padding=1),
        )
        self.vq = VectorQuantizer(codebook_size, code_dim)
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(code_dim, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
        )

    def _final_activation(self, x: torch.Tensor) -> torch.Tensor:
        if VQVAE_OUTPUT_ACTIVATION == "sigmoid":
            return torch.sigmoid(x)
        if VQVAE_OUTPUT_ACTIVATION == "tanh":
            return torch.tanh(x)
        return x

    @torch.no_grad()
    def encode_to_indices(self, img: torch.Tensor) -> torch.Tensor:
        """img: (B, 3, H, W) normalized to [0, 1], H and W multiples of 8."""
        z = self.enc(img)
        return self.vq.encode(z)

    @torch.no_grad()
    def decode_from_indices(self, idx: torch.Tensor) -> torch.Tensor:
        """idx: (B, H, W) -> image (B, 3, H*8, W*8) in [0, 1]."""
        z_q = self.vq.decode(idx)
        x = self.dec(z_q)
        return self._final_activation(x).clamp(0, 1)
# --------------------------------------------------------------------------- #
# GPT (nanoGPT-style: pre-norm, fused qkv, 4x MLP)
# --------------------------------------------------------------------------- #
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd: int, n_head: int, block_size: int):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        mask = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
        self.register_buffer("mask", mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, c = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.split(c, dim=2)
        hd = c // self.n_head
        q = q.view(b, t, self.n_head, hd).transpose(1, 2)
        k = k.view(b, t, self.n_head, hd).transpose(1, 2)
        v = v.view(b, t, self.n_head, hd).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(hd)
        att = att.masked_fill(self.mask[:, :, :t, :t] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.proj(y)


class Block(nn.Module):
    def __init__(self, n_embd: int, n_head: int, block_size: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPT(nn.Module):
    def __init__(self, vocab_size: int, n_embd: int, block_size: int, n_layer: int, n_head: int):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, block_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        b, t = idx.shape
        assert t <= self.block_size, f"sequence length ({t}) exceeds model context ({self.block_size})"
        pos = torch.arange(t, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        return self.head(x)
# --------------------------------------------------------------------------- #
# Unified pipeline
# --------------------------------------------------------------------------- #
@dataclass
class VocabLayout:
    text_vocab_size: int    # BPE + control tokens (<TEXT>, <IMAGE>, ...)
    visual_offset: int      # = text_vocab_size: where visual codes start
    visual_vocab_size: int  # VQ-VAE codebook size
    total_vocab_size: int   # text_vocab_size + visual_vocab_size


class SupraA2A:
    """High-level wrapper: tokenizer + VQ-VAE + GPT sharing one unified vocabulary."""

    def __init__(self, weights_dir: Path, device: Optional[str] = None, n_head: int = N_HEAD):
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        from transformers import PreTrainedTokenizerFast

        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(weights_dir / "tokenizer.json"))
        self.tokenizer.pad_token = "<|endoftext|>"
        self.tokenizer.bos_token = "<|endoftext|>"
        self.tokenizer.eos_token = "<|endoftext|>"
        self.tokenizer.unk_token = "<|endoftext|>"

        vq_state = load_state_dict_any(weights_dir, "vqvae")
        gpt_state = load_state_dict_any(weights_dir, "model")

        # Hyperparameters are inferred from tensor shapes; no config.json is shipped.
        codebook_size, code_dim = vq_state["vq.embedding.weight"].shape
        self.vqvae = VQVAE(codebook_size=codebook_size, code_dim=code_dim)
        self.vqvae.load_state_dict(vq_state, strict=True)
        self.vqvae.eval().to(self.device)

        total_vocab, n_embd = gpt_state["tok_emb.weight"].shape
        block_size = gpt_state["pos_emb.weight"].shape[0]
        n_layer = len({k.split(".")[1] for k in gpt_state if k.startswith("blocks.")})
        self.gpt = GPT(total_vocab, n_embd, block_size, n_layer, n_head)
        self.gpt.load_state_dict(gpt_state, strict=True)
        self.gpt.eval().to(self.device)

        text_vocab_size = len(self.tokenizer)
        self.vocab = VocabLayout(
            text_vocab_size=text_vocab_size,
            visual_offset=text_vocab_size,
            visual_vocab_size=codebook_size,
            total_vocab_size=total_vocab,
        )
        expected_total = self.vocab.text_vocab_size + self.vocab.visual_vocab_size
        if expected_total != total_vocab:
            raise ValueError(
                f"Vocabulary mismatch: tokenizer + codebook = {expected_total}, "
                f"but the GPT's tok_emb expects {total_vocab}. "
                "Make sure you're using the tokenizer.json that matches these weights."
            )
        print(
            f"[ok] SupraA2A loaded on {self.device} | "
            f"GPT: {n_layer}L/{n_embd}d/ctx{block_size} | "
            f"vocab text={text_vocab_size} + visual={codebook_size} = {total_vocab} | "
            f"VQ-VAE: {codebook_size}x{code_dim} codes (/8 downsample)"
        )
# ------------------------------------------------------------------ #
# Text
# ------------------------------------------------------------------ #
    def encode_text(self, text: str) -> List[int]:
        return self.tokenizer.encode(text)

    def decode_text(self, ids: List[int]) -> str:
        # visual token ids (outside the tokenizer's range) are filtered out
        text_ids = [i for i in ids if i < self.vocab.text_vocab_size]
        return self.tokenizer.decode(text_ids)

    # ------------------------------------------------------------------ #
    # Image <-> visual tokens
    # ------------------------------------------------------------------ #
    def image_to_tokens(self, image: Image.Image) -> Tuple[List[int], Tuple[int, int]]:
        """Convert a PIL image into visual token ids (already offset into the unified vocab).

        Height/width must be multiples of 8 (three /2 downsampling stages).
        """
        w, h = image.size
        if w % 8 or h % 8:
            raise ValueError(f"Image {w}x{h} must have dimensions that are multiples of 8.")
        img = image.convert("RGB")
        arr = np.array(img, dtype=np.uint8)  # copy (writable), avoids a torch warning
        tensor = torch.from_numpy(arr).float().permute(2, 0, 1) / 255.0
        tensor = tensor.unsqueeze(0).to(self.device)
        with torch.no_grad():
            idx = self.vqvae.encode_to_indices(tensor)  # (1, h/8, w/8)
        grid_h, grid_w = idx.shape[1], idx.shape[2]
        flat = (idx.view(-1) + self.vocab.visual_offset).tolist()
        return flat, (grid_h, grid_w)

    def tokens_to_image(self, token_ids: List[int], grid: Tuple[int, int]) -> Image.Image:
        """Convert visual token ids (offset) back into a PIL image."""
        grid_h, grid_w = grid
        expected = grid_h * grid_w
        if len(token_ids) != expected:
            raise ValueError(f"Expected {expected} visual tokens for grid {grid}, got {len(token_ids)}.")
        raw = [i - self.vocab.visual_offset for i in token_ids]
        if any(i < 0 or i >= self.vocab.visual_vocab_size for i in raw):
            raise ValueError("Token id(s) outside the visual codebook range.")
        idx = torch.tensor(raw, device=self.device).view(1, grid_h, grid_w)
        with torch.no_grad():
            img_t = self.vqvae.decode_from_indices(idx)[0].cpu()
        arr = (img_t.permute(1, 2, 0).numpy() * 255).round().astype("uint8")
        return Image.fromarray(arr)

    def reconstruct_image(self, image: Image.Image) -> Image.Image:
        """Encode -> decode round-trip through the VQ-VAE only (no GPT). Useful sanity check."""
        tokens, grid = self.image_to_tokens(image)
        return self.tokens_to_image(tokens, grid)
# ------------------------------------------------------------------ #
# Autoregressive sampling
# ------------------------------------------------------------------ #
    @torch.no_grad()
    def _sample_step(
        self,
        ids: List[int],
        temperature: float,
        top_k: Optional[int],
        allowed_range: Optional[Tuple[int, int]] = None,
    ) -> int:
        # Only the most recent block_size tokens fit into the model's context.
        ctx = ids[-self.gpt.block_size :]
        x = torch.tensor([ctx], device=self.device)
        logits = self.gpt(x)[0, -1] / max(temperature, 1e-5)
        if allowed_range is not None:
            # Push every id outside [lo, hi) to -inf so it can never be sampled.
            lo, hi = allowed_range
            mask = torch.full_like(logits, float("-inf"))
            mask[lo:hi] = 0
            logits = logits + mask
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[-1]] = float("-inf")
        probs = F.softmax(logits, dim=-1)
        return int(torch.multinomial(probs, num_samples=1).item())

    def generate(
        self,
        prompt_ids: List[int],
        max_new_tokens: int = 64,
        temperature: float = 0.8,
        top_k: Optional[int] = 40,
        stop_token_id: Optional[int] = None,
    ) -> List[int]:
        ids = list(prompt_ids)
        for _ in range(max_new_tokens):
            next_id = self._sample_step(ids, temperature, top_k)
            ids.append(next_id)
            if stop_token_id is not None and next_id == stop_token_id:
                break
        return ids

    def complete_text(self, prompt: str, max_new_tokens: int = 64, temperature: float = 0.8, top_k: int = 40) -> str:
        eot_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
        prompt_ids = self.encode_text(prompt)
        out_ids = self.generate(prompt_ids, max_new_tokens, temperature, top_k, stop_token_id=eot_id)
        return self.decode_text(out_ids)
    def generate_image(
        self,
        prompt: str,
        grid: Tuple[int, int] = (8, 8),
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> Image.Image:
        """Generate grid_h*grid_w visual tokens conditioned on `prompt` and decode to an image.

        While sampling image tokens, the sampler is restricted to the visual
        codebook's id range, so it can never "leak" a text token into the image.
        """
        ids = self.encode_text(prompt)
        lo, hi = self.vocab.visual_offset, self.vocab.visual_offset + self.vocab.visual_vocab_size
        n_tokens = grid[0] * grid[1]
        for _ in range(n_tokens):
            next_id = self._sample_step(ids, temperature, top_k, allowed_range=(lo, hi))
            ids.append(next_id)
        visual_ids = ids[-n_tokens:]
        return self.tokens_to_image(visual_ids, grid)

    def generate_video(
        self,
        prompt: str,
        num_frames: int = 4,
        grid: Tuple[int, int] = (8, 8),
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> List[Image.Image]:
        frame_token_id = self.tokenizer.convert_tokens_to_ids("<FRAME>")
        ids = self.encode_text(prompt)
        lo, hi = self.vocab.visual_offset, self.vocab.visual_offset + self.vocab.visual_vocab_size
        n_tokens_per_frame = grid[0] * grid[1]
        frames = []
        for f in range(num_frames):
            print(f"[info] Generating frame {f+1}/{num_frames}...")
            # Each frame is introduced by a <FRAME> marker, then sampled token by token.
            ids.append(frame_token_id)
            for _ in range(n_tokens_per_frame):
                next_id = self._sample_step(ids, temperature, top_k, allowed_range=(lo, hi))
                ids.append(next_id)
            frame_tokens = ids[-n_tokens_per_frame:]
            img = self.tokens_to_image(frame_tokens, grid)
            frames.append(img)
        return frames
    def image_to_text(self, image: Image.Image, max_new_tokens: int = 64, temperature: float = 0.8, top_k: int = 40) -> str:
        visual_ids, _ = self.image_to_tokens(image)
        image_start_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
        frame_token_id = self.tokenizer.convert_tokens_to_ids("<FRAME>")
        image_end_id = self.tokenizer.convert_tokens_to_ids("</IMAGE>")
        text_start_id = self.tokenizer.convert_tokens_to_ids("<TEXT>")
        eot_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
        prompt_ids = [image_start_id, frame_token_id] + visual_ids + [image_end_id, text_start_id]
        out_ids = self.generate(prompt_ids, max_new_tokens, temperature, top_k, stop_token_id=eot_id)
        gen_ids = out_ids[len(prompt_ids):]
        return self.decode_text(gen_ids)

    def video_to_text(self, frames: List[Image.Image], max_new_tokens: int = 64, temperature: float = 0.8, top_k: int = 40) -> str:
        video_start_id = self.tokenizer.convert_tokens_to_ids("<VIDEO>")
        frame_token_id = self.tokenizer.convert_tokens_to_ids("<FRAME>")
        video_end_id = self.tokenizer.convert_tokens_to_ids("</VIDEO>")
        text_start_id = self.tokenizer.convert_tokens_to_ids("<TEXT>")
        eot_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
        prompt_ids = [video_start_id]
        for frame in frames:
            visual_ids, _ = self.image_to_tokens(frame)
            prompt_ids += [frame_token_id] + visual_ids
        prompt_ids += [video_end_id, text_start_id]
        out_ids = self.generate(prompt_ids, max_new_tokens, temperature, top_k, stop_token_id=eot_id)
        gen_ids = out_ids[len(prompt_ids):]
        return self.decode_text(gen_ids)

    def image_to_image(self, image: Image.Image, prompt: str, grid: Tuple[int, int] = (8, 8), temperature: float = 1.0, top_k: Optional[int] = None) -> Image.Image:
        visual_ids, _ = self.image_to_tokens(image)
        image_start_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
        frame_token_id = self.tokenizer.convert_tokens_to_ids("<FRAME>")
        image_end_id = self.tokenizer.convert_tokens_to_ids("</IMAGE>")
        prompt_ids = [image_start_id, frame_token_id] + visual_ids + [image_end_id] + self.encode_text(prompt) + [image_start_id]
        lo, hi = self.vocab.visual_offset, self.vocab.visual_offset + self.vocab.visual_vocab_size
        n_tokens = grid[0] * grid[1]
        for _ in range(n_tokens):
            next_id = self._sample_step(prompt_ids, temperature, top_k, allowed_range=(lo, hi))
            prompt_ids.append(next_id)
        return self.tokens_to_image(prompt_ids[-n_tokens:], grid)

    def image_to_video(self, image: Image.Image, prompt: str, num_frames: int = 4, grid: Tuple[int, int] = (8, 8), temperature: float = 1.0, top_k: Optional[int] = None) -> List[Image.Image]:
        visual_ids, _ = self.image_to_tokens(image)
        image_start_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
        frame_token_id = self.tokenizer.convert_tokens_to_ids("<FRAME>")
        image_end_id = self.tokenizer.convert_tokens_to_ids("</IMAGE>")
        video_start_id = self.tokenizer.convert_tokens_to_ids("<VIDEO>")
        prompt_ids = [image_start_id, frame_token_id] + visual_ids + [image_end_id] + self.encode_text(prompt) + [video_start_id]
        lo, hi = self.vocab.visual_offset, self.vocab.visual_offset + self.vocab.visual_vocab_size
        n_tokens_per_frame = grid[0] * grid[1]
        frames = []
        for f in range(num_frames):
            print(f"[info] Generating frame {f+1}/{num_frames}...")
            prompt_ids.append(frame_token_id)
            for _ in range(n_tokens_per_frame):
                next_id = self._sample_step(prompt_ids, temperature, top_k, allowed_range=(lo, hi))
                prompt_ids.append(next_id)
            frames.append(self.tokens_to_image(prompt_ids[-n_tokens_per_frame:], grid))
        return frames

    def video_to_image(self, frames: List[Image.Image], prompt: str, grid: Tuple[int, int] = (8, 8), temperature: float = 1.0, top_k: Optional[int] = None) -> Image.Image:
        video_start_id = self.tokenizer.convert_tokens_to_ids("<VIDEO>")
        frame_token_id = self.tokenizer.convert_tokens_to_ids("<FRAME>")
        video_end_id = self.tokenizer.convert_tokens_to_ids("</VIDEO>")
        image_start_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
        prompt_ids = [video_start_id]
        for frame in frames:
            visual_ids, _ = self.image_to_tokens(frame)
            prompt_ids += [frame_token_id] + visual_ids
        prompt_ids += [video_end_id] + self.encode_text(prompt) + [image_start_id]
        lo, hi = self.vocab.visual_offset, self.vocab.visual_offset + self.vocab.visual_vocab_size
        n_tokens = grid[0] * grid[1]
        for _ in range(n_tokens):
            next_id = self._sample_step(prompt_ids, temperature, top_k, allowed_range=(lo, hi))
            prompt_ids.append(next_id)
        return self.tokens_to_image(prompt_ids[-n_tokens:], grid)

    def video_to_video(self, frames: List[Image.Image], prompt: str, num_frames: int = 4, grid: Tuple[int, int] = (8, 8), temperature: float = 1.0, top_k: Optional[int] = None) -> List[Image.Image]:
        video_start_id = self.tokenizer.convert_tokens_to_ids("<VIDEO>")
        frame_token_id = self.tokenizer.convert_tokens_to_ids("<FRAME>")
        video_end_id = self.tokenizer.convert_tokens_to_ids("</VIDEO>")
        prompt_ids = [video_start_id]
        for frame in frames:
            visual_ids, _ = self.image_to_tokens(frame)
            prompt_ids += [frame_token_id] + visual_ids
        prompt_ids += [video_end_id] + self.encode_text(prompt) + [video_start_id]
        lo, hi = self.vocab.visual_offset, self.vocab.visual_offset + self.vocab.visual_vocab_size
        n_tokens_per_frame = grid[0] * grid[1]
        out_frames = []
        for f in range(num_frames):
            print(f"[info] Generating frame {f+1}/{num_frames}...")
            prompt_ids.append(frame_token_id)
            for _ in range(n_tokens_per_frame):
                next_id = self._sample_step(prompt_ids, temperature, top_k, allowed_range=(lo, hi))
                prompt_ids.append(next_id)
            out_frames.append(self.tokens_to_image(prompt_ids[-n_tokens_per_frame:], grid))
        return out_frames
# --------------------------------------------------------------------------- #
# CLI
# --------------------------------------------------------------------------- #
def main() -> None:
parser = argparse.ArgumentParser(description="Supra-A2A-Nano-Exp - inference runner")
parser.add_argument("--mode", choices=[
"text", "chat", "reconstruct",
"text2image", "text2video",
"image2text", "video2text",
"image2image", "image2video", "video2image", "video2video"
], default="text")
parser.add_argument("--weights_dir", type=Path, default=Path(__file__).resolve().parent)
parser.add_argument("--repo_id", default=DEFAULT_REPO_ID)
parser.add_argument("--prompt", default="<TEXT>")
parser.add_argument("--image", type=Path, help="input image (reconstruct mode)")
parser.add_argument("--out", type=Path, default=Path("output.png"))
parser.add_argument("--grid", type=int, nargs=2, default=[8, 8], metavar=("H", "W"))
parser.add_argument("--frames", type=int, default=4, help="Number of generated video frames")
parser.add_argument("--max_new_tokens", type=int, default=64)
parser.add_argument("--temperature", type=float, default=0.8)
parser.add_argument("--top_k", type=int, default=40)
parser.add_argument("--device", default=None)
parser.add_argument("--seed", type=int, default=None)
args = parser.parse_args()
if args.seed is not None:
torch.manual_seed(args.seed)
weights_dir = resolve_weights_dir(args.weights_dir, args.repo_id)
model = SupraA2A(weights_dir, device=args.device)
if args.mode == "text":
        print(model.complete_text(args.prompt, args.max_new_tokens, args.temperature, args.top_k))
    elif args.mode == "chat":
        print("Chat mode - Ctrl+C to exit.")
        while True:
            try:
                prompt = input("\n> ")
            except (KeyboardInterrupt, EOFError):
                break
            print(model.complete_text(prompt, args.max_new_tokens, args.temperature, args.top_k))
    elif args.mode == "reconstruct":
        if not args.image:
            sys.exit("--image is required in reconstruct mode")
        img = Image.open(args.image)
        recon = model.reconstruct_image(img)
        recon.save(args.out)
        print(f"[ok] reconstruction saved to {args.out}")
    elif args.mode == "text2image":
        img = model.generate_image(
            args.prompt, grid=tuple(args.grid), temperature=args.temperature, top_k=args.top_k
        )
        img.save(args.out)
        print(f"[ok] generated image saved to {args.out}")
    elif args.mode == "text2video":
        frames = model.generate_video(
            args.prompt, num_frames=args.frames, grid=tuple(args.grid),
            temperature=args.temperature, top_k=args.top_k
        )
        # Videos are always written as animated GIFs, whatever suffix --out has.
        out_gif = args.out.with_suffix(".gif")
        frames[0].save(
            out_gif,
            save_all=True,
            append_images=frames[1:],
            duration=200,
            loop=0
        )
        print(f"[ok] Video saved as animated GIF in: {out_gif}")
    elif args.mode == "image2text":
        if not args.image:
            sys.exit("--image is needed for image2text")
        img = Image.open(args.image)
        description = model.image_to_text(img, args.max_new_tokens, args.temperature, args.top_k)
        print(f"\n[Caption for {args.image.name}]:\n{description}")
    elif args.mode == "video2text":
        if not args.image:
            sys.exit("--image is needed for video2text (pass e.g. the animated .gif)")
        gif = Image.open(args.image)
        # Step through the GIF until seek() runs past the last frame.
        frames = []
        try:
            while True:
                frames.append(gif.convert("RGB"))
                gif.seek(gif.tell() + 1)
        except EOFError:
            pass
        print(f"[info] Processing {len(frames)} frames from video/GIF...")
        description = model.video_to_text(frames, args.max_new_tokens, args.temperature, args.top_k)
        print(f"\n[Caption for video {args.image.name}]:\n{description}")
    elif args.mode == "image2image":
        if not args.image:
            sys.exit("--image is needed for image2image")
        img = Image.open(args.image)
        out_img = model.image_to_image(
            img, args.prompt, grid=tuple(args.grid), temperature=args.temperature, top_k=args.top_k
        )
        out_img.save(args.out)
        print(f"[ok] Image-to-Image result saved as: {args.out}")
    elif args.mode == "image2video":
        if not args.image:
            sys.exit("--image is needed for image2video")
        img = Image.open(args.image)
        frames = model.image_to_video(
            img, args.prompt, num_frames=args.frames, grid=tuple(args.grid),
            temperature=args.temperature, top_k=args.top_k
        )
        out_gif = args.out.with_suffix(".gif")
        frames[0].save(out_gif, save_all=True, append_images=frames[1:], duration=200, loop=0)
        print(f"[ok] Image-to-Video result saved as: {out_gif}")
    elif args.mode == "video2image":
        if not args.image:
            sys.exit("--image (animated .gif) is needed for video2image")
        gif = Image.open(args.image)
        frames = []
        try:
            while True:
                frames.append(gif.convert("RGB"))
                gif.seek(gif.tell() + 1)
        except EOFError:
            pass
        # Only the first three frames are used as conditioning input.
        frames = frames[:3]
        print(f"[info] Processing {len(frames)} input frames...")
        out_img = model.video_to_image(
            frames, args.prompt, grid=tuple(args.grid), temperature=args.temperature, top_k=args.top_k
        )
        out_img.save(args.out)
        print(f"[ok] Video-to-Image result saved as: {args.out}")
    elif args.mode == "video2video":
        if not args.image:
            sys.exit("--image (animated .gif) is needed for video2video")
        gif = Image.open(args.image)
        frames = []
        try:
            while True:
                frames.append(gif.convert("RGB"))
                gif.seek(gif.tell() + 1)
        except EOFError:
            pass
        # Only the first three frames are used as conditioning input.
        frames = frames[:3]
        print(f"[info] Processing {len(frames)} input frames...")
        out_frames = model.video_to_video(
            frames, args.prompt, num_frames=args.frames, grid=tuple(args.grid),
            temperature=args.temperature, top_k=args.top_k
        )
        out_gif = args.out.with_suffix(".gif")
        out_frames[0].save(out_gif, save_all=True, append_images=out_frames[1:], duration=200, loop=0)
        print(f"[ok] Video-to-Video result saved as: {out_gif}")


if __name__ == "__main__":
    main()
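The video-input modes above keep only the first three frames of the input GIF. The script does not say why, but the 384-token context is a plausible reason. The arithmetic below is a rough sketch under stated assumptions: each frame costs one `<FRAME>` tag plus 64 visual tokens (a 64x64 image on an 8x8 grid), and the clip is wrapped in one `<VIDEO>` and one `</VIDEO>` tag. The function names are hypothetical and do not appear in `run_supra_a2a.py`.

```python
CONTEXT_LEN = 384          # context window stated in this model card
TOKENS_PER_FRAME = 64 + 1  # 8x8 visual tokens plus one <FRAME> tag (assumed layout)


def video_tokens(num_frames: int) -> int:
    """Tokens for <VIDEO>, the frames, and </VIDEO> under the assumed layout."""
    return 2 + num_frames * TOKENS_PER_FRAME


def remaining_budget(input_frames: int, prompt_tokens: int = 0) -> int:
    """Context left for generation after the input clip and a text prompt."""
    return CONTEXT_LEN - video_tokens(input_frames) - prompt_tokens


if __name__ == "__main__":
    for n in (1, 3, 5, 6):
        print(f"{n} frame(s): {video_tokens(n)} tokens, {remaining_budget(n)} left")
```

Under these assumptions three input frames take 197 tokens and leave 187 for the prompt and the output, while six frames alone would already exceed the context.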
Actions
Modalities:
Text --> Image
Text --> Video
Text --> Text
Image --> Image
Image --> Video
Image --> Text
Video --> Image
Video --> Video
Video --> Text
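The example commands in this README cue the output modality by ending the prompt with the matching opening tag, with the instruction wrapped in `<TEXT>...</TEXT>`. A minimal sketch of that convention follows; `build_prompt` is a hypothetical helper, not part of `run_supra_a2a.py`, and it assumes that text-to-text prompts are passed raw, as in the `--mode text` example.

```python
# Opening tag that cues each output modality, as used in the example commands.
TARGET_TAG = {"text": "", "image": "<IMAGE>", "video": "<VIDEO>"}


def build_prompt(instruction: str, target: str) -> str:
    """Wrap an instruction in <TEXT> tags and append the target's opening tag."""
    if target not in TARGET_TAG:
        raise ValueError(f"unknown target modality: {target!r}")
    if target == "text":
        # Plain text completion takes the raw prompt (assumption from the examples).
        return instruction
    return f"<TEXT>{instruction}</TEXT>{TARGET_TAG[target]}"


if __name__ == "__main__":
    print(build_prompt("a red square", "image"))
    print(build_prompt("A snow scene", "video"))
```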
# Text
python run_supra_a2a.py --mode text --prompt "Once upon a time"
# Chat
python run_supra_a2a.py --mode chat
# Text-To-Image
python run_supra_a2a.py --mode text2image --prompt "<TEXT>a red square</TEXT><IMAGE>" --out gen.png
# Text-To-Video
python run_supra_a2a.py --mode text2video --prompt "<TEXT>A snow scene</TEXT><VIDEO>" --frames 4 --out video.gif
# Reconstruct (VQ-VAE round trip, as a sanity check)
python run_supra_a2a.py --mode reconstruct --image gen.png --out gen_recon.png
# Image-To-Text
python run_supra_a2a.py --mode image2text --image gen.png
# Image-To-Image
python run_supra_a2a.py --mode image2image --image gen.png --prompt "<TEXT>make it red</TEXT><IMAGE>" --out edit.png
# Image-To-Video
python run_supra_a2a.py --mode image2video --image gen.png --prompt "<TEXT>make it move</TEXT><VIDEO>" --frames 4 --out animated.gif
# Video-To-Text
python run_supra_a2a.py --mode video2text --image video.gif
# Video-To-Image
python run_supra_a2a.py --mode video2image --image video.gif --prompt "<TEXT>extract style</TEXT><IMAGE>" --out frame.png
# Video-To-Video
python run_supra_a2a.py --mode video2video --image video.gif --prompt "<TEXT>change style</TEXT><VIDEO>" --frames 4 --out transformed.gif
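To drive these commands from another script, the argument list can be assembled programmatically. The sketch below uses only the flags shown above (`--mode`, `--image`, `--prompt`, `--frames`, `--out`) and mirrors the script's rule that every image- or video-input mode requires `--image`. `build_command` and `NEEDS_IMAGE` are hypothetical names introduced for illustration.

```python
import shlex

# Modes for which the script exits unless --image is given.
NEEDS_IMAGE = {
    "reconstruct", "image2text", "image2image", "image2video",
    "video2text", "video2image", "video2video",
}


def build_command(mode, prompt=None, image=None, out=None, frames=None):
    """Return an argv list for run_supra_a2a.py, checking the --image rule."""
    if mode in NEEDS_IMAGE and not image:
        raise ValueError(f"--image is required for mode {mode!r}")
    argv = ["python", "run_supra_a2a.py", "--mode", mode]
    for flag, value in (("--image", image), ("--prompt", prompt),
                        ("--frames", frames), ("--out", out)):
        if value is not None:
            argv += [flag, str(value)]
    return argv


if __name__ == "__main__":
    cmd = build_command("text2image", prompt="<TEXT>a red square</TEXT><IMAGE>", out="gen.png")
    # shlex.join quotes the prompt so the tags survive the shell.
    print(shlex.join(cmd))
```

The returned list can be passed directly to `subprocess.run`, which avoids shell quoting problems with the `<` and `>` characters in the tags.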
run_supra_a2a.py is self-contained: it rebuilds the GPT and VQ-VAE modules
directly from the raw state_dicts, validates that the tokenizer vocabulary
and the visual codebook line up with the GPT's embedding table, and falls
back to downloading missing files from this Hub repo automatically. No other
custom code or config file is required.
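The alignment check mentioned above comes down to simple arithmetic on the sizes given earlier: 50,257 BPE tokens plus 7 control tokens, followed by 256 visual codes. The sketch below illustrates that check and the id offset it implies. It assumes visual code `k` maps to token id `50,264 + k`, which follows from the codes being appended right after the text vocabulary; the function names are hypothetical and the real script may organize this differently.

```python
TEXT_VOCAB = 50_257 + 7  # GPT-2 BPE tokens plus 7 control tokens
CODEBOOK_SIZE = 256      # VQ-VAE codebook entries


def check_alignment(text_vocab: int, codebook_size: int, embedding_rows: int) -> int:
    """Raise if the embedding table does not match text vocab + visual codes."""
    expected = text_vocab + codebook_size
    if embedding_rows != expected:
        raise ValueError(f"embedding has {embedding_rows} rows, expected {expected}")
    return expected


def visual_code_to_token_id(code: int) -> int:
    """Map a codebook index to its id in the combined vocabulary (assumed offset)."""
    if not 0 <= code < CODEBOOK_SIZE:
        raise ValueError(f"code {code} outside codebook of size {CODEBOOK_SIZE}")
    return TEXT_VOCAB + code


def token_id_to_visual_code(token_id: int) -> int:
    """Inverse mapping; rejects ids that belong to the text side."""
    code = token_id - TEXT_VOCAB
    if not 0 <= code < CODEBOOK_SIZE:
        raise ValueError(f"token id {token_id} is not a visual token")
    return code


if __name__ == "__main__":
    print(check_alignment(TEXT_VOCAB, CODEBOOK_SIZE, 50_520))
    print(visual_code_to_token_id(0), visual_code_to_token_id(255))
```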
Limitations
- Trained at nano scale (~30M params, 384-token context) for proof-of-concept purposes, not benchmark performance. Text generations will be incoherent past a sentence or two, and generated images will be low-resolution and abstract rather than photorealistic.
- The image side only handles square images whose side length is a multiple of 8 (the VQ-VAE's downsampling factor).
- No instruction-tuning or RLHF; this is a base, unaligned model trained purely on the next-token objective across the unified token stream.
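The image-size constraint can be checked before an image is handed to the model. The sketch below validates a size against the /8 downsampling factor and returns the resulting token grid. Both helpers are hypothetical, and rounding an invalid side down to the nearest multiple of 8 is only one possible policy, not something the script is known to do.

```python
DOWNSAMPLE = 8  # total downsampling of the VQ-VAE's three strided convolutions


def grid_for_image(width: int, height: int) -> tuple:
    """Return the (rows, cols) token grid for a valid square image."""
    if width != height:
        raise ValueError(f"image must be square, got {width}x{height}")
    if width <= 0 or width % DOWNSAMPLE != 0:
        raise ValueError(f"side {width} is not a positive multiple of {DOWNSAMPLE}")
    side = width // DOWNSAMPLE
    return (side, side)


def nearest_valid_side(side: int) -> int:
    """Round a side length down to a multiple of 8, with a minimum of 8."""
    return max(DOWNSAMPLE, (side // DOWNSAMPLE) * DOWNSAMPLE)


if __name__ == "__main__":
    rows, cols = grid_for_image(64, 64)
    print(rows, cols, rows * cols)
    print(nearest_valid_side(100))
```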
About SupraLabs
SupraLabs is a small open-source AI lab building tiny language and multimodal models from scratch on consumer hardware, released openly so others can study, fine-tune, and build on them. Supra-A2A-Nano-Exp is part of that broader effort alongside the Supra language model family and the SupraVID text-to-video project.
