basebody / app.py
hwajjala's picture
Add additional check
cb8ddf6
raw
history blame
5.51 kB
import os
import clip
import torch
import logging
import json
import pickle
from PIL import Image
import gradio as gr
from scipy.special import softmax
# ---------------------------------------------------------------------------
# Module-level setup: logger, asset file names, the flagging writer, the CLIP
# model, the prompt dictionaries and the pickled classifiers. Everything here
# runs once at import time.
# ---------------------------------------------------------------------------
logger = logging.getLogger("basebody")

CLIP_MODEL_NAME = "ViT-B/16"

# Prompt files (JSON) that live next to this module.
TEXT_PROMPTS_OLD_FILE_NAME = "text_prompts.json"
TEXT_PROMPTS_FILE_NAME = "text_prompts2.json"
HAIR_TEXT_PROMPTS_FILE_NAME = "text_prompts_hair.json"

# Pickled classifier files that live next to this module.
LOGISTIC_REGRESSION_OLD_MODEL_FILE_NAME = "logistic_regression_l1.pkl"
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_26.pkl"
HAIR_RF_CLASSIFIER_MODEL_FILE_NAME = "hairclassifier_rf.pkl"

# Token is read from the environment; it is None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Handed to the Gradio interface below as its flagging callback.
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "Roblox/basebody_feedback")

# The CLIP model and its image preprocessing transform, loaded on CPU.
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")


def _asset_path(file_name):
    """Return the path of *file_name* inside this module's directory."""
    return os.path.join(os.path.dirname(__file__), file_name)


def _read_json(file_name):
    """Parse and return the JSON asset called *file_name*."""
    with open(_asset_path(file_name), "r") as handle:
        return json.load(handle)


def _read_pickle(file_name):
    """Unpickle and return the asset called *file_name*.

    NOTE(review): unpickling executes arbitrary code, so these files must
    only ever come from a trusted source (here: files next to this module).
    """
    with open(_asset_path(file_name), "rb") as handle:
        return pickle.load(handle)


# Prompt dictionaries, loaded in the same order as before.
text_prompts = _read_json(TEXT_PROMPTS_FILE_NAME)
hair_text_prompts = _read_json(HAIR_TEXT_PROMPTS_FILE_NAME)
text_prompts_old = _read_json(TEXT_PROMPTS_OLD_FILE_NAME)

# Classifiers: current model, older model, and the hair classifier.
lr_model = _read_pickle(LOGISTIC_REGRESSION_MODEL_FILE_NAME)
lr_old_model = _read_pickle(LOGISTIC_REGRESSION_OLD_MODEL_FILE_NAME)
hair_rf_model = _read_pickle(HAIR_RF_CLASSIFIER_MODEL_FILE_NAME)

# NOTE(review): the message mentions coefficients but none are logged.
logger.info("Logistic regression model loaded, coefficients: ")
def get_text_features(text_prompts):
    """Encode every prompt pair in *text_prompts* with the CLIP text encoder.

    Args:
        text_prompts: Mapping whose values are sequences of exactly two
            prompt strings. Keys are only used in error messages.

    Returns:
        A CPU tensor built by concatenating the encoder output of each
        pair along dim 0, in the mapping's iteration order.

    Raises:
        ValueError: If the mapping is empty, or if any entry does not hold
            exactly two prompts.
    """
    if not text_prompts:
        raise ValueError("text_prompts must contain at least one prompt pair")

    encoded_pairs = []
    with torch.no_grad():
        for name, prompts in text_prompts.items():
            # Explicit check instead of `assert`, which is stripped under -O.
            # Downstream code reshapes the features assuming pairs.
            if len(prompts) != 2:
                raise ValueError(
                    f"expected exactly 2 prompts for {name!r}, "
                    f"got {len(prompts)}"
                )
            tokens = clip.tokenize(prompts)
            encoded_pairs.append(clip_model.encode_text(tokens))

    return torch.cat(encoded_pairs, dim=0).cpu()
# Encode each prompt set once at import time so requests only have to encode
# the image. Order: current prompts, hair prompts, older prompts.
all_text_features, hair_text_features, old_text_features = map(
    get_text_features,
    (text_prompts, hair_text_prompts, text_prompts_old),
)
def get_cosine_similarities(image_features, text_features, text_prompts):
    """Score an image against each prompt pair.

    The text/image dot products are grouped into consecutive pairs, a
    softmax is taken within each pair, and the weight of the first prompt
    of every pair is returned. No normalisation is applied here, so despite
    the name the inputs to the softmax are raw dot products.

    Args:
        image_features: Torch tensor of image embeddings, one row per image.
        text_features: Torch tensor of text embeddings with two consecutive
            rows per entry of *text_prompts*.
        text_prompts: Sized collection of prompt pairs; only its length is
            used, to recover the pair grouping.

    Returns:
        NumPy array of shape ``(len(text_prompts), n_images)`` holding the
        softmax weight of the first prompt in each pair.
    """
    # Convert to NumPy explicitly: handing a torch tensor to scipy relies on
    # implicit array conversion, which fails for tensors that require grad.
    # Both operands are moved to CPU, not just the image features.
    text_matrix = text_features.detach().cpu()
    image_matrix = image_features.detach().cpu()
    dot_products = (text_matrix @ image_matrix.T).numpy()

    # Rows 2*i and 2*i + 1 belong to the i-th prompt pair.
    paired = dot_products.reshape(len(text_prompts), 2, -1)
    pair_weights = softmax(paired, axis=1)

    # Keep only the first prompt of each pair.
    return pair_weights[:, 0, :]
# Decision thresholds used by `_decide`.
_AUTO_ACCEPT_MIN_PROBABILITY = 0.77  # above this: accept, if hair check passes
_AUTO_REJECT_MAX_PROBABILITY = 0.2  # below this: reject outright
_HAIR_MIN_PROBABILITY = 0.5  # hair check fails below this
_OLD_MODEL_REJECT_MAX_PROBABILITY = 0.06  # older model vetoes below this


def _decide(base_probability, hair_probability, old_probability):
    """Map the three classifier probabilities to (reported value, decision).

    The reported value is the base-body probability, except when a secondary
    check causes the rejection; then that check's probability is reported.
    """
    if base_probability > _AUTO_ACCEPT_MIN_PROBABILITY:
        if hair_probability < _HAIR_MIN_PROBABILITY:
            logger.info("hair_result_probabilty < 0.5")
            return hair_probability, "AUTO REJECT"
        return base_probability, "AUTO ACCEPT"
    if base_probability < _AUTO_REJECT_MAX_PROBABILITY:
        return base_probability, "AUTO REJECT"
    # Middle band: the older model can still force a rejection.
    if old_probability < _OLD_MODEL_REJECT_MAX_PROBABILITY:
        logger.info("old_result_probabilty < 0.06")
        return old_probability, "AUTO REJECT"
    return base_probability, "MODERATION"


def predict_fn(input_img):
    """Classify an image and return the moderation decision.

    The image is embedded with CLIP and scored against three prompt sets.
    Each score vector goes to its own classifier (current model, hair
    classifier, older model), and the three probabilities are combined into
    one decision.

    Args:
        input_img: Image array supporting ``astype``; it is cast to uint8
            and interpreted as RGB.

    Returns:
        UTF-8 encoded JSON bytes with the keys ``"is_base_body"`` (reported
        probability, rounded to 3 decimals) and ``"decision"`` (one of
        ``"AUTO ACCEPT"``, ``"AUTO REJECT"``, ``"MODERATION"``).
    """
    pil_image = Image.fromarray(input_img.astype("uint8"), "RGB")
    image_batch = preprocess(pil_image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        image_features = clip_model.encode_image(image_batch)

    base_similarities = get_cosine_similarities(
        image_features, all_text_features, text_prompts
    )
    hair_similarities = get_cosine_similarities(
        image_features, hair_text_features, hair_text_prompts
    )
    old_similarities = get_cosine_similarities(
        image_features, old_text_features, text_prompts_old
    )
    logger.info("cosine_simlarities: %s", base_similarities)

    # Each classifier takes a single flattened row of per-pair scores.
    probabilities = lr_model.predict_proba(base_similarities.reshape(1, -1))
    hair_probabilities = hair_rf_model.predict_proba(
        hair_similarities.reshape(1, -1)
    )
    old_probabilities = lr_old_model.predict_proba(
        old_similarities.reshape(1, -1)
    )
    logger.info("probabilities: %s", probabilities)

    # Column 1 is read as the positive-class probability.
    reported_probability, decision = _decide(
        float(probabilities[0][1].round(3)),
        float(hair_probabilities[0][1].round(3)),
        float(old_probabilities[0][1].round(3)),
    )
    logger.info("decision: %s", decision)

    # NOTE(review): this returns bytes, not str, to a Gradio "text" output.
    # Kept as-is because consumers may depend on it; confirm it is intended.
    decision_json = json.dumps(
        {"is_base_body": reported_probability, "decision": decision}
    ).encode("utf-8")
    logger.info("decision_json: %s", decision_json)
    return decision_json
# Gradio UI: one image in, the JSON decision text out, with manual flagging
# routed through `hf_writer`.
# The thresholds quoted in the description must mirror those in `predict_fn`.
# The description lines stay at column 0 so indentation is not picked up if
# the text is rendered as Markdown.
iface = gr.Interface(
    fn=predict_fn,
    inputs="image",
    outputs="text",
    description="""
The model returns the probability of the image being a base body. If
probability > 0.77 and the hair check passes (hair probability >= 0.5), the
image can be automatically tagged as a base body; if the hair check fails,
the image is automatically REJECTED. If probability < 0.2, the image is
automatically REJECTED as NOT a base body. For probabilities in between, the
image is also REJECTED when the older model's probability is < 0.06; all
other cases will be submitted for moderation.
Please flag if you think the decision is wrong.
""",
    allow_flagging="manual",
    flagging_options=[
        ": decision should be accept",
        ": decision should be reject",
        ": decision should be moderation",
    ],
    flagging_callback=hf_writer,
)

# Start the app as soon as the module is executed.
iface.launch()