sam749's picture
Upload folder using huggingface_hub
3442a32 verified
raw
history blame
1.47 kB
import os
import torch
import gradio as gr
from PIL import Image
from transformers import AutoModelForCausalLM,AutoProcessor
# Prefer the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Image preprocessing / tokenization is loaded from the base GIT checkpoint,
# while the captioning weights come from a separate repo (presumably a
# fine-tune of git-base -- confirm the processor matches the weights).
processor = AutoProcessor.from_pretrained("microsoft/git-base")
model = AutoModelForCausalLM.from_pretrained("sam749/sd-portrait-caption")
model = model.to(device)
def generate_captions(images: "Image.Image | list[Image.Image]", max_length: int = 200) -> "list[str]":
    """Generate captions for one or more PIL images.

    Args:
        images: A single PIL image or a list of PIL images, passed straight
            to the processor.
        max_length: Maximum generated sequence length, forwarded to
            ``model.generate``. Coerced to ``int`` because values coming from
            a UI slider are not guaranteed to be integers. ``None`` is passed
            through unchanged (the model's generation config then applies).

    Returns:
        The decoded caption strings (special tokens stripped), one per
        generated sequence.
    """
    # Turn the image(s) into a pixel tensor on the same device as the model.
    inputs = processor(images=images, return_tensors="pt").to(device)
    if max_length is not None:
        max_length = int(max_length)
    generated_ids = model.generate(
        pixel_values=inputs.pixel_values,
        max_length=max_length,
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)
def generate_caption(image: "Image.Image | None", max_length: int = 200) -> str:
    """Caption a single image; this is the function wired into the Gradio UI.

    Args:
        image: The uploaded PIL image, or ``None`` (what Gradio passes when
            the user submits without uploading anything).
        max_length: Generation length limit forwarded to ``generate_captions``.

    Returns:
        The caption generated for ``image``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a processor traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image to caption.")
    # Wrap in a list to match generate_captions' batch contract; take the
    # single resulting caption.
    return generate_captions([image], max_length)[0]
# UI controls, in the same order as generate_caption's parameters:
# (image, max_length).
inputs = [
    gr.Image(
        sources=["upload", "clipboard"],
        height=400,
        type="pil",  # hand the callback a PIL image, which the processor accepts
    ),
    # The minimum is a multiple of ``step`` so that both the default (200) and
    # the maximum (400) lie on the slider's step grid. With minimum=10 the
    # grid was 10, 18, ..., 194, 202, ..., so 200 and 400 were unreachable and
    # got snapped to a neighbouring value.
    gr.Slider(
        minimum=8,
        maximum=400,
        value=200,
        label='max length',
        step=8,
    ),
]
# A single text box that displays the caption returned by generate_caption.
outputs = [gr.Text(label="Generated Caption")]

# Custom submit button so the action is labelled "caption it".
caption_button = gr.Button("caption it", variant="primary")

# Assemble the interface: image + slider in, caption text out. ``api_name``
# names the API endpoint and flagging is turned off.
demo = gr.Interface(
    fn=generate_caption,
    inputs=inputs,
    outputs=outputs,
    title="Stable Diffusion Portrait Captioner",
    theme="gradio/monochrome",
    api_name="caption",
    submit_btn=caption_button,
    allow_flagging="never",
)
# Bound the request queue to 10 pending events; per Gradio's queue docs,
# further submissions are turned away while the queue is full.
demo.queue(
    max_size=10,
)

# Start the server only when executed as a script (e.g. `python app.py`), so
# importing this module (tests, Gradio reload mode) does not block on launch.
if __name__ == "__main__":
    demo.launch()