koclip / utils.py
devtrent's picture
Text to image Search Engine demo
cf349fd
raw history blame
No virus
1.52 kB
import nmslib
import streamlit as st
from transformers import CLIPProcessor, AutoTokenizer, ViTFeatureExtractor
import numpy as np
from koclip import FlaxHybridCLIP
def _read_embedding_file(img_file):
    """Parse a tab-separated embedding file into filenames and an embedding matrix.

    Each non-blank line must look like ``<filename>\t<v1>,<v2>,...``. Only the
    first two tab-separated columns are used; any extra columns are ignored.

    Args:
        img_file: Path to the embedding file.

    Returns:
        A ``(filenames, embeddings)`` tuple: a list of filename strings and a
        NumPy array built from one float vector per line (2-D when every
        vector has the same length).
    """
    filenames, embeddings = [], []
    # ``with`` guarantees the handle is closed even if parsing raises.
    with open(img_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line (e.g. a trailing newline at EOF) has no
                # embedding column; indexing cols[1] on it would raise.
                continue
            cols = line.split("\t")
            filenames.append(cols[0])
            embeddings.append(np.array([float(x) for x in cols[1].split(",")]))
    return filenames, np.array(embeddings)


@st.cache(allow_output_mutation=True)
def load_index(img_file):
    """Build an nmslib HNSW index over precomputed image embeddings.

    The result is cached by Streamlit, so the file is parsed and the index is
    built only once per distinct ``img_file``.

    Args:
        img_file: Path to a tab-separated file whose lines are
            ``<filename>\t<comma-separated floats>``.

    Returns:
        A ``(filenames, index)`` tuple. ``filenames[i]`` names the item stored
        at position ``i`` of ``index``, so ids returned by index queries can
        be mapped back to files.
    """
    filenames, embeddings = _read_embedding_file(img_file)
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(embeddings)
    # 'post' is nmslib's HNSW index post-processing level (see nmslib docs).
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index
@st.cache(allow_output_mutation=True)
def load_model(model_name="koclip/koclip"):
    """Load a KoCLIP model together with a processor assembled for it.

    The processor starts from ``openai/clip-vit-base-patch32``. Its tokenizer
    is replaced with the ``klue/roberta-large`` tokenizer. For
    ``koclip/koclip-large``, its feature extractor is also replaced with the
    ``google/vit-large-patch16-224`` one. The result is cached by Streamlit.

    Args:
        model_name: Either ``"koclip/koclip"`` or ``"koclip/koclip-large"``.

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        ValueError: If ``model_name`` is not a supported checkpoint.
    """
    supported = {"koclip/koclip", "koclip/koclip-large"}
    # An explicit raise rather than ``assert``, which is stripped under ``python -O``.
    if model_name not in supported:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            f"expected one of {sorted(supported)}"
        )
    model = FlaxHybridCLIP.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Presumably matches the text encoder used by the KoCLIP checkpoints — confirm.
    processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
    if model_name == "koclip/koclip-large":
        # Presumably matches the larger vision backbone of koclip-large — confirm.
        processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
            "google/vit-large-patch16-224"
        )
    return model, processor
@st.cache(allow_output_mutation=True)
def load_model_v2(model_name="koclip/koclip"):
    """Return a ``(model, processor)`` pair, both loaded from one checkpoint name.

    Unlike ``load_model``, the processor is taken directly from
    ``model_name`` rather than being assembled from other checkpoints. The
    pair is cached by Streamlit.
    """
    # The model is loaded first, then the processor, matching the original order.
    loaded_pair = (
        FlaxHybridCLIP.from_pretrained(model_name),
        CLIPProcessor.from_pretrained(model_name),
    )
    return loaded_pair