import gradio as gr
import json
from typing import Any, List, Dict, Union
import torch
from sentence_transformers import SentenceTransformer
from huggingface_hub import login
import os

# Pick the best currently available device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Similarity model configuration
class Config:
    """Configuration settings for the application."""
    EMBEDDING_MODEL_ID = "google/embeddinggemma-300M"
    QUERY_PROMPT_NAME = "query"
    TOOL_PROMPT_NAME = "document"
    TOP_K = 3
    HF_TOKEN = os.getenv('HF_TOKEN')
    DEVICE = device

# Encapsulated similarity model
class SimilarityModel:
    """
    Finds the tools most similar to a given query using Sentence Transformer embeddings.
    """
    def __init__(self, config: Config):
        self.config = config
        self._login_to_hf()
        self.model = self._load_model()
        # Maps a serialized tools list (JSON string) to its embedding tensor
        self.tool_embeddings_cache = {}

    def _login_to_hf(self):
        """Logs into Hugging Face Hub if a token is provided."""
        if self.config.HF_TOKEN:
            print("Logging into Hugging Face Hub...")
            login(token=self.config.HF_TOKEN)
        else:
            print("HF_TOKEN not found. Proceeding without login.")
            print("Note: This may fail if the model is gated.")

    def _load_model(self) -> SentenceTransformer:
        """Loads the Sentence Transformer model."""
        print(f"Initializing embedding model: {self.config.EMBEDDING_MODEL_ID}...")
        try:
            return SentenceTransformer(self.config.EMBEDDING_MODEL_ID).to(self.config.DEVICE)
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def _validate_query_tools(self, query: Union[str, Any], tools_list: Union[str, List[Dict], Any]) -> Union[str, List[Dict]]:
        """
        Validates the query and tools data and normalizes the tools into a list of dicts.

        Args:
            query: The user query string.
            tools_list: Either a list of dicts (one per tool declaration) or a JSON
                string that parses to such a list.

        Returns:
            The tools data as a list of dicts if both inputs are valid,
            otherwise an error message string.
        """
        is_valid_query = isinstance(query, str) and len(query.strip()) > 0
        if not is_valid_query:
            return "Invalid query. It should be a non-empty string."
        # tools_list is already a list of dicts
        is_already_valid_tools = isinstance(tools_list, list) and all(isinstance(d, dict) for d in tools_list)
        if is_already_valid_tools:
            return tools_list
        # tools_list is a string; try to parse it as JSON
        try:
            tools_data = json.loads(tools_list)
        except (json.JSONDecodeError, TypeError):
            return "Invalid JSON format for tools data."
        is_valid_tools = isinstance(tools_data, list) and all(isinstance(d, dict) for d in tools_data)
        if not is_valid_tools:
            return "Invalid tools data. It should be a list of dictionaries."
        return tools_data
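
    # For reference, both of these calls validate to the same tools list:
    #   _validate_query_tools("weather?", [{"name": "t", "description": "d"}])
    #   _validate_query_tools("weather?", '[{"name": "t", "description": "d"}]')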

    def cache_tool_embeddings(self, tools_data: List[Dict], tools_cache_key: str, cache_tool: bool = True) -> torch.Tensor:
        """
        Returns cached tool embeddings if available; otherwise computes them
        (and caches the result when cache_tool is True).

        Args:
            tools_data: List of dicts, where each dict represents a tool declaration.
            tools_cache_key: Unique cache key derived from the tools data.
            cache_tool: Whether to cache the computed tool embeddings.

        Returns:
            The embeddings of the tool descriptions.
        """
        if tools_cache_key in self.tool_embeddings_cache:
            tool_description_embeddings = self.tool_embeddings_cache[tools_cache_key]
        else:
            tool_descriptions = [tool["description"] for tool in tools_data]
            tool_description_embeddings = self.model.encode(tool_descriptions, normalize_embeddings=True, prompt_name=self.config.TOOL_PROMPT_NAME)
            if cache_tool:
                self.tool_embeddings_cache[tools_cache_key] = tool_description_embeddings
        return tool_description_embeddings

    def find_similar_tools(self, query: str, tools_list: Union[str, List[Dict]], top_k: int, cache_tool_embs: bool = True) -> tuple[str, str]:
        """
        Finds the top_k tools most similar to a given query using Sentence Transformer embeddings.

        Args:
            query: The user query string.
            tools_list: A list of dicts (one per tool declaration) or a JSON string that parses to one.
            top_k: The number of top similar tools to return.
            cache_tool_embs: Whether to cache the tool embeddings. Default is True.

        Returns:
            A tuple of (a human-readable string naming and describing the top_k
            similar tools, the same tools serialized as a JSON string).
        """
        # Validate query and tools_list; a string result signals a validation error
        tools_data = self._validate_query_tools(query, tools_list)
        if isinstance(tools_data, str):
            return tools_data, json.dumps([{"Error": tools_data}])
        # Create a unique key for caching based on the tools data
        tools_cache_key = json.dumps(tools_data, sort_keys=True)
        # Compute tool embeddings, or fetch them from the cache
        tool_description_embeddings = self.cache_tool_embeddings(tools_data, tools_cache_key, cache_tool=cache_tool_embs)
        # The query embedding is always computed fresh, since user queries vary from call to call
        query_embedding = self.model.encode(query, normalize_embeddings=True, prompt_name=self.config.QUERY_PROMPT_NAME)
        # Similarity scores between the user query and the tool embeddings
        similarity_scores = self.model.similarity(query_embedding, tool_description_embeddings).cpu()
        # Ensure top_k does not exceed the number of available tools
        actual_top_k = min(top_k or self.config.TOP_K, len(tools_data))
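        # Worked example: scores [[0.2, 0.9, 0.5]] with actual_top_k=2 ->
        # argsort (ascending) gives [0, 2, 1], the last two entries are [2, 1],
        # and the reversal below yields [1, 2]: the most similar tool first.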
        top_tool_indices = similarity_scores.argsort().flatten()[-actual_top_k:]
        # Reverse the indices so the most similar tool comes first
        top_tool_indices = top_tool_indices.tolist()[::-1]
        top_tools = [tools_data[int(i)] for i in top_tool_indices]
        # Format the output for the Gradio Textbox
        output_text = f"Top {actual_top_k} most similar tools:\n\n"
        for i, tool in enumerate(top_tools):
            output_text += f"{i+1}. Name: {tool['name']}\n"
            output_text += f"   Description: {tool['description']}\n"
            if i < len(top_tools) - 1:
                output_text += "---\n"  # Separator between tools
        if not top_tools:
            output_text = "No tools found."
        return output_text, json.dumps(top_tools)


def create_ui(model: SimilarityModel):
    """Builds the Gradio UI for interacting with the model."""
    with gr.Blocks() as demo:
        gr.Interface(
            fn=model.find_similar_tools,
            inputs=[
                gr.Textbox(label="Query"),
                gr.Textbox(
                    lines=6,
                    label="Define tool declarations here",
                    info="Please enter a valid JSON string, e.g. a list of dicts with name & description fields.",
                    placeholder='''[
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location"
    }
]'''),
                gr.Number(label="Top K", value=3, precision=0),
                gr.Checkbox(label="Cache Tool Embeddings", value=True)
            ],
            outputs=[
                gr.TextArea(label="Similar Tools (Name and Description)", lines=5),
                gr.JSON(label="Similar Tools JSON format")
            ],
            title="Tool Similarity Finder using Embedding Gemma 300M",
            description="Enter a query and a list of tools to find the most similar tools based on embeddings."
        )
    return demo


if __name__ == "__main__":
    similarity_model = SimilarityModel(config=Config())
    demo = create_ui(similarity_model)
    demo.launch(
        mcp_server=True
    )
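
# Illustrative usage, bypassing the UI (assumes the model loaded successfully):
#   tools = [{"name": "get_current_weather",
#             "description": "Get the current weather in a given location"}]
#   text, tools_json = similarity_model.find_similar_tools(
#       "What's the weather like in Paris?", tools, top_k=1)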