"""Gradio demo that captions CT-scan images with a fine-tuned PaliGemma model."""

import functools
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

# -----------------------------------------------------------------------------
# Load HF token from environment
# -----------------------------------------------------------------------------
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not HF_TOKEN:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set")


# -----------------------------------------------------------------------------
# 1) GPU inference function
# -----------------------------------------------------------------------------
@functools.lru_cache(maxsize=1)
def _load_model(model_id: str, dtype: torch.dtype):
    """Load the processor and model for *model_id* and move the model to CUDA.

    The result is cached so repeated inference calls reuse the weights that
    are already on the GPU instead of reloading them on every request.
    ``maxsize=1`` keeps at most one (model_id, dtype) pair alive at a time.
    """
    # `token=` replaces the deprecated `use_auth_token=` keyword.
    processor = AutoProcessor.from_pretrained(model_id, token=HF_TOKEN)
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=None,
        token=HF_TOKEN,
    )
    return processor, model.to(torch.device("cuda")).eval()


def run_inference_on_gpu(
    model_id: str,
    image: Image.Image,
    prompt: str = "caption",
    max_new_tokens: int = 100,
) -> str:
    """Generate text for *image* and *prompt* with a PaliGemma model on the GPU.

    Args:
        model_id: Hugging Face model identifier to load (cached after first use).
        image: Input image; the processor is asked to convert it to RGB.
        prompt: Text prompt passed to the model alongside the image.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded newly generated text, without the echoed prompt tokens
        and with special tokens removed.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available—check your PyTorch installation!")
    device = torch.device("cuda")
    dtype = torch.float16

    processor, model = _load_model(model_id, dtype)

    # build multimodal prompt
    # NOTE(review): the image placeholder is empty, so the processor is relied
    # on to insert the image tokens itself. The PaliGemma docs recommend an
    # explicit "<image>" prefix — confirm whether it was dropped by accident.
    image_tokens = ""
    multimodal_prompt = f"{image_tokens} {prompt}"

    # prepare inputs
    inputs = processor(
        text=multimodal_prompt,
        images=[image],
        padding="longest",
        return_tensors="pt",
        do_convert_rgb=True,
    )
    # Move everything to the GPU; cast only floating-point tensors (e.g. pixel
    # values) to the model dtype so integer token ids / masks stay untouched.
    inputs = {
        name: (
            tensor.to(device, dtype=dtype)
            if torch.is_floating_point(tensor)
            else tensor.to(device)
        )
        for name, tensor in inputs.items()
    }
    prompt_len = inputs["input_ids"].shape[-1]

    # generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_beams=3,
            do_sample=False,
        )

    # decode only the newly generated tokens; the returned sequence starts
    # with the prompt tokens, which would otherwise be echoed in the caption.
    generated = outputs[0][prompt_len:]
    return processor.decode(generated.cpu(), skip_special_tokens=True)


# -----------------------------------------------------------------------------
# 2) Gradio UI
# -----------------------------------------------------------------------------
MODEL_ID = "mychen76/paligemma-3b-mix-448-med_30k-ct-brain"


def caption_fn(image, prompt, max_tokens):
    """Gradio callback: return the generated caption for an uploaded image.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was
            uploaded.
        prompt: Text prompt from the textbox.
        max_tokens: Slider value limiting the number of generated tokens.

    Returns:
        The caption string produced by ``run_inference_on_gpu``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of a processor traceback.
        raise gr.Error("Please upload an image before submitting.")
    return run_inference_on_gpu(
        model_id=MODEL_ID,
        image=image,
        prompt=prompt,
        # Coerce to int in case the slider delivers a float.
        max_new_tokens=int(max_tokens),
    )


demo = gr.Interface(
    fn=caption_fn,
    inputs=[
        gr.Image(type="pil", label="Upload CT Scan"),
        gr.Textbox(
            value="What do you see in this CT scan?",
            label="Prompt",
        ),
        gr.Slider(
            minimum=10,
            maximum=300,
            step=10,
            value=100,
            label="Max New Tokens",
        ),
    ],
    outputs=gr.Textbox(label="Model Caption"),
    title="PaliGemma CT-Scan Captioning",
    description=(
        "Upload a brain CT scan (or any image), write a short prompt, "
        "and let the PaliGemma model describe what it sees."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch(share=False, show_api=False)