File size: 2,462 Bytes
f4b1311
194b093
 
f4b1311
 
 
2a62a79
91d8e2e
2a62a79
91d8e2e
f4b1311
 
 
 
 
 
 
e1f60ba
 
 
 
f4b1311
e1f60ba
 
 
f4b1311
 
 
 
 
 
 
 
 
 
 
 
 
267519a
 
 
 
 
 
 
 
 
 
 
4e55465
267519a
f4b1311
77c92b5
2a62a79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a2bd30
2a62a79
 
 
f4b1311
91d8e2e
f4b1311
267519a
f4b1311
 
e1f60ba
7a2bd30
 
 
 
 
e1f60ba
f4b1311
267519a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import clip
import torch
import logging
import json
import pickle
from PIL import Image
import gradio as gr
from scipy.special import softmax


# Named logger shared by everything in this module.
logger = logging.getLogger("basebody")

# Checkpoint identifier handed to ``clip.load`` further down.
CLIP_MODEL_NAME = "ViT-B/16"

# Asset files that are looked up in the same directory as this module.
TEXT_PROMPTS_FILE_NAME = "text_prompts.json"
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"

# Token is read from the environment; it is ``None`` when the variable
# is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Flagging callback: flagged samples go to this HuggingFace dataset.
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "Roblox/basebody_feedback")

# Load the CLIP model and its matching image preprocessing transform on CPU.
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")

# Both asset files are resolved relative to this module, not the CWD.
_module_dir = os.path.dirname(__file__)

# Text prompts: a JSON mapping whose values are each a pair of prompts
# (the pair length is checked where the prompts are encoded).
# The encoding is explicit so decoding does not depend on the platform's
# locale default.
with open(
    os.path.join(_module_dir, TEXT_PROMPTS_FILE_NAME),
    "r",
    encoding="utf-8",
) as f:
    text_prompts = json.load(f)

# SECURITY NOTE(review): ``pickle.load`` can execute arbitrary code from the
# file it reads. This is only acceptable because the model file ships with
# the module; never point it at an untrusted or user-supplied path.
with open(
    os.path.join(_module_dir, LOGISTIC_REGRESSION_MODEL_FILE_NAME), "rb"
) as f:
    lr_model = pickle.load(f)

# The message announces coefficients, so actually log them. ``getattr`` keeps
# start-up from failing if the unpickled object exposes no ``coef_``.
logger.info(
    "Logistic regression model loaded, coefficients: %s",
    getattr(lr_model, "coef_", None),
)


# Encode every prompt pair once at import time. The resulting tensor holds
# two consecutive rows per entry of ``text_prompts`` (in iteration order);
# ``predict_fn`` relies on that layout when it reshapes the similarities to
# ``(len(text_prompts), 2, -1)``.
_prompt_pair_features = []
with torch.no_grad():
    for prompt_key, prompts in text_prompts.items():
        # Explicit check instead of ``assert`` so the validation still runs
        # when Python is started with ``-O``.
        if len(prompts) != 2:
            raise ValueError(
                f"text prompt entry {prompt_key!r} must contain exactly 2 "
                f"prompts, got {len(prompts)}"
            )
        tokens = clip.tokenize(prompts)
        _prompt_pair_features.append(clip_model.encode_text(tokens))
    all_text_features = torch.cat(_prompt_pair_features, dim=0).cpu()


def predict_fn(input_img):
    """Score one image and return the decision as encoded JSON.

    The image is embedded with the CLIP model and multiplied against the
    precomputed prompt features. The scores are grouped into prompt pairs
    and a softmax is taken within each pair. The weight of the first prompt
    of every pair becomes the feature vector for the logistic regression
    model.

    Args:
        input_img: Array-like object supporting ``astype``; it is cast to
            ``uint8`` and interpreted as an RGB image.
            NOTE(review): presumably the numpy array Gradio passes for an
            "image" input -- confirm against the interface definition.

    Returns:
        UTF-8 encoded ``bytes`` holding a JSON object with the single key
        ``"is_base_body"``, whose value is ``predict_proba``'s second-column
        probability for the sample, rounded to 3 decimals.
    """
    pil_image = Image.fromarray(input_img.astype("uint8"), "RGB")
    # Add a leading batch dimension of size 1 for the image encoder.
    batch = preprocess(pil_image).unsqueeze(0)
    pair_count = len(text_prompts)

    with torch.no_grad():
        image_features = clip_model.encode_image(batch)
        scores = all_text_features @ image_features.cpu().T
        # Group scores as (pair, prompt-in-pair, image).
        paired_scores = scores.squeeze().reshape(pair_count, 2, -1)
        # Normalise within each pair, then keep the first prompt's share.
        pair_weights = softmax(paired_scores, axis=1)
        first_prompt_weights = pair_weights[:, 0, :]
        logger.info(f"cosine_simlarities: {first_prompt_weights}")

    # One sample (row) with one feature per prompt pair.
    lr_features = first_prompt_weights.reshape(1, -1)
    probabilities = lr_model.predict_proba(lr_features)
    logger.info(f"probabilities: {probabilities}")

    positive_probability = float(probabilities[0][1].round(3))
    payload = {"is_base_body": positive_probability}
    decision_json = json.dumps(payload).encode("utf-8")
    logger.info(f"decision_json: {decision_json}")
    return decision_json


# Choices offered to the user when manually flagging a prediction.
_flagging_choices = [
    "probability error > 0.5",
    "0.2 < probability error < 0.5",
    "probability error < 0.2",
]

# Image-in / text-out interface around ``predict_fn``; flagged samples are
# handed to ``hf_writer``.
iface = gr.Interface(
    fn=predict_fn,
    inputs="image",
    outputs="text",
    allow_flagging="manual",
    flagging_options=_flagging_choices,
    flagging_callback=hf_writer,
)

# Start serving as soon as the module is executed.
iface.launch()