City
Initial version
218e10f
raw history blame
No virus
3.78 kB
import os
import torch
import gradio as gr
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from model import AestheticPredictorModel
# Hugging Face Hub repository hosting the predictor checkpoints.
HFREPO = "City96/CityAesthetics"
# Checkpoint names (file stems, without ".safetensors") loaded by the demo.
MODELS = ["CityAesthetics-Anime-v1.8"]
class CityAestheticsPipeline:
    """
    Demo pipeline for [image=>score] prediction.

    Accepts a list of model paths on initialization.
    Resulting object can be called directly with a PIL image as the input.
    Returns a dict with the model name as key and the score [0.0;1.0] as a value.
    """

    def __init__(self, model_paths):
        """
        Load every predictor checkpoint plus the shared CLIP image encoder.

        Args:
            model_paths: iterable of paths to ``.safetensors`` predictor
                checkpoints. Each model is keyed by its file name with the
                directory and extension removed.
        """
        self.models = {}
        for path in model_paths:
            name = os.path.splitext(os.path.basename(path))[0]
            self.models[name] = self.load_model(path)
        # One CLIP vision tower produces the embeddings fed to every predictor.
        clip_ver = "openai/clip-vit-large-patch14"
        self.proc = CLIPImageProcessor.from_pretrained(clip_ver)
        self.clip = CLIPVisionModelWithProjection.from_pretrained(clip_ver)
        print("CityAesthetics: Pipeline init ok") # debug

    def load_model(self, path):
        """
        Load one predictor checkpoint and return it in eval mode.

        Args:
            path: path to a ``.safetensors`` state dict for AestheticPredictorModel.

        Returns:
            The AestheticPredictorModel with the weights loaded, switched to eval().

        Raises:
            ValueError: if ``up.0.weight`` is not shaped (1024, 768), i.e. the
                checkpoint does not match the CLIP variant used by this pipeline.
        """
        sd = load_file(path)
        # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
        # NOTE(review): presumably 768 is the input (CLIP embedding) width — confirm.
        shape = tuple(sd["up.0.weight"].shape)
        if shape != (1024, 768):
            raise ValueError(
                f"{path}: unsupported predictor, up.0.weight has shape {shape}, "
                "expected (1024, 768) (only the matching CLIP version is allowed)"
            )
        model = AestheticPredictorModel()
        model.load_state_dict(sd)
        model.eval()
        return model

    def __call__(self, raw):
        """
        Score one image with every loaded predictor.

        Args:
            raw: the input image (a PIL image in this demo), passed to the CLIP processor.

        Returns:
            dict mapping model name -> predicted score as a Python float.
        """
        img = self.proc(images=raw, return_tensors="pt")
        # The predictors run inside no_grad too: inference needs no autograd
        # graph, and building one per call only wastes time and memory.
        with torch.no_grad():
            emb = self.clip(pixel_values=img["pixel_values"])
            emb = emb["image_embeds"].detach().cpu()
            return {
                name: float(model(emb).squeeze(0))
                for name, model in self.models.items()
            }
def get_model_path(name):
    """
    Return a filesystem path to the ``<name>.safetensors`` checkpoint.

    A copy in the ``models`` folder next to this script takes priority.
    Otherwise the file is fetched via ``hf_hub_download`` from the ``HFREPO``
    Hugging Face Hub repository.

    Args:
        name: checkpoint name without extension, e.g. "CityAesthetics-Anime-v1.8".

    Returns:
        str: path to the checkpoint file.
    """
    fname = f"{name}.safetensors"
    # local path: [models/CityAesthetics-Anime-v1.8.safetensors]
    local = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", fname)
    if os.path.isfile(local):
        print("CityAesthetics: Using local model")
        return local
    # huggingface hub fallback
    print("CityAesthetics: Using HF Hub model")
    return str(hf_hub_download(
        # An explicit HFS_TOKEN wins. Otherwise pass None so huggingface_hub uses
        # any implicitly configured token or downloads anonymously. ``True``
        # would raise when no token is configured, breaking public downloads.
        token=os.environ.get("HFS_TOKEN") or None,
        repo_id=HFREPO,
        filename=fname,
    ))
# Markdown text for the Gradio interface's ``article`` section (passed to
# gr.Interface(article=...) below). User-facing: edit the wording, not the code.
article = """\
# About
This is the live demo for the CityAesthetics class of predictors.
For more information, you can check out the [Huggingface Hub](https://huggingface.co/city96/CityAesthetics) or [GitHub page](https://github.com/city96/CityAesthetics).
## CityAesthetics-Anime
This flavor is optimized for scoring anime images with at least one subject present.
### Intentional biases:
- Completely negative towards real life photos (ideal score of 0%)
- Strongly Negative towards text (subtitles, memes, etc) and manga panels
- Fairly negative towards 3D and to some extent 2.5D images
- Negative towards western cartoons and stylized images (chibi, parody)
### Expected output scores:
- Non-anime images should always score below 20%
- Sketches/rough lineart/oekaki get around 20-40%
- Flat shading/TV anime gets around 40-50%
- Above 50% is mostly scored based on my personal style preferences
### Issues:
- Tends to filter male characters.
- Requires at least 1 subject, won't work for scenery/landscapes.
- Noticeable positive bias towards anime characters with animal ears.
- Hit-or-miss with AI generated images due to style/quality not being correlated.
"""
# Resolve every configured checkpoint (local copy first, HF Hub otherwise)
# and build one pipeline that scores an image against all of them.
pipeline = CityAestheticsPipeline([get_model_path(x) for x in MODELS])

gr.Interface(
    fn=pipeline,
    title="CityAesthetics demo",
    article=article,
    inputs=gr.Image(label="Input image", type="pil"),
    outputs=gr.Label(label="Model prediction", show_label=False),
    # Anchor to this script's directory (as get_model_path does) instead of the
    # process working directory, so launching from elsewhere still finds them.
    examples=os.path.join(os.path.dirname(os.path.realpath(__file__)), "examples"),
    allow_flagging="never",
    analytics_enabled=False,
).launch()