import streamlit as st
from transformers import CLIPProcessor, AutoTokenizer, ViTFeatureExtractor

from koclip import FlaxHybridCLIP

# Checkpoints that load_model knows how to pair with a processor.
_SUPPORTED_MODELS = frozenset({"koclip/koclip", "koclip/koclip-large"})


@st.cache(allow_output_mutation=True)
def load_model(model_name="koclip/koclip"):
    """Load a KoCLIP model and assemble a matching processor from base parts.

    The processor starts from OpenAI's CLIP processor. Its tokenizer is
    replaced with the Korean ``klue/roberta-large`` tokenizer. For the large
    variant, its feature extractor is also replaced with the one from
    ``google/vit-large-patch16-224``.

    Args:
        model_name: Hub id of the KoCLIP checkpoint. It must be one of
            ``"koclip/koclip"`` or ``"koclip/koclip-large"``.

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        ValueError: If ``model_name`` is not a supported checkpoint.
    """
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would let an unsupported name through and pair it
    # with a mismatched processor.
    if model_name not in _SUPPORTED_MODELS:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            f"expected one of {sorted(_SUPPORTED_MODELS)}"
        )

    model = FlaxHybridCLIP.from_pretrained(model_name)

    # Build the processor from generic components, then swap in the pieces
    # that match the KoCLIP text and vision backbones.
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
    if model_name == "koclip/koclip-large":
        # The large variant uses a larger ViT, so it needs that model's
        # feature extractor.
        processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
            "google/vit-large-patch16-224"
        )
    return model, processor


@st.cache(allow_output_mutation=True)
def load_model_v2(model_name="koclip/koclip"):
    """Load a KoCLIP model and its processor from the same checkpoint.

    Unlike ``load_model``, the processor is loaded directly from
    ``model_name``. This presumably means the checkpoint ships its own
    processor configuration; verify this against the hub repo. No name
    validation is done here.

    Args:
        model_name: Hub id (or local path) of the checkpoint.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = FlaxHybridCLIP.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    return model, processor