# cheapsake / app.py — Hugging Face Space entry point (commit 739eb54)
import os
import json
import re
from typing import Any, Dict, Tuple
import torch
import gradio as gr
import spaces
from PIL import Image, ImageOps
# Qwen3-VL requires the latest Transformers from source.
# In your Space requirements, use:
# pip install git+https://github.com/huggingface/transformers
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
# ---------------------------
# Environment / cache setup
# ---------------------------
# Make CUDA device numbering match the physical PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Reduce CUDA allocator fragmentation when loading large model weights.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:64"
# Writable cache for Spaces (the default HF cache location is not writable there).
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["HF_HUB_CACHE"] = "/tmp/hf/hub"
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent Transformers releases;
# HF_HOME above should already cover it — confirm against the pinned version.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
os.makedirs("/tmp/hf/hub", exist_ok=True)
os.makedirs("/tmp/hf/transformers", exist_ok=True)
# Allow TF32 matmul on Ampere+ GPUs for faster inference.
torch.set_float32_matmul_precision("high")
# Optional token for gated/private models; empty string falls back to anonymous access.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Qwen3-VL upgrade path
MODEL_ID = "Qwen/Qwen3-VL-8B-Instruct"
# Populated lazily by load_model() on first request (inside the GPU worker).
processor = None
model = None
def load_model() -> None:
    """Lazily initialize the module-level processor and model.

    Idempotent: returns immediately once both globals are populated, so it is
    safe to call at the start of every request.
    """
    global processor, model
    if processor is not None and model is not None:
        return  # already initialized
    auth_token = HF_TOKEN or None
    print("Loading processor...")
    processor = AutoProcessor.from_pretrained(MODEL_ID, token=auth_token)
    print("Loading model...")
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        token=auth_token,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    print("Setting eval mode...")
    model.eval()
    print("Model ready")
def normalize_image(image: Image.Image) -> Image.Image:
    """Apply the EXIF orientation tag, then force a 3-channel RGB image."""
    oriented = ImageOps.exif_transpose(image)
    return oriented.convert("RGB")
def extract_json(text: str) -> Dict[str, Any]:
    """Parse a model response into a dict, tolerating markdown fences and chatter.

    Tries, in order: the whole fence-stripped text as JSON, the first embedded
    JSON object, then the first embedded JSON array. A top-level array is
    wrapped as {"items": [...]} so the return type is always a dict (the prompt
    asks for a list of items, so array-shaped replies are expected). Falls back
    to {"raw_output": text} when nothing parses.
    """
    text = (text or "").strip()
    # Strip common markdown fences (```json ... ```).
    text = re.sub(r"^\s*```(?:json)?\s*", "", text, flags=re.I)
    text = re.sub(r"\s*```\s*$", "", text, flags=re.I)
    try:
        parsed = json.loads(text)
        if isinstance(parsed, dict):
            return parsed
        if isinstance(parsed, list):
            return {"items": parsed}
        # Scalars (numbers, strings) are not useful results; fall through.
    except Exception:
        pass
    # Try to find the first JSON object embedded in surrounding prose.
    match = re.search(r"\{.*\}", text, flags=re.S)
    if match:
        try:
            obj = json.loads(match.group(0))
            if isinstance(obj, dict):
                return obj
        except Exception:
            pass
    # Accept a bare JSON array embedded in prose as well.
    match = re.search(r"\[.*\]", text, flags=re.S)
    if match:
        try:
            arr = json.loads(match.group(0))
            if isinstance(arr, list):
                return {"items": arr}
        except Exception:
            pass
    return {"raw_output": text}
# Instruction appended after the image. The format example must itself be valid
# JSON (the original showed {["item1", "item2"]}, which is not parseable and
# contradicts the "Return only valid JSON" instruction), and it must be a
# shape extract_json() can handle.
PROMPT = """
Return only valid JSON.
List each pantry item once.
Use this format:
{"items": ["item1", "item2"]}
"""
@spaces.GPU(size="large", duration=60)
def analyze_pantry(image: Image.Image) -> Tuple[Image.Image, Dict[str, Any]]:
    """Run the vision model on a pantry photo; return (prepared image, parsed JSON).

    Returns (None, {"error": ...}) when no image was supplied so that both
    Gradio output components still receive a value.
    """
    if image is None:
        return None, {"error": "Upload an image first."}
    # Lazy-load inside the GPU worker; load_model() is a no-op after first call.
    load_model()
    # EXIF-orient and force RGB before handing the image to the processor.
    prepared = normalize_image(image)
    messages = [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": "Analyze this pantry image in detail, list all items"}
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": prepared},
                {"type": "text", "text": PROMPT},
            ],
        },
    ]
    # Qwen3-VL official Transformers usage: the chat template tokenizes text
    # and inserts the image placeholders in one call.
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)
    # NOTE(review): debug print of full input tensors — very noisy in Space logs.
    print("inputs:", inputs)
    with torch.inference_mode():
        # Greedy decoding with mild anti-repetition settings.
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3
        )
    # Slice off the prompt tokens; decode only the newly generated continuation.
    prompt_len = inputs["input_ids"].shape[-1]
    generated_text = processor.batch_decode(
        [output_ids[0][prompt_len:]],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0].strip()
    print("generated_text:", generated_text)
    parsed = extract_json(generated_text)
    # Attach the raw model text for debugging, unless the result already IS
    # the raw output fallback from extract_json.
    if isinstance(parsed, dict) and "raw_output" not in parsed:
        parsed["_raw_output"] = generated_text
    return prepared, parsed
# ---------------------------
# Gradio UI wiring
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Pantry Scanner")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Pantry image")
    with gr.Row():
        analyze_btn = gr.Button("Analyze", variant="primary")
    with gr.Row():
        # Left: the normalized image actually fed to the model; right: parsed items.
        prepared_output = gr.Image(type="pil", label="Feeding image")
        output_json = gr.JSON(label="Detected items")
    analyze_btn.click(
        analyze_pantry,
        inputs=image_input,
        outputs=[prepared_output, output_json],
    )
# Bound the request queue so users don't pile up behind the single GPU worker.
demo.queue(max_size=8)
demo.launch()