mm-commerce / app.py
zhangzhe45
add
50d1ff1
raw
history blame
No virus
2.96 kB
import os
from urllib.parse import urlparse
from PIL import Image
import requests
import torch
from timm.models.hub import download_cached_file
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import gradio as gr
from mm_commerce import BLIP_Decoder
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
def is_url(url_or_filename):
    """Return True when *url_or_filename* parses with an http or https scheme.

    Any other input (local paths, other schemes such as ftp) yields False.
    """
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
def load_checkpoint(url_or_filename):
    """Load a checkpoint from an http(s) URL or a local file onto the CPU.

    Args:
        url_or_filename: An http/https URL, or a path to an existing file.

    Returns:
        Whatever ``torch.load`` deserializes from the checkpoint file.

    Raises:
        RuntimeError: If the argument is neither an http(s) URL nor an
            existing file.
    """
    if is_url(url_or_filename):
        # Remote checkpoints are fetched through timm's download cache.
        local_path = download_cached_file(url_or_filename, check_hash=False, progress=True)
    elif os.path.isfile(url_or_filename):
        local_path = url_or_filename
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints from a trusted source.
    return torch.load(local_path, map_location='cpu')
# Side length, in pixels, of the square image fed to the model.
image_size = 224

# Per-channel normalization statistics.
# NOTE(review): these look like the published CLIP preprocessing values --
# confirm they match what the checkpoint was trained with.
_NORM_MEAN = (0.48145466, 0.4578275, 0.40821073)
_NORM_STD = (0.26862954, 0.26130258, 0.27577711)

# Preprocessing pipeline: resize to a fixed square (aspect ratio is not
# preserved), convert to a tensor, then normalize each channel.
_preprocess_steps = [
    transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)
# Build the captioning model and load its published weights.
model = BLIP_Decoder(med_config='configs/med_large_config.json', vit='large_v2', prompt='[DEC]')
ckpt = 'https://huggingface.co/zhezh/mm_commerce_zhcn/resolve/main/model.pth'
sd = load_checkpoint(ckpt)
# strict=True: a missing or unexpected key in the state dict raises instead of
# being silently ignored.
model.load_state_dict(sd, strict=True)
model.eval()
# Place the model on the module-level `device` rather than a hard-coded 'cuda':
# `device` falls back to the CPU when CUDA is unavailable, and `inference`
# moves its input tensors to that same device.
model = model.to(device)
def inference(raw_image, strategy):
    """Generate a product description for an image.

    Args:
        raw_image: Image accepted by the module-level ``transform``
            (the UI is configured to pass a PIL image).
        strategy: ``"Beam search"`` selects beam search; any other value
            selects nucleus (top-p) sampling.

    Returns:
        A string made of the Chinese "product description" label followed by
        the generated caption in double quotes.
    """
    batch = transform(raw_image).unsqueeze(0).to(device)

    # Only the decoding-specific options differ between the two strategies.
    if strategy == "Beam search":
        decode_options = {"sample": False, "num_beams": 10}
    else:
        decode_options = {"sample": True, "top_p": 0.9}

    with torch.no_grad():
        captions = model.generate(batch, max_length=100, min_length=10, **decode_options)

    # Drop the first 6 and last 5 characters of the first caption, then remove
    # all whitespace.
    # NOTE(review): presumably this strips the '[DEC]' prompt and a trailing
    # end marker -- confirm against BLIP_Decoder.generate.
    trimmed = captions[0][6:-5]
    text = ''.join(trimmed.split())
    return f'商品描述: "{text}"'
# UI components and page text.
# The Chinese labels below were stored as mojibake (UTF-8 bytes decoded with a
# single-byte code page); they are restored to the intended text.
inputs = [
    gr.inputs.Image(type='pil'),
    # Label reads "text generation strategy".
    gr.inputs.Radio(
        choices=['Beam search', 'Nucleus sampling'],
        type="value",
        default="Beam search",
        label="文本生成策略",
    ),
]
# Label reads "generated title (Output)".
outputs = gr.outputs.Textbox(label="生成的标题(Output)")

# Title and description read "Chinese product description generation".
title = "MM Commerce ZhCN (中文商品描述生成)"
description = "中文商品描述生成 -- By Zhe Zhang"
# Example rows shown in the UI: [image path, decoding strategy].
_example_rows = [
    ['starrynight.jpeg', "Nucleus sampling"],
    ['resources/examples/zhuobu.jpg', "Beam search"],
    ['resources/examples/jiandao.jpg', "Beam search"],
    ['resources/examples/lego-yellow.jpg', "Beam search"],
    ['resources/examples/charger.jpg', "Beam search"],
    ['resources/examples/charger-ugreen.jpg', "Beam search"],
    ['resources/examples/charger-hw.jpg', "Beam search"],
]

# Wire the inference function to the UI components defined above.
demo = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=_example_rows,
)

# Start the app with request queuing enabled. To expose a shareable server
# instead, pass share=True, server_name='0.0.0.0' and server_port=8080 as well.
demo.launch(enable_queue=True)