# after SmilingWolf/wd-tagger
import gradio as gr
import torch
import torch.nn as nn
import timm
import timm.layers.ml_decoder
from transformers import AutoModel, AutoTokenizer
import torchvision
from torchvision import transforms
import PIL
from PIL import Image
import requests
from io import BytesIO
import json
import pickle
headers = {
    "User-Agent": "Gradio 0-shot classification demo",
}
TITLE = "Danbooru 0-shot classification demo"
DESCRIPTION = """
Demo for 0-shot tag classification on Danbooru images.
DaViT-tiny backbone, ML-Decoder classification head, Alibaba-NLP/gte-large-en-v1.5 text embedding model.
The training set contains posts with IDs <= 5,400,000 whose last 3 digits are in the range [0, 899], inclusive.
Provide an image by uploading one or by fetching a post ID.
Provide a tag description via the input box or by fetching a tag name's wiki entry.
"""
def scrape_img(postID):
    postURL = f"https://danbooru.donmai.us/posts/{postID}.json"
    postData = json.loads(requests.get(postURL, headers=headers).content)
    imageURL = postData['file_url']
    print(f"Getting image from {imageURL}")
    response = requests.get(imageURL, headers=headers)
    image = Image.open(BytesIO(response.content))
    image.load()
    return image
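
# Look up a tag's wiki page history and return "tag: body" from the first
# returned revision, or just the tag name if no wiki page exists.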
def scrape_wiki(tagName):
    wikiHistoryURL = f"https://danbooru.donmai.us/wiki_page_versions.json?search[title]={tagName}"
    wikiHistory = json.loads(requests.get(wikiHistoryURL, headers=headers).content)
    wikiBody = (": " + wikiHistory[0]['body'] if len(wikiHistory) > 0 else "")
    return tagName + wikiBody
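
# Wraps checkpoint loading and inference for both seen-tag and zero-shot
# tag prediction.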
class Predictor:
    def __init__(self):
        self.img_size = (288, 288)
        self.cls_model = None
        self.tokenizer = None
        self.text_emb_model = None
        self.class_embed = None
        self.tag_names = None
        self.load_model()
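
    # Restore the trained checkpoint: tag vocabulary, DaViT-tiny + ML-Decoder
    # classifier weights, and the gte text encoder used to embed tag
    # descriptions. Note: the class_embed/class_embed_merge/shared_fc arguments
    # to add_ml_decoder_head assume a timm build that ships this extended
    # ML-Decoder; the stock timm API may differ.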
    def load_model(self):
        with open('tags1588.pkl', 'rb') as f:
            classes = pickle.load(f)
        tagNames = classes[0].to_list()
        self.tag_names = tagNames
        pretrained_weights = torch.load('model.pth', map_location=torch.device('cpu'))
        self.class_embed = pretrained_weights['0.head.head.class_embed.weight']
        cls_model = timm.create_model('davit_tiny', num_classes=len(classes))
        cls_model = timm.layers.ml_decoder.add_ml_decoder_head(
            cls_model,
            num_groups=len(classes),
            class_embed=self.class_embed,
            class_embed_merge='',
            shared_fc=True)
        cls_model = nn.Sequential(cls_model)
        cls_model.load_state_dict(pretrained_weights, strict=True)
        cls_model = cls_model.eval()
        self.cls_model = cls_model
        model_path = 'Alibaba-NLP/gte-large-en-v1.5'
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.text_emb_model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        self.text_emb_model = self.text_emb_model.eval()
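
    # Embed tag descriptions with gte-large-en-v1.5, using the CLS token
    # (position 0) of the last hidden state as the sentence embedding.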
    def embed_text(self, input_strings):
        with torch.no_grad():
            embeddingList = []
            for text in input_strings:
                # Tokenize the input text
                batch_dict = self.tokenizer(text, padding=True, truncation=False, return_tensors='pt')
                outputs = self.text_emb_model(**batch_dict.to(self.text_emb_model.device))
                embeddings = outputs.last_hidden_state[:, 0]
                embeddingList.append(embeddings.cpu())
            embeddings = torch.cat(embeddingList)
        return embeddings
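
    # Flatten transparency onto a white background, resize to the model's
    # 288x288 input with bicubic interpolation, and convert to a CHW tensor.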
    def prepare_image(self, image):
        image.load()  # check if file valid
        image = image.convert("RGBA")
        color = (255, 255, 255)
        background = Image.new('RGB', image.size, color)
        background.paste(image, mask=image.split()[3])
        image = background
        image = transforms.Resize(self.img_size, interpolation=transforms.InterpolationMode.BICUBIC)(image)
        image = transforms.ToTensor()(image)
        return image
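
    # Shared inference path: run the DaViT backbone once, then decode with
    # ML-Decoder using `query` as the class query embeddings; return the
    # top-50 tags by sigmoid probability.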
    def predict(
        self,
        image,
        query,
        tag_names,
    ):
        image = self.prepare_image(image)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            image_features = self.cls_model[0].forward_features(image.unsqueeze(0))
            outputs = self.cls_model[0].head(image_features, q=query).sigmoid().float()
        general_tag_list = list(zip(tag_names, outputs[0].tolist()))
        general_tag_list.sort(key=lambda y: y[1], reverse=True)
        general_tag_preds_dict = {}
        for tag, prob in general_tag_list[:50]:
            general_tag_preds_dict[tag] = prob
        return general_tag_preds_dict
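
    # Score the image against the tag embeddings learned during training.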
    def predict_seen_tags(
        self,
        image,
    ):
        return self.predict(image, self.class_embed, self.tag_names)
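
    # Zero-shot path: embed a free-text tag description and use it as a single
    # ML-Decoder query, yielding one probability.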
    def predict_new_tag(
        self,
        image,
        description,
    ):
        return self.predict(image, self.embed_text([description]), ["embedding"])["embedding"]
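
# Build the Gradio UI: an image panel (upload or fetch by post ID) and a tag
# panel (wiki lookup, free-text description, and per-description prediction).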
def main():
    predictor = Predictor()
    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            gr.Markdown(value=DESCRIPTION)
            with gr.Row():
                with gr.Column(variant="panel"):
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input", height=600)
                    with gr.Row():
                        post_id = gr.Textbox(label="Post ID")
                        with gr.Column():
                            clear = gr.ClearButton(
                                value="Clear image",
                                components=[
                                    image,
                                ],
                                variant="secondary",
                                size="lg",
                            )
                            get_post = gr.Button(value="Get Post", variant="primary", size="lg")
                    with gr.Row():
                        submit = gr.Button(value="Predict known tags", variant="primary", size="lg")
                with gr.Column(variant="panel"):
                    tag_description = gr.Textbox(label="Tag description")
                    with gr.Row():
                        tag_name = gr.Textbox(label="Tag Name")
                        description_prediction = gr.Textbox(label="Probability")
                    with gr.Row():
                        clear_tag_data = gr.ClearButton(value="Clear tag", variant="secondary", size="lg")
                        get_tag_description = gr.Button(value="Get tag description", variant="primary", size="lg")
                    predict_on_description = gr.Button(value="Predict described tag")
                    general_bars = gr.Label(label="Known tags")
        clear.add(
            [
                general_bars,
                description_prediction,
                post_id,
            ]
        )
        clear_tag_data.add(
            [
                tag_description,
                tag_name,
                description_prediction,
            ]
        )
        examples = gr.Examples(
            [
                [
                    "8801249",
                    "short_over_long_sleeves",
                ],
            ],
            inputs=[
                post_id,
                tag_name,
            ],
            run_on_click=False,
            cache_examples=False,
        )
        submit.click(
            predictor.predict_seen_tags,
            inputs=[
                image,
            ],
            outputs=[general_bars],
        )
        predict_on_description.click(
            predictor.predict_new_tag,
            inputs=[image, tag_description],
            outputs=[description_prediction],
        )
        get_post.click(
            scrape_img,
            inputs=[post_id],
            outputs=[image],
        )
        get_tag_description.click(
            scrape_wiki,
            inputs=[tag_name],
            outputs=[tag_description],
        )
    demo.queue(max_size=10)
    demo.launch()
if __name__ == "__main__":
    main()