File size: 2,835 Bytes
2202e15
a69843f
d776322
a69843f
89efbe0
2202e15
 
89efbe0
 
 
 
2202e15
 
89efbe0
 
 
 
2202e15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89efbe0
2202e15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89efbe0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# NOTE: the HuggingFace wrappers have moved between llama_index releases;
# the current import path for the inference API is
# llama_index.llms.huggingface_api (older paths are tried at runtime below).
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
import torch
# If that doesn't work, try:
# from llama_index.llms.huggingface import HuggingFaceLLM

def setup_llm(model_name: str = "microsoft/phi-3-mini-4k-instruct",
              device: str = None,
              context_window: int = 4096,
              max_new_tokens: int = 512):
    """Set up the language model for the CSV chatbot.

    The HuggingFace wrapper classes have moved between llama_index
    releases, so three import paths are tried in order:

    1. ``llama_index.llms.huggingface.HuggingFaceLLM`` — runs the model
       locally (4-bit quantized on CUDA).
    2. ``llama_index.llms.HuggingFaceInferenceAPI`` — remote inference.
    3. ``llama_index.llms.huggingface.HuggingFaceInference`` — legacy
       fallback.

    Args:
        model_name: HuggingFace model id, used for both model and tokenizer.
        device: "cuda" or "cpu"; auto-detected from torch when None.
        context_window: Maximum prompt context size in tokens.
        max_new_tokens: Maximum number of tokens to generate per response.

    Returns:
        A configured llama_index LLM instance.

    Raises:
        ImportError: If none of the known llama_index APIs is importable.
    """
    # Auto-detect the device when the caller did not specify one.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        # Preferred path: run the model locally.
        from llama_index.llms.huggingface import HuggingFaceLLM

        # fp16 + trust_remote_code: keeps memory low enough for HF Spaces
        # and allows models (e.g. phi-3) that ship custom modeling code.
        model_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": torch.float16,
        }

        if device == "cuda":
            # 4-bit quantization further reduces GPU memory usage.
            from transformers import BitsAndBytesConfig
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            )

        llm = HuggingFaceLLM(
            model_name=model_name,
            tokenizer_name=model_name,
            context_window=context_window,
            max_new_tokens=max_new_tokens,
            generate_kwargs={"temperature": 0.7, "top_p": 0.95},
            device_map=device,
            tokenizer_kwargs={"trust_remote_code": True},
            model_kwargs=model_kwargs,
            # Cache the model to avoid re-downloading between restarts.
            cache_folder="./model_cache",
        )

    except (ImportError, AttributeError):
        # Fallback: older llama_index versions expose the inference API here.
        try:
            from llama_index.llms import HuggingFaceInferenceAPI

            llm = HuggingFaceInferenceAPI(
                model_name=model_name,
                tokenizer_name=model_name,
                context_window=context_window,
                max_new_tokens=max_new_tokens,
                generate_kwargs={"temperature": 0.7, "top_p": 0.95},
            )
        # Narrowed from a bare `except:` — only import/API-shape failures
        # should trigger the legacy fallback, not e.g. KeyboardInterrupt.
        except (ImportError, AttributeError):
            # Last resort: legacy HuggingFaceInference class.
            from llama_index.llms.huggingface import HuggingFaceInference

            llm = HuggingFaceInference(
                model_name=model_name,
                tokenizer_name=model_name,
                context_window=context_window,
                max_new_tokens=max_new_tokens,
                generate_kwargs={"temperature": 0.7, "top_p": 0.95},
            )

    return llm