# app.py
import os
import sys
import types
import importlib.machinery
from PIL import Image
import gradio as gr
# ===============================
# Helper to create package-like dummy modules
# ===============================
def _mk_pkg(name: str):
    """Create an empty in-memory module that the import system treats as a package."""
    m = types.ModuleType(name)
    spec = importlib.machinery.ModuleSpec(name, loader=None, is_package=True)
    spec.submodule_search_locations = []
    m.__spec__ = spec
    m.__path__ = []
    return m
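# Usage note: registering the result under sys.modules is all the import
# machinery needs, e.g. `sys.modules["foo"] = _mk_pkg("foo")` makes
# `import foo` succeed with no file on disk. Because the search path is
# empty, submodules are never discovered automatically and must each be
# registered explicitly, as done for flash_attn below.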
# ===============================
# Disable GPU-only/optional paths
# ===============================
os.environ.setdefault("FLASH_ATTENTION", "0")
os.environ.setdefault("XFORMERS_DISABLED", "1")
os.environ.setdefault("ACCELERATE_USE_DEVICE_MAP", "0")
os.environ.setdefault("DISABLE_TRITON", "1") # avoid triton kernels
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "") # uncomment to force CPU
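# These are best-effort hints; whether they are honored depends on the
# installed library versions. They are cheap to set, and they must be set
# before the heavy imports below to have any chance of taking effect.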
# ===============================
# flash_attn stubs (package + submodules)
# ===============================
flash_attn_pkg = _mk_pkg("flash_attn")
flash_attn_interface = types.ModuleType("flash_attn.flash_attn_interface")
flash_attn_interface.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.flash_attn_interface", loader=None
)
flash_attn_bert_padding = types.ModuleType("flash_attn.bert_padding")
flash_attn_bert_padding.__spec__ = importlib.machinery.ModuleSpec(
    "flash_attn.bert_padding", loader=None
)
def _dummy_func(*args, **kwargs):
    raise RuntimeError("flash_attn is not available in this environment.")
flash_attn_interface.flash_attn_unpadded_qkvpacked_func = _dummy_func
flash_attn_interface.flash_attn_varlen_qkvpacked_func = _dummy_func
flash_attn_bert_padding.pad_input = _dummy_func
flash_attn_bert_padding.unpad_input = _dummy_func
# Register the stubs; also attach the submodules as attributes on the parent
# package, so code that does `import flash_attn` and then uses dotted access
# (rather than `from ... import ...`) still finds them.
flash_attn_pkg.flash_attn_interface = flash_attn_interface
flash_attn_pkg.bert_padding = flash_attn_bert_padding
sys.modules["flash_attn"] = flash_attn_pkg
sys.modules["flash_attn.flash_attn_interface"] = flash_attn_interface
sys.modules["flash_attn.bert_padding"] = flash_attn_bert_padding
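# Quick self-check (safe to delete): the import system should now hand back
# our stub rather than searching for a real flash_attn wheel.
import flash_attn as _fa_check
assert _fa_check is flash_attn_pkg, "flash_attn stub was not picked up"
del _fa_check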
# ===============================
# ps3 stub (optional vision tower)
# ===============================
ps3_pkg = _mk_pkg("ps3")
class _PS3Config: pass
class _PS3VisionConfig: pass
class _PS3ImageProcessor: pass
class _PS3VisionModel: pass
ps3_pkg.PS3Config = _PS3Config
ps3_pkg.PS3VisionConfig = _PS3VisionConfig
ps3_pkg.PS3ImageProcessor = _PS3ImageProcessor
ps3_pkg.PS3VisionModel = _PS3VisionModel
sys.modules["ps3"] = ps3_pkg
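# These placeholders only need to make `from ps3 import PS3VisionModel`
# (and friends) succeed at import time. They carry no behavior: any
# attribute access on an instance raises AttributeError, which surfaces
# quickly if a code path actually relies on the real PS3 vision tower.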
# ===============================
# Quantization stubs to avoid Triton/Torch custom kernels
# VILA sometimes imports:
# - from .FloatPointQuantizeTriton import *
# - from FloatPointQuantizeTriton import *
# - from FloatPointQuantizeTorch import *
# Provide both names (absolute and package-qualified) with no-op funcs.
# ===============================
def _mk_fpq_module(mod_name: str):
    """Build a stub quantization module whose helpers are no-ops."""
    mod = types.ModuleType(mod_name)

    # Provide the APIs qfunction expects; each simply returns its first input.
    def _id(x, *a, **k):
        return x

    mod.block_cut = _id
    mod.block_quant = _id
    mod.block_reshape = _id
    return mod
# Absolute names
sys.modules["FloatPointQuantizeTorch"] = _mk_fpq_module("FloatPointQuantizeTorch")
sys.modules["FloatPointQuantizeTriton"] = _mk_fpq_module("FloatPointQuantizeTriton")
# Package-qualified under llava.model
sys.modules["llava.model.FloatPointQuantizeTorch"] = sys.modules["FloatPointQuantizeTorch"]
sys.modules["llava.model.FloatPointQuantizeTriton"] = sys.modules["FloatPointQuantizeTriton"]
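# With both spellings registered, every import VILA might use resolves to
# the same no-op module:
#   from FloatPointQuantizeTriton import *               # absolute
#   from llava.model.FloatPointQuantizeTriton import *   # package-qualified
# Either way, block_cut/block_quant/block_reshape are bound to identity
# functions, so quantization degrades to a pass-through -- assuming the
# downstream code tolerates unquantized tensors.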
# ===============================
# Load VILA
# ===============================
from llava.model.builder import load_pretrained_model
from llava.constants import DEFAULT_IMAGE_TOKEN
MODEL_PATH = "Efficient-Large-Model/VILA1.5-3b"
try:
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        MODEL_PATH, model_name="", model_base=None
    )
except Exception as e:
    ERR = f"Failed to load model '{MODEL_PATH}': {e}"

    def _boot_error_ui():
        with gr.Blocks(title="VILA 1.5 3B – Error") as demo:
            gr.Markdown("### ❌ Model failed to load")
            gr.Markdown(ERR)
        demo.launch()

    _boot_error_ui()
    raise
# Fallback chat template if missing
if getattr(tokenizer, "chat_template", None) is None:
    tokenizer.chat_template = (
        "{% for message in messages %}{{ message['role'] | upper }}: "
        "{{ message['content'] }}\n{% endfor %}ASSISTANT:"
    )
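# Rendered example: [{"role": "user", "content": "hi"}] becomes
#   "USER: hi\nASSISTANT:"
# i.e. a minimal single-turn prompt for checkpoints that ship no template.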
# ===============================
# Inference
# ===============================
def vila_infer(image, prompt):
    if image is None:
        return "Please upload an image."
    if not prompt or not str(prompt).strip():
        prompt = "Please describe the image."
    # gr.Image(type="numpy") hands us an HxWxC uint8 array.
    pil = Image.fromarray(image).convert("RGB")
    try:
        out = model.generate_content(
            prompt=[{
                "from": "human",
                "value": [
                    {"type": "image", "value": pil},
                    {"type": "text", "value": prompt},
                ],
            }],
            generation_config=None,  # default decoding
        )
        return str(out).strip()
    except Exception as e:
        return f"❌ Inference error: {e}"
# ===============================
# UI
# ===============================
with gr.Blocks(title="VILA 1.5 3B (HF Space)") as demo:
    gr.Markdown("## 🖼️ VILA-1.5-3B – Image Description Demo")
    gr.Markdown("Upload an image and press **Run**.")
    with gr.Row():
        img = gr.Image(type="numpy", label="Image", height=320)
        prompt = gr.Textbox(label="Prompt", value="Please describe the image", lines=2)
    run_btn = gr.Button("Run")
    out = gr.Textbox(label="Output", lines=10)
    run_btn.click(vila_infer, [img, prompt], out)

demo.launch()