File size: 1,255 Bytes
f4b1311
194b093
 
f4b1311
 
 
91d8e2e
 
f4b1311
 
 
 
 
 
 
 
267519a
f4b1311
 
 
 
 
 
 
 
 
 
 
 
 
 
267519a
 
 
 
 
 
 
 
 
 
 
 
 
f4b1311
267519a
f4b1311
 
91d8e2e
f4b1311
267519a
f4b1311
 
 
 
267519a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import clip
import torch
import logging
import json
import pickle
import gradio as gr


logger = logging.getLogger("basebody")

# CLIP backbone used to embed the text prompts (and, downstream, images).
CLIP_MODEL_NAME = "ViT-B/16"

TEXT_PROMPTS_FILE_NAME = "text_prompts.json"
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"


# Load once at import time; CPU keeps the demo runnable without a GPU.
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")

# Resolve sibling data files relative to this module, not the CWD.
_MODULE_DIR = os.path.dirname(__file__)

with open(os.path.join(_MODULE_DIR, TEXT_PROMPTS_FILE_NAME), "r") as f:
    text_prompts = json.load(f)

# NOTE(review): pickle.load executes arbitrary code — only ship trusted
# model files alongside this module.
with open(os.path.join(_MODULE_DIR, LOGISTIC_REGRESSION_MODEL_FILE_NAME), "rb") as f:
    lr_model = pickle.load(f)

# Fix: the original message promised coefficients but logged none.
# Lazy %-args so formatting is skipped when INFO is disabled; getattr in
# case the pickled model is not a linear model with `coef_`.
logger.info(
    "Logistic regression model loaded, coefficients: %s",
    getattr(lr_model, "coef_", "<unavailable>"),
)


# Precompute CLIP text embeddings for every prompt pair once at import time.
with torch.no_grad():
    encoded_chunks = []
    for prompt_pair in text_prompts.values():
        # Each entry must hold exactly two prompts (a contrastive pair).
        assert len(prompt_pair) == 2
        encoded_chunks.append(clip_model.encode_text(clip.tokenize(prompt_pair)))
    # Stack all pairs into a single (2 * num_keys, dim) numpy array on CPU.
    all_text_features = torch.cat(encoded_chunks, dim=0).cpu().numpy()


def predict_fn(name):
    """Return a greeting for *name*.

    NOTE(review): the Gradio interface below declares inputs="image", so at
    runtime this handler would receive image data, not a string — this looks
    like a placeholder; confirm the intended handler.
    """
    # join preserves the original's TypeError for non-str input, like `+` did.
    return "".join(("Hello ", name, "!"))


# Wire the handler into a Gradio UI and start the local server (blocking).
# NOTE(review): inputs="image" means predict_fn receives an image array, but
# predict_fn builds "Hello " + name + "!" — this will raise for real image
# input; presumably a placeholder handler. Confirm before release.
iface = gr.Interface(
    fn=predict_fn,
    inputs="image",
    outputs="text",
    allow_flagging="manual"
)
iface.launch()