# basebody / app.py
# Author: hwajjala; commit f4b1311, "Add model files and parse" (807 bytes).
import os
import clip
import torch
import logging
import json
import pickle
import gradio as gr
# Identifiers for the model backbone and the data files loaded at start-up.
CLIP_MODEL_NAME: str = "ViT-B/16"  # backbone name handed to clip.load()
TEXT_PROMPTS_FILE_NAME: str = "text_prompts.json"
LOGISTIC_REGRESSION_MODEL_FILE_NAME: str = "logistic_regression_l1_oct_2.pkl"

# Application-wide named logger.
logger = logging.getLogger("basebody")
# Load the CLIP backbone and its matching image preprocessing transform,
# pinned to CPU.
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")

# Resolve bundled data files relative to this source file, not the current
# working directory; abspath guards against a relative ``__file__``.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))

# Prompt definitions read from JSON (schema not visible here; presumably the
# text prompts fed to CLIP, so confirm). JSON is UTF-8 by specification, so
# decode explicitly instead of relying on the platform's locale-dependent
# default encoding (e.g. cp1252 on Windows), which can mis-decode non-ASCII.
with open(
    os.path.join(_APP_DIR, TEXT_PROMPTS_FILE_NAME), "r", encoding="utf-8"
) as f:
    text_prompts = json.load(f)

# Pre-trained classifier, presumably a logistic-regression model (per the
# file name); the feature layout it expects is not visible here.
# NOTE(security): pickle.load can execute arbitrary code while unpickling.
# Only ever point this at a trusted file, such as the one bundled with the app.
with open(
    os.path.join(_APP_DIR, LOGISTIC_REGRESSION_MODEL_FILE_NAME), "rb"
) as f:
    lr_model = pickle.load(f)
def greet(name):
    """Return a greeting for *name*, e.g. ``greet("Ada")`` -> ``"Hello Ada!"``.

    Placeholder handler: it builds the result by string concatenation, so only
    string input is supported.

    NOTE(review): the Gradio interface in this file wires this function to an
    ``"image"`` input, so at runtime it presumably receives an image array (or
    None) rather than a string, and the concatenation will fail. The CLIP and
    logistic-regression models loaded in this file appear unused; confirm the
    intended handler.
    """
    greeting = "Hello " + name
    return greeting + "!"
# Gradio UI: a single image input and a single text output.
# allow_flagging="manual" lets users explicitly flag submissions.
# NOTE(review): the handler `greet` concatenates its argument with strings,
# but the input component here is an image. Presumably a CLIP-based
# classifier is meant to be plugged in here instead; confirm.
iface = gr.Interface(
    greet,
    inputs="image",
    outputs="text",
    allow_flagging="manual",
)

# Start serving the app. This is not behind a ``__main__`` guard, so it runs
# whenever this module is executed or imported.
iface.launch()