schroneko committed
Commit 46358a2 · verified · 1 Parent(s): ca0aa0f

Update app.py

Files changed (1): app.py (+7 -2)
app.py CHANGED

@@ -1,8 +1,13 @@
+import os
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 import gradio as gr
 import spaces
 
+huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
+if not huggingface_token:
+    raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
+
 model_id = "meta-llama/Llama-Guard-3-8B-INT8"
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.bfloat16
@@ -10,12 +15,13 @@ dtype = torch.bfloat16
 quantization_config = BitsAndBytesConfig(load_in_8bit=True)
 
 def load_model():
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
         torch_dtype=dtype,
         device_map="auto",
         quantization_config=quantization_config,
+        token=huggingface_token,
         low_cpu_mem_usage=True
     )
     return tokenizer, model
@@ -39,7 +45,6 @@ def moderate(user_input, assistant_response):
     )
 
     result = tokenizer.decode(output[0], skip_special_tokens=True)
-
     result = result.split(assistant_response)[-1].strip()
 
     is_safe = "safe" in result.lower()