import gzip
import json

import numpy as np
import streamlit as st
import torch
import tqdm
import tqdm.auto  # `import tqdm` alone does not load the `auto` submodule
from sentence_transformers import SentenceTransformer


# NOTE(review): `st.cache` is deprecated in recent Streamlit releases in
# favour of `st.cache_resource` / `st.cache_data`; it is kept here because the
# installed Streamlit version is not visible from this file.
@st.cache(allow_output_mutation=True)
def load_model(model_name, model_dict):
    """Instantiate the SentenceTransformer model(s) registered under a name.

    Models are only constructed when requested (lazy downloading).

    Args:
        model_name: Key to look up in ``model_dict``.
        model_dict: Mapping from a display name to either a single model
            identifier string or an iterable of identifier strings.

    Returns:
        A single ``SentenceTransformer`` when the entry is a string, otherwise
        a list with one ``SentenceTransformer`` per identifier.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_dict``.
        TypeError: If the entry is neither a string nor an iterable.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if model_name not in model_dict:
        raise KeyError(f"Unknown model name: {model_name!r}")

    model_ids = model_dict[model_name]
    # `isinstance` also covers str subclasses, which would otherwise be
    # iterated character by character in the branch below.
    if isinstance(model_ids, str):
        return SentenceTransformer(model_ids)
    if hasattr(model_ids, "__iter__"):
        return [SentenceTransformer(name) for name in model_ids]
    raise TypeError(
        f"model_dict[{model_name!r}] must be a string or an iterable of "
        f"strings, got {type(model_ids).__name__}"
    )


@st.cache(allow_output_mutation=True)
def load_embeddings():
    """Load the pre-generated corpus embeddings from disk.

    At most the first 10000 rows of the embedding file are read.

    Returns:
        A ``torch`` tensor of the loaded values, cast to 32-bit float.
    """
    corpus_emb = torch.from_numpy(
        np.loadtxt(
            "./data/stackoverflow-titles-distilbert-emb.csv",
            max_rows=10000,
        )
    )
    return corpus_emb.float()


@st.cache(allow_output_mutation=True)
def filter_questions(tag, max_questions=10000):
    """Collect the first posts whose ``"tags"`` field contains ``tag``.

    The gzipped JSON-lines dump is streamed line by line, so only matching
    posts are kept in memory, and reading stops as soon as enough matches
    have been found.

    Args:
        tag: Value tested with ``tag in post["tags"]`` for every post.
        max_questions: Maximum number of matching posts to return.

    Returns:
        A list of the decoded post objects, in file order, containing at most
        ``max_questions`` entries taken from the first 6,000,000 lines.
    """
    filtered_posts = []
    if max_questions <= 0:
        return filtered_posts

    # Upper bound on how many lines are scanned; an int so that the progress
    # bar total is a whole number.
    max_posts = 6_000_000

    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with gzip.open(
        "./data/stackoverflow-titles.jsonl.gz", "rt", encoding="utf-8"
    ) as fIn:
        # The context manager closes the progress bar even on early exit.
        with tqdm.auto.tqdm(fIn, total=max_posts, desc="Load data") as lines:
            for posts_read, line in enumerate(lines, start=1):
                post = json.loads(line)
                if tag in post["tags"]:
                    filtered_posts.append(post)
                    if len(filtered_posts) >= max_questions:
                        break
                if posts_read >= max_posts:
                    break

    return filtered_posts