Jayashree Sridhar committed on
Commit aeee3e3 · 1 Parent(s): f8a9066

fallback to mistral

Files changed (1)
  1. models/mistral_model.py +6 -5
models/mistral_model.py CHANGED
@@ -24,22 +24,23 @@ class MistralModel:
 
     def _initialize_model(self):
         """Initialize Mistral model with optimizations"""
-        print("Loading TinyGPT2Model model...")
+        print("Loading Mistral model...")
 
-        model_id = "sshleifer/tiny-gpt2"
+        model_id = "mistralai/Mistral-7B-Instruct-v0.2"
 
         # Load tokenizer
-        MistralModel._tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACE_TOKEN,use_fast=False)
+        MistralModel._tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACE_TOKEN)
 
         # Load model with optimizations
         MistralModel._model = AutoModelForCausalLM.from_pretrained(
             model_id,
             token=HUGGINGFACE_TOKEN,
-            torch_dtype=torch.float32,
+            torch_dtype=torch.float16,
+            device_map="auto",
             load_in_8bit=True  # Use 8-bit quantization for memory efficiency
         )
 
-        print("TinyGPT2Model loaded successfully!")
+        print("Mistral model loaded successfully!")
 
     def generate(
         self,
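
For context, a minimal standalone sketch of the loading path this commit falls back to. The `BitsAndBytesConfig` spelling is an assumption (recent transformers releases deprecate passing the bare `load_in_8bit=True` flag directly to `from_pretrained`), and the token value shown is a hypothetical placeholder; the model id, dtype, and device placement mirror the diff above.

```python
# Sketch of the post-commit loading path, not the repo's exact module.
# Assumes transformers, accelerate, and bitsandbytes are installed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

HUGGINGFACE_TOKEN = "hf_..."  # hypothetical placeholder; the repo supplies its own
model_id = "mistralai/Mistral-7B-Instruct-v0.2"

tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACE_TOKEN)

# quantization_config is the newer form of the load_in_8bit=True flag in the diff;
# torch_dtype=torch.float16 applies to the modules bitsandbytes leaves unquantized,
# and device_map="auto" lets accelerate spread the 7B weights across available GPUs.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=HUGGINGFACE_TOKEN,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
    device_map="auto",
)

# Quick smoke test of the loaded model.
inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```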