File size: 2,240 Bytes
cf349fd
503acf7
2cf3514
 
f1d50b1
587ab22
f1d50b1
98e7562
 
2cf3514
a1fc7fb
cf349fd
a1fc7fb
 
 
 
 
 
 
 
7dd4f77
a1fc7fb
 
 
 
 
cf349fd
1991cb1
 
 
 
 
 
 
cf349fd
2cf3514
cf349fd
2cf3514
cf349fd
f1d50b1
2cf3514
0e0bacc
98e7562
 
 
 
 
 
 
 
 
 
 
a1fc7fb
98e7562
 
587ab22
f1d50b1
 
 
 
2cf3514
 
 
503acf7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import nmslib
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor

from config import MODEL_LIST
from koclip import FlaxHybridCLIP
from global_session import GlobalState
from threading import Lock


# Guards creation of the per-file lock itself. Without it, two threads can
# both pass the hasattr() check below and attach *different* Lock objects to
# the same GlobalState, so both proceed concurrently and the mutual exclusion
# this wrapper exists for is silently lost.
_INDEX_LOCK_GUARD = Lock()


def load_index(img_file):
    """Thread-safe entry point for loading the image feature index.

    Streamlit may serve several sessions concurrently, and ``st.cache`` is not
    safe against two threads populating the same cache entry at once, so the
    first load of each ``img_file`` is serialized through a lock stored on the
    shared GlobalState instance for that file.

    Args:
        img_file: Path to the TSV file of image filenames and embeddings.

    Returns:
        The result of ``load_index_cached(img_file)`` — a ``(filenames, index)``
        pair.
    """
    state = GlobalState(img_file)
    # Atomically create the per-file lock exactly once (fixes a check-then-set
    # race where two threads could each install their own Lock).
    with _INDEX_LOCK_GUARD:
        if not hasattr(state, '_lock'):
            state._lock = Lock()
    print(f"Locking loading of features : {img_file} to avoid concurrent caching.")

    with state._lock:
        cached_index = load_index_cached(img_file)

    print(f"Unlocking loading of features : {img_file} to avoid concurrent caching.")
    return cached_index


@st.cache(allow_output_mutation=True)
def load_index_cached(img_file):
    """Load precomputed image embeddings and build an approximate-NN index.

    Each line of ``img_file`` is expected to hold a filename and a
    comma-separated embedding vector, separated by a tab.

    Args:
        img_file: Path to the TSV file of filenames and embeddings.

    Returns:
        A ``(filenames, index)`` pair: the list of image filenames and an
        nmslib HNSW cosine-similarity index over the corresponding rows.
    """
    records = []
    with open(img_file, "r") as handle:
        for row in handle:
            fields = row.strip().split("\t")
            vector = [float(value) for value in fields[1].split(",")]
            records.append((fields[0], vector))

    filenames = [name for name, _ in records]
    embeddings = np.array([vector for _, vector in records])

    # HNSW graph over cosine similarity; "post": 2 enables post-processing
    # passes during construction, print_progress logs build progress.
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(embeddings)
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index


# Guards creation of the per-model lock itself. Without it, two threads can
# both pass the hasattr() check below and attach *different* Lock objects to
# the same GlobalState, so both proceed concurrently and the mutual exclusion
# this wrapper exists for is silently lost.
_MODEL_LOCK_GUARD = Lock()


def load_model(model_name="koclip/koclip-base"):
    """Thread-safe entry point for loading a koclip model.

    Streamlit may serve several sessions concurrently, and ``st.cache`` is not
    safe against two threads populating the same cache entry at once, so the
    first load of each ``model_name`` is serialized through a lock stored on
    the shared GlobalState instance for that model.

    Args:
        model_name: Hugging Face model identifier; defaults to the base koclip.

    Returns:
        The result of ``load_model_cached(model_name)`` — a
        ``(model, processor)`` pair.
    """
    state = GlobalState(model_name)
    # Atomically create the per-model lock exactly once (fixes a check-then-set
    # race where two threads could each install their own Lock).
    with _MODEL_LOCK_GUARD:
        if not hasattr(state, '_lock'):
            state._lock = Lock()
    print(f"Locking loading of model : {model_name} to avoid concurrent caching.")

    with state._lock:
        cached_model = load_model_cached(model_name)

    print(f"Unlocking loading of model : {model_name} to avoid concurrent caching.")
    return cached_model


@st.cache(allow_output_mutation=True)
def load_model_cached(model_name):
    """Instantiate a koclip model with its matching processor.

    Args:
        model_name: One of the ``koclip/...`` identifiers listed in MODEL_LIST.

    Returns:
        A ``(model, processor)`` pair: the FlaxHybridCLIP model and a
        CLIPProcessor whose tokenizer (and, for the large model, feature
        extractor) has been swapped for the koclip-compatible components.

    Raises:
        ValueError: If ``model_name`` is not a known koclip model.
    """
    # Explicit validation instead of `assert`, which is stripped under -O and
    # would otherwise let a bad name through to a confusing download failure.
    valid_names = {f"koclip/{model}" for model in MODEL_LIST}
    if model_name not in valid_names:
        raise ValueError(
            f"Unknown model name {model_name!r}; expected one of {sorted(valid_names)}"
        )

    model = FlaxHybridCLIP.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # koclip text encoders were trained with the KLUE RoBERTa tokenizer.
    processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
    if model_name == "koclip/koclip-large":
        # The large variant pairs with a ViT-large vision tower, so the default
        # CLIP feature extractor must be replaced to match.
        processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
            "google/vit-large-patch16-224"
        )
    return model, processor