File size: 2,631 Bytes
7ac4196
fbbadab
7ac4196
093f79d
 
e80f948
ab9eef9
fbbadab
 
ab9eef9
e80f948
 
 
ab9eef9
fbbadab
e80f948
fbbadab
 
e80f948
 
4f1e215
e80f948
 
 
4f1e215
fbbadab
 
4f1e215
fbbadab
ab9eef9
e80f948
 
ab9eef9
e80f948
fbbadab
 
ab9eef9
fbbadab
 
ab9eef9
fbbadab
 
 
ab9eef9
00b84bb
 
e80f948
fbbadab
4f1e215
 
fbbadab
4f1e215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab9eef9
4f1e215
 
fbbadab
 
e80f948
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import requests
from io import BytesIO
import spaces  # Import spaces for ZeroGPU support

# Hugging Face Hub checkpoint to serve (the name indicates a bitsandbytes
# 4-bit quantization of Molmo-7B-O).
repo_name = "cyan2k/molmo-7B-O-bnb-4bit"

# Loader keyword arguments, shared by the processor here and the model loader
# used inside the GPU-decorated request handler.
arguments = dict(
    device_map="auto",       # let transformers choose device placement
    torch_dtype="auto",      # let transformers pick the dtype from the checkpoint
    trust_remote_code=True,  # the checkpoint relies on its own remote code
)

# The processor only prepares inputs, so it is created once at import time;
# this step does not need a GPU.
processor = AutoProcessor.from_pretrained(repo_name, **arguments)

# Define the function for image description
@spaces.GPU  # This ensures the function gets GPU access when needed
def describe_image(image, question):
    # Load the model inside the function and move it to GPU
    model = AutoModelForCausalLM.from_pretrained(repo_name, **arguments).to('cuda')
    
    # Process the uploaded image along with the user's question
    inputs = processor.process(
        images=[image],
        text=question if question else "Describe this image in great detail without missing any piece of information"
    )

    # Move inputs to model device (GPU)
    inputs = {k: v.to('cuda').unsqueeze(0) for k, v in inputs.items()}

    # Generate output using the model on GPU
    output = model.generate_from_batch(
        inputs,
        GenerationConfig(max_new_tokens=1024, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer,
    )

    # Decode the generated tokens
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return generated_text

# Gradio user interface
def gradio_app():
    """Build the Blocks UI wired to ``describe_image`` and start serving it.

    Layout: a title, a row holding the image upload and the optional question
    box, a read-only output box, and a button that triggers generation.
    """
    header = (
        "# Image Long Description with Molmo-7B 4 bit quantized\n"
        "### Upload an image and ask a question about it!"
    )
    hint = "Ask a question about the image (e.g., 'What is happening in this image?')"

    with gr.Blocks() as ui:
        gr.Markdown(header)

        with gr.Row():
            picture = gr.Image(type="pil", label="Upload an Image")
            prompt_box = gr.Textbox(placeholder=hint, label="Question (Optional)")

        answer_box = gr.Textbox(label="Image Description", interactive=False)

        # Clicking the button sends (image, question) to the model handler and
        # writes the returned text into the output box.
        run_button = gr.Button("Generate Description")
        run_button.click(
            fn=describe_image,
            inputs=[picture, prompt_box],
            outputs=answer_box,
        )

    # Start the web server (blocks until it is stopped).
    ui.launch()

# Launch the Gradio app only when this file is executed as a script
# (e.g. `python app.py`), not as a side effect of importing it.
if __name__ == "__main__":
    gradio_app()