import gzip
import json

import numpy as np
import streamlit as st
import torch
import tqdm.auto
from sentence_transformers import SentenceTransformer


@st.cache(allow_output_mutation=True)
def load_model(model_name, model_dict):
    """Load one SentenceTransformer model (or a list of them) for the given name."""
    assert model_name in model_dict

    model_ids = model_dict[model_name]
    if isinstance(model_ids, str):
        # A single model id maps to a single model.
        output = SentenceTransformer(model_ids)
    elif hasattr(model_ids, '__iter__'):
        # An iterable of ids maps to a list of models.
        output = [SentenceTransformer(name) for name in model_ids]
    else:
        raise ValueError(f"Unsupported model id(s) for '{model_name}': {model_ids!r}")

    return output


@st.cache(allow_output_mutation=True)
def load_embeddings():
    # Precomputed embeddings for the first 10,000 StackOverflow titles.
    corpus_emb = torch.from_numpy(np.loadtxt('./data/stackoverflow-titles-distilbert-emb.csv', max_rows=10000))
    return corpus_emb.float()


@st.cache(allow_output_mutation=True)
def filter_questions(tag, max_questions=10000):
    """Return up to `max_questions` posts whose tags contain `tag`."""
    posts = []
    max_posts = 6_000_000
    with gzip.open("./data/stackoverflow-titles.jsonl.gz", "rt") as fIn:
        for line in tqdm.auto.tqdm(fIn, total=max_posts, desc="Load data"):
            posts.append(json.loads(line))

            if len(posts) >= max_posts:
                break

    filtered_posts = []
    for post in posts:
        if tag in post["tags"]:
            filtered_posts.append(post)
            if len(filtered_posts) >= max_questions:
                break
    return filtered_posts
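

# --------------------------------------------------------------------------
# Usage sketch (not part of the original helpers): one minimal way these
# cached functions could be wired into the Streamlit page. MODEL_DICT, the
# "distilbert" key, the widget label, and the assumption that the query
# embeddings share the dimensionality of the precomputed corpus embeddings
# are illustrative guesses, not something the code above specifies.
# --------------------------------------------------------------------------
from sentence_transformers import util

MODEL_DICT = {"distilbert": "distilbert-base-nli-stsb-mean-tokens"}  # hypothetical mapping

model = load_model("distilbert", MODEL_DICT)
corpus_emb = load_embeddings()

query = st.text_input("Search StackOverflow titles")
if query:
    # Encode the query and rank corpus rows by cosine similarity.
    query_emb = model.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_emb, corpus_emb, top_k=5)[0]
    for hit in hits:
        st.write(f"corpus row {hit['corpus_id']} (score {hit['score']:.3f})")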