Create wrapper.py
Browse files- wrapper.py +54 -0
wrapper.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Dict
|
| 3 |
+
import shutil
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import glob
|
| 6 |
+
import tempfile
|
| 7 |
+
|
| 8 |
+
from sorghum_pipeline.pipeline import SorghumPipeline
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts: bool = True) -> Dict[str, str]:
|
| 12 |
+
"""
|
| 13 |
+
Run sorghum pipeline on a single image (no instance segmentation).
|
| 14 |
+
Returns dict[label -> image_path] for gallery display.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
work = Path(work_dir)
|
| 18 |
+
work.mkdir(parents=True, exist_ok=True)
|
| 19 |
+
|
| 20 |
+
# Copy input to work dir
|
| 21 |
+
input_copy = work / Path(input_image_path).name
|
| 22 |
+
shutil.copy(input_image_path, input_copy)
|
| 23 |
+
|
| 24 |
+
# Initialize pipeline with config
|
| 25 |
+
# adjust this if you have a YAML config file (e.g., "configs/demo.yaml")
|
| 26 |
+
pipeline = SorghumPipeline(
|
| 27 |
+
config_path=str(Path("sorghum_pipeline/config.py")),
|
| 28 |
+
enable_occlusion_handling=False,
|
| 29 |
+
enable_instance_integration=False
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# Run the pipeline (single image, no frames, no SAM2Long)
|
| 33 |
+
results = pipeline.run(
|
| 34 |
+
load_all_frames=False,
|
| 35 |
+
segmentation_only=False,
|
| 36 |
+
run_instance_segmentation=False,
|
| 37 |
+
features_frame_only=None
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# Collect outputs
|
| 41 |
+
outputs: Dict[str, str] = {}
|
| 42 |
+
|
| 43 |
+
# Save original for reference
|
| 44 |
+
original = work / "original.png"
|
| 45 |
+
Image.open(input_copy).convert("RGB").save(original)
|
| 46 |
+
outputs["Original"] = str(original)
|
| 47 |
+
|
| 48 |
+
# Gather all PNG files created by OutputManager
|
| 49 |
+
for f in glob.glob(str(work / "**/*.png"), recursive=True):
|
| 50 |
+
name = Path(f).stem
|
| 51 |
+
if name.lower() not in outputs: # avoid duplicate "Original"
|
| 52 |
+
outputs[name] = f
|
| 53 |
+
|
| 54 |
+
return outputs
|