File size: 2,158 Bytes
f4b1311
194b093
 
f4b1311
 
 
2a62a79
91d8e2e
2a62a79
91d8e2e
f4b1311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267519a
 
 
 
 
 
 
 
 
 
 
4e55465
267519a
f4b1311
77c92b5
2a62a79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4b1311
91d8e2e
f4b1311
267519a
f4b1311
 
 
 
267519a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import clip
import torch
import logging
import json
import pickle
from PIL import Image
import gradio as gr
from scipy.special import softmax


# CLIP backbone identifier handed to ``clip.load``.
CLIP_MODEL_NAME: str = "ViT-B/16"

# Data artefacts expected in the same directory as this script.
TEXT_PROMPTS_FILE_NAME: str = "text_prompts.json"
LOGISTIC_REGRESSION_MODEL_FILE_NAME: str = "logistic_regression_l1_oct_2.pkl"

# Named module logger (fixed name rather than __name__).
logger: logging.Logger = logging.getLogger(name="basebody")


# Directory holding this script and its bundled data files.
_MODULE_DIR = os.path.dirname(__file__)

# CLIP backbone plus its matching image-preprocessing transform, on CPU.
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")

# Dict mapping a key to its list of prompts; each entry must hold exactly two
# prompts (checked below). Iteration order follows the JSON file and fixes the
# feature order passed to lr_model, so it must match the order used when the
# classifier was trained.
with open(
    os.path.join(_MODULE_DIR, TEXT_PROMPTS_FILE_NAME), "r", encoding="utf-8"
) as f:
    text_prompts = json.load(f)

# SECURITY: pickle.load can execute arbitrary code. This is acceptable only
# because the model file ships alongside this script; never load a pickle
# from user-supplied or downloaded data here.
with open(
    os.path.join(_MODULE_DIR, LOGISTIC_REGRESSION_MODEL_FILE_NAME), "rb"
) as f:
    lr_model = pickle.load(f)

# coef_ is the scikit-learn linear-model attribute; getattr keeps this log
# line from failing if the pickled estimator does not expose it.
logger.info(
    "Logistic regression model loaded, coefficients: %s",
    getattr(lr_model, "coef_", None),
)

if not text_prompts:
    # torch.cat([]) below would otherwise fail with an opaque error.
    raise ValueError(f"{TEXT_PROMPTS_FILE_NAME} contains no prompt pairs")

# Encode every prompt pair and stack them so rows 2*i and 2*i + 1 hold the two
# prompts of the i-th entry. The embeddings are deliberately left
# unnormalised; the classifier presumably expects exactly these features, so
# do not L2-normalise them without retraining.
all_text_features = []
with torch.no_grad():
    for key, prompts in text_prompts.items():
        # Explicit check instead of `assert` so it survives `python -O`;
        # predict_fn reshapes these features into pairs of two.
        if len(prompts) != 2:
            raise ValueError(
                f"Expected exactly 2 prompts for {key!r}, got {len(prompts)}"
            )
        all_text_features.append(
            clip_model.encode_text(clip.tokenize(prompts))
        )
    all_text_features = torch.cat(all_text_features, dim=0).cpu()


def predict_fn(input_img):
    """Classify an uploaded image as "base body" or not.

    The image is embedded with CLIP and scored against every pair of text
    prompts in ``text_prompts``. The per-pair softmax scores are fed to the
    logistic-regression model ``lr_model``.

    Args:
        input_img: Image array supplied by the Gradio ``"image"`` input.
            Presumably an ``(H, W, 3)`` array (TODO: confirm for the Gradio
            version in use). Grayscale and RGBA arrays are converted to RGB.

    Returns:
        UTF-8 encoded JSON bytes ``{"is_base_body": p}``, where ``p`` is
        ``lr_model``'s predicted probability for class index 1.

    Raises:
        ValueError: If no image was supplied (``input_img`` is ``None``).
    """
    if input_img is None:
        # Fail with a clear message instead of an AttributeError on .astype().
        raise ValueError("No input image provided")

    # Let Pillow infer the mode from the array shape, then normalise to RGB.
    # Forcing mode="RGB" misreads grayscale/RGBA buffers, and the mode
    # argument of fromarray is deprecated in recent Pillow releases.
    pil_image = Image.fromarray(input_img.astype("uint8")).convert("RGB")
    image = preprocess(pil_image).unsqueeze(0)
    num_pairs = len(text_prompts)

    with torch.no_grad():
        image_features = clip_model.encode_image(image).cpu()
        # Dot products between the text and image embeddings, grouped as
        # (num_pairs, 2, 1) so each prompt pair is softmaxed together.
        # Neither side is L2-normalised, so these are not true cosine
        # similarities.
        pair_logits = (
            (all_text_features @ image_features.T)
            .squeeze()
            .reshape(num_pairs, 2, -1)
        )
        # Convert explicitly to NumPy for scipy, then keep the probability
        # of the first prompt in each pair -> shape (num_pairs, 1).
        pair_scores = softmax(pair_logits.numpy(), axis=1)[:, 0, :]
        logger.info("cosine_simlarities: %s", pair_scores)

    probabilities = lr_model.predict_proba(pair_scores.reshape(1, -1))
    logger.info("probabilities: %s", probabilities)

    # NOTE(review): bytes are returned to a Gradio "text" output and may be
    # rendered as the repr b'...' instead of plain JSON. Confirm what API
    # consumers expect before changing the return type.
    decision_json = json.dumps(
        {"is_base_body": float(probabilities[0][1])}
    ).encode("utf-8")
    logger.info("decision_json: %s", decision_json)
    return decision_json


# Gradio UI: one image input, with the JSON verdict shown in a text output.
# Flagging is only recorded when the user explicitly clicks "Flag".
iface = gr.Interface(
    fn=predict_fn,
    inputs="image",
    outputs="text",
    allow_flagging="manual",
)

# Only start the server when run as a script, so importing this module
# (e.g. for tests or reload tooling) does not block on launch().
if __name__ == "__main__":
    iface.launch()