NeverMore0123 committed on
Commit
e103991
·
1 Parent(s): 2c4bb7c

update presets and fix detectron2

Browse files
base_world_generation_pipeline.py CHANGED
@@ -22,7 +22,7 @@ import numpy as np
22
  import torch
23
 
24
  from .t5_text_encoder import CosmosT5TextEncoder
25
- from .guardrail_common_presets import guardrail_common_presets as guardrail_presets
26
 
27
 
28
  class BaseWorldGenerationPipeline(ABC):
 
22
  import torch
23
 
24
  from .t5_text_encoder import CosmosT5TextEncoder
25
+ from .guardrail_common_presets import presets as guardrail_presets
26
 
27
 
28
  class BaseWorldGenerationPipeline(ABC):
guardrail_common_presets.py CHANGED
@@ -24,54 +24,56 @@ from .guardrail_face_blur_filter import RetinaFaceFilter
24
  from .guardrail_video_content_safety_filter import VideoContentSafetyFilter
25
  from .log import log
26
 
27
-
28
- def create_text_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
29
- """Create the text guardrail runner."""
30
- blocklist_checkpoint_dir = os.path.join(checkpoint_dir, "blocklist")
31
- aegis_checkpoint_dir = os.path.join(checkpoint_dir, "aegis")
32
- return GuardrailRunner(safety_models=[Blocklist(blocklist_checkpoint_dir), Aegis(aegis_checkpoint_dir)])
33
-
34
-
35
- def create_video_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
36
- """Create the video guardrail runner."""
37
- video_filter_checkpoint_dir = os.path.join(checkpoint_dir, "video_content_safety_filter")
38
- retinaface_checkpoint_path = os.path.join(checkpoint_dir, "face_blur_filter/Resnet50_Final.pth")
39
- return GuardrailRunner(
40
- safety_models=[VideoContentSafetyFilter(video_filter_checkpoint_dir)],
41
- postprocessors=[RetinaFaceFilter(retinaface_checkpoint_path)],
42
- )
43
-
44
-
45
- def run_text_guardrail(prompt: str, guardrail_runner: GuardrailRunner) -> bool:
46
- """Run the text guardrail on the prompt, checking for content safety.
47
-
48
- Args:
49
- prompt: The text prompt.
50
- guardrail_runner: The text guardrail runner.
51
-
52
- Returns:
53
- bool: Whether the prompt is safe.
54
- """
55
- is_safe, message = guardrail_runner.run_safety_check(prompt)
56
- if not is_safe:
57
- log.critical(f"GUARDRAIL BLOCKED: {message}")
58
- return is_safe
59
-
60
-
61
- def run_video_guardrail(frames: np.ndarray, guardrail_runner: GuardrailRunner) -> np.ndarray | None:
62
- """Run the video guardrail on the frames, checking for content safety and applying face blur.
63
-
64
- Args:
65
- frames: The frames of the generated video.
66
- guardrail_runner: The video guardrail runner.
67
-
68
- Returns:
69
- The processed frames if safe, otherwise None.
70
- """
71
- is_safe, message = guardrail_runner.run_safety_check(frames)
72
- if not is_safe:
73
- log.critical(f"GUARDRAIL BLOCKED: {message}")
74
- return None
75
-
76
- frames = guardrail_runner.postprocess(frames)
77
- return frames
 
 
 
24
  from .guardrail_video_content_safety_filter import VideoContentSafetyFilter
25
  from .log import log
26
 
27
class presets:
    """Namespace class grouping the guardrail preset helpers.

    Exists so callers can do ``from .guardrail_common_presets import presets
    as guardrail_presets`` and access the factory/run helpers through a single
    name. All methods are stateless ``@staticmethod``s; the class is never
    instantiated.
    """

    @staticmethod
    def create_text_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
        """Create the text guardrail runner.

        Args:
            checkpoint_dir: Root checkpoint directory; the ``blocklist`` and
                ``aegis`` model checkpoints are expected as subdirectories.

        Returns:
            A GuardrailRunner chaining the Blocklist and Aegis safety models.
        """
        blocklist_checkpoint_dir = os.path.join(checkpoint_dir, "blocklist")
        aegis_checkpoint_dir = os.path.join(checkpoint_dir, "aegis")
        return GuardrailRunner(safety_models=[Blocklist(blocklist_checkpoint_dir), Aegis(aegis_checkpoint_dir)])

    @staticmethod
    def create_video_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
        """Create the video guardrail runner.

        Args:
            checkpoint_dir: Root checkpoint directory; expects a
                ``video_content_safety_filter`` subdirectory and RetinaFace
                weights at ``face_blur_filter/Resnet50_Final.pth``.

        Returns:
            A GuardrailRunner with the video content safety model and a
            face-blur postprocessor.
        """
        video_filter_checkpoint_dir = os.path.join(checkpoint_dir, "video_content_safety_filter")
        retinaface_checkpoint_path = os.path.join(checkpoint_dir, "face_blur_filter/Resnet50_Final.pth")
        return GuardrailRunner(
            safety_models=[VideoContentSafetyFilter(video_filter_checkpoint_dir)],
            postprocessors=[RetinaFaceFilter(retinaface_checkpoint_path)],
        )

    @staticmethod
    def run_text_guardrail(prompt: str, guardrail_runner: GuardrailRunner) -> bool:
        """Run the text guardrail on the prompt, checking for content safety.

        Args:
            prompt: The text prompt.
            guardrail_runner: The text guardrail runner.

        Returns:
            bool: Whether the prompt is safe.
        """
        is_safe, message = guardrail_runner.run_safety_check(prompt)
        if not is_safe:
            # Blocked prompts are logged loudly; callers decide how to react.
            log.critical(f"GUARDRAIL BLOCKED: {message}")
        return is_safe

    @staticmethod
    def run_video_guardrail(frames: np.ndarray, guardrail_runner: GuardrailRunner) -> np.ndarray | None:
        """Run the video guardrail on the frames, checking for content safety and applying face blur.

        Args:
            frames: The frames of the generated video.
            guardrail_runner: The video guardrail runner.

        Returns:
            The processed frames if safe, otherwise None.
        """
        is_safe, message = guardrail_runner.run_safety_check(frames)
        if not is_safe:
            log.critical(f"GUARDRAIL BLOCKED: {message}")
            return None

        # Safe content still gets postprocessed (face blur) before returning.
        frames = guardrail_runner.postprocess(frames)
        return frames
lazy.py CHANGED
@@ -78,7 +78,7 @@ class LazyCall:
78
 
79
  Examples:
80
  ::
81
- from detectron2.config import instantiate, LazyCall
82
 
83
  layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
84
  layer_cfg.out_channels = 64 # can edit it afterwards
 
78
 
79
  Examples:
80
  ::
81
+ # from detectron2.config import instantiate, LazyCall
82
 
83
  layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
84
  layer_cfg.out_channels = 64 # can edit it afterwards