# Hugging Face Space "Molmo-4bit" — app.py
# Last commit by zamal: 4f1e215 ("Update app.py", verified) · 2.63 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import requests
from io import BytesIO
import spaces # Import spaces for ZeroGPU support
# Hugging Face Hub checkpoint to load (a 4-bit bitsandbytes build of
# Molmo-7B-O, judging by its name).
repo_name = "cyan2k/molmo-7B-O-bnb-4bit"

# Loading options shared by the processor and, later, the model.
arguments = dict(
    device_map="auto",  # let transformers/accelerate decide device placement
    torch_dtype="auto",  # take the dtype from the checkpoint's config
    trust_remote_code=True,  # the repo ships custom Molmo processing/model code
)

# The processor is CPU-only work, so it is created once at import time,
# outside any GPU-decorated function.
processor = AutoProcessor.from_pretrained(repo_name, **arguments)
# Define the function for image description
@spaces.GPU # This ensures the function gets GPU access when needed
def describe_image(image, question):
# Load the model inside the function and move it to GPU
model = AutoModelForCausalLM.from_pretrained(repo_name, **arguments).to('cuda')
# Process the uploaded image along with the user's question
inputs = processor.process(
images=[image],
text=question if question else "Describe this image in great detail without missing any piece of information"
)
# Move inputs to model device (GPU)
inputs = {k: v.to('cuda').unsqueeze(0) for k, v in inputs.items()}
# Generate output using the model on GPU
output = model.generate_from_batch(
inputs,
GenerationConfig(max_new_tokens=1024, stop_strings="<|endoftext|>"),
tokenizer=processor.tokenizer,
)
# Decode the generated tokens
generated_tokens = output[0, inputs["input_ids"].size(1):]
generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
return generated_text
# Gradio interface
def gradio_app():
    """Build the Blocks UI (image + optional question -> description) and launch it."""
    header = "# Image Long Description with Molmo-7B 4 bit quantized\n### Upload an image and ask a question about it!"

    with gr.Blocks() as demo:
        gr.Markdown(header)

        # Inputs side by side: the image and an optional free-text question.
        with gr.Row():
            image_input = gr.Image(label="Upload an Image", type="pil")
            question_input = gr.Textbox(
                label="Question (Optional)",
                placeholder="Ask a question about the image (e.g., 'What is happening in this image?')",
            )

        # Read-only box that receives the model's answer.
        output_text = gr.Textbox(interactive=False, label="Image Description")

        # Clicking the button runs describe_image on both inputs.
        submit_btn = gr.Button("Generate Description")
        submit_btn.click(
            outputs=output_text,
            inputs=[image_input, question_input],
            fn=describe_image,
        )

    # Start the web server once the layout is finalised.
    demo.launch()
# Launch the Gradio app only when this file is executed as a script
# (presumably `python app.py`, as the Space starts it — confirm), so that
# importing the module does not start a web server.
if __name__ == "__main__":
    gradio_app()