| | import gradio as gr |
| | import spaces |
| | import time |
| | import os |
| | from PIL import Image, ImageOps, ImageDraw |
| | import numpy as np |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | import torch |
| |
|
# Default side length (px) of the square drawing canvas.
DEFAULT_CANVAS = 64
# Default brush size (px) for the image editor.
DEFAULT_BRUSH = 2
| |
|
def make_blank_canvas(w: int, h: int) -> Image.Image:
    """Return a new all-black 8-bit grayscale image of size ``w`` x ``h``."""
    blank = Image.new("L", (w, h), color=0)
    return blank
| |
|
def pil_to_rowstring(img: Image.Image) -> str:
    """Serialize a grayscale image as CSV text.

    Each pixel row becomes one line of the form ``v0,v1,...,vN;`` with
    values in 0-255; lines are joined with newlines.
    """
    pixels = np.asarray(img.convert("L"), dtype=np.uint8)
    return "\n".join(
        ",".join(str(v) for v in row) + ";" for row in pixels.tolist()
    )
| |
|
def pil_to_binstring(img: Image.Image, thresh: int = 128) -> str:
    """Serialize a grayscale image as CSV of 0/1 values.

    Pixels >= ``thresh`` map to 1, the rest to 0; each row becomes one
    ``b0,b1,...,bN;`` line, joined with newlines.
    """
    gray = np.asarray(img.convert("L"), dtype=np.uint8)
    bits = np.where(gray >= int(thresh), 1, 0).astype(np.uint8)
    rows = [",".join(str(b) for b in line) + ";" for line in bits.tolist()]
    return "\n".join(rows)
| |
|
# Cache of loaded models keyed by model_id -> (tokenizer, model), so
# switching back to a previously used model does not reload it.
_LLM_CACHE = {}
| |
|
def load_llm(model_id: str):
    """Load (and cache) the tokenizer/model pair for *model_id*.

    Returns:
        (tokenizer, model) tuple. Repeated calls with the same model_id
        return the cached pair without re-authenticating or reloading.
    """
    # Serve from the cache first: the original logged in to the HF hub on
    # every call, even for cache hits, paying needless auth/network cost.
    if model_id in _LLM_CACHE:
        return _LLM_CACHE[model_id]

    # Authenticate only when we actually need to download weights
    # (gated models such as meta-llama require a token).
    from huggingface_hub import login
    token = os.environ.get("HF_TOKEN")
    if token:
        login(token=token)

    # Half precision saves memory on GPU; CPU inference needs float32.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    tok = AutoTokenizer.from_pretrained(model_id)
    if tok.pad_token is None:
        # Many causal LMs ship without a pad token; reuse EOS for padding.
        tok.pad_token = tok.eos_token

    mdl = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto" if torch.cuda.is_available() else None,
        trust_remote_code=True,
    )

    # Without CUDA, device_map is None and the model stays wherever
    # from_pretrained put it; pin it to the CPU explicitly.
    if not torch.cuda.is_available():
        mdl = mdl.to("cpu")

    _LLM_CACHE[model_id] = (tok, mdl)
    return tok, mdl
| |
|
@spaces.GPU
def run_llm(prompt: str, max_new_tokens: int = 64, temperature: float = 0.0, model_id: str = "meta-llama/Llama-3.2-1B") -> str:
    """Generate a continuation of *prompt* with a causal LM.

    Args:
        prompt: Raw text fed to the model (no chat template applied).
        max_new_tokens: Upper bound on generated tokens.
        temperature: 0 means greedy decoding; > 0 enables sampling.
        model_id: HF hub id passed through to load_llm.

    Returns:
        Only the newly generated text (prompt prefix stripped), or an
        "[LLM error: ...]" string instead of raising, so the UI can
        display failures inline.
    """
    try:
        tok, mdl = load_llm(model_id)

        inputs = tok(prompt, return_tensors="pt")
        device = next(mdl.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Build generation kwargs explicitly: the original passed
        # temperature=None when not sampling, which some transformers
        # versions reject, and it dropped the attention mask entirely,
        # which triggers warnings and can mis-handle padded input.
        gen_kwargs = {
            "max_new_tokens": int(max_new_tokens),
            "pad_token_id": tok.eos_token_id,
            "eos_token_id": tok.eos_token_id,
            "use_cache": True,
        }
        if temperature > 0:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = temperature
        else:
            gen_kwargs["do_sample"] = False

        with torch.inference_mode():
            # **inputs forwards input_ids AND attention_mask.
            outputs = mdl.generate(**inputs, **gen_kwargs)

        # Strip the prompt tokens so only the continuation is returned.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        text = tok.decode(new_tokens, skip_special_tokens=True)
        return text.strip()

    except Exception as e:
        # Surface the failure as text; callers check the prefix.
        return f"[LLM error: {e}]"
| |
|
def csv_single_line(csv_multiline: str) -> str:
    """Collapse a multi-line CSV string onto a single line.

    ``None`` (or any falsy value) yields the empty string.
    """
    text = csv_multiline if csv_multiline else ""
    return "".join(text.split("\n"))
| |
|
def parse_csv_image(s: str, width: int):
    """Parse a ';'-terminated CSV pixel string into an 'L'-mode PIL image.

    Non-digit characters are stripped from each token; values are clamped
    to 0-255. Every row is padded with zeros / truncated to *width*
    columns. Returns None when nothing parseable is found or on error.
    """
    try:
        grid = []
        for raw_row in s.strip().split(";"):
            if raw_row == "":
                continue
            row_vals = []
            for cell in raw_row.split(","):
                digits = "".join(c for c in cell if c.isdigit())
                if not digits:
                    continue
                row_vals.append(min(255, max(0, int(digits))))
            if not row_vals:
                continue
            # Normalize row length so the array is rectangular.
            grid.append((row_vals + [0] * width)[:width])
        if not grid:
            return None
        return Image.fromarray(np.array(grid, dtype=np.uint8), mode="L")
    except Exception:
        return None
| |
|
def apply_settings(canvas_px):
    """Rebuild the ImageEditor with a square canvas of *canvas_px* pixels."""
    side = int(canvas_px)
    palette = ["black", "#404040", "#808080", "#C0C0C0", "white"]
    brush = gr.Brush(
        default_size=DEFAULT_BRUSH,
        colors=palette,
        default_color="white",
        color_mode="fixed",
    )
    return gr.ImageEditor(
        canvas_size=(side, side),
        value=make_blank_canvas(side, side),
        image_mode="RGBA",
        brush=brush,
        eraser=gr.Eraser(default_size=1),
        transforms=("crop", "resize"),
        height=500,
    )
| |
|
| | |
def process_upload(im, canvas_px, scale, invert, binarize, bin_thresh):
    """Fit an uploaded image onto a square canvas and derive preview + CSV.

    Returns (editor_value, preview_image, csv_text), or (None, None, None)
    when the editor payload has no uploaded background.
    """
    if not im or im.get("background") is None:
        return None, None, None

    gray = Image.fromarray(im["background"]).convert("L")
    src_w, src_h = gray.size

    # Square target side comes from the canvas slider; fall back to the
    # source width when the slider value is missing or non-positive.
    side = src_w if canvas_px is None else int(canvas_px)
    if side <= 0:
        side = src_w
    new_h = max(1, round(src_h * side / max(1, src_w)))
    scaled = gray.resize((side, new_h), Image.LANCZOS)

    # Paste onto a black square so the editor canvas stays side x side
    # (tall images are cropped at the bottom by the paste).
    square = Image.new("L", (side, side), 0)
    square.paste(scaled, (0, 0))

    # The editor gets the un-inverted canvas; invert only affects the
    # preview/CSV representation.
    base = ImageOps.invert(square) if invert else square
    if bool(binarize):
        csv_text = pil_to_binstring(base, bin_thresh)
    else:
        csv_text = pil_to_rowstring(base)

    zoom = max(1, int(scale) if scale is not None else 8)
    preview = base.resize((base.width * zoom, base.height * zoom), Image.NEAREST)
    return square, preview, csv_text
| |
|
def make_preview(im, scale, invert, binarize, bin_thresh):
    """Render the editor composite as a zoomed preview plus its CSV text.

    Returns (preview_image, csv_text); (None, "") when there is no
    composite in the editor payload.
    """
    if im is None or im.get("composite") is None:
        return None, ""

    gray = Image.fromarray(im["composite"]).convert("L")
    shown = ImageOps.invert(gray) if invert else gray

    if bool(binarize):
        csv_text = pil_to_binstring(shown, bin_thresh)
    else:
        csv_text = pil_to_rowstring(shown)

    # Nearest-neighbor upscale keeps the pixel grid crisp.
    zoom = max(1, int(scale) if scale is not None else 8)
    preview = shown.resize((shown.width * zoom, shown.height * zoom), Image.NEAREST)
    return preview, csv_text
| |
|
def _parse_csv_rows(csv_text):
    """Parse ';'-separated CSV rows into (list-of-int-lists, max_width).

    Tokens that don't parse as numbers are skipped; values are clamped to
    0-255. Rows with no usable values are dropped.
    """
    parsed = []
    max_w = 0
    for r in csv_text.split(";"):
        vals = []
        for tok in r.split(","):
            tok = tok.strip()
            if not tok:
                continue
            try:
                v = int(float(tok))
            except Exception:
                continue
            vals.append(min(255, max(0, v)))
        if vals:
            parsed.append(vals)
            max_w = max(max_w, len(vals))
    return parsed, max_w


def _rows_to_image(parsed, max_w):
    """Pad/truncate rows to *max_w* and build an 'L'-mode PIL image."""
    arr_rows = [(vals + [0] * (max_w - len(vals)))[:max_w] for vals in parsed]
    arr = np.array(arr_rows, dtype=np.uint8)
    # Binarized output is all 0/1, which would render black; stretch to 0/255.
    if set(np.unique(arr).tolist()).issubset({0, 1}):
        arr = arr * 255
    return Image.fromarray(arr, mode="L")


def extrapolate_with_llm(csv_text, canvas_px, out_rows, model_id):
    """Feed the CSV canvas to an LLM and render input + continuation.

    Args:
        csv_text: Multi-line CSV from the preview box.
        canvas_px: Canvas width, used to budget output tokens.
        out_rows: Number of extra rows to request from the model.
        model_id: HF hub id of the model to use.

    Returns:
        (display_text, image): the generated CSV (rows newline-separated)
        and a 512px-wide rendering with a red line marking where the
        model's continuation begins — or (error_string, None) on failure.
    """
    one_line = csv_single_line(csv_text)
    input_rows_count = len([r for r in (one_line or "").split(";") if r.strip()])
    try:
        width = int(canvas_px)
    except Exception:
        width = DEFAULT_CANVAS
    # Budget roughly 2 tokens per pixel value ("NNN," spans 1-2 tokens).
    max_tokens = int(out_rows) * width * 2

    gen = run_llm(one_line, int(max_tokens), model_id=model_id)
    if gen.startswith("[LLM error:"):
        return gen, None

    # Parse prompt + continuation together so the rendering shows both.
    parsed, max_w = _parse_csv_rows((one_line or "") + (gen or ""))
    if not parsed:
        return gen, None

    img = _rows_to_image(parsed, max_w)

    # Upscale to a fixed display width, preserving aspect ratio.
    target_w = 512
    orig_w, orig_h = img.size
    target_h = max(1, round(orig_h * target_w / max(1, orig_w)))
    img = img.resize((target_w, target_h), Image.NEAREST)

    # Red separator between the input rows and the LLM continuation.
    if input_rows_count > 0 and orig_h > 0:
        y = round(input_rows_count * target_h / orig_h)
        y = max(0, min(target_h - 1, y))
        img_rgb = img.convert("RGB")
        draw = ImageDraw.Draw(img_rgb)
        draw.line([(0, y), (img_rgb.width - 1, y)], fill=(255, 0, 0), width=1)
        img = img_rgb

    display_text = (gen or "").replace(";", ";\n")
    return display_text, img
| |
|
| | |
# Global Gradio theme: xkcd style from the hub, with a custom translucent
# pale-green block background.
theme = gr.Theme.from_hub('gstaff/xkcd')
theme.set(block_background_fill="#7ffacd8e")
| |
|
# --- UI layout and event wiring ---
with gr.Blocks(theme=theme, title="Image Extrapolation with LLMs") as demo:
    gr.Markdown("### Extrapolate images with LLMs")
    gr.Markdown("Draw or upload an image, and let an LLM continue the pattern!")

    with gr.Row():
        # Left column: settings controls.
        with gr.Column(scale=1, min_width=220):
            canvas_px = gr.Slider(32, 128, value=DEFAULT_CANVAS, step=1, label="Canvas size (px)")
            preview_scale = gr.Slider(1, 16, value=8, step=1, label="Preview scale (×)")
            invert_preview = gr.Checkbox(value=False, label="Invert preview")

            with gr.Accordion("Binarize", open=False):
                binarize_csv = gr.Checkbox(value=False, label="Turn 0-255 into 0/1")
                bin_thresh = gr.Slider(0, 255, value=128, step=1, label="Threshold")

            out_rows_default_value = 3
            out_rows = gr.Slider(1, 16, value=out_rows_default_value, step=1, label="Number of output rows")
            llm_choice = gr.Dropdown(
                label="LLM model",
                choices=[
                    "meta-llama/Llama-3.2-1B",
                    "meta-llama/Llama-3.2-3B",
                    "meta-llama/Llama-3.1-8B",
                    "HuggingFaceTB/SmolLM2-1.7B",
                    "HuggingFaceTB/SmolLM3-3B",
                    "openai/gpt-oss-20b",
                    "openai/gpt-oss-120b",
                ],
                value="meta-llama/Llama-3.2-1B",
            )
            # Shows out_rows * canvas_px * 2, kept in sync by update_tokens.
            out_tokens_info = gr.Markdown(f"**Output tokens:** {DEFAULT_CANVAS * out_rows_default_value * 2}")

        # Right column: editor, previews, and LLM output.
        with gr.Column(scale=4):
            im = gr.ImageEditor(
                type="numpy",
                canvas_size=(DEFAULT_CANVAS, DEFAULT_CANVAS),
                image_mode="RGBA",
                brush=gr.Brush(
                    default_size=DEFAULT_BRUSH,
                    colors=["black", "#404040", "#808080", "#C0C0C0", "white"],
                    default_color="black",
                    color_mode="fixed",
                ),
                eraser=gr.Eraser(default_size=1),
                transforms=("crop", "resize"),
                height=500,
            )
            im_preview = gr.Image(height=512, label="Preview (scaled)")

            preview_text = gr.Code(
                label="Preview as CSV (rows end with ';')",
                lines=12,
                interactive=False,
                max_lines=12
            )

            # Keeps the button label in sync with the selected model's
            # short name (everything after the last '/').
            def update_button_label(model_id):
                return f"Extrapolate with LLM ({model_id.split('/')[-1]})"

            extrap_btn = gr.Button(
                value="Extrapolate with LLM (Llama-3.2-1B)",
                variant="primary"
            )

            llm_text = gr.Code(
                label="LLM output (single-line CSV)",
                lines=6,
                interactive=False,
            )
            llm_image = gr.Image(label="LLM parsed image", height=512)

    # Resizing the canvas rebuilds the editor and refreshes the preview.
    canvas_px.change(apply_settings, inputs=[canvas_px], outputs=im)
    canvas_px.change(make_preview, inputs=[im, preview_scale, invert_preview, binarize_csv, bin_thresh], outputs=[im_preview, preview_text])

    # Uploads are fitted to the canvas; any edit refreshes the preview.
    im.upload(process_upload, inputs=[im, canvas_px, preview_scale, invert_preview, binarize_csv, bin_thresh], outputs=[im, im_preview, preview_text])
    im.change(make_preview, inputs=[im, preview_scale, invert_preview, binarize_csv, bin_thresh], outputs=[im_preview, preview_text], show_progress="hidden")
    preview_scale.change(make_preview, inputs=[im, preview_scale, invert_preview, binarize_csv, bin_thresh], outputs=[im_preview, preview_text])
    invert_preview.change(make_preview, inputs=[im, preview_scale, invert_preview, binarize_csv, bin_thresh], outputs=[im_preview, preview_text])
    binarize_csv.change(make_preview, inputs=[im, preview_scale, invert_preview, binarize_csv, bin_thresh], outputs=[im_preview, preview_text])
    bin_thresh.change(make_preview, inputs=[im, preview_scale, invert_preview, binarize_csv, bin_thresh], outputs=[im_preview, preview_text])

    extrap_btn.click(extrapolate_with_llm, inputs=[preview_text, canvas_px, out_rows, llm_choice], outputs=[llm_text, llm_image])

    # Returning a string to a Button output updates its label.
    llm_choice.change(update_button_label, inputs=[llm_choice], outputs=[extrap_btn])

    # Recomputes the "Output tokens" readout (out_rows * width * 2).
    def update_tokens(out_rows, canvas_px):
        try:
            width = int(canvas_px)
        except Exception:
            width = DEFAULT_CANVAS
        tokens = int(out_rows) * width * 2
        return f"**Output tokens:** {tokens}"

    out_rows.change(update_tokens, inputs=[out_rows, canvas_px], outputs=out_tokens_info)
    canvas_px.change(update_tokens, inputs=[out_rows, canvas_px], outputs=out_tokens_info)

    demo.load(update_tokens, inputs=[out_rows, canvas_px], outputs=out_tokens_info)
| |
|
if __name__ == "__main__":
    # Launch the Gradio app when run as a script.
    demo.launch()