# FastVLM Screenshot Explainer (CPU-only, no uploads)
# Space idea: curated gallery → caption / extract numbers / VQA
# Model: apple/FastVLM-0.5B (Research-only license)
# ─────────────────────────────────────────────────────────────────────────────
import time
import io
import requests
from PIL import Image
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "apple/FastVLM-0.5B"
IMAGE_TOKEN_INDEX = -200 # per model card
DEVICE = "cpu"
# A tiny curated gallery (HF/COCO-hosted images)
SAMPLES = {
    # general photo (COCO val2017 image 39769: two cats on a couch)
    "Cats on a couch (COCO)": "http://images.cocodataset.org/val2017/000000039769.jpg",
    # charts (ChartMuseum dataset)
    "Chart – Blind wine tasting": "https://huggingface.co/datasets/lytang/ChartMuseum/resolve/main/images/wine_blind_taste.png",
    "Chart – Life expectancy (Africa vs Asia)": "https://huggingface.co/datasets/lytang/ChartMuseum/resolve/main/images/life-expectancy-africa-vs-asia.png",
    # document-like page (HF internal testing)
    "Document page – example": "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/1.jpg",
}
TASK_PROMPTS = {
    "Explain": "Describe this image in detail.",
    "Extract numbers": (
        "Extract every number you can see with its label/context. "
        "Return a concise YAML list with fields: value, what_it_refers_to."
    ),
    "Write alt-text": (
        "Write high-quality alt-text (<=200 chars) that would help a blind user understand "
        "the key content and purpose of this image."
    ),
    "Ask a question…": None,  # free-form
}
# ── Model load (CPU) ─────────────────────────────────────────────────────────
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,  # CPU
    device_map={"": DEVICE},
    trust_remote_code=True,
)
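# Note: this Space is deliberately pinned to CPU/float32. If you later move it to a
# GPU Space, a minimal (untested here) change would be to set DEVICE to "cuda" when
# torch.cuda.is_available() and load with torch_dtype=torch.float16.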
# ── Helpers ──────────────────────────────────────────────────────────────────
def _fetch_image(url: str) -> Image.Image:
    r = requests.get(url, timeout=20)
    r.raise_for_status()
    return Image.open(io.BytesIO(r.content)).convert("RGB")
def _build_inputs(prompt: str):
    # Build chat with <image> placeholder exactly once (per model card)
    messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
    rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    pre, post = rendered.split("<image>", 1)
    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
    attention_mask = torch.ones_like(input_ids, device=model.device)
    return input_ids, attention_mask
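# Shape sketch (per the model card's <image> handling): the returned input_ids are
# [chat-template tokens before <image>] + [IMAGE_TOKEN_INDEX (-200)] + [tokens after],
# and the remote code replaces the -200 placeholder with the projected vision features
# inside the model's forward pass.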
def _prepare_pixels(pil_image: Image.Image):
    # Use the model's own processor from the vision tower
    px = model.get_vision_tower().image_processor(images=pil_image, return_tensors="pt")["pixel_values"]
    return px.to(model.device, dtype=model.dtype)
@torch.inference_mode()
def run_inference(choice: str, task: str, user_q: str, max_new_tokens: int, temperature: float):
    try:
        img = _fetch_image(SAMPLES[choice])
    except Exception as e:
        return None, f"Could not load image: {e}", ""
    # Decide prompt
    if task == "Ask a question…":
        prompt = user_q.strip() or "Answer questions about this image."
    else:
        prompt = TASK_PROMPTS[task]
    # Build model inputs
    input_ids, attention_mask = _build_inputs(prompt)
    px = _prepare_pixels(img)
    # Generate (sample only when temperature > 0; greedy decoding otherwise)
    gen_kwargs = dict(
        inputs=input_ids,
        attention_mask=attention_mask,
        images=px,
        max_new_tokens=int(max_new_tokens),
    )
    if float(temperature) > 0:
        gen_kwargs.update(do_sample=True, temperature=float(temperature))
    t0 = time.perf_counter()
    out = model.generate(**gen_kwargs)
    t1 = time.perf_counter()
    text = tok.decode(out[0], skip_special_tokens=True)
    # Rough throughput metric; guard in case generate returns only the new tokens
    gen_len = out.shape[-1] - input_ids.shape[-1]
    if gen_len <= 0:
        gen_len = out.shape[-1]
    elapsed = t1 - t0
    meta = f"⏱️ {elapsed:.2f}s • new tokens: {gen_len} • ~{(gen_len/elapsed if elapsed > 0 else 0):.1f} tok/s • device: {DEVICE.upper()}"
    return img, text.strip(), meta
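# Local debugging sketch (assumes network access to the sample URLs above):
#   img, text, meta = run_inference("Cats on a couch (COCO)", "Explain", "", 64, 0.0)
#   print(meta); print(text)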
# ── Gradio UI ────────────────────────────────────────────────────────────────
with gr.Blocks(title="FastVLM Screenshot Explainer (CPU)") as demo:
    gr.Markdown(
        """
# ⚡ FastVLM Screenshot Explainer – CPU-only (no uploads)
Click an example image, pick a task, and go.
Model: **apple/FastVLM-0.5B** (research license).
"""
    )
    with gr.Row():
        choice = gr.Dropdown(
            label="Choose example image",
            choices=list(SAMPLES.keys()),
            value=list(SAMPLES.keys())[0],
        )
        task = gr.Radio(
            label="Task",
            choices=list(TASK_PROMPTS.keys()),
            value="Explain",
            info="‘Ask a question…’ enables free-form VQA.",
        )
        user_q = gr.Textbox(label="If asking a question, type it here", placeholder="e.g., What is the trend from 1950 to 2000?")
    with gr.Accordion("Generation settings", open=False):
        max_new = gr.Slider(32, 256, 128, step=8, label="max_new_tokens")
        temp = gr.Slider(0.0, 1.0, 0.2, step=0.05, label="temperature")
    go = gr.Button("Explain / Answer", variant="primary")
    with gr.Row():
        img_out = gr.Image(label="Image", interactive=False)
        txt_out = gr.Textbox(label="Model output", lines=14)
    meta = gr.Markdown()
    go.click(run_inference, [choice, task, user_q, max_new, temp], [img_out, txt_out, meta])
    gr.Markdown(
        """
**Notes**
- Runs on CPU in float32. To use a GPU, switch `DEVICE` to "cuda", load the model in float16, and restart the Space.
- Model usage follows the official model card's `trust_remote_code` API and `<image>` token handling.
- **License:** Apple AML Research License – *research & non-commercial use only*.
"""
    )
if __name__ == "__main__":
    demo.launch()