# Repository file-view header: last commit 6cae924 "Add pillow" by Samet Yilmaz; file size 2.86 kB.
from vllm import LLM, SamplingParams
import gradio as gr
from PIL import Image
from io import BytesIO
import base64
import requests
# Image budget per chat message. The product of these two values is used as
# the engine's max_num_batched_tokens when the model is constructed.
max_img_per_msg = 5
max_tokens_per_img = 4096

# Decoding settings handed to the chat call.
sampling_params = SamplingParams(temperature=0.7, max_tokens=8192)

# Replace with the model you would like to use.
repo_id = "mistral-community/pixtral-12b-240910"
def encode_image(image: Image.Image, image_format="PNG") -> str:
    """Return *image* serialized as base64 text.

    The image is written to an in-memory buffer in ``image_format`` and the
    resulting bytes are base64-encoded and decoded to a UTF-8 string.
    """
    buffer = BytesIO()
    image.save(buffer, format=image_format)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
# Shared vLLM engine. Built lazily on the first call to infer() and reused
# afterwards, so the model weights are not reloaded for every request.
_llm = None


def _get_llm():
    """Return the shared vLLM engine, constructing it on first use."""
    global _llm
    if _llm is None:
        # NOTE(review): the module-level `repo_id` is not used here; the model
        # name is hard-coded. Confirm which of the two is intended.
        _llm = LLM(
            model="mistralai/Pixtral-12B-2409",  # Name or path of your model
            tokenizer_mode="mistral",
            max_model_len=65536,
            max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
            limit_mm_per_prompt={"image": max_img_per_msg},
        )
    return _llm


# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(image_url, prompt, progress=gr.Progress(track_tqdm=True)):
    """Answer ``prompt`` about the image found at ``image_url``.

    The image is downloaded, resized, re-encoded as a PNG data URL and sent
    together with the prompt as a single user message to the chat model.

    Args:
        image_url: HTTP(S) URL of the image to download.
        prompt: Text instruction or question about the image.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        The text of the first completion, suitable for a text output
        component.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    # NOTE(review): image_url is user-supplied and fetched server-side; if
    # this app is exposed publicly, consider restricting allowed hosts.
    # Download before touching the model so a bad URL fails fast.
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))

    # NOTE(review): the fixed target size ignores the source aspect ratio and
    # upscales small images; kept as-is, confirm it is intentional.
    image = image.resize((3844, 2408))
    new_image_url = f"data:image/png;base64,{encode_image(image, image_format='PNG')}"

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": new_image_url}},
            ],
        },
    ]

    outputs = _get_llm().chat(messages, sampling_params=sampling_params)
    # One conversation was submitted, so take its first completion.
    text = outputs[0].outputs[0].text
    print(text)
    return text
# Sample inputs offered in the UI's example galleries.
example_images = [
    "https://picsum.photos/id/237/200/300",
]
example_prompts = [
    "What do you see in this image?",
]

# Stylesheet passed to gr.Blocks: centres the main column and caps its width.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 640px;\n"
    "}\n"
)
# UI layout: a prompt box, an image-URL box and a Run button in one row, a
# text box for the model's answer, and clickable examples for both inputs.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # The app takes an image plus a prompt and produces text, so the
        # heading says so (the previous heading described text-to-image).
        gr.Markdown("# Image-to-Text Gradio Demo (Pixtral-12B)")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
            )
            image_url = gr.Text(
                label="Image URL",
                show_label=False,
                max_lines=1,
                placeholder="Enter your image URL",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        result = gr.Textbox(
            show_label=False
        )

        gr.Examples(
            examples=example_images,
            inputs=[image_url]
        )
        gr.Examples(
            examples=example_prompts,
            inputs=[prompt]
        )

    # Run inference on a button click or when either text box is submitted.
    gr.on(
        triggers=[run_button.click, image_url.submit, prompt.submit],
        fn=infer,
        inputs=[image_url, prompt],
        outputs=[result]
    )

demo.queue().launch()