koclip / utils.py

from threading import Lock

import nmslib
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor

from config import MODEL_LIST
from global_session import GlobalState
from koclip import FlaxHybridCLIP


def load_index(img_file):
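    """Load the image feature index for ``img_file``, serializing concurrent loads.

    A per-file lock is kept on the shared GlobalState so that only one Streamlit
    session builds the cached index at a time; other sessions wait on the lock and
    then reuse the cached result.
    """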
    state = GlobalState(img_file)
    # Lazily attach a lock to the shared state object for this index file.
    if not hasattr(state, "_lock"):
        state._lock = Lock()
    print(f"Acquiring lock for feature index {img_file} to avoid concurrent caching.")
    with state._lock:
        cached_index = load_index_cached(img_file)
    print(f"Released lock for feature index {img_file}.")
    return cached_index


@st.cache(allow_output_mutation=True)
def load_index_cached(img_file):
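    """Read precomputed image embeddings and build an nmslib HNSW index.

    ``img_file`` is a TSV file with one image per line: the image filename in the
    first column and its embedding as comma-separated floats in the second.
    Returns the list of filenames and a cosine-similarity HNSW index over the
    embeddings. Cached by Streamlit so the index is only built once per file.
    """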
filenames, embeddings = [], []
with open(img_file, "r") as f:
for line in f:
cols = line.strip().split("\t")
filename = cols[0]
embedding = [float(x) for x in cols[1].split(",")]
filenames.append(filename)
embeddings.append(embedding)
embeddings = np.array(embeddings)
index = nmslib.init(method="hnsw", space="cosinesimil")
index.addDataPointBatch(embeddings)
index.createIndex({"post": 2}, print_progress=True)
    return filenames, index


def load_model(model_name="koclip/koclip-base"):
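    """Load a KoCLIP model and its processor, serializing concurrent loads.

    Mirrors ``load_index``: a per-model lock on the shared GlobalState ensures
    that only one session downloads and caches the model at a time.
    """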
    state = GlobalState(model_name)
    # Lazily attach a lock to the shared state object for this model.
    if not hasattr(state, "_lock"):
        state._lock = Lock()
    print(f"Acquiring lock for model {model_name} to avoid concurrent caching.")
    with state._lock:
        cached_model = load_model_cached(model_name)
    print(f"Released lock for model {model_name}.")
    return cached_model


@st.cache(allow_output_mutation=True)
def load_model_cached(model_name):
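    """Load the FlaxHybridCLIP model and a matching processor.

    The processor starts from the OpenAI CLIP ViT-B/32 processor, with its
    tokenizer replaced by the Korean ``klue/roberta-large`` tokenizer. For
    ``koclip/koclip-large`` the feature extractor is swapped for the ViT-L/16
    one to match the larger vision encoder. Cached by Streamlit per model name.
    """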
    assert model_name in {f"koclip/{model}" for model in MODEL_LIST}
    model = FlaxHybridCLIP.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # KoCLIP pairs the CLIP image pipeline with a Korean text tokenizer.
    processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
    if model_name == "koclip/koclip-large":
        # The large model uses a ViT-L/16 vision backbone, so the feature
        # extractor must match its input preprocessing.
        processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
            "google/vit-large-patch16-224"
        )
    return model, processor
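

# Example usage (a minimal sketch, not part of the app): how the helpers above
# would typically be combined for text-to-image search. The index path and the
# query text are hypothetical, and the sketch assumes FlaxHybridCLIP exposes
# ``get_text_features`` and uses nmslib's ``knnQuery`` for nearest neighbours.
#
#     model, processor = load_model("koclip/koclip-base")
#     filenames, index = load_index("features/photos.tsv")  # hypothetical path
#     tokens = processor.tokenizer(["바다 위의 노을"], return_tensors="jax", padding=True)
#     query = model.get_text_features(tokens["input_ids"], tokens["attention_mask"])
#     ids, dists = index.knnQuery(np.asarray(query)[0], k=10)
#     nearest = [filenames[i] for i in ids]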