# Commit 75361de ("Add reranking") by RobertoBarrosoLuque
import os
import yaml
from openai import OpenAI
from dotenv import load_dotenv
from typing import List, Dict
from pathlib import Path
import requests
from src.config import EMBEDDING_MODEL, LLM_MODEL, RERANKER_MODEL
load_dotenv()
_FILE_PATH = Path(__file__).parents[2]
RERANK_URL = "https://api.fireworks.ai/inference/v1/rerank"
INFERENCE_URL = "https://api.fireworks.ai/inference/v1"
def load_prompt_library():
    """Load prompts from the YAML configuration file.

    Returns:
        Parsed contents of configs/prompt_library.yaml (typically a dict
        keyed by prompt name, e.g. "query_expansion").
    """
    # Explicit encoding avoids platform-dependent default codecs when
    # reading the YAML file.
    with open(_FILE_PATH / "configs" / "prompt_library.yaml", "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
def create_client() -> OpenAI:
    """
    Create an OpenAI-compatible client pointed at the Fireworks inference API.

    Returns:
        Configured OpenAI client using FIREWORKS_API_KEY from the environment.

    Raises:
        RuntimeError: If FIREWORKS_API_KEY is not set.
    """
    api_key = os.getenv("FIREWORKS_API_KEY")
    # `assert` is stripped under `python -O`, so a missing key would surface
    # later as an opaque auth failure; raise explicitly instead.
    if api_key is None:
        raise RuntimeError("FIREWORKS_API_KEY not found in environment variables")
    return OpenAI(
        api_key=api_key,
        base_url=INFERENCE_URL,
    )
# Module-level singletons built at import time: the Fireworks inference
# client and the parsed prompt library. Both fail fast on import if the
# API key or the prompt YAML file is missing.
CLIENT = create_client()
PROMPT_LIBRARY = load_prompt_library()
def get_embedding(text: str) -> List[float]:
    """
    Embed the given text with the configured Fireworks AI embedding model.

    Args:
        text: Input text to embed.

    Returns:
        The embedding vector as a list of floats.
    """
    result = CLIENT.embeddings.create(model=EMBEDDING_MODEL, input=text)
    first_item = result.data[0]
    return first_item.embedding
def expand_query(query: str) -> str:
    """
    Expand a search query using the LLM with few-shot prompting.

    The model augments the user's query with relevant terms, synonyms, and
    related concepts to improve search recall and relevance.

    Args:
        query: Original search query.

    Returns:
        Expanded query string with additional relevant terms.
    """
    messages = [
        {
            "role": "system",
            "content": PROMPT_LIBRARY["query_expansion"]["system_prompt"],
        },
        {"role": "user", "content": query},
    ]
    completion = CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        temperature=0.3,
        max_tokens=100,
        reasoning_effort="none",
    )
    return completion.choices[0].message.content.strip()
def rerank_results(query: str, results: List[Dict], top_n: int = 5) -> List[Dict]:
    """
    Rerank search results using the Fireworks AI reranker model.

    Takes search results and reranks them based on relevance to the query
    using a specialized reranking model that considers cross-attention
    between query and documents.

    Args:
        query: Original search query.
        results: List of dicts with at least 'product_name' and 'description'.
        top_n: Number of top results to return after reranking (default: 5).

    Returns:
        List of dictionaries containing reranked product information, each
        with its 'score' replaced by the reranker's relevance score.

    Raises:
        requests.HTTPError: If the rerank API responds with an error status.
    """
    # Nothing to rerank — skip the API round trip entirely.
    if not results:
        return []
    # Represent each product as "name. description" for the reranker.
    documents = [f"{r['product_name']}. {r['description']}" for r in results]
    payload = {
        "model": RERANKER_MODEL,
        "query": query,
        "documents": documents,
        "top_n": top_n,
        "return_documents": False,
    }
    headers = {
        "Authorization": f"Bearer {os.getenv('FIREWORKS_API_KEY')}",
        "Content-Type": "application/json",
    }
    # timeout prevents an unbounded hang; raise_for_status surfaces API
    # errors instead of silently parsing an error body as results.
    response = requests.post(RERANK_URL, json=payload, headers=headers, timeout=30)
    response.raise_for_status()
    rerank_data = response.json()
    # Map reranked indices back to the original product data, keeping the
    # reranker's relevance score.
    reranked_results = []
    for item in rerank_data.get("data", []):
        idx = item["index"]
        reranked_results.append({**results[idx], "score": item["relevance_score"]})
    return reranked_results