File size: 3,737 Bytes
5ac1ec2
 
0731ede
 
dfa643b
5ac1ec2
 
 
dfa643b
 
 
 
 
 
 
 
 
a5b3518
5ac1ec2
 
 
 
 
 
 
 
 
8c0edc5
5ac1ec2
 
 
 
 
 
 
0731ede
5ac1ec2
 
 
 
 
e06da36
 
 
 
 
5ac1ec2
e06da36
5ac1ec2
dfa643b
5ac1ec2
 
 
 
 
 
 
e06da36
 
 
 
 
 
 
 
5ac1ec2
dfa643b
 
e06da36
dfa643b
 
 
e06da36
dfa643b
 
 
 
 
 
 
 
 
 
5ac1ec2
dfa643b
e06da36
dfa643b
 
 
 
 
 
5ac1ec2
dfa643b
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Path: src/nl2sql/hf_engine.py
# This module defines the HuggingFace-based engine for generating SQL queries from natural language questions.
import os
from huggingface_hub import InferenceClient
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.language_models.llms import LLM
from typing import Any, List, Optional

# Model Registry: every model ID available for testing, mapped to the API style
# used to query it ("text" = text-generation endpoint, "chat" = chat completions).
MODEL_REGISTRY = {
    # Text-generation models
    "defog/sqlcoder-7b-2": "text",
    # Chat-completion models
    "Qwen/Qwen2.5-Coder-7B-Instruct:featherless-ai": "chat",
    "Qwen/Qwen2.5-Coder-32B-Instruct:featherless-ai": "chat",
    "defog/llama-3-sqlcoder-8b:featherless-ai": "chat",
    # Currently disabled:
    # "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B:featherless-ai": "chat",
}

# Model used whenever the caller does not pick one explicitly.
DEFAULT_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct:featherless-ai"

# Custom LangChain wrapper for HuggingFace Inference API
class HFChatWrapper(LLM):
    """
    Custom LLM wrapper for HuggingFace Inference API to maintain compatibility with LangChain's LLM interface.

    Each prompt is sent as a single ``user`` message to the client's
    chat-completions endpoint and the text of the first choice is returned.
    """
    # Client object exposing ``chat.completions.create(...)``.
    client: Any
    # Model identifier sent as ``model`` with every request.
    model_id: str

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """Run one chat completion for ``prompt`` and return the reply text.

        Args:
            prompt: Text sent as the content of a single user message.
            stop: Optional stop sequences; forwarded to the endpoint when given.
            run_manager: Callback manager supplied by LangChain. Accepted so the
                signature matches ``LLM._call``; not used here.
            **kwargs: Extra keyword arguments LangChain may pass through.
                Accepted for interface compatibility and intentionally not
                forwarded, so unknown options cannot break the request.

        Returns:
            The message content of the first choice, or an empty string when
            the endpoint returns no text content.
        """
        request = {
            "model": self.model_id,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.0,
            "max_tokens": 512,
        }
        # Only add ``stop`` when the caller supplied it, so the default request
        # stays exactly what it was before stop sequences were supported.
        if stop:
            request["stop"] = stop

        completion = self.client.chat.completions.create(**request)
        content = completion.choices[0].message.content
        # The declared return type is ``str``; normalise a missing content.
        return content if content is not None else ""

    @property
    def _identifying_params(self) -> dict:
        """Parameters that distinguish this LLM instance (used by LangChain for caching/tracing)."""
        return {"model_id": self.model_id}

    @property
    def _llm_type(self) -> str:
        """Short identifier for this LLM implementation."""
        return "huggingface_inference_client"

def get_models() -> List[str]:
    """Return the IDs of all models registered in MODEL_REGISTRY, in registry order."""
    # Iterating a dict yields its keys in insertion order.
    return [registered_id for registered_id in MODEL_REGISTRY]

# Initialize the HuggingFace endpoint using the InferenceClient
def get_llm(model_id: str = DEFAULT_MODEL_ID):
    """
    Build the LangChain-compatible LLM for ``model_id``.

    The model's type is looked up in MODEL_REGISTRY: "chat" models are served
    through an InferenceClient wrapped in HFChatWrapper, "text" models through
    a HuggingFaceEndpoint. Models missing from the registry are treated as
    "chat" after printing a warning. A falsy ``model_id`` falls back to
    DEFAULT_MODEL_ID.

    Raises:
        ValueError: if the HF_TOKEN environment variable is unset or empty, or
            if the registry maps the model to an unknown type.
    """
    # The API token comes from the environment; fail before doing anything else.
    token = os.getenv("HF_TOKEN")
    if not token:
        raise ValueError("HuggingFace API token not found!")

    active_model = model_id or DEFAULT_MODEL_ID

    # Resolve the model type, falling back to "chat" for unregistered models.
    try:
        model_type = MODEL_REGISTRY[active_model]
    except KeyError:
        print(f"Warning: Model '{active_model}' not found in MODEL_REGISTRY. Defaulting to 'chat' type.")
        model_type = "chat"

    print(f"Initializing HuggingFace InferenceClient with model: {active_model}")

    if model_type == "text":
        # Route to standard Text Generation API
        # NOTE(review): some backends may reject temperature=0.0; do_sample=False
        # is also set here — confirm against the target endpoint.
        return HuggingFaceEndpoint(
            repo_id=active_model,
            task="text-generation",
            max_new_tokens=512,
            temperature=0.0,
            huggingfacehub_api_token=token,
            do_sample=False,
            return_full_text=False
        )

    if model_type != "chat":
        raise ValueError(f"Unknown model type: {model_type}")

    # Chat models go through the InferenceClient's chat-completions interface.
    chat_client = InferenceClient(api_key=token)
    return HFChatWrapper(client=chat_client, model_id=active_model)

if __name__ == "__main__":
    # Manual smoke test: load .env, build the default LLM and send one prompt.
    from dotenv import load_dotenv

    load_dotenv()

    ping_prompt = "write a single SQL statement to count all rows in a table name 'Employee'."

    try:
        smoke_llm = get_llm()
        print("Model loaded successfully! Running a quick ping...")
        response = smoke_llm.invoke(ping_prompt)
        print(f"\nResponse:\n{response}")
    except Exception as e:
        # Top-level boundary of the script: report any failure instead of a traceback.
        print(f"Error during LLM initialization: {e}")