yuhangzang committed · Commit 67b36a4 · 1 Parent(s): 12e3e78

update
Files changed:
- .gitattributes +4 -0
- README.md +23 -0
- app.py +216 -0
- examples/example_0.png +3 -0
- requirements.txt +6 -0
    	
.gitattributes CHANGED

@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.webp filter=lfs diff=lfs merge=lfs -text
    	
README.md CHANGED

@@ -12,3 +12,26 @@ short_description: ' A unified framework for reasoning and reward modeling'
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+## Usage (ZeroGPU)
+
+- Choose `Gradio` as the Space SDK and `ZeroGPU` as the hardware (requires PRO or an Enterprise organization).
+- This repo contains a minimal working Spark-VL demo: upload an image, enter a prompt, and the model's generated text is returned.
+- The key code is in `app.py`:
+  - The inference function is decorated with `spaces.GPU`; a GPU is requested when it is called and released when it is done.
+  - `internlm/Spark-VL-7B` is loaded on demand at the first call, trying `flash_attention_2` first and falling back to `eager` on failure.
+  - After inference the model is moved back to the CPU so ZeroGPU memory is freed quickly.
+
+### Running locally / on a Space
+
+1) After pushing to a Hugging Face Space, select `ZeroGPU` as the hardware in the Space settings.
+2) Entry point: `app.py`. The UI includes an image input, a prompt box, and sampling parameters (max_new_tokens/temperature/top_p/top_k).
+3) Optional environment variables:
+   - `SPARK_MODEL_ID`: defaults to `internlm/Spark-VL-7B`.
+   - `ATTN_IMPL`: defaults to `flash_attention_2`; set it to `eager` to skip flash-attn.
+
+### Dependencies
+
+See `requirements.txt` (Gradio 5.x, Transformers 4.45+, qwen-vl-utils, etc.). The ZeroGPU base image already includes a suitable PyTorch build.
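Once deployed, the demo can also be called programmatically; a hedged sketch with gradio_client ("<user>/<space>" is a placeholder Space id, and api_name="/generate" assumes Gradio's default endpoint naming, derived from the function name wired to the button in app.py):

    # A sketch of calling the deployed Space with gradio_client.
    # Keyword names mirror the generate() signature in app.py below.
    from gradio_client import Client, handle_file

    client = Client("<user>/<space>")  # placeholder Space id
    result = client.predict(
        image=handle_file("examples/example_0.png"),
        prompt="Describe this image.",
        max_new_tokens=128,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        api_name="/generate",  # assumes Gradio's default endpoint name
    )
    print(result)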
    	
app.py ADDED

import os
import time
import glob
from typing import List

import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

MODEL_ID = os.environ.get("SPARK_MODEL_ID", "internlm/Spark-VL-7B")
DTYPE = torch.bfloat16

_model = None
_processor = None
_attn_impl = None


def _load_model_and_processor():
    global _model, _processor, _attn_impl
    if _model is not None and _processor is not None:
        return _model, _processor

    # Prefer flash-attn if available, otherwise fall back to eager.
    attn_impl = os.environ.get("ATTN_IMPL", "flash_attention_2")
    try:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=DTYPE,
            attn_implementation=attn_impl,
            device_map="auto",
        )
        _attn_impl = attn_impl
    except Exception:
        # Fallback for environments without flash-attn.
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=DTYPE,
            attn_implementation="eager",
            device_map="auto",
        )
        _attn_impl = "eager"

    processor = AutoProcessor.from_pretrained(MODEL_ID)

    _model = model
    _processor = processor
    return _model, _processor


def _prepare_inputs(image, prompt):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    chat_text = _processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = _processor(
        text=[chat_text],
        # Pass the single image directly; the chat template already contains
        # the image placeholder.
        images=[image] if image is not None else None,
        return_tensors="pt",
    )
    return inputs


def _decode(generated_ids, input_ids):
    # Trim the prompt tokens before decoding.
    trimmed = generated_ids[:, input_ids.shape[1] :]
    out = _processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return out[0] if out else ""


@spaces.GPU(duration=120)
def generate(image, prompt, max_new_tokens, temperature, top_p, top_k):
    if image is None:
        return "Please upload an image."
    prompt = (prompt or "").strip()
    if not prompt:
        return "Please enter a prompt."

    start = time.time()
    model, _ = _load_model_and_processor()
    try:
        # Make sure the model resides on the GPU for the duration of the call.
        p = next(model.parameters())
        if p.device.type != "cuda":
            model.to("cuda")
    except StopIteration:
        pass

    try:
        inputs = _prepare_inputs(image, prompt)
        dev = next(model.parameters()).device
        inputs = {k: v.to(dev) if hasattr(v, "to") else v for k, v in inputs.items()}

        gen_kwargs = {
            "max_new_tokens": int(max_new_tokens),
            "do_sample": True,
            "temperature": float(temperature),
            "top_p": float(top_p),
            "top_k": int(top_k),
            "use_cache": True,
        }
        with torch.inference_mode():
            out_ids = model.generate(**inputs, **gen_kwargs)
        text = _decode(out_ids, inputs["input_ids"])
        took = time.time() - start
        return f"{text}\n\n[attn={_attn_impl}, time={took:.1f}s]"
    except Exception as e:
        return f"Inference failed: {type(e).__name__}: {e}"
    finally:
        # Release the GPU quickly on ZeroGPU by moving the weights off CUDA.
        try:
            if hasattr(model, "to"):
                model.to("cpu")
            torch.cuda.empty_cache()
        except Exception:
            pass


def build_ui():
    with gr.Blocks() as demo:
        gr.Markdown(
            "# Spark-VL ZeroGPU Demo\n"
            "Upload an image or choose one from the example gallery, then enter a prompt."
        )

        # Build an image gallery from ./examples.
        def _gather_examples() -> List[str]:
            exts = ("*.jpg", "*.jpeg", "*.png", "*.webp")
            imgs: List[str] = []
            for ptn in exts:
                imgs.extend(sorted(glob.glob(os.path.join("examples", ptn))))
            # Deduplicate while keeping order.
            seen = set()
            uniq = []
            for p in imgs:
                if p not in seen:
                    uniq.append(p)
                    seen.add(p)
            return uniq

        example_images = _gather_examples()

        default_candidates = [
            os.path.join("examples", "example_0.png"),
        ]
        default_image_path = next((p for p in default_candidates if os.path.exists(p)), None)
        default_image = Image.open(default_image_path) if default_image_path else None

        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(type="pil", label="Image", value=default_image)
                gallery = gr.Gallery(
                    value=example_images,
                    label="Example Gallery",
                    show_label=True,
                    columns=4,
                    height=240,
                    allow_preview=True,
                )

                # When a thumbnail is clicked, load it into the image input.
                # The gr.SelectData annotation is required for Gradio to inject
                # the event payload when inputs=None.
                def _on_gallery_select(evt: gr.SelectData):
                    try:
                        idx = int(evt.index)
                    except Exception:
                        return None
                    if idx < 0 or idx >= len(example_images):
                        return None
                    # Return a PIL image to match the gr.Image(type="pil") input.
                    try:
                        return Image.open(example_images[idx])
                    except Exception:
                        return example_images[idx]

                gallery.select(fn=_on_gallery_select, inputs=None, outputs=image)

            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    value=(
                        "As seen in the diagram, three darts are thrown at nine fixed balloons. "
                        "If a balloon is hit it will burst and the dart continues in the same direction "
                        "it had beforehand. How many balloons will not be hit by a dart?"
                    ),
                    lines=4,
                )
                max_new_tokens = gr.Slider(16, 512, value=128, step=8, label="max_new_tokens")
                temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
                top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="top_p")
                top_k = gr.Slider(1, 200, value=50, step=1, label="top_k")
                run = gr.Button("Generate")

        output = gr.Textbox(label="Model Output", lines=8)

        run.click(
            fn=generate,
            inputs=[image, prompt, max_new_tokens, temperature, top_p, top_k],
            outputs=output,
            show_progress="full",
        )

    # Gradio 5 removed queue(concurrency_count=...); use the per-event default
    # concurrency limit instead, and launch outside the Blocks context.
    demo.queue(default_concurrency_limit=1, max_size=10)
    return demo


if __name__ == "__main__":
    build_ui().launch()
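For a quick local check of the UI wiring, a minimal sketch (run_local.py is a hypothetical helper, not part of this commit; it forces ATTN_IMPL=eager so machines without flash-attn can still load the model, and the spaces.GPU decorator is documented to be a no-op outside Spaces):

    # run_local.py - hypothetical local smoke test for app.py (a sketch).
    import os

    os.environ["ATTN_IMPL"] = "eager"  # skip the flash-attn path locally

    from app import build_ui

    if __name__ == "__main__":
        # The model only downloads on the first Generate click.
        build_ui().launch(server_name="0.0.0.0", server_port=7860)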
    	
examples/example_0.png ADDED

Binary image stored via Git LFS (Git LFS Details).
    	
requirements.txt ADDED

transformers>=4.45.0
accelerate>=0.33.0
qwen-vl-utils>=0.0.8
gradio>=5.49.1
spaces>=0.24.0
pillow
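To verify an environment against these version floors, a small sketch (check_deps.py is hypothetical; packaging is already pulled in as a transformers dependency):

    # check_deps.py - sanity-check the minimum versions pinned above (a sketch).
    from packaging.version import Version

    import gradio
    import transformers

    assert Version(transformers.__version__) >= Version("4.45.0"), transformers.__version__
    assert Version(gradio.__version__) >= Version("5.49.1"), gradio.__version__
    print("dependency versions OK")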