File size: 899 Bytes
f1d50b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import streamlit as st
from transformers import CLIPProcessor, AutoTokenizer, ViTFeatureExtractor

from koclip import FlaxHybridCLIP


@st.cache(allow_output_mutation=True)
def load_model(model_name="koclip/koclip"):
    """Load a KoCLIP checkpoint together with a processor configured for it.

    The processor starts from OpenAI's ``clip-vit-base-patch32`` processor.
    Its tokenizer is swapped for ``klue/roberta-large``. For
    ``koclip/koclip-large``, its feature extractor is also swapped for
    ``google/vit-large-patch16-224``.

    Args:
        model_name: Checkpoint to load; one of ``"koclip/koclip"`` or
            ``"koclip/koclip-large"``.

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        ValueError: If ``model_name`` is not one of the supported checkpoints.
    """
    supported = ("koclip/koclip", "koclip/koclip-large")
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would silently pair an unknown checkpoint with the wrong processor.
    if model_name not in supported:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of {supported}"
        )

    model = FlaxHybridCLIP.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Text side: Korean RoBERTa tokenizer replaces the default CLIP tokenizer.
    processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
    if model_name == "koclip/koclip-large":
        # The large variant's vision side uses a ViT-L/16 feature extractor.
        processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
            "google/vit-large-patch16-224"
        )
    return model, processor

@st.cache(allow_output_mutation=True)
def load_model_v2(model_name="koclip/koclip"):
    """Load a KoCLIP model and its processor from one checkpoint.

    Unlike ``load_model``, the processor here comes from ``model_name`` itself
    rather than being assembled from separate tokenizer/feature-extractor
    checkpoints.

    Args:
        model_name: Pretrained checkpoint identifier or path.

    Returns:
        A ``(model, processor)`` tuple.
    """
    # Tuple elements are evaluated left to right: model first, then processor.
    loaded = (
        FlaxHybridCLIP.from_pretrained(model_name),
        CLIPProcessor.from_pretrained(model_name),
    )
    return loaded