# aq-predictor / app.py
# Commit 05b1066 "Add aq model" (author: hwajjala); file size 1.72 kB, virus scan: clean.
import os
import clip
import torch
import logging
import json
import pandas as pd
from PIL import Image
import gradio as gr
from autogluon.tabular import TabularPredictor
# AutoGluon tabular model; predict_fn feeds it CLIP image embeddings.
_AUTOGLUON_MODEL_DIR = "AutogluonModels/ag-20240615_190835"
predictor = TabularPredictor.load(_AUTOGLUON_MODEL_DIR)

# Root logging: INFO and above, with timestamp / logger name / level prefix.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("AQ")

# CLIP backbone used to embed input images, loaded on the CPU.
CLIP_MODEL_NAME = "ViT-B/32"
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")
def predict_fn(input_img):
    """Score one image with the CLIP-embedding + AutoGluon pipeline.

    The image is embedded with the module-level CLIP model (``preprocess`` /
    ``clip_model``). The embedding is then passed as a single-row DataFrame
    to the module-level AutoGluon ``predictor``.

    Args:
        input_img: Image array from the Gradio ``"image"`` input. Presumably
            a numpy array of shape (H, W, 3). Grayscale (H, W) and RGBA
            (H, W, 4) arrays are converted to RGB. TODO confirm against the
            Gradio version in use.

    Returns:
        bytes: UTF-8 encoded JSON of the form ``{"quality_score": <prediction>}``.

    Raises:
        ValueError: If no image was supplied (``input_img`` is None).
    """
    if input_img is None:
        raise ValueError("No input image provided.")

    # Let Pillow infer the mode from the array shape, then normalise to RGB.
    # Forcing mode "RGB" on a 4-channel or 2-D array mis-reads the pixel buffer
    # (or fails). For an (H, W, 3) uint8 array the result is unchanged.
    input_img = Image.fromarray(input_img.astype("uint8")).convert("RGB")
    image = preprocess(input_img).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():
        image_features = clip_model.encode_image(image).numpy()

    # A single row whose columns are the embedding dimensions.
    input_df = pd.DataFrame(image_features[0].reshape(1, -1))
    prediction = predictor.predict(input_df).iloc[0]
    # .iloc[0] can yield a numpy scalar (e.g. numpy.int64 / numpy.float32)
    # that json.dumps refuses to serialise. .item() converts it to the
    # equivalent native Python value.
    quality_score = prediction.item() if hasattr(prediction, "item") else prediction
    logger.info("decision: %s", quality_score)

    decision_json = json.dumps({"quality_score": quality_score}).encode("utf-8")
    logger.info("decision_json: %s", decision_json)
    # NOTE(review): bytes are returned to a Gradio "text" output. If the UI
    # shows a b'...' literal, return the decoded str instead. The type is kept
    # as bytes here to preserve the existing contract.
    return decision_json
# Text shown to users on the demo page.
# NOTE(review): this text describes a probability / "base body" decision with
# 0.9 and 0.2 thresholds, but predict_fn returns predictor.predict(...) under
# the key "quality_score". Confirm it matches this model. It also contains a
# typo ("NOT as base body").
_DESCRIPTION = """
The model returns the probability of the image being a base body. If
probability > 0.9, the image can be automatically tagged as a base body. If
probability < 0.2, the image can be automatically REJECTED as NOT as base
body. All other cases will be submitted for moderation.
Please flag if you think the decision is wrong.
"""

# Choices offered to users when they flag a result by hand.
_FLAG_CHOICES = [
    ": decision should be accept",
    ": decision should be reject",
    ": decision should be moderation",
]

iface = gr.Interface(
    fn=predict_fn,
    inputs="image",
    outputs="text",
    description=_DESCRIPTION,
    allow_flagging="manual",
    flagging_options=_FLAG_CHOICES,
)
iface.launch()