made1570 committed on
Commit
3ea1454
·
verified ·
1 Parent(s): 7afcb61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -8,6 +8,7 @@ base_model_name = "unsloth/gemma-3-12b-it-unsloth-bnb-4bit"
8
  adapter_name = "adarsh3601/my_gemma3_pt"
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
 
11
  base_model = AutoModelForCausalLM.from_pretrained(
12
  base_model_name,
13
  device_map={"": device},
@@ -15,6 +16,7 @@ base_model = AutoModelForCausalLM.from_pretrained(
15
  load_in_4bit=True
16
  )
17
 
 
18
  tokenizer = AutoTokenizer.from_pretrained(base_model_name)
19
  model = PeftModel.from_pretrained(base_model, adapter_name)
20
  model.to(device)
@@ -22,7 +24,14 @@ model.to(device)
22
  # Chat function
23
  def chat(message):
24
  inputs = tokenizer(message, return_tensors="pt")
25
- inputs = {k: v.to(device).half() for k, v in inputs.items()}
 
 
 
 
 
 
 
26
  outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True)
27
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
  return response
 
8
  adapter_name = "adarsh3601/my_gemma3_pt"
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
+ # Load base model
12
  base_model = AutoModelForCausalLM.from_pretrained(
13
  base_model_name,
14
  device_map={"": device},
 
16
  load_in_4bit=True
17
  )
18
 
19
+ # Load tokenizer and adapter
20
  tokenizer = AutoTokenizer.from_pretrained(base_model_name)
21
  model = PeftModel.from_pretrained(base_model, adapter_name)
22
  model.to(device)
 
24
  # Chat function
25
  def chat(message):
26
  inputs = tokenizer(message, return_tensors="pt")
27
+
28
+ # Move tensors to the correct device and convert only float tensors to half
29
+ for k in inputs:
30
+ if inputs[k].dtype == torch.float32:
31
+ inputs[k] = inputs[k].to(device).half()
32
+ else:
33
+ inputs[k] = inputs[k].to(device)
34
+
35
  outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True)
36
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
37
  return response