import os
import yaml
from openai import OpenAI
from dotenv import load_dotenv
from typing import List, Dict
from pathlib import Path
import requests

from src.config import EMBEDDING_MODEL, LLM_MODEL, RERANKER_MODEL

load_dotenv()

_FILE_PATH = Path(__file__).parents[2]

RERANK_URL = "https://api.fireworks.ai/inference/v1/rerank"
INFERENCE_URL = "https://api.fireworks.ai/inference/v1"


def load_prompt_library():
    """Load prompts from YAML configuration."""
    with open(_FILE_PATH / "configs" / "prompt_library.yaml", "r") as f:
        return yaml.safe_load(f)
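# The prompt library is expected at configs/prompt_library.yaml under the project
# root. A minimal sketch of the shape this module relies on -- only the
# query_expansion.system_prompt key is read below; the actual prompt wording is
# an assumption:
#
#     query_expansion:
#       system_prompt: |
#         Expand the user's product search query with synonyms and related terms...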
def create_client() -> OpenAI:
    """Create an OpenAI-compatible client for Fireworks AI inference."""
    api_key = os.getenv("FIREWORKS_API_KEY")
    assert api_key is not None, "FIREWORKS_API_KEY not found in environment variables"
    return OpenAI(
        api_key=api_key,
        base_url=INFERENCE_URL,
    )


CLIENT = create_client()
PROMPT_LIBRARY = load_prompt_library()
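# Note: the client and prompt library are initialised once at import time, so
# importing this module requires FIREWORKS_API_KEY to be set, e.g. via a .env file
# at the project root (sketch):
#
#     FIREWORKS_API_KEY=<your-fireworks-api-key>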
def get_embedding(text: str) -> List[float]:
    """
    Get embedding for a given text using Fireworks AI embedding model.

    Args:
        text: Input text to embed

    Returns:
        List of float values representing the embedding vector
    """
    response = CLIENT.embeddings.create(model=EMBEDDING_MODEL, input=text)
    return response.data[0].embedding
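# Example usage (a sketch; assumes EMBEDDING_MODEL in src.config points at a
# Fireworks embedding model and FIREWORKS_API_KEY is set):
#
#     vector = get_embedding("wireless noise-cancelling headphones")
#     print(len(vector))  # dimensionality depends on the configured model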
def expand_query(query: str) -> str:
    """
    Expand a search query using LLM with few-shot prompting.

    Takes a user's search query and expands it with relevant terms, synonyms,
    and related concepts to improve search recall and relevance.

    Args:
        query: Original search query

    Returns:
        Expanded query string with additional relevant terms
    """
    system_prompt = PROMPT_LIBRARY["query_expansion"]["system_prompt"]
    response = CLIENT.chat.completions.create(
        model=LLM_MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
        temperature=0.3,
        max_tokens=100,
        reasoning_effort="none",
    )
    expanded = response.choices[0].message.content.strip()
    return expanded
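# Example usage (a sketch; the expanded output depends entirely on the configured
# system prompt and LLM_MODEL, so the result shown is illustrative only):
#
#     expanded = expand_query("running shoes")
#     # e.g. "running shoes sneakers trainers jogging footwear athletic"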
def rerank_results(query: str, results: List[Dict], top_n: int = 5) -> List[Dict]:
    """
    Rerank search results using Fireworks AI reranker model.

    Takes search results and reranks them based on relevance to the query
    using a specialized reranking model that considers cross-attention between
    query and documents.

    Args:
        query: Original search query
        results: List of dictionaries containing product information and scores
        top_n: Number of top results to return after reranking (default: 5)

    Returns:
        List of dictionaries containing reranked product information with updated scores
    """
    # Prepare documents as text for reranker (product name + description)
    documents = [f"{r['product_name']}. {r['description']}" for r in results]

    payload = {
        "model": RERANKER_MODEL,
        "query": query,
        "documents": documents,
        "top_n": top_n,
        "return_documents": False,
    }
    headers = {
        "Authorization": f"Bearer {os.getenv('FIREWORKS_API_KEY')}",
        "Content-Type": "application/json",
    }

    response = requests.post(RERANK_URL, json=payload, headers=headers)
    rerank_data = response.json()

    # Map reranked results back to original product data
    reranked_results = []
    for item in rerank_data.get("data", []):
        idx = item["index"]
        reranked_results.append({**results[idx], "score": item["relevance_score"]})

    return reranked_results
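# Example usage (a sketch; the `results` list would normally come from an upstream
# vector search, and this helper assumes each dict carries at least "product_name"
# and "description" keys):
#
#     candidates = [
#         {"product_name": "Trail Runner X", "description": "Lightweight trail running shoe", "score": 0.71},
#         {"product_name": "Road Glide 2", "description": "Cushioned road running shoe", "score": 0.69},
#     ]
#     top = rerank_results(expand_query("running shoes"), candidates, top_n=1)
#     print(top[0]["product_name"], top[0]["score"])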