basebody / app.py
hwajjala's picture
Add execution logic
2a62a79
raw
history blame
2.17 kB
import os
import clip
import torch
import logging
import json
import pickle
from PIL import Image
import gradio as gr
from scipy.special import softmax
# Module-level logger for the app.
# NOTE(review): no handler or level is configured anywhere in the visible code,
# so the INFO messages below may be dropped by the default logging setup —
# confirm logging is configured by the hosting environment.
logger = logging.getLogger("basebody")

# Name of the CLIP checkpoint handed to ``clip.load``.
CLIP_MODEL_NAME = "ViT-B/16"
# JSON file (next to this script) holding the text prompts.
TEXT_PROMPTS_FILE_NAME = "text_prompts.json"
# Pickled classifier (next to this script) used on top of the CLIP features.
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"

# Load the CLIP model and its image preprocessing transform on the CPU.
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")

# Read the prompts. The encoding is pinned to UTF-8 (the JSON interchange
# encoding) so the result does not depend on the platform's default locale.
with open(
    os.path.join(os.path.dirname(__file__), TEXT_PROMPTS_FILE_NAME),
    "r",
    encoding="utf-8",
) as f:
    text_prompts = json.load(f)
# Load the pickled classifier that sits next to this script.
# SECURITY NOTE: ``pickle.load`` can execute arbitrary code while
# deserialising. This is only acceptable because the file ships with the app;
# never point this at a file obtained from an untrusted source.
with open(
    os.path.join(
        os.path.dirname(__file__), LOGISTIC_REGRESSION_MODEL_FILE_NAME
    ),
    "rb",
) as f:
    lr_model = pickle.load(f)

# The original message announced coefficients but never logged them.
# ``getattr`` keeps start-up safe if the unpickled object has no ``coef_``.
logger.info(
    "Logistic regression model loaded, coefficients: %s",
    getattr(lr_model, "coef_", None),
)
# Pre-compute the CLIP text embeddings for every prompt pair once at start-up
# so that each prediction only has to encode the image.
#
# ``text_prompts`` maps a key to exactly two prompts; the embeddings are
# concatenated in dict order, two consecutive rows per key. ``predict_fn``
# relies on this layout when it reshapes to (len(text_prompts), 2, -1).
encoded_prompt_pairs = []
with torch.no_grad():
    for prompt_key, prompts in text_prompts.items():
        # Explicit check instead of ``assert`` so the validation still runs
        # when Python is started with -O.
        if len(prompts) != 2:
            raise ValueError(
                f"Expected exactly 2 prompts for {prompt_key!r}, "
                f"got {len(prompts)}"
            )
        tokens = clip.tokenize(prompts)
        encoded_prompt_pairs.append(clip_model.encode_text(tokens))

# Stack all pairs along the first axis and keep the result as a NumPy array.
all_text_features = torch.cat(encoded_prompt_pairs, dim=0).cpu().numpy()
def predict_fn(input_img):
    """Score an image with the CLIP prompt features and the pickled classifier.

    The image is encoded with CLIP and multiplied against the pre-computed
    text embeddings. A softmax is taken within each prompt pair, the weight of
    the first prompt of every pair is kept, and that vector is passed to
    ``lr_model.predict_proba``.

    Args:
        input_img: Pixel data exposing ``astype`` (it is converted to
            ``uint8`` and interpreted as RGB).

    Returns:
        UTF-8 encoded JSON bytes of the form ``{"is_base_body": <float>}``,
        where the value is column 1 of ``predict_proba`` for this image.

    Raises:
        ValueError: If ``input_img`` is ``None`` (for example, nothing was
            submitted).
    """
    # Fail with a clear message instead of an AttributeError on ``None``.
    if input_img is None:
        raise ValueError("No image provided: input_img is None")

    pil_image = Image.fromarray(input_img.astype("uint8"), "RGB")
    image = preprocess(pil_image).unsqueeze(0)  # add a batch axis of size 1
    with torch.no_grad():
        image_features = clip_model.encode_image(image)

    # Convert explicitly to NumPy so the product below is a plain
    # ndarray @ ndarray operation rather than relying on how NumPy and torch
    # negotiate a mixed ndarray @ Tensor expression.
    image_features = image_features.cpu().numpy()

    # One row per prompt, grouped as (pair, prompt-within-pair, batch).
    pair_logits = (all_text_features @ image_features.T).reshape(
        len(text_prompts), 2, -1
    )
    # Softmax within each pair; keep the weight of the first prompt.
    # NOTE(review): no normalisation of the features is visible here, so these
    # are softmaxed dot products rather than true cosine similarities. Left
    # unchanged because lr_model presumably expects exactly these features.
    first_prompt_weights = softmax(pair_logits, axis=1)[:, 0, :]
    logger.info("cosine_simlarities: %s", first_prompt_weights)

    probabilities = lr_model.predict_proba(first_prompt_weights.reshape(1, -1))
    logger.info("probabilities: %s", probabilities)

    # NOTE(review): the result is returned as bytes, not str; kept as is so
    # existing consumers of the output see the same value.
    decision_json = json.dumps(
        {"is_base_body": float(probabilities[0][1])}
    ).encode("utf-8")
    logger.info("decision_json: %s", decision_json)
    return decision_json
# Web UI: a single image input wired to ``predict_fn``, with the returned
# value shown as text and manual flagging enabled.
_interface_options = {
    "fn": predict_fn,
    "inputs": "image",
    "outputs": "text",
    "allow_flagging": "manual",
}
iface = gr.Interface(**_interface_options)

# Start serving the interface as soon as the script is executed.
iface.launch()