import os
import base64
import requests
import numpy as np
import faiss
import re
import logging
from pathlib import Path

from dotenv import load_dotenv

load_dotenv()

from sentence_transformers import SentenceTransformer, CrossEncoder
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate

try:
    from rank_bm25 import BM25Okapi
except ImportError:
    BM25Okapi = None
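

# Configuration is read from the environment (typically a .env file):
# GITHUB_API_KEY authenticates GitHub REST calls, GROQ_API_KEY backs the ChatGroq
# LLM, and CROSS_ENCODER_MODEL selects the re-ranking model (ms-marco-MiniLM by default).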
GITHUB_API_KEY = os.getenv("GITHUB_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
CROSS_ENCODER_MODEL = os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")

session = requests.Session()
session.headers.update({
    "Authorization": f"token {GITHUB_API_KEY}",
    "Accept": "application/vnd.github.v3+json"
})
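

# Tag-generation chain: a Groq-hosted DeepSeek model is prompted to convert the
# user's natural-language query into colon-separated GitHub search tags.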
llm = ChatGroq(
    model="deepseek-r1-distill-llama-70b",
    temperature=0.3,
    max_tokens=512,
    max_retries=3,
    api_key=GROQ_API_KEY
)

prompt = ChatPromptTemplate.from_messages([
    ("system",
     """You are a GitHub search optimization expert.

Your job is to:
1. Read a user's query about tools, research, or tasks.
2. Detect if the query mentions a specific programming language other than Python (for example, JavaScript or JS). If so, record that language as the target language.
3. Think iteratively and generate your internal chain-of-thought enclosed in <think> ... </think> tags.
4. After your internal reasoning, output up to five GitHub-style search tags or library names that maximize repository discovery.
Use as many tags as necessary based on the query's complexity, but never more than five.
5. If you detected a non-Python target language, append an additional tag at the end in the format target-[language] (e.g., target-javascript).
If no specific language is mentioned, do not include any target tag.

Output Format:
tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]

Rules:
- Use lowercase and hyphenated keywords (e.g., image-augmentation, chain-of-thought).
- Use terms commonly found in GitHub repo names, topics, or descriptions.
- Avoid generic terms like "python", "ai", "tool", "project".
- Do NOT use full phrases or vague words like "no-code", "framework", or "approach".
- Prefer real tools, popular methods, or dataset names when mentioned.
- If your output does not strictly match the required format, correct it after your internal reasoning.
- Choose high-signal keywords to ensure the search yields the most relevant GitHub repositories.

Excellent Examples:

Input: "No code tool to augment image and annotation"
Output: image-augmentation:albumentations

Input: "Repos around chain of thought prompting mainly for finetuned models"
Output: chain-of-thought:finetuned-llm

Input: "Find repositories implementing data augmentation pipelines in JavaScript"
Output: data-augmentation:target-javascript

Output must be ONLY the search tags separated by colons. Do not include any extra text, bullet points, or explanations.
"""),
    ("human", "{query}")
])
chain = prompt | llm
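

# Helpers for validating and extracting the colon-separated tag string returned
# by the LLM (the <think>...</think> reasoning block is stripped first).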
def valid_tags(tags: str) -> bool:
    pattern = r'^[a-z0-9-]+(?::[a-z0-9-]+){1,5}$'
    return re.match(pattern, tags) is not None


def parse_search_tags(response: str) -> str:
    cleaned = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
    pattern = r'([a-z0-9-]+(?::[a-z0-9-]+){1,5})'
    match = re.search(pattern, cleaned)
    if match:
        return match.group(1).strip()
    return cleaned.strip()
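

# Ask the LLM for tags and retry (up to max_iterations) with an explicit format
# reminder when the response does not match the required tag format.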
def iterative_convert_to_search_tags(query: str, max_iterations: int = 2) -> str:
    print(f"\n[iterative_convert_to_search_tags] Input Query: {query}")
    refined_query = query
    tags_output = ""
    for iteration in range(max_iterations):
        print(f"\nIteration {iteration + 1}")
        response = chain.invoke({"query": refined_query})
        full_output = response.content.strip()
        tags_output = parse_search_tags(full_output)
        print(f"Output Tags: {tags_output}")
        if valid_tags(tags_output):
            print("Valid tags format detected.")
            return tags_output
        else:
            print("Invalid tags format. Requesting refinement...")
            refined_query = f"{query}\nPlease refine your answer so that the output strictly matches the format: tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]."
    print("Final output (may be invalid):", tags_output)
    return tags_output
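

# GitHub REST helpers: fetch a repository's README plus any top-level *.md files
# so the ranking stage has richer text to embed.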
def fetch_readme_content(repo_full_name: str) -> str:
    readme_url = f"https://api.github.com/repos/{repo_full_name}/readme"
    response = session.get(readme_url)
    if response.status_code == 200:
        readme_data = response.json()
        try:
            return base64.b64decode(readme_data.get('content', '')).decode('utf-8', errors='replace')
        except Exception:
            return ""
    return ""


def fetch_markdown_contents(repo_full_name: str) -> str:
    url = f"https://api.github.com/repos/{repo_full_name}/contents"
    response = session.get(url)
    contents = ""
    if response.status_code == 200:
        items = response.json()
        for item in items:
            if item.get("type") == "file" and item.get("name", "").lower().endswith(".md"):
                file_url = item.get("download_url")
                if file_url:
                    file_resp = requests.get(file_url)
                    if file_resp.status_code == 200:
                        contents += "\n" + file_resp.text
    return contents


def fetch_all_markdown(repo_full_name: str) -> str:
    readme = fetch_readme_content(repo_full_name)
    other_md = fetch_markdown_contents(repo_full_name)
    return readme + "\n" + other_md
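

# Query the GitHub /search/repositories endpoint and return, for each hit, the
# repository title, link, and combined description + markdown text.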
def fetch_github_repositories(query: str, max_results: int = 10) -> list:
    url = "https://api.github.com/search/repositories"
    params = {
        "q": query,
        "per_page": max_results
    }
    response = session.get(url, params=params)
    if response.status_code != 200:
        print(f"Error {response.status_code}: {response.json().get('message')}")
        return []
    repo_list = []
    for repo in response.json().get('items', []):
        repo_link = repo.get('html_url')
        description = repo.get('description') or ""
        combined_markdown = fetch_all_markdown(repo.get('full_name'))
        combined_text = (description + "\n" + combined_markdown).strip()
        repo_list.append({
            "title": repo.get('name', 'No title available'),
            "link": repo_link,
            "combined_text": combined_text
        })
    return repo_list
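

# Dense retrieval model: all-mpnet-base-v2 sentence embeddings, loaded on GPU
# when available with a CPU fallback.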
try:
    # Attempt GPU first; the except branch falls back to CPU.
    model = SentenceTransformer('all-mpnet-base-v2', device='cuda')
except Exception as e:
    print("Error initializing GPU for SentenceTransformer; falling back to CPU:", e)
    model = SentenceTransformer('all-mpnet-base-v2', device='cpu')
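

# Min-max normalize a score vector; returns all ones when the scores are
# (near-)constant to avoid division by zero.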
def robust_min_max_norm(scores: np.ndarray) -> np.ndarray:
    min_val = scores.min()
    max_val = scores.max()
    if max_val - min_val < 1e-10:
        return np.ones_like(scores)
    return (scores - min_val) / (max_val - min_val)
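

# Cross-encoder re-ranking: score (query, document) pairs directly. Documents are
# truncated to MAX_DOC_LENGTH; those at or above MIN_DOC_LENGTH are split into
# CHUNK_SIZE-character chunks and scored as a 50/50 blend of max and mean chunk scores.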
def cross_encoder_rerank_candidates(candidates: list, query: str, model_name: str, top_n: int = 10) -> list:
    try:
        # Attempt GPU first; the except branch falls back to CPU.
        cross_encoder = CrossEncoder(model_name, device='cuda')
    except Exception as e:
        print("Error initializing CrossEncoder on GPU; falling back to CPU:", e)
        cross_encoder = CrossEncoder(model_name, device='cpu')

    CHUNK_SIZE = 2000
    MAX_DOC_LENGTH = 5000
    MIN_DOC_LENGTH = 200

    def split_text(text: str, chunk_size: int = CHUNK_SIZE) -> list:
        return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

    for candidate in candidates:
        doc = candidate.get("combined_text", "")
        if len(doc) > MAX_DOC_LENGTH:
            doc = doc[:MAX_DOC_LENGTH]
        try:
            if len(doc) < MIN_DOC_LENGTH:
                # Short documents are scored as a single pair.
                score = cross_encoder.predict([[query, doc]])
                if hasattr(score, '__len__') and len(score) == 1:
                    candidate["cross_encoder_score"] = float(score[0])
                else:
                    candidate["cross_encoder_score"] = float(score)
            else:
                chunks = split_text(doc)
                pairs = [[query, chunk] for chunk in chunks]
                scores = cross_encoder.predict(pairs)
                scores = np.array(scores)
                max_score = float(np.max(scores)) if scores.size > 0 else 0.0
                avg_score = float(np.mean(scores)) if scores.size > 0 else 0.0
                candidate["cross_encoder_score"] = 0.5 * max_score + 0.5 * avg_score
        except Exception as e:
            logging.debug(f"[cross-encoder] Error scoring candidate {candidate.get('link', 'unknown')}: {e}")
            candidate["cross_encoder_score"] = 0.0

    # Shift all scores up so the minimum is zero when the model emits negative logits.
    all_scores = [candidate["cross_encoder_score"] for candidate in candidates]
    if all_scores:
        min_score = min(all_scores)
        if min_score < 0:
            for candidate in candidates:
                candidate["cross_encoder_score"] += -min_score

    return candidates
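

# End-to-end pipeline: query -> LLM search tags -> GitHub search -> hybrid dense
# (FAISS) + BM25 scoring -> cross-encoder re-ranking -> formatted ranking output.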
def run_repository_ranking(query: str, num_results: int = 10) -> str:
    logging.info("[DeepGit] Step 1: Generate search tags from the query.")
    search_tags = iterative_convert_to_search_tags(query)
    tag_list = [tag.strip() for tag in search_tags.split(":") if tag.strip()]

    logging.info("[DeepGit] Step 2: Handle target language extraction.")
    if any(tag.startswith("target-") for tag in tag_list):
        target_tag = next(tag for tag in tag_list if tag.startswith("target-"))
        lang_query = f"language:{target_tag.replace('target-', '')}"
        tag_list = [tag for tag in tag_list if not tag.startswith("target-")]
    else:
        lang_query = "language:python"
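
    # Each tag is searched individually and once more as a combined OR query,
    # scoped to name/description/readme and the language filter chosen above.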
    logging.info("[DeepGit] Step 3: Build advanced search qualifiers and fetch repositories.")
    advanced_qualifier = "in:name,description,readme"
    all_repositories = []

    for tag in tag_list:
        github_query = f"{tag} {advanced_qualifier} {lang_query}"
        logging.info(f"[DeepGit] GitHub Query: {github_query}")
        repos = fetch_github_repositories(github_query, max_results=15)
        all_repositories.extend(repos)

    combined_query = " OR ".join(tag_list)
    combined_query = f"({combined_query}) {advanced_qualifier} {lang_query}"
    logging.info(f"[DeepGit] Combined GitHub Query: {combined_query}")
    repos = fetch_github_repositories(combined_query, max_results=15)
    all_repositories.extend(repos)

    unique_repositories = {}
    for repo in all_repositories:
        if repo["link"] not in unique_repositories:
            unique_repositories[repo["link"]] = repo
        else:
            existing_text = unique_repositories[repo["link"]]["combined_text"]
            unique_repositories[repo["link"]]["combined_text"] = existing_text + "\n" + repo["combined_text"]
    repositories = list(unique_repositories.values())

    if not repositories:
        return "No repositories found for your query."

    logging.info("[DeepGit] Step 4: Prepare documents for dense retrieval.")
    docs = [repo.get("combined_text", "") for repo in repositories]

    logging.info("[DeepGit] Step 5: Compute dense embeddings and scores.")
    doc_embeddings = model.encode(docs, convert_to_numpy=True, show_progress_bar=True, batch_size=16)
    if doc_embeddings.ndim == 1:
        doc_embeddings = doc_embeddings.reshape(1, -1)
    norms = np.linalg.norm(doc_embeddings, axis=1, keepdims=True)
    norm_doc_embeddings = doc_embeddings / (norms + 1e-10)

    query_embedding = model.encode(query, convert_to_numpy=True)
    if query_embedding.ndim == 1:
        query_embedding = query_embedding.reshape(1, -1)
    norm_query_embedding = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)

    dim = norm_doc_embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(norm_doc_embeddings)
    k = norm_doc_embeddings.shape[0]
    D, I = index.search(norm_query_embedding, k)
    # FAISS returns hits sorted by similarity; map them back to document order
    # so each score lines up with its repository before blending with BM25.
    dense_scores = np.empty(k, dtype=np.float32)
    dense_scores[I[0]] = D[0]
    norm_dense_scores = robust_min_max_norm(dense_scores)
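
    # Sparse lexical signal: BM25 over word tokens complements the dense scores;
    # it falls back to zeros when rank_bm25 is not installed.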
    logging.info("[DeepGit] Step 6: Compute BM25 scores.")
    if BM25Okapi is not None:
        tokenized_docs = [re.findall(r'\w+', doc.lower()) for doc in docs]
        bm25 = BM25Okapi(tokenized_docs)
        query_tokens = re.findall(r'\w+', query.lower())
        bm25_scores = np.array(bm25.get_scores(query_tokens))
        norm_bm25_scores = robust_min_max_norm(bm25_scores)
    else:
        norm_bm25_scores = np.zeros_like(norm_dense_scores)

    logging.info("[DeepGit] Step 7: Combine dense and BM25 scores.")
    alpha = 0.8
    combined_scores = alpha * norm_dense_scores + (1 - alpha) * norm_bm25_scores
    for idx, repo in enumerate(repositories):
        repo["combined_score"] = float(combined_scores[idx])

    logging.info("[DeepGit] Step 8: Initial ranking by combined score.")
    ranked_repositories = sorted(repositories, key=lambda x: x.get("combined_score", 0), reverse=True)

    logging.info("[DeepGit] Step 9: Cross-encoder re-ranking.")
    top_candidates = ranked_repositories[:100] if len(ranked_repositories) > 100 else ranked_repositories
    cross_encoder_rerank_candidates(top_candidates, query, model_name=CROSS_ENCODER_MODEL, top_n=len(top_candidates))
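
    # The final score blends the hybrid retrieval score (weight 0.7) with the
    # cross-encoder score (weight 0.3), then keeps the top num_results entries.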
    logging.info("[DeepGit] Step 10: Final scoring and output formatting.")
    w1 = 0.7
    w2 = 0.3
    for candidate in top_candidates:
        candidate["final_score"] = w1 * candidate.get("combined_score", 0) + w2 * candidate.get("cross_encoder_score", 0)

    final_ranked = sorted(top_candidates, key=lambda x: x.get("final_score", 0), reverse=True)[:num_results]

    output = "\n=== Ranked Repositories ===\n"
    for rank, repo in enumerate(final_ranked, 1):
        output += f"Final Rank: {rank}\n"
        output += f"Title: {repo['title']}\n"
        output += f"Link: {repo['link']}\n"
        output += f"Combined Score: {repo.get('combined_score', 0) * 100:.2f}%\n"
        output += f"Cross-Encoder Score: {repo.get('cross_encoder_score', 0) * 100:.2f}%\n"
        output += f"Final Score: {repo.get('final_score', 0) * 100:.2f}%\n"
        snippet = repo['combined_text'][:300].replace('\n', ' ')
        output += f"Snippet: {snippet}...\n"
        output += '-' * 80 + "\n"
    output += "\n=== End of Results ==="
    return output
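

# Note: the [DeepGit] step messages use logging.info; configure logging in the
# entry point (e.g., logging.basicConfig(level=logging.INFO)) to see them.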
if __name__ == "__main__":
    test_query = "Chain of thought prompting for reasoning models"
    result = run_repository_ranking(test_query)
    print(result)
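
    # Illustrative sketch only (the query string below is made up): uncomment to
    # preview the LLM-generated tags without hitting the GitHub search or
    # ranking stages.
    # example_tags = iterative_convert_to_search_tags(
    #     "JavaScript library for image augmentation"
    # )
    # print(example_tags)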