import os
from urllib.parse import urlparse

from PIL import Image
import requests
import torch
from timm.models.hub import download_cached_file
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import gradio as gr

from mm_commerce import BLIP_Decoder

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def is_url(url_or_filename):
    # Treat anything with an http(s) scheme as a remote URL.
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def load_checkpoint(url_or_filename):
    # Load a checkpoint from a URL (cached locally) or from a local file path.
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')
    return checkpoint
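
# Note: load_checkpoint returns the raw object stored in the .pth file. Some BLIP-style
# checkpoints nest the weights under a 'model' key; whether this particular checkpoint
# does is not confirmed here. If model.load_state_dict below fails with unexpected keys,
# a minimal guard like the following hypothetical helper can be applied to the loaded
# checkpoint first (a sketch, kept commented out so behavior is unchanged):
#
#     def unwrap_state_dict(ckpt):
#         # No-op for a flat state dict; unwraps the common {'model': state_dict} layout.
#         return ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt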
# Preprocessing: resize to 224x224 and normalize with the CLIP image statistics used by BLIP.
image_size = 224
transform = transforms.Compose([
    transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
# Build the BLIP decoder and load the pretrained commerce-captioning weights.
model = BLIP_Decoder(med_config='configs/med_large_config.json', vit='large_v2', prompt='[DEC]')
ckpt = 'https://huggingface.co/zhezh/mm_commerce_zhcn/resolve/main/model.pth'
sd = load_checkpoint(ckpt)
model.load_state_dict(sd, strict=True)
model.eval()
model = model.to(device)
def inference(raw_image, strategy):
    # Preprocess the PIL image and add a batch dimension.
    image = transform(raw_image).unsqueeze(0).to(device)
    with torch.no_grad():
        if strategy == "Beam search":
            caption = model.generate(image, sample=False, num_beams=10, max_length=100, min_length=10)
        else:
            caption = model.generate(image, sample=True, top_p=0.9, max_length=100, min_length=10)
    # Strip the leading prompt and trailing special-token text, then drop whitespace between characters.
    return '商品描述: ' + '"' + ''.join(caption[0][6:-5].split()) + '"'  # "商品描述" = "product description"
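
# Quick local test (a minimal sketch, kept commented out so it does not run on startup).
# The path below is one of the example images referenced in `examples`; any RGB image
# opened with PIL works.
# img = Image.open('resources/examples/zhuobu.jpg').convert('RGB')
# print(inference(img, 'Beam search'))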
# Note: the legacy gr.inputs/gr.outputs namespaces were removed in recent Gradio releases,
# so the current component API is used here.
inputs = [
    gr.Image(type='pil'),
    gr.Radio(choices=['Beam search', 'Nucleus sampling'], value="Beam search", label="文本生成策略"),  # "text generation strategy"
]
outputs = gr.Textbox(label="生成的标题(Output)")  # "generated title"
title = "MM Commerce ZhCN (中文商品描述生成)"  # "Chinese product description generation"
description = "中文商品描述生成 -- By Zhe Zhang"  # "Chinese product description generation -- By Zhe Zhang"
demo = gr.Interface(
    inference, inputs, outputs, title=title, description=description,
    # article=article,
    examples=[
        ['starrynight.jpeg', "Nucleus sampling"],
        ['resources/examples/zhuobu.jpg', "Beam search"],
        ['resources/examples/jiandao.jpg', "Beam search"],
        ['resources/examples/lego-yellow.jpg', "Beam search"],
        ['resources/examples/charger.jpg', "Beam search"],
        ['resources/examples/charger-ugreen.jpg', "Beam search"],
        ['resources/examples/charger-hw.jpg', "Beam search"],
    ],
)
# demo.launch(enable_queue=True, share=True, server_name='0.0.0.0', server_port=8080,)
# launch(enable_queue=...) was removed in recent Gradio releases; launch without it.
demo.launch()
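
# If request queueing is wanted on newer Gradio releases, it is enabled via the queue()
# method rather than a launch flag (a sketch, assuming a recent Gradio version and the
# same host/port as the commented-out line above):
# demo.queue().launch(server_name='0.0.0.0', server_port=8080)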