schroneko committed on
Commit 0a17bfe
1 Parent(s): 50ffd30

Update app.py

Files changed (1)
  1. app.py +31 -31
app.py CHANGED
@@ -14,18 +14,34 @@ dtype = torch.bfloat16
 
 quantization_config = BitsAndBytesConfig(load_in_8bit=True)
 
+tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=dtype,
+    device_map="auto",
+    quantization_config=quantization_config,
+    token=huggingface_token,
+    low_cpu_mem_usage=True
+)
+
+def parse_llama_guard_output(result):
+    lines = [line.strip().lower() for line in result.split('\n') if line.strip()]
+
+    if not lines:
+        return "Error", "No valid output", result
+
+    safety_status = next((line for line in lines if line in ['safe', 'unsafe']), None)
+
+    if safety_status == 'safe':
+        return "Safe", "None", result
+    elif safety_status == 'unsafe':
+        violated_categories = next((lines[i+1] for i, line in enumerate(lines) if line == 'unsafe' and i+1 < len(lines)), "Unspecified")
+        return "Unsafe", violated_categories, result
+    else:
+        return "Error", f"Invalid output: {safety_status}", result
+
 @spaces.GPU
 def moderate(user_input, assistant_response):
-    tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=dtype,
-        device_map="auto",
-        quantization_config=quantization_config,
-        token=huggingface_token,
-        low_cpu_mem_usage=True
-    )
-
     chat = [
         {"role": "user", "content": user_input},
         {"role": "assistant", "content": assistant_response},
@@ -35,30 +51,14 @@ def moderate(user_input, assistant_response):
     with torch.no_grad():
         output = model.generate(
             input_ids=input_ids,
-            max_new_tokens=200,
+            max_new_tokens=100,
             pad_token_id=tokenizer.eos_token_id,
-            do_sample=False
         )
 
-    result = tokenizer.decode(output[0], skip_special_tokens=True)
-    result = result.split(assistant_response)[-1].strip()
-
-    lines = [line.strip().lower() for line in result.split('\n') if line.strip()]
-    if not lines:
-        return "Error", "Empty output", "No valid output from model"
-
-    first_line = lines[0]
-    if first_line == 'safe':
-        safety_status = "Safe"
-        violated_categories = "None"
-    elif first_line == 'unsafe':
-        safety_status = "Unsafe"
-        violated_categories = lines[1] if len(lines) > 1 else "Unspecified"
-    else:
-        safety_status = "Error"
-        violated_categories = f"Invalid output: {first_line}"
-
-    return safety_status, violated_categories, result
+    prompt_len = input_ids.shape[-1]
+    result = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
+
+    return parse_llama_guard_output(result)
 
 iface = gr.Interface(
     fn=moderate,
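
For reference, a self-contained sketch of the parsing behaviour this commit factors out: the helper below is copied verbatim from the diff, while the sample strings ("safe", "unsafe\nS1") are illustrative Llama Guard-style model outputs, not taken from the commit.

# Standalone sketch: parse_llama_guard_output copied from the diff above so it can
# run without loading the model. The sample inputs are illustrative assumptions.
def parse_llama_guard_output(result):
    lines = [line.strip().lower() for line in result.split('\n') if line.strip()]

    if not lines:
        return "Error", "No valid output", result

    safety_status = next((line for line in lines if line in ['safe', 'unsafe']), None)

    if safety_status == 'safe':
        return "Safe", "None", result
    elif safety_status == 'unsafe':
        violated_categories = next((lines[i+1] for i, line in enumerate(lines) if line == 'unsafe' and i+1 < len(lines)), "Unspecified")
        return "Unsafe", violated_categories, result
    else:
        return "Error", f"Invalid output: {safety_status}", result

print(parse_llama_guard_output("safe"))        # ("Safe", "None", "safe")
print(parse_llama_guard_output("unsafe\nS1"))  # ("Unsafe", "s1", "unsafe\nS1")
print(parse_llama_guard_output(""))            # ("Error", "No valid output", "")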