Pr123 committed on
Commit
dd17305
1 Parent(s): 8cbf8ae

Update app.py

Files changed (1)
  1. app.py +4 -3
app.py CHANGED
@@ -7,11 +7,12 @@ import gradio as gr
 peft_model_id = "Pr123/TinyLlama-EA-Chat"

 # Load Model with PEFT adapter
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = AutoPeftModelForCausalLM.from_pretrained(
     peft_model_id,
-    device_map="auto",
-    torch_dtype=torch.float16
-)
+    torch_dtype=torch.float16  # Keeping half precision
+).to(device)
+
 tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
 pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=500)

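
For reference, below is a minimal sketch of how the model-loading section of app.py reads after this commit. The import lines are assumptions inferred from the names used in the diff (torch, peft's AutoPeftModelForCausalLM, transformers' AutoTokenizer and pipeline, plus gradio from the hunk header); the actual file may arrange them differently.

import torch
import gradio as gr
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline

peft_model_id = "Pr123/TinyLlama-EA-Chat"

# Load Model with PEFT adapter
# Pick the GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoPeftModelForCausalLM.from_pretrained(
    peft_model_id,
    torch_dtype=torch.float16  # Keeping half precision
).to(device)

tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=500)

Net effect of the change: rather than relying on device_map="auto" to place the weights, the model is loaded in full and moved to one explicitly chosen device, with a CPU fallback when no GPU is available.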