artwork-scorer / app.py
Muinez's picture
Upload 2 files
31ac1b2
raw
history blame
No virus
1.12 kB
import gradio as gr
import torch
from transformers import AutoImageProcessor, ConvNextV2ForImageClassification
from transformers import AutoModelForImageClassification
from torch import nn
import dbimutils as utils
# Run on the GPU when torch reports one as available, otherwise on the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

# The preprocessor and the classifier are both loaded from the same Hub repo.
image_processor = AutoImageProcessor.from_pretrained("Muinez/artwork-scorer")
model = AutoModelForImageClassification.from_pretrained(
    "Muinez/artwork-scorer",
    problem_type="multi_label_classification",
)
# Move the weights to the selected device once, at import time.
model = model.to(DEVICE)
def predict(img):
    """Score a single artwork image with the module-level model.

    Args:
        img: Image supplied by the Gradio input component (configured
            below as ``gr.Image(type="pil")``), or ``None`` when the form
            is submitted without an image.

    Returns:
        A 2-tuple of Python floats between 0 and 1: the sigmoid of the
        first and second logits for the image. The UI labels them "Score"
        and "View count ratio" respectively.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Fail with a readable message instead of letting the preprocessing
        # helper crash on a missing image.
        raise gr.Error("Please provide an image to score.")

    file = utils.preprocess_image(img)
    encoded = image_processor(file, return_tensors="pt").to(DEVICE)

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        logits = model(**encoded).logits.cpu()

    # Independent per-output sigmoid (multi-label head), first batch item.
    probs = torch.sigmoid(logits)[0]

    # Return plain floats so the UI layer never has to handle tensors.
    return probs[0].item(), probs[1].item()
# Web UI: one PIL image in, two numeric fields out (matching the pair that
# predict returns), with flagging turned off.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Number(label="Score"),
        gr.Number(label="View count ratio (probably useless)"),
    ],
    title="Artwork scorer",
    description="Predicts score (0-1) for artwork.\nCould be wrong!!!\nDoes not work very well with nsfw i.e. it was not trained on it",
    allow_flagging="never",
)

# Start serving as soon as the script is executed.
demo.launch()