from threading import Lock

import nmslib
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor

from config import MODEL_LIST
from global_session import GlobalState
from koclip import FlaxHybridCLIP


def load_index(img_file):
    # Serialize first-time loading of the image features for this file so that
    # concurrent Streamlit sessions do not build the cached index twice.
    state = GlobalState(img_file)
    if not hasattr(state, "_lock"):
        state._lock = Lock()
    print(f"Locking loading of features : {img_file} to avoid concurrent caching.")
    with state._lock:
        cached_index = load_index_cached(img_file)
    print(f"Unlocking loading of features : {img_file} to avoid concurrent caching.")
    return cached_index
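

# The locking in load_index() only prevents duplicate work if GlobalState hands
# back the same object for the same key in every Streamlit session, so that all
# callers share one Lock. global_session is a local module that is not shown
# here; the stand-in below is a hypothetical sketch of the per-key-singleton
# behaviour load_index() relies on, not the actual implementation.

_SKETCH_STATES = {}
_SKETCH_REGISTRY_LOCK = Lock()


class _GlobalStateSketch:
    """Hypothetical stand-in: GlobalState(key) is assumed to behave like this."""

    def __new__(cls, key):
        # Hand out one shared instance per key so attributes such as `_lock`
        # set by one session are visible to every other session.
        with _SKETCH_REGISTRY_LOCK:
            if key not in _SKETCH_STATES:
                _SKETCH_STATES[key] = super().__new__(cls)
            return _SKETCH_STATES[key]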


def load_index_cached(img_file):
    # Each line of img_file holds an image filename and its precomputed
    # embedding, separated by a tab; the embedding itself is comma-separated.
    filenames, embeddings = [], []
    with open(img_file, "r") as f:
        for line in f:
            cols = line.strip().split("\t")
            filename = cols[0]
            embedding = [float(x) for x in cols[1].split(",")]
            filenames.append(filename)
            embeddings.append(embedding)
    embeddings = np.array(embeddings)
    # Build an HNSW approximate nearest-neighbour index over cosine similarity.
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(embeddings)
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index
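

# load_index_cached() expects a tab-separated file in which each line holds an
# image filename and a comma-separated embedding, e.g.
# "images/cat.jpg<TAB>0.12,-0.03,0.47,..." (path and values are made up for
# illustration). Once the HNSW index is built, nearest neighbours are retrieved
# with nmslib's knnQuery; in the "cosinesimil" space the returned distance is
# 1 - cosine similarity. A minimal query helper, assuming a precomputed query
# embedding of the same dimensionality as the indexed vectors:

def search_index(index, filenames, query_embedding, k=10):
    # Return the k most similar filenames together with cosine-similarity scores.
    ids, distances = index.knnQuery(query_embedding, k=k)
    return [(filenames[i], 1.0 - float(d)) for i, d in zip(ids, distances)]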


def load_model(model_name="koclip/koclip-base"):
    # Same pattern as load_index: one shared lock per model name, so only one
    # session loads (and caches) the model weights at a time.
    state = GlobalState(model_name)
    if not hasattr(state, "_lock"):
        state._lock = Lock()
    print(f"Locking loading of model : {model_name} to avoid concurrent caching.")
    with state._lock:
        cached_model = load_model_cached(model_name)
    print(f"Unlocking loading of model : {model_name} to avoid concurrent caching.")
    return cached_model


def load_model_cached(model_name):
    assert model_name in {f"koclip/{model}" for model in MODEL_LIST}
    model = FlaxHybridCLIP.from_pretrained(model_name)
    # Reuse the OpenAI CLIP image pipeline but swap in a Korean tokenizer.
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
    if model_name == "koclip/koclip-large":
        # The large model pairs the text encoder with a ViT-Large image encoder.
        processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
            "google/vit-large-patch16-224"
        )
    return model, processor
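

# A minimal end-to-end sketch of how these helpers could be combined for
# text-to-image search. It assumes FlaxHybridCLIP exposes get_text_features(),
# as in Hugging Face's Flax hybrid-CLIP example, and uses a hypothetical
# feature-file path; neither is guaranteed by the code above.

def search_images_by_text(query, img_file="features/photos.tsv", k=10):
    model, processor = load_model("koclip/koclip-base")
    filenames, index = load_index(img_file)
    # Tokenize the (Korean) query with the KLUE RoBERTa tokenizer set above.
    inputs = processor(text=[query], return_tensors="jax", padding=True)
    text_embedding = model.get_text_features(
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
    )
    # knnQuery takes a single 1-D vector; distance = 1 - cosine similarity.
    ids, distances = index.knnQuery(np.asarray(text_embedding)[0], k=k)
    return [(filenames[i], 1.0 - float(d)) for i, d in zip(ids, distances)]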