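"""Minimal Gradio clone of the LiT-tuning demo.

Uses the `open_clip` implementation of the SigLIP model
`timm/ViT-SO400M-14-SigLIP-384` to score text prompts against an input image.
"""
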
import contextlib
import functools
import json
import logging
import os
import time
import urllib.request

import gradio as gr
import open_clip  # works on open-clip-torch>=2.23.0, timm>=0.9.8
import PIL.Image
import torch
import torch.nn.functional as F
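
# Metadata (image ids and prompts) and image URLs for the examples published
# with the original lit-tuning-demo.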
INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'


@contextlib.contextmanager
def timed(name):
    """Context manager that logs how long the wrapped block took."""
    t0 = time.monotonic()
    try:
        yield
    finally:
        logging.info('Timed %s: %.1f secs', name, time.monotonic() - t0)


@functools.cache
def load_model(name='hf-hub:timm/ViT-SO400M-14-SigLIP-384'):
    """Loads model, preprocessing transform, and tokenizer (cached across calls)."""
    with timed('loading model, preprocess, tokenizer'):
        t0 = time.time()
        model, preprocess = open_clip.create_model_from_pretrained(name)
        tokenizer = open_clip.get_tokenizer(name)
        logging.info('loaded in %.1fs', time.time() - t0)
    return model, preprocess, tokenizer


def generate_answers(image_path, prompts):
    """Scores each comma-separated prompt against the image at `image_path`.

    Returns a list of (prompt, probability) pairs.
    """
    model, preprocess, tokenizer = load_model()
    with torch.no_grad(), torch.cuda.amp.autocast():
        logging.info('Opening image "%s"', image_path)
        with timed(f'opening image "{image_path}"'):
            image = PIL.Image.open(image_path)
        with timed('image features'):
            image = preprocess(image).unsqueeze(0)
            image_features = model.encode_image(image)
        with timed('text features'):
            prompts = prompts.split(', ')
            text = tokenizer(prompts, context_length=model.context_length)
            text_features = model.encode_text(text)
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        # SigLIP assigns each prompt an independent sigmoid probability
        # (no softmax across prompts).
        exp, bias = model.logit_scale.exp(), model.logit_bias
        text_probs = torch.sigmoid(image_features @ text_features.T * exp + bias)
    return list(zip(prompts, [round(p.item(), 3) for p in text_probs[0]]))


def create_app():
    """Builds the Gradio UI, with examples fetched from the lit-tuning-demo site."""
    info = json.load(urllib.request.urlopen(INFO_URL))
    with gr.Blocks() as demo:
        gr.Markdown('Minimal gradio clone of [lit-tuning-demo](https://google-research.github.io/vision_transformer/lit/)')
        gr.Markdown('Using `open_clip` implementation of SigLIP model `timm/ViT-SO400M-14-SigLIP-384`')
        with gr.Row():
            image = gr.Image(label='input_image', type='filepath')
            with gr.Column():
                prompts = gr.Textbox(label='prompts')
                answer = gr.Textbox(label='answer')
                run = gr.Button('Run')
        gr.Examples(
            examples=[
                [IMG_URL_FMT.format(ex['id']), ex['prompts']]
                for ex in info
            ],
            inputs=[image, prompts],
            outputs=[answer],
        )
        run.click(fn=generate_answers, inputs=[image, prompts], outputs=[answer])
    return demo


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s')
    for k, v in os.environ.items():
        logging.info('environ["%s"] = %r', k, v)
    # Pre-load the model once before starting the Gradio server.
    _ = load_model()
    create_app().queue().launch()