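# Semantic search API for Zigistry: embeds repository metadata with
# sentence-transformers and serves nearest-neighbor queries over FAISS indexes.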
import faiss
import numpy as np
from fastapi import FastAPI, Query, HTTPException
from fastapi.responses import JSONResponse
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from typing import Dict, List, Optional, Tuple

app = FastAPI()

# Repository metadata fields concatenated into the text that is embedded for
# each repository.
FIELDS = (
    "full_name",
    "description",
    "default_branch",
    "open_issues",
    "stargazers_count",
    "forks_count",
    "watchers_count",
    "license",
    "size",
    "fork",
    "updated_at",
    "has_build_zig",
    "has_build_zig_zon",
    "created_at",
)

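# The embedding model is loaded once at startup and shared by all requests.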
print("Loading sentence transformer model (all-MiniLM-L6-v2)...")
model = SentenceTransformer("all-MiniLM-L6-v2")
print("Model loaded successfully.")


def load_and_index_dataset(name: str, include_readme: bool = False) -> Tuple[faiss.IndexFlatL2, List[Dict]]:
    """Load a Hugging Face dataset and build a FAISS L2 index over its text embeddings."""
    try:
        print(f"Loading dataset '{name}'...")
        dataset = load_dataset(name)["train"]
        # Concatenate the metadata fields, optionally the README, and the topic
        # list into one text blob per repository for embedding.
        repo_texts = [
            " ".join(str(x.get(field, "")) for field in FIELDS)
            + (" " + (x.get("readme_content") or "") if include_readme else "")
            + " " + " ".join(x.get("topics") or [])
            for x in dataset
        ]
        if not include_readme:
            # Drop the bulky README field from the payload returned to clients.
            dataset = [{k: v for k, v in item.items() if k != "readme_content"} for item in dataset]
        print(f"Creating embeddings for {len(repo_texts)} documents in '{name}'...")
        # Normalizing the embeddings makes the squared L2 distances returned by
        # IndexFlatL2 equal to 2 - 2*cosine, so the relevance score computed in
        # perform_search() is a true cosine similarity.
        repo_embeddings = model.encode(repo_texts, show_progress_bar=True, normalize_embeddings=True)
        embedding_dim = repo_embeddings.shape[1]
        index = faiss.IndexFlatL2(embedding_dim)
        index.add(np.array(repo_embeddings, dtype=np.float32))
        print(f"'{name}' dataset indexed with {index.ntotal} vectors.")
        return index, list(dataset)
    except Exception as e:
        print(f"Error loading dataset '{name}': {e}")
        raise RuntimeError(f"Dataset loading/indexing failed: {name}") from e


# Build an index for each dataset at startup. A dataset that fails to load is
# registered as (None, []) so its endpoint can fail with a clean 500 instead of
# crashing the whole service.
indices: Dict[str, Tuple[Optional[faiss.IndexFlatL2], List[Dict]]] = {}
for key, readme_flag in {"packages": True, "programs": True}.items():
    try:
        index, data = load_and_index_dataset(f"zigistry/{key}", include_readme=readme_flag)
        indices[key] = (index, data)
    except Exception as e:
        print(f"Failed to prepare index for {key}: {e}")
        indices[key] = (None, [])


def perform_search(query: str, dataset_key: str, k: int) -> List[Dict]:
    index, dataset = indices.get(dataset_key, (None, []))
    if index is None:
        raise HTTPException(status_code=500, detail=f"Index not available for {dataset_key}")
    try:
        # The query must be normalized the same way as the indexed vectors.
        query_embedding = model.encode([query], normalize_embeddings=True)
        distances, idxs = index.search(np.array(query_embedding, dtype=np.float32), k)
        results = []
        for dist, idx in zip(distances[0], idxs[0]):
            if idx == -1:  # FAISS pads with -1 when fewer than k vectors exist
                continue
            item = dataset[int(idx)].copy()
            # For unit vectors, squared L2 distance d = 2 - 2*cos, so 1 - d/2
            # recovers the cosine similarity in [-1, 1].
            item["relevance_score"] = float(1.0 - dist / 2.0)
            results.append(item)
        return results
    except Exception as e:
        print(f"Error during search: {e}")
        raise HTTPException(status_code=500, detail="Search failed") from e
@app.get("/searchPackages/")
def search_packages(q: str = Query(...), k: int = Query(10)) -> JSONResponse:
if not q:
raise HTTPException(status_code=400, detail="Query parameter 'q' is required.")
results = perform_search(q, "packages", k)
return JSONResponse(content=results, headers={"Access-Control-Allow-Origin": "*"})
@app.get("/searchPrograms/")
def search_programs(q: str = Query(...), k: int = Query(10)) -> JSONResponse:
if not q:
raise HTTPException(status_code=400, detail="Query parameter 'q' is required.")
results = perform_search(q, "programs", k)
return JSONResponse(content=results, headers={"Access-Control-Allow-Origin": "*"})
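
# Local usage sketch (assumes this file is saved as main.py; the filename is an
# assumption, not part of the source):
#   uvicorn main:app --host 0.0.0.0 --port 8000
#   curl "http://localhost:8000/searchPackages/?q=http%20server&k=5"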