koclip / utils.py
jaketae's picture
feature: add streamlit backbone
f1d50b1
raw history blame
No virus
899 Bytes
import streamlit as st
from transformers import CLIPProcessor, AutoTokenizer, ViTFeatureExtractor
from koclip import FlaxHybridCLIP
@st.cache(allow_output_mutation=True)
def load_model(model_name="koclip/koclip"):
    """Load a KoCLIP model and an assembled processor, cached by Streamlit.

    The processor starts from the ``openai/clip-vit-base-patch32`` CLIP
    processor, then has its tokenizer replaced by the ``klue/roberta-large``
    tokenizer. For ``koclip/koclip-large`` the feature extractor is also
    replaced by the one from ``google/vit-large-patch16-224``.

    Args:
        model_name: Hub id of the checkpoint to load. Must be
            ``"koclip/koclip"`` or ``"koclip/koclip-large"``.

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        ValueError: If ``model_name`` is not one of the supported ids.
    """
    supported = ("koclip/koclip", "koclip/koclip-large")
    # Explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O``, which would silently let an unsupported name through
    # and pair it with a processor built for a different checkpoint.
    if model_name not in supported:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of {supported}"
        )
    model = FlaxHybridCLIP.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Replace CLIP's tokenizer with the klue/roberta-large one; presumably this
    # matches the text encoder inside the KoCLIP checkpoints — confirm.
    processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
    if model_name == "koclip/koclip-large":
        # NOTE(review): presumably the large variant's vision encoder expects
        # ViT-Large/16 preprocessing rather than CLIP's default — confirm.
        processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
            "google/vit-large-patch16-224"
        )
    return model, processor
@st.cache(allow_output_mutation=True)
def load_model_v2(model_name="koclip/koclip"):
    """Load a KoCLIP checkpoint together with the processor stored alongside it.

    Unlike ``load_model``, nothing is assembled by hand. Both the model and the
    processor are fetched from the same ``model_name`` hub id. The result is
    cached by Streamlit, with output mutation allowed.

    Args:
        model_name: Hub id from which to load both the model and the processor.

    Returns:
        A ``(model, processor)`` tuple.
    """
    # Tuple elements are evaluated left to right, so the model is still
    # loaded before the processor.
    return (
        FlaxHybridCLIP.from_pretrained(model_name),
        CLIPProcessor.from_pretrained(model_name),
    )