An ImageGPT-like model trained from LFM-2.5-350M-Base on 250k caption–image pairs from laion/conceptual-captions-12m-webdataset, using images with a maximum resolution of 64x64 and 64-color palettes extracted via sklearn's MiniBatchKMeans.

Usage:

import importlib.util
import re

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the weights in bfloat16 and place the whole model on the GPU when one
# is available, otherwise on the CPU.
dtype = torch.bfloat16
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_id = "crumb/imagegpt-lfm2.5-350m"

# FlashAttention-2 needs a CUDA device and the `flash_attn` package, so it is
# requested only when both are present. Otherwise no attention implementation
# is passed and transformers falls back to its default instead of failing at
# load time. NOTE: an older GPU that flash_attn does not support can still
# fail here; in that case drop `attn_implementation` manually.
use_flash_attention = (
    device == 'cuda' and importlib.util.find_spec("flash_attn") is not None
)
attn_kwargs = (
    {"attn_implementation": "flash_attention_2"} if use_flash_attention else {}
)

model = AutoModelForCausalLM.from_pretrained(
    model_id, dtype=dtype, device_map={"": device}, **attn_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

def decode_to_image(generated_text: str) -> Image.Image:
    """
    Parse a generated text string back into a PIL Image.

    The text must contain a ``<|res|>`` section holding ``WxH``, a ``<|pal|>``
    section holding ``<|index|>R_G_B`` palette entries, and an ``<|img|>``
    section holding one ``<|index|>`` token per pixel, laid out row by row.
    Incomplete generations are tolerated: only the fully generated rows are
    drawn, and a trailing partial row is discarded.

    Args:
        generated_text: Decoded model output, special tokens included.

    Returns:
        An RGB image that is ``W`` pixels wide and as tall as the number of
        complete rows found, or a 1x1 black image when not even one complete
        row was generated. Pixels whose index is missing from the palette are
        drawn black.

    Raises:
        ValueError: If the resolution, palette, or image section is missing
            or malformed, or if the parsed width is not positive.
    """
    try:
        res_section = generated_text.split("<|res|>")[1].split("<|pal|>")[0].strip()
        # The height is parsed only to validate the "WxH" format; the number
        # of rows drawn is derived from the tokens actually generated.
        w, _h = map(int, res_section.split("x"))
    except (IndexError, ValueError) as err:
        raise ValueError("Could not parse resolution. Model output malformed.") from err

    # A non-positive width would otherwise surface as a ZeroDivisionError (or
    # a nonsensical reshape) when the pixels are split into rows below.
    if w <= 0:
        raise ValueError("Could not parse resolution. Model output malformed.")

    palette = {}
    try:
        pal_section = generated_text.split("<|pal|>")[1].split("<|img|>")[0].strip()

        # Splitting on "<|" leaves entries shaped like "index|>R_G_B".
        for entry in pal_section.split("<|"):
            if not entry.strip():
                continue
            idx_str, color_str = entry.split("|>")
            r, g, b = map(int, color_str.split("_"))
            # Clamp each channel so an out-of-range value cannot overflow the
            # uint8 conversion below.
            palette[int(idx_str)] = [min(255, max(0, c)) for c in (r, g, b)]

    except (IndexError, ValueError) as err:
        raise ValueError("Could not parse palette. Model output malformed.") from err

    try:
        img_section = generated_text.split("<|img|>")[1].strip()
        indices = [int(m) for m in re.findall(r"<\|(\d+)\|>", img_section)]

    except (IndexError, ValueError) as err:
        raise ValueError("Could not parse image tokens. Model output malformed.") from err

    # Indices the palette does not define fall back to black.
    black = [0, 0, 0]
    pixels = [palette.get(idx, black) for idx in indices]

    num_complete_rows = len(pixels) // w
    if num_complete_rows == 0:
        return Image.new('RGB', (1, 1), (0, 0, 0))

    # Drop the trailing partial row so the data reshapes cleanly.
    img_array = np.array(
        pixels[:num_complete_rows * w], dtype=np.uint8
    ).reshape(num_complete_rows, w, 3)
    # The mode is inferred as RGB from the (rows, w, 3) uint8 array; the
    # explicit `mode` argument is deprecated in recent Pillow releases.
    return Image.fromarray(img_array)

# Sample one image from the model. Gradients are not needed for inference, so
# the whole generation runs under no_grad.
with torch.no_grad():
    # to generate conditioned on a prompt and resolution:
    # prompt = "<|cap|>\nA cyberpunk city at night\n<|res|>\n64x64\n<|pal|>\n"
    # keep in mind this was trained with a constant longest side length of 64px
    prompt = "<|cap|>"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]
    output_ids = model.generate(
        input_ids,
        # Pass the tokenizer's mask explicitly so generate() does not have to
        # guess it from the pad token.
        attention_mask=inputs["attention_mask"],
        max_new_tokens=8000,
        temperature=1.0,
        top_k=0,
        do_sample=True,
        # penalty_alpha=0.3 # <- sometimes a penalty alpha between 0.2 and 0.4 helps avoid degenerate cases
    )
    # Special tokens must be kept: decode_to_image locates the <|res|>,
    # <|pal|> and <|img|> markers in the decoded text.
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
    img = decode_to_image(generated_text)

# now img is a PIL Image, you can do
img.save("output.png")

# or if in a notebook, you can just run
img
# to view it immediately
Downloads last month
4
Safetensors
Model size
0.4B params
Tensor type
BF16
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support