import faiss
import numpy as np
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

app = FastAPI()

FIELDS = (
    "full_name",
    "description",
    "default_branch",
    "open_issues",
    "stargazers_count",
    "forks_count",
    "watchers_count",
    "license",
    "size",
    "fork",
    "updated_at",
    "has_build_zig",
    "has_build_zig_zon",
    "created_at",
)

print("Loading sentence transformer model (all-MiniLM-L6-v2)...")
model = SentenceTransformer("all-MiniLM-L6-v2")
print("Model loaded successfully.")

def load_and_index_dataset(name: str, include_readme: bool = False):
    """Load a Hugging Face dataset and build a FAISS index over its text fields."""
    print(f"Loading dataset '{name}'...")
    dataset = load_dataset(name)["train"]

    # Concatenate the metadata fields, the topics list, and (optionally) the
    # README into one searchable blob per repository. The `or ""` / `or []`
    # fallbacks guard against fields that are present but None.
    repo_texts = [
        " ".join(str(x.get(field, "")) for field in FIELDS) +
        (" " + (x.get("readme_content") or "") if include_readme else "") +
        " " + " ".join(x.get("topics") or [])
        for x in dataset
    ]

    if not include_readme:
        # Drop the potentially large README body from the records returned to clients.
        dataset = [{k: v for k, v in item.items() if k != "readme_content"} for item in dataset]

    print(f"Creating embeddings for {len(repo_texts)} documents in '{name}'...")
    repo_embeddings = model.encode(repo_texts, show_progress_bar=True)

    print(f"Building FAISS index for '{name}'...")
    embedding_dim = repo_embeddings.shape[1]

    # Exact (brute-force) L2 search; adequate at this corpus size.
    index = faiss.IndexFlatL2(embedding_dim)
    index.add(np.array(repo_embeddings, dtype=np.float32))

    print(f"'{name}' dataset indexed with {index.ntotal} vectors.")
    return index, list(dataset)
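
# Design note (a sketch of an alternative, not what this file does): the
# all-MiniLM-L6-v2 pipeline unit-normalizes its embeddings by default, so
# cosine similarity could be served directly from an inner-product index,
# skipping the distance-to-score conversion done in perform_search below:
#
#     emb = np.asarray(repo_embeddings, dtype=np.float32)
#     faiss.normalize_L2(emb)                  # no-op for unit-length vectors
#     index = faiss.IndexFlatIP(emb.shape[1])
#     index.add(emb)                           # search() then yields cosine scores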

indices = {}

# Build one (FAISS index, records) pair per corpus; README text is folded into
# the searchable blob for both datasets.
for key, readme_flag in {"packages": True, "programs": True}.items():
    index, data = load_and_index_dataset(f"zigistry/{key}", include_readme=readme_flag)
    indices[key] = (index, data)

def perform_search(query: str, dataset_key: str, k: int):
    index, dataset = indices[dataset_key]

    query_embedding = model.encode([query])
    query_embedding = np.array(query_embedding, dtype=np.float32)

    distances, idxs = index.search(query_embedding, k)

    results = []
    for dist, idx in zip(distances[0], idxs[0]):
        # FAISS pads results with -1 when fewer than k vectors are available.
        if idx == -1:
            continue

        item = dataset[int(idx)].copy()
        # IndexFlatL2 returns squared L2 distances; for unit-length embeddings
        # d = 2 - 2*cos, so 1 - d/2 recovers the cosine similarity. Cast to a
        # plain float: np.float32 is not JSON-serializable by json.dumps.
        item["relevance_score"] = float(1.0 - dist / 2.0)
        results.append(item)

    return results

@app.get("/searchPackages/")
def search_packages(q: str, k: int = 10):
    results = perform_search(query=q, dataset_key="packages", k=k)
    headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
    return JSONResponse(content=results, headers=headers)

@app.get("/searchPrograms/")
def search_programs(q: str, k: int = 10):
    results = perform_search(query=q, dataset_key="programs", k=k)
    headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
    return JSONResponse(content=results, headers=headers)
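
# Convenience entry point (an addition, not in the original file): lets the
# service be started with `python <this file>` instead of the uvicorn CLI.
# Example query once running:
#     curl "http://localhost:8000/searchPackages/?q=http+server&k=5"
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)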