rootxhacker committed
Commit 2e03541
1 Parent(s): a04d31e

Update app.py

Files changed (1)
  1. app.py +10 -10
app.py CHANGED
@@ -4,23 +4,23 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import gradio as gr
 import spaces
 
-
-
-@spaces.GPU(duration=200)
-def get_completion(query, model, tokenizer):
-    peft_model_id = "rootxhacker/CodeAstra-7B"
-    config = PeftConfig.from_pretrained(peft_model_id)
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model = AutoModelForCausalLM.from_pretrained(
+peft_model_id = "rootxhacker/CodeAstra-7B"
+config = PeftConfig.from_pretrained(peft_model_id)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = AutoModelForCausalLM.from_pretrained(
     config.base_model_name_or_path,
     return_dict=True,
     load_in_4bit=True,
     device_map="auto"  # This will automatically handle device placement
 )
 
-    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
+tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
 
-    model = PeftModel.from_pretrained(model, peft_model_id)
+model = PeftModel.from_pretrained(model, peft_model_id)
+
+
+@spaces.GPU(duration=200)
+def get_completion(query, model, tokenizer):
     inputs = tokenizer(query, return_tensors="pt").to(device)  # Move inputs to the same device as the model
     outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
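
Net effect of the change: the PEFT/4-bit model load moves out of get_completion and up to module scope, so the weights are loaded once at Space startup rather than on every request, and the @spaces.GPU(duration=200) window is spent on generation alone. Below is a minimal sketch of what app.py plausibly looks like after this commit; the hunk starts at line 4, so the torch/peft imports and the Gradio wiring at the bottom are assumptions inferred from usage, not part of the diff.

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import spaces

# Load once at startup: resolve the adapter config, load the 4-bit base
# model, then attach the CodeAstra LoRA adapter on top of it.
peft_model_id = "rootxhacker/CodeAstra-7B"
config = PeftConfig.from_pretrained(peft_model_id)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    load_in_4bit=True,   # 4-bit quantization via bitsandbytes
    device_map="auto",   # automatic device placement
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)


@spaces.GPU(duration=200)  # ZeroGPU: request a GPU for up to 200 s per call
def get_completion(query, model, tokenizer):
    inputs = tokenizer(query, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Hypothetical Gradio wiring (not shown in the hunk): bind the module-level
# model and tokenizer so Gradio only has to supply the query string.
demo = gr.Interface(
    fn=lambda query: get_completion(query, model, tokenizer),
    inputs=gr.Textbox(label="Prompt"),
    outputs=gr.Textbox(label="Completion"),
)

if __name__ == "__main__":
    demo.launch()

One caveat the diff leaves in place: get_completion still accepts model and tokenizer as parameters even though they are now module-level globals, so whatever calls it must keep passing them (hence the lambda in the sketch above).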