import os
import torch
import gradio as gr
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from model import AestheticPredictorModel

HFREPO = "City96/CityAesthetics"
MODELS = [
	"CityAesthetics-Anime-v1.8",
]

class CityAestheticsPipeline:
	"""

	Demo pipeline for [image=>score] prediction

		Accepts a list of model paths on initialization.

		Resulting object can be called directly with a PIL image as the input.

		Returns a dict with the model name as key and the score [0.0;1.0] as a value.

	"""
	def __init__(self, model_paths):
		self.models = {}
		for path in model_paths:
			name = os.path.splitext(os.path.basename(path))[0]
			self.models[name] = self.load_model(path)

		clip_ver = "openai/clip-vit-large-patch14"
		self.proc = CLIPImageProcessor.from_pretrained(clip_ver)
		self.clip = CLIPVisionModelWithProjection.from_pretrained(clip_ver)
		print("CityAesthetics: Pipeline init ok") # debug

	def load_model(self, path):
		sd = load_file(path)
		# sanity check: the first layer must accept 768-dim CLIP ViT-L/14 embeddings
		assert tuple(sd["up.0.weight"].shape) == (1024, 768), "only CLIP ViT-L/14 embeddings are supported"
		model = AestheticPredictorModel()
		model.load_state_dict(sd)
		model.eval()
		return model

	def __call__(self, raw):
		# preprocess the PIL image and embed it with the CLIP vision tower
		img = self.proc(images=raw, return_tensors="pt")
		with torch.no_grad():
			emb = self.clip(pixel_values=img["pixel_values"])
		emb = emb["image_embeds"].detach().cpu()
		# score the shared embedding with every loaded predictor head
		out = {}
		for name, model in self.models.items():
			pred = model(emb)
			out[name] = float(pred.squeeze(0))
		return out
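
# Example of calling the pipeline directly, outside of Gradio (a sketch; the model
# and image paths below are hypothetical and assume a local checkpoint):
#
#   from PIL import Image
#   pipe = CityAestheticsPipeline(["models/CityAesthetics-Anime-v1.8.safetensors"])
#   scores = pipe(Image.open("example.png"))
#   # -> {"CityAesthetics-Anime-v1.8": <score in [0.0;1.0]>}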

def get_model_path(name):
	fname = f"{name}.safetensors"

	# local path: [models/CityAesthetics-Anime-v1.8.safetensors]
	path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
	if os.path.isfile(os.path.join(path, fname)):
		print("CityAesthetics: Using local model")
		return os.path.join(path, fname)

	# huggingface hub fallback
	print("CityAesthetics: Using HF Hub model")
	return str(hf_hub_download(
		token    = os.environ.get("HFS_TOKEN") or True, # optional token for private repos
		repo_id  = HFREPO,
		filename = fname,
		# subfolder = fname.split('-')[1],
	))
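
# Resolution order sketch for get_model_path (paths below are illustrative, not literal):
#   get_model_path("CityAesthetics-Anime-v1.8")
#   -> "<this file's dir>/models/CityAesthetics-Anime-v1.8.safetensors" if it exists,
#   -> otherwise the hf_hub_download() cache path for that file in the City96/CityAesthetics repo.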

article = """\

# About



This is the live demo for the CityAesthetics class of predictors.



For more information, you can check out the [Huggingface Hub](https://huggingface.co/city96/CityAesthetics) or [GitHub page](https://github.com/city96/CityAesthetics).



## CityAesthetics-Anime



This flavor is optimized for scoring anime images with at least one subject present.



### Intentional biases:



- Completely negative towards real life photos (ideal score of 0%)

- Strongly Negative towards text (subtitles, memes, etc) and manga panels

- Fairly negative towards 3D and to some extent 2.5D images

- Negative towards western cartoons and stylized images (chibi, parody)



### Expected output scores:



- Non-anime images should always score below 20%

- Sketches/rough lineart/oekaki get around 20-40%

- Flat shading/TV anime gets around 40-50%

- Above 50% is mostly scored based on my personal style preferences



### Issues:



- Tends to filter male characters.

- Requires at least 1 subject, won't work for scenery/landscapes.

- Noticeable positive bias towards anime characters with animal ears.

- Hit-or-miss with AI generated images due to style/quality not being correlated.

"""

pipeline = CityAestheticsPipeline([get_model_path(x) for x in MODELS])
gr.Interface(
	fn      = pipeline,
	title   = "CityAesthetics demo",
	article = article,
	inputs  = gr.Image(label="Input image", type="pil"),
	outputs = gr.Label(label="Model prediction", show_label=False),
	examples = "./examples",
	allow_flagging = "never",
	analytics_enabled = False,
).launch()