# Hugging Face Space: Florence-2 image-captioning demo (CPU)
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

# Use CPU as requested — the Space runs without a GPU, so both models are
# loaded and executed on the CPU.
device = "cpu"
def load_vlm(model_name):
    """Load a Florence-2 checkpoint and its processor from the microsoft hub org.

    Returns a ``(model, processor)`` pair on success, or ``(None, None)``
    when anything goes wrong (the error is logged to stdout so the app can
    keep running with whichever model did load).
    """
    repo_id = f'microsoft/{model_name}'
    try:
        print(f"Loading {model_name}...")
        vlm = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
        vlm = vlm.to(device).eval()
        proc = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    except Exception as e:
        # Best-effort: report and signal failure instead of crashing startup.
        print(f"Error loading {model_name}: {e}")
        return None, None
    return vlm, proc
# Load both models eagerly at startup so the UI can switch between them per
# request without re-downloading; either may be (None, None) if loading failed.
model_base, proc_base = load_vlm('Florence-2-base')
model_large, proc_large = load_vlm('Florence-2-large')
def describe_image(uploaded_image, model_choice):
    """Generate a detailed caption for an uploaded image with Florence-2.

    Args:
        uploaded_image: PIL image (or numpy array) from the Gradio input;
            may be None if the user hit the button without uploading.
        model_choice: "Florence-2-base" or "Florence-2-large" (radio value).

    Returns:
        The caption string, or a human-readable error message when the
        image is missing or the chosen model failed to load.
    """
    if uploaded_image is None:
        return "Please upload an image."

    # Select model based on UI choice
    if model_choice == "Florence-2-base":
        model, processor = model_base, proc_base
    else:
        model, processor = model_large, proc_large
    if model is None:
        return f"{model_choice} failed to load."

    if not isinstance(uploaded_image, Image.Image):
        uploaded_image = Image.fromarray(uploaded_image)
    # Florence-2's image processor expects 3-channel input; normalize RGBA
    # (typical PNG uploads), grayscale, and palette images to RGB.
    if uploaded_image.mode != "RGB":
        uploaded_image = uploaded_image.convert("RGB")

    # Core generation logic: the task token doubles as the prompt and as the
    # key under which post-processing returns the parsed result.
    task = "<MORE_DETAILED_CAPTION>"
    inputs = processor(text=task, images=uploaded_image, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            do_sample=False,
        )
    # Keep special tokens: Florence-2's post-processor parses them to extract
    # the task output, so skip_special_tokens must stay False here.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    result = processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(uploaded_image.width, uploaded_image.height),
    )
    return result[task]
# Simplified Gradio Layout: image + model picker on the left, caption on the
# right; a single button wires describe_image between them.
css = ".submit-btn { background-color: #4682B4 !important; color: white !important; }"

with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
    gr.Markdown("# **Florence-2 Models Image Captions**")
    gr.Markdown("> Select the model to use. **Base** is faster; **Large** is more accurate.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil")
            model_choice = gr.Radio(
                choices=["Florence-2-base", "Florence-2-large"],
                label="Model Choice",
                value="Florence-2-base"
            )
            generate_btn = gr.Button("Generate Caption", elem_classes="submit-btn")
        with gr.Column():
            # interactive=True lets the user edit/copy the generated caption.
            output = gr.Textbox(label="Generated Caption", lines=6, interactive=True)
    generate_btn.click(
        fn=describe_image,
        inputs=[image_input, model_choice],
        outputs=output
    )

if __name__ == "__main__":
    # ssr_mode=False: disable server-side rendering (presumably for Spaces
    # compatibility — NOTE(review): confirm this is still needed).
    demo.launch(ssr_mode=False)