"""Gradio demo that captions an uploaded image with an open_clip CoCa model."""

import pathlib

import gradio as gr
import open_clip
import torch

# Special tokens that wrap the decoded caption text; everything from the end
# token onwards is discarded and the start token is removed.
START_OF_TEXT = "<start_of_text>"
END_OF_TEXT = "<end_of_text>"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model, _, transform = open_clip.create_model_and_transforms(
    "coca_ViT-L-14", pretrained="mscoco_finetuned_laion2B-s13B-b90k"
)
model.to(device)


def _prepare_image(image):
    """Turn a PIL image into a single-item batch tensor on ``device``.

    Raises:
        gr.Error: if no image was supplied (e.g. the button was clicked
            before uploading anything), so the UI shows a readable message.
    """
    if image is None:
        raise gr.Error("Please provide an image to caption.")
    return transform(image).unsqueeze(0).to(device)


def _decode_caption(token_ids):
    """Decode generated token ids to text and strip the special tokens."""
    text = open_clip.decode(token_ids.detach())
    # Keep only what precedes the end-of-text marker. Splitting on an empty
    # string (as the code previously did) raises ValueError.
    return text.split(END_OF_TEXT)[0].replace(START_OF_TEXT, "")


def output_generate(image):
    """Caption ``image`` using the model's default generation settings.

    Args:
        image: PIL image to caption.

    Returns:
        The generated caption with special tokens removed.
    """
    im = _prepare_image(image)
    with torch.no_grad(), torch.cuda.amp.autocast():
        generated = model.generate(im, seq_len=20)
    return _decode_caption(generated[0])


def inference_caption(
    image,
    decoding_method="Beam search",
    rep_penalty=1.2,
    top_p=0.5,
    min_seq_len=5,
    seq_len=20,
):
    """Caption ``image`` with user-selected decoding parameters.

    Args:
        image: PIL image to caption.
        decoding_method: "Beam search" selects beam search; any other value
            selects top-p (nucleus) sampling.
        rep_penalty: forwarded to ``model.generate`` as ``repetition_penalty``.
        top_p: forwarded to ``model.generate`` as ``top_p``.
        min_seq_len: forwarded to ``model.generate`` as ``min_seq_len``.
        seq_len: forwarded to ``model.generate`` as ``seq_len``.

    Returns:
        The generated caption with special tokens removed.
    """
    im = _prepare_image(image)
    generation_type = "beam_search" if decoding_method == "Beam search" else "top_p"
    with torch.no_grad(), torch.cuda.amp.autocast():
        generated = model.generate(
            im,
            generation_type=generation_type,
            top_p=top_p,
            min_seq_len=min_seq_len,
            seq_len=seq_len,
            repetition_penalty=rep_penalty,
        )
    return _decode_caption(generated[0])


# Example images; only referenced by the commented-out gr.Interface below.
paths = sorted(pathlib.Path("images").glob("*.jpg"))

with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    # NOTE: the CSS above targets a component by its auto-generated id, so the
    # number and order of components created here should not be changed
    # casually (this is why the otherwise unused State is kept).
    state = gr.State([])

    # gr.Markdown(title)
    # gr.Markdown(description)
    # gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil")

            # with gr.Row():
            sampling = gr.Radio(
                choices=["Beam search", "Nucleus sampling"],
                value="Beam search",
                label="Text Decoding Method",
                interactive=True,
            )

            rep_penalty = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=1.5,
                step=0.5,
                interactive=True,
                label="Repeat Penalty (larger value prevents repetition)",
            )

            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Top p (used with nucleus sampling)",
            )

            min_seq_len = gr.Number(
                value=5,
                label="Minimum Sequence Length",
                precision=0,
                interactive=True,
            )

            seq_len = gr.Number(
                value=20,
                label="Maximum Sequence Length",
                precision=0,
                interactive=True,
            )

        with gr.Column(scale=1.8):
            with gr.Column():
                caption_output = gr.Textbox(lines=1, label="Caption Output")
                caption_button = gr.Button(
                    value="Caption it!", interactive=True, variant="primary"
                )
                caption_button.click(
                    inference_caption,
                    [
                        image_input,
                        sampling,
                        rep_penalty,
                        top_p,
                        min_seq_len,
                        seq_len,
                    ],
                    [caption_output],
                )

# Earlier single-function interface, kept for reference (inactive):
# iface = gr.Interface(
#     fn=output_generate,
#     inputs=gr.Image(label="Input image", type="pil"),
#     outputs=gr.Text(label="Caption output"),
#     title="CoCa: Contrastive Captioners",
#     description=(
#         """
#         An open source implementation of CoCa: Contrastive Captioners are
#         Image-Text Foundation Models https://arxiv.org/abs/2205.01917.
#         Built using open_clip with an effort from LAION.
#         For faster inference without waiting in queue, you may duplicate the
#         space and upgrade to GPU in settings. Duplicate Space"""
#     ),
#     article="""""",
#     examples=[path.as_posix() for path in paths],
# )

iface.launch()