koclip / utils.py
jaketae's picture
fix: close read file
1991cb1
raw history blame
No virus
1.34 kB
import nmslib
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor
from koclip import FlaxHybridCLIP
@st.cache(allow_output_mutation=True)
def load_index(img_file):
    """Build an nmslib HNSW cosine-similarity index from a TSV embedding file.

    Each non-blank line of ``img_file`` holds a filename, a tab, and a
    comma-separated list of floats (the embedding). Columns after the second
    are ignored.

    Args:
        img_file: Path to the tab-separated embedding file.

    Returns:
        A ``(filenames, index)`` tuple. No explicit ids are passed to
        ``addDataPointBatch``, so nmslib's default sequential ids are used and
        ``filenames[i]`` is the name of the i-th vector added to ``index``.
    """
    filenames, embeddings = [], []
    # Explicit UTF-8 so filenames decode identically regardless of host locale.
    with open(img_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line at EOF)
                # instead of failing with IndexError on the missing column.
                continue
            cols = line.split("\t")
            filenames.append(cols[0])
            embeddings.append([float(x) for x in cols[1].split(",")])
    embeddings = np.array(embeddings)
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(embeddings)
    # "post" controls nmslib's post-processing of the built HNSW graph.
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index
@st.cache(allow_output_mutation=True)
def load_model(model_name="koclip/koclip-base"):
    """Load a KoCLIP model together with a matching input processor.

    Args:
        model_name: Hugging Face Hub id of the KoCLIP checkpoint; must be
            ``"koclip/koclip-base"`` or ``"koclip/koclip-large"``.

    Returns:
        A ``(model, processor)`` tuple. The processor starts from the
        ``openai/clip-vit-base-patch32`` CLIP processor with its tokenizer
        replaced by ``klue/roberta-large``. For ``koclip/koclip-large``, its
        feature extractor is also replaced by the one from
        ``google/vit-large-patch16-224``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported checkpoints.
    """
    # Vision feature-extractor override for each supported checkpoint; None
    # keeps the feature extractor that ships with the base CLIP processor.
    feature_extractors = {
        "koclip/koclip-base": None,
        "koclip/koclip-large": "google/vit-large-patch16-224",
    }
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if model_name not in feature_extractors:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of "
            f"{sorted(feature_extractors)}"
        )
    model = FlaxHybridCLIP.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
    vit_name = feature_extractors[model_name]
    if vit_name is not None:
        processor.feature_extractor = ViTFeatureExtractor.from_pretrained(vit_name)
    return model, processor