from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import Dict


def get_model():
    model_id = "mistralai/Mistral-7B-Instruct-v0.2"

    # Force CUDA to be the default device
    if torch.cuda.is_available():
        torch.set_default_device('cuda')

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Mistral's tokenizer ships without a pad token; reuse EOS so padding
    # and the pad_token_id passed to generate() are valid
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model with explicit device placement
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    # Explicitly move model to GPU
    if torch.cuda.is_available():
        model = model.cuda()

    return model, tokenizer


# Initialize model and tokenizer
model, tokenizer = get_model()


def generate(text: str, params: Dict) -> Dict:
    try:
        # Use CUDA when it is available, otherwise fall back to CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")

        # Tokenize, then move the tensors to the target device
        inputs = tokenizer(text, return_tensors="pt", padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Print debug info
        print(f"Input device: {inputs['input_ids'].device}")
        print(f"Model device: {next(model.parameters()).device}")

        # Generate without autograd bookkeeping; do_sample must be enabled
        # for temperature/top_p/top_k to take effect
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=params.get('max_new_tokens', 500),
                do_sample=True,
                temperature=params.get('temperature', 0.7),
                top_p=params.get('top_p', 0.95),
                top_k=params.get('top_k', 50),
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": response}
    except Exception as e:
        print(f"Error in generation: {str(e)}")
        # Print device information for debugging
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"Current CUDA device: {torch.cuda.current_device()}")
            print(f"Device count: {torch.cuda.device_count()}")
        raise


def inference(inputs: Dict) -> Dict:
    prompt = inputs.get("inputs", "")
    params = inputs.get("parameters", {})
    return generate(prompt, params)
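

# Usage sketch (assumption: the serving framework that normally imports and
# calls `inference` is not shown in this file, so this block just exercises
# the handler directly with a hypothetical request payload in the
# Mistral-Instruct [INST] ... [/INST] prompt format):
if __name__ == "__main__":
    sample_request = {
        "inputs": "[INST] Summarize what this service does. [/INST]",
        "parameters": {"max_new_tokens": 128, "temperature": 0.7},
    }
    result = inference(sample_request)
    print(result["generated_text"])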