# KoCLIP demo utilities (Hugging Face Spaces app).
import nmslib
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor
from config import MODEL_LIST
from koclip import FlaxHybridCLIP
from global_session import GlobalState
from threading import Lock
# Guards first-time creation of the per-file lock. The original
# hasattr-then-assign sequence was itself racy: two threads could each
# install a different Lock on the same state and both proceed into the
# cached loader concurrently — exactly what the lock is meant to prevent.
_INDEX_LOCK_GUARD = Lock()


def load_index(img_file):
    """Load (and cache) the image-embedding index built from *img_file*.

    Serializes concurrent first-time loads of the same file so that the
    Streamlit cache in ``load_index_cached`` is populated exactly once.

    Args:
        img_file: Path to the TSV file of image filenames and embeddings.

    Returns:
        The ``(filenames, index)`` pair produced by ``load_index_cached``.
    """
    state = GlobalState(img_file)
    if not hasattr(state, "_lock"):
        # Double-checked locking: re-test under the guard so only one
        # thread ever installs the per-state lock.
        with _INDEX_LOCK_GUARD:
            if not hasattr(state, "_lock"):
                state._lock = Lock()
    print(f"Locking loading of features : {img_file} to avoid concurrent caching.")
    with state._lock:
        cached_index = load_index_cached(img_file)
    print(f"Unlocking loading of features : {img_file} to avoid concurrent caching.")
    return cached_index
@st.cache(allow_output_mutation=True)
def load_index_cached(img_file):
    """Parse a TSV of image embeddings and build an HNSW cosine index.

    Each line of *img_file* is ``<filename> TAB <comma-separated floats>``.

    Returns:
        (filenames, index): the image filenames in file order, and an
        nmslib HNSW index (cosine similarity) over their embeddings.
    """
    filenames = []
    vectors = []
    with open(img_file, "r") as handle:
        for row in handle:
            fields = row.strip().split("\t")
            filenames.append(fields[0])
            vectors.append([float(value) for value in fields[1].split(",")])
    matrix = np.array(vectors)
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(matrix)
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index
# Guards first-time creation of the per-model lock. The original
# hasattr-then-assign sequence was racy: two threads could each install
# a different Lock on the same state and both run the cached loader
# concurrently — defeating the lock's purpose.
_MODEL_LOCK_GUARD = Lock()


def load_model(model_name="koclip/koclip-base"):
    """Load (and cache) a KoCLIP model and its processor.

    Serializes concurrent first-time loads of the same model so that the
    Streamlit cache in ``load_model_cached`` is populated exactly once.

    Args:
        model_name: KoCLIP model id; defaults to ``"koclip/koclip-base"``.

    Returns:
        The ``(model, processor)`` pair produced by ``load_model_cached``.
    """
    state = GlobalState(model_name)
    if not hasattr(state, "_lock"):
        # Double-checked locking: re-test under the guard so only one
        # thread ever installs the per-state lock.
        with _MODEL_LOCK_GUARD:
            if not hasattr(state, "_lock"):
                state._lock = Lock()
    print(f"Locking loading of model : {model_name} to avoid concurrent caching.")
    with state._lock:
        cached_model = load_model_cached(model_name)
    print(f"Unlocking loading of model : {model_name} to avoid concurrent caching.")
    return cached_model
@st.cache(allow_output_mutation=True)
def load_model_cached(model_name):
    """Load a KoCLIP model together with its paired processor.

    Args:
        model_name: Fully-qualified model id (e.g. ``"koclip/koclip-base"``).
            Must name one of the models in ``MODEL_LIST``.

    Returns:
        (model, processor): the FlaxHybridCLIP model, and a CLIPProcessor
        whose tokenizer is replaced with the Korean ``klue/roberta-large``
        tokenizer.

    Raises:
        ValueError: If *model_name* is not a known KoCLIP model.
    """
    # `assert` is stripped under `python -O`, so validate explicitly —
    # a bad name should fail loudly here, not as a confusing download error.
    known_models = {f"koclip/{model}" for model in MODEL_LIST}
    if model_name not in known_models:
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {sorted(known_models)}"
        )
    model = FlaxHybridCLIP.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # KoCLIP pairs CLIP's image pipeline with a Korean text tokenizer.
    processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
    if model_name == "koclip/koclip-large":
        # The large checkpoint uses ViT-large image features, so swap in
        # the matching feature extractor.
        processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
            "google/vit-large-patch16-224"
        )
    return model, processor