Update app.py
app.py
CHANGED
@@ -14,18 +14,34 @@ dtype = torch.bfloat16
 
 quantization_config = BitsAndBytesConfig(load_in_8bit=True)
 
+tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=dtype,
+    device_map="auto",
+    quantization_config=quantization_config,
+    token=huggingface_token,
+    low_cpu_mem_usage=True
+)
+
+def parse_llama_guard_output(result):
+    lines = [line.strip().lower() for line in result.split('\n') if line.strip()]
+
+    if not lines:
+        return "Error", "No valid output", result
+
+    safety_status = next((line for line in lines if line in ['safe', 'unsafe']), None)
+
+    if safety_status == 'safe':
+        return "Safe", "None", result
+    elif safety_status == 'unsafe':
+        violated_categories = next((lines[i+1] for i, line in enumerate(lines) if line == 'unsafe' and i+1 < len(lines)), "Unspecified")
+        return "Unsafe", violated_categories, result
+    else:
+        return "Error", f"Invalid output: {safety_status}", result
+
 @spaces.GPU
 def moderate(user_input, assistant_response):
-    tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=dtype,
-        device_map="auto",
-        quantization_config=quantization_config,
-        token=huggingface_token,
-        low_cpu_mem_usage=True
-    )
-
     chat = [
         {"role": "user", "content": user_input},
         {"role": "assistant", "content": assistant_response},
@@ -35,30 +51,14 @@ def moderate(user_input, assistant_response):
     with torch.no_grad():
         output = model.generate(
             input_ids=input_ids,
-            max_new_tokens=
+            max_new_tokens=100,
             pad_token_id=tokenizer.eos_token_id,
-            do_sample=False
         )
 
-
-    result =
-
-
-    if not lines:
-        return "Error", "Empty output", "No valid output from model"
-
-    first_line = lines[0]
-    if first_line == 'safe':
-        safety_status = "Safe"
-        violated_categories = "None"
-    elif first_line == 'unsafe':
-        safety_status = "Unsafe"
-        violated_categories = lines[1] if len(lines) > 1 else "Unspecified"
-    else:
-        safety_status = "Error"
-        violated_categories = f"Invalid output: {first_line}"
-
-    return safety_status, violated_categories, result
+    prompt_len = input_ids.shape[-1]
+    result = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
+
+    return parse_llama_guard_output(result)
 
 iface = gr.Interface(
     fn=moderate,
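As a quick sanity check on the new helper, the parser can be run standalone. A minimal sketch, with parse_llama_guard_output copied verbatim from the first hunk above; the sample strings are illustrative and assume Llama Guard's usual convention of "safe", or "unsafe" followed by a category code:

def parse_llama_guard_output(result):
    # Normalize: drop blank lines, lowercase everything for comparison.
    lines = [line.strip().lower() for line in result.split('\n') if line.strip()]

    if not lines:
        return "Error", "No valid output", result

    # First line that is exactly 'safe' or 'unsafe' decides the verdict.
    safety_status = next((line for line in lines if line in ['safe', 'unsafe']), None)

    if safety_status == 'safe':
        return "Safe", "None", result
    elif safety_status == 'unsafe':
        # The category code is expected on the line after 'unsafe'.
        violated_categories = next((lines[i+1] for i, line in enumerate(lines)
                                    if line == 'unsafe' and i+1 < len(lines)), "Unspecified")
        return "Unsafe", violated_categories, result
    else:
        return "Error", f"Invalid output: {safety_status}", result

print(parse_llama_guard_output("safe"))        # -> ('Safe', 'None', 'safe')
print(parse_llama_guard_output("unsafe\nS1"))  # -> ('Unsafe', 's1', 'unsafe\nS1')
print(parse_llama_guard_output(""))            # -> ('Error', 'No valid output', '')

Note that the helper lowercases the raw output, so a category code like "S1" comes back as "s1"; the untouched raw string is still returned as the third element for display.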
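The second hunk's decode change slices the prompt tokens off the generated sequence so that only the model's new text reaches the parser. A minimal end-to-end sketch of that pattern, assuming the Space builds input_ids with tokenizer.apply_chat_template (implied but not shown in the visible hunks) and using an illustrative model id in place of the Space's own model_id, token, and quantization setup:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/LlamaGuard-7b"  # illustrative; the Space configures its own model_id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

chat = [
    {"role": "user", "content": "How do I make tea?"},
    {"role": "assistant", "content": "Boil water and steep the leaves for three minutes."},
]
input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(model.device)

with torch.no_grad():
    output = model.generate(
        input_ids=input_ids,
        max_new_tokens=100,
        pad_token_id=tokenizer.eos_token_id,
    )

# output[0] holds prompt + completion; slicing at prompt_len keeps only the
# newly generated tokens, which is what the diff's decode step does.
prompt_len = input_ids.shape[-1]
result = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
print(result)  # expected: "safe", or "unsafe" followed by a category line

Loading the tokenizer and model at module level, as the diff now does, also means they are initialized once per worker rather than rebuilt on every @spaces.GPU call, which appears to be the motivation for moving them out of moderate.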