from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import Dict

def get_model():
    model_id = "mistralai/Mistral-7B-Instruct-v0.2"
    
    # Force CUDA to be the default device
    if torch.cuda.is_available():
        torch.set_default_device('cuda')
    
    # Load tokenizer; Mistral's tokenizer ships without a pad token,
    # so fall back to EOS so that padded inputs don't raise an error
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Load model with explicit device placement
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
    
    # Explicitly move model to GPU
    if torch.cuda.is_available():
        model = model.cuda()
    
    return model, tokenizer

# Initialize model and tokenizer
model, tokenizer = get_model()

def generate(text: str, params: Dict) -> Dict:
    try:
        # Ensure we're using CUDA
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")
        
        # Tokenize with explicit device placement
        inputs = tokenizer(text, return_tensors="pt", padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Print debug info
        print(f"Input device: {inputs['input_ids'].device}")
        print(f"Model device: {next(model.parameters()).device}")
        
        # Generate without tracking gradients; torch.cuda.device(device) would
        # raise on the CPU fallback path, so no_grad() is used here instead.
        # do_sample=True is required for temperature/top_p/top_k to take effect.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=params.get('max_new_tokens', 500),
                do_sample=params.get('do_sample', True),
                temperature=params.get('temperature', 0.7),
                top_p=params.get('top_p', 0.95),
                top_k=params.get('top_k', 50),
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": response}
        
    except Exception as e:
        print(f"Error in generation: {str(e)}")
        # Print device information for debugging
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"Current CUDA device: {torch.cuda.current_device()}")
            print(f"Device count: {torch.cuda.device_count()}")
        raise

def inference(inputs: Dict) -> Dict:
    prompt = inputs.get("inputs", "")
    params = inputs.get("parameters", {})
    return generate(prompt, params)
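
# Minimal local smoke test (a sketch, not part of the serving contract).
# The exact request/response schema depends on the hosting platform; the
# payload below simply mirrors the "inputs"/"parameters" keys that
# inference() reads, and the [INST] prompt format is an assumption based
# on Mistral-Instruct's chat template.
if __name__ == "__main__":
    sample_request = {
        "inputs": "[INST] Explain what a tokenizer does in one sentence. [/INST]",
        "parameters": {"max_new_tokens": 64, "temperature": 0.7},
    }
    result = inference(sample_request)
    print(result["generated_text"])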