import os | |
import clip | |
import torch | |
import logging | |
import json | |
import pickle | |
import gradio as gr | |
logger = logging.getLogger("basebody")

# Model / asset configuration. Data files are resolved relative to the
# directory containing this script, not the current working directory.
CLIP_MODEL_NAME = "ViT-B/16"
TEXT_PROMPTS_FILE_NAME = "text_prompts.json"
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
ASSET_DIR = os.path.dirname(os.path.abspath(__file__))

# Load the CLIP model and its image preprocessing transform once, at import
# time, on the CPU.
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")

# JSON text is UTF-8 by specification. Pass the encoding explicitly so that
# reading does not depend on the platform's locale default (e.g. cp1252 on
# Windows).
with open(
    os.path.join(ASSET_DIR, TEXT_PROMPTS_FILE_NAME), "r", encoding="utf-8"
) as f:
    text_prompts = json.load(f)

# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# That is only acceptable because this file is read from the script's own
# directory (presumably shipped with the app). Never point this at
# user-supplied or downloaded data.
with open(
    os.path.join(ASSET_DIR, LOGISTIC_REGRESSION_MODEL_FILE_NAME), "rb"
) as f:
    lr_model = pickle.load(f)
def greet(name):
    """Return a greeting for *name*.

    Placeholder handler for the Gradio interface defined below.

    Args:
        name: The value to greet. It is formatted with ``str()`` semantics,
            so non-string inputs do not raise.

    Returns:
        The string ``"Hello <name>!"``.
    """
    # NOTE(review): the interface declares inputs="image", so Gradio passes
    # an image object (presumably a numpy array) rather than a str. Plain
    # string concatenation fails for non-str values; an f-string does not.
    # TODO: replace this placeholder with the CLIP + logistic-regression
    # prediction. Those models are loaded at module level but not used yet.
    return f"Hello {name}!"
# Web UI: a single image input mapped through `greet` to a text output,
# with manual flagging of examples enabled.
# NOTE(review): newer Gradio releases rename `allow_flagging` to
# `flagging_mode`; check against the pinned Gradio version.
iface = gr.Interface(
    fn=greet,
    inputs="image",
    outputs="text",
    allow_flagging="manual",
)

# Start the server only when this file is run as a script, so that importing
# the module (tests, reload tooling, another app) does not launch a web
# server as a side effect.
if __name__ == "__main__":
    iface.launch()