import torch
import gradio as gr
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Load the GIT image processor and the fine-tuned portrait-captioning checkpoint.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("microsoft/git-base")
model = AutoModelForCausalLM.from_pretrained("sam749/sd-portrait-caption").to(device)
def generate_captions(images: list[Image.Image], max_length: int = 200) -> list[str]:
    # Prepare the image(s) for the model.
    inputs = processor(images=images, return_tensors="pt").to(device)
    pixel_values = inputs.pixel_values
    # Generate caption token ids, then decode them into strings.
    generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
    generated_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_captions
def generate_caption(image: Image.Image, max_length: int = 200) -> str:
    # Single-image wrapper: the processor accepts one PIL image directly,
    # and batch_decode still returns a list, so take the first element.
    return generate_captions(image, max_length)[0]
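# Quick local sanity-check sketch (not part of the app): assumes a file
# named "portrait.png" exists next to this script; the filename is a
# placeholder, not something the Space ships with.
#
# print(generate_caption(Image.open("portrait.png"), max_length=100))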
inputs = [
    gr.Image(
        sources=["upload", "clipboard"],
        height=400,
        type="pil",
    ),
    gr.Slider(
        minimum=10,
        maximum=400,
        value=200,
        label="max length",
        step=8,
    ),
]
outputs = [
    gr.Text(label="Generated Caption"),
]
demo = gr.Interface(
    fn=generate_caption,
    inputs=inputs,
    outputs=outputs,
    title="Stable Diffusion Portrait Captioner",
    theme="gradio/monochrome",
    api_name="caption",
    submit_btn=gr.Button("caption it", variant="primary"),
    allow_flagging="never",
)

demo.queue(max_size=10)
demo.launch()
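# --- Remote usage sketch (not part of the app) -------------------------------
# A minimal example of querying the deployed Space programmatically through
# gradio_client, hitting the named "/caption" endpoint exposed above. The
# Space id is a placeholder assumption; substitute the real <user>/<space>.
#
# from gradio_client import Client, handle_file
#
# client = Client("user/sd-portrait-captioner")  # hypothetical Space id
# caption = client.predict(
#     handle_file("portrait.png"),  # image input (placeholder file)
#     200,                          # max length
#     api_name="/caption",
# )
# print(caption)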