import json
import logging
import os
import pickle

import clip
import gradio as gr
import torch

# Module-level logger. The name is the fixed string "basebody" rather than
# being derived from __name__.
logger = logging.getLogger("basebody")

# Model identifier handed to clip.load() to pick the CLIP variant.
CLIP_MODEL_NAME = "ViT-B/16"

# Names of the data files that are opened from this module's own directory.
TEXT_PROMPTS_FILE_NAME = "text_prompts.json"
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"

# Load the CLIP model and the preprocessing callable returned with it, on CPU.
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")

# Both data files are resolved relative to this module's directory, so loading
# does not depend on the process's current working directory.
with open(
    os.path.join(os.path.dirname(__file__), TEXT_PROMPTS_FILE_NAME),
    "r",
    encoding="utf-8",  # JSON is UTF-8; do not rely on the platform default
) as f:
    text_prompts = json.load(f)

# SECURITY NOTE: pickle.load() can execute arbitrary code while unpickling.
# The file is read from this module's own directory; never point this at a
# pickle obtained from an untrusted source.
with open(
    os.path.join(
        os.path.dirname(__file__), LOGISTIC_REGRESSION_MODEL_FILE_NAME
    ),
    "rb",
) as f:
    lr_model = pickle.load(f)

# The message announces the coefficients, so actually include them.
# NOTE(review): assumes the unpickled object exposes a scikit-learn style
# `coef_` attribute - confirm. getattr() keeps the log call from failing if
# it does not.
logger.info(
    "Logistic regression model loaded, coefficients: %s",
    getattr(lr_model, "coef_", "<unavailable>"),
)

# Encode every entry of text_prompts with the CLIP text encoder and combine
# the results into a single NumPy array, in the dict's iteration order.
all_text_features = []
with torch.no_grad():  # inference only, no autograd bookkeeping needed
    for key, prompts in text_prompts.items():
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would let a malformed prompts file through.
        if len(prompts) != 2:
            raise ValueError(
                f"Expected exactly 2 prompts for {key!r}, got {len(prompts)}"
            )
        inputs = clip.tokenize(prompts)
        outputs = clip_model.encode_text(inputs)
        all_text_features.append(outputs)
    # Join the per-entry outputs along the first dimension.
    all_text_features = torch.cat(all_text_features, dim=0)

# From here on the name refers to a NumPy array, not a list or tensor.
all_text_features = all_text_features.cpu().numpy()

def predict_fn(name: str) -> str:
    """Return a greeting built from *name*.

    Args:
        name: Text to place between ``"Hello "`` and ``"!"``. Expected to be
            a ``str``; the value is combined with string literals via ``+``.

    Returns:
        The concatenation ``"Hello " + name + "!"``.
    """
    # NOTE(review): this file registers this function with a Gradio interface
    # configured with inputs="image", so at runtime the argument will be image
    # data rather than text. This looks like a placeholder - confirm and
    # replace with the real prediction logic.
    return "Hello " + name + "!"

# Build the Gradio UI: one image input, one text output, with predict_fn as
# the handler and flagging set to "manual".
# NOTE(review): inputs="image" does not match predict_fn, which concatenates
# its argument with string literals - confirm predict_fn is a placeholder.
iface = gr.Interface(
    fn=predict_fn,
    inputs="image",
    outputs="text",
    allow_flagging="manual",
)

# Launch the app. There is no `if __name__ == "__main__":` guard, so this
# runs whenever the module is executed or imported.
iface.launch()