# Path: src/nl2sql/hf_engine.py
# This module defines the HuggingFace-based engine for generating SQL queries from natural language questions.
import os
from huggingface_hub import InferenceClient
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.language_models.llms import LLM
from typing import Any, List, Optional
# Model registry: maps each model ID to be tested to its API type ("text" or "chat").
# Keys are model IDs; values name the API type used to call the model
# ("text" for plain text generation, "chat" for chat completions).
MODEL_REGISTRY = {
    "defog/sqlcoder-7b-2": "text",
    "Qwen/Qwen2.5-Coder-7B-Instruct:featherless-ai": "chat",
    "Qwen/Qwen2.5-Coder-32B-Instruct:featherless-ai": "chat",
    "defog/llama-3-sqlcoder-8b:featherless-ai": "chat",
    # Currently disabled:
    # "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B:featherless-ai": "chat",
}

# Model ID used when the caller does not specify one.
DEFAULT_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct:featherless-ai"
# Custom LangChain wrapper for HuggingFace Inference API
class HFChatWrapper(LLM):
    """
    LangChain ``LLM`` adapter that sends each prompt as a single-turn chat completion.

    The prompt is wrapped in one ``user`` message and submitted through
    ``client.chat.completions.create``; the text of the first returned choice
    is handed back to LangChain.
    """

    # Client object exposing ``chat.completions.create(...)``.
    client: Any
    # Model identifier forwarded as the ``model`` argument of every request.
    model_id: str
    # Sampling temperature sent with every request (0.0 was the former hard-coded value).
    temperature: float = 0.0
    # Maximum number of tokens to generate (512 was the former hard-coded value).
    max_tokens: int = 512

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """
        Send ``prompt`` to the chat-completion endpoint and return the reply text.

        Args:
            prompt: Text placed in the single ``user`` message.
            stop: Optional stop sequences; forwarded to the API when provided.
            run_manager: Accepted for compatibility with LangChain's ``_call``
                signature; not used here.
            **kwargs: Extra per-call arguments from LangChain; accepted so that
                passing them does not raise ``TypeError``, but not forwarded.

        Returns:
            The content of the first choice's message, or ``""`` when the
            response carries no text content.
        """
        request = {
            "model": self.model_id,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        # Only add ``stop`` when the caller supplied it, so the request is
        # unchanged from the previous behaviour in the common case.
        if stop:
            request["stop"] = stop

        completion = self.client.chat.completions.create(**request)
        content = completion.choices[0].message.content
        # Guard against a missing content field so the ``str`` contract holds.
        return content if content is not None else ""

    @property
    def _llm_type(self) -> str:
        """Return the identifier string LangChain uses for this LLM type."""
        return "huggingface_inference_client"
def get_models() -> List[str]:
    """Return the model IDs registered in MODEL_REGISTRY as a new list."""
    # Iterating a dict yields its keys in insertion order.
    return list(MODEL_REGISTRY)
# Initialize the HuggingFace endpoint using the InferenceClient
def get_llm(model_id: str = DEFAULT_MODEL_ID):
    """
    Build the LangChain LLM object for ``model_id``.

    The API token is read from the ``HF_TOKEN`` environment variable. The
    model's entry in ``MODEL_REGISTRY`` selects the interface: ``"chat"``
    yields an ``HFChatWrapper`` around an ``InferenceClient``, while ``"text"``
    yields a ``HuggingFaceEndpoint`` configured for text generation. IDs that
    are missing from the registry are treated as ``"chat"`` after a warning is
    printed. A falsy ``model_id`` falls back to ``DEFAULT_MODEL_ID``.

    Raises:
        ValueError: If ``HF_TOKEN`` is unset or empty, or if the registry maps
            the model to a type other than ``"chat"`` or ``"text"``.
    """
    # Fail fast when no API token is configured.
    token = os.getenv("HF_TOKEN")
    if not token:
        raise ValueError("HuggingFace API token not found!")

    # Resolve the model to use and look up how it must be called.
    chosen = model_id or DEFAULT_MODEL_ID
    if chosen in MODEL_REGISTRY:
        kind = MODEL_REGISTRY[chosen]
    else:
        print(f"Warning: Model '{chosen}' not found in MODEL_REGISTRY. Defaulting to 'chat' type.")
        kind = "chat"
    print(f"Initializing HuggingFace InferenceClient with model: {chosen}")

    if kind == "chat":
        # Chat models go through the custom chat-completions wrapper.
        return HFChatWrapper(
            client=InferenceClient(api_key=token),
            model_id=chosen,
        )

    if kind == "text":
        # Text models are routed to the standard text-generation endpoint.
        endpoint_options = {
            "repo_id": chosen,
            "task": "text-generation",
            "max_new_tokens": 512,
            "temperature": 0.0,
            "huggingfacehub_api_token": token,
            "do_sample": False,
            "return_full_text": False,
        }
        return HuggingFaceEndpoint(**endpoint_options)

    raise ValueError(f"Unknown model type: {kind}")
# Initialize the HuggingFace InferenceClient
#client = InferenceClient(api_key=hf_token)
#llm = HFChatWrapper(client=client, model_id=active_model)
#return llm
if __name__ == "__main__":
    # Manual smoke test: build the default LLM and send it one short prompt.
    from dotenv import load_dotenv

    # Load environment variables (e.g. HF_TOKEN) before get_llm() reads them.
    load_dotenv()

    # Initialization and invocation are guarded separately so the printed
    # message names the step that actually failed.
    try:
        test_llm = get_llm()
    except Exception as e:
        print(f"Error during LLM initialization: {e}")
        # Exit non-zero so shells and CI can detect the failure.
        raise SystemExit(1)

    print("Model loaded successfully! Running a quick ping...")

    try:
        response = test_llm.invoke("write a single SQL statement to count all rows in a table name 'Employee'.")
    except Exception as e:
        print(f"Error during LLM invocation: {e}")
        raise SystemExit(1)

    print(f"\nResponse:\n{response}")