import gzip
import json

import numpy as np
import streamlit as st
import torch
import tqdm.auto
from sentence_transformers import SentenceTransformer


def load_model(model_name, model_dict):
    assert model_name in model_dict.keys()
    # Lazy downloading: checkpoints are only fetched from the Hub when requested
    model_ids = model_dict[model_name]
    if isinstance(model_ids, str):
        output = SentenceTransformer(model_ids)
    elif hasattr(model_ids, '__iter__'):
        output = [SentenceTransformer(name) for name in model_ids]
    else:
        raise ValueError(f"Unsupported model id(s) for '{model_name}': {model_ids!r}")
    return output
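
# Example usage (illustrative only; the model names and Hub ids below are
# assumptions, not taken from this Space):
#   model_dict = {"multi-qa-distilbert": "multi-qa-distilbert-cos-v1"}
#   model = load_model("multi-qa-distilbert", model_dict)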

def load_embeddings():
    # Pre-computed title embeddings (first 10,000 rows only), generated offline
    # with a DistilBERT-based SentenceTransformer
    corpus_emb = torch.from_numpy(np.loadtxt('./data/stackoverflow-titles-distilbert-emb.csv', max_rows=10000))
    return corpus_emb.float()
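
# The embedding rows are assumed to line up, in order, with the first 10,000
# titles in stackoverflow-titles.jsonl.gz. A query embedding can then be
# matched against them, e.g. (sketch, variable names assumed):
#   from sentence_transformers import util
#   query_emb = model.encode("sort a dict by value", convert_to_tensor=True)
#   hits = util.semantic_search(query_emb, load_embeddings(), top_k=5)[0]
#   # -> [{"corpus_id": row index, "score": similarity}, ...]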

def filter_questions(tag, max_questions=10000):
    # Stream the gzipped JSONL corpus, keeping at most max_posts entries
    posts = []
    max_posts = 6e6
    with gzip.open("./data/stackoverflow-titles.jsonl.gz", "rt") as fIn:
        for line in tqdm.auto.tqdm(fIn, total=max_posts, desc="Load data"):
            posts.append(json.loads(line))
            if len(posts) >= max_posts:
                break

    # Keep only the posts carrying the requested tag, up to max_questions
    filtered_posts = []
    for post in posts:
        if tag in post["tags"]:
            filtered_posts.append(post)
            if len(filtered_posts) >= max_questions:
                break
    return filtered_posts
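
# Minimal Streamlit glue: a sketch under stated assumptions, since the original
# app body is not part of this excerpt. The widget layout, the model_dict
# entries, and the "title" field on each post are assumptions; the sketch also
# re-encodes the filtered titles on the fly instead of using the pre-computed
# embeddings from load_embeddings(), whose row-to-title alignment is not shown
# here.
from sentence_transformers import util

model_dict = {"multi-qa-distilbert": "multi-qa-distilbert-cos-v1"}  # assumed Hub id

st.title("StackOverflow title search")
tag = st.text_input("Tag to filter on", value="python")
query = st.text_input("Search query", value="sort a dictionary by value")

if st.button("Search"):
    model = load_model("multi-qa-distilbert", model_dict)
    posts = filter_questions(tag, max_questions=1000)
    titles = [post["title"] for post in posts]  # "title" field is assumed
    title_emb = model.encode(titles, convert_to_tensor=True)
    query_emb = model.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_emb, title_emb, top_k=5)[0]
    for hit in hits:
        # corpus_id indexes into the titles list encoded above
        st.write(f'{titles[hit["corpus_id"]]} (score: {hit["score"]:.3f})')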