|
|
from pathlib import Path |
|
|
from typing import Dict, Callable, Optional, Generator, Any |
|
|
import shutil |
|
|
from PIL import Image |
|
|
import glob |
|
|
import os |
|
|
|
|
|
from sorghum_pipeline.pipeline import SorghumPipeline |
|
|
from sorghum_pipeline.config import Config, Paths |
|
|
|
|
|
|
|
|
def run_pipeline_on_image(input_image_path: str, work_dir: str, save_artifacts: bool = True,
                          progress_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
                          single_plant_mode: bool = False) -> Generator[Dict[str, str], None, None]:
    """Run the sorghum pipeline on a single image (no instance segmentation).

    This is a generator. For every result produced by
    ``SorghumPipeline.run_with_progress`` it re-scans ``work_dir`` and yields
    a fresh ``{label: image_path}`` mapping of everything available so far,
    suitable for progressive gallery display.

    Args:
        input_image_path: Path to the input image.
        work_dir: Working directory for outputs. It is created if missing and is
            also used as the pipeline's input, output and bounding-box folder.
        save_artifacts: Accepted for API compatibility. It is not read by this
            function.
        progress_callback: Optional ``callback(stage_name, data)``, passed
            through unchanged to ``SorghumPipeline.run_with_progress``.
        single_plant_mode: Forwarded both to the ``SorghumPipeline``
            constructor and to ``run_with_progress``.

    Yields:
        A dict of gallery label -> file path. It may also contain a
        ``'StatsText'`` summary string. A new dict is built for each stage
        result.

    Note:
        Sets ``MINIMAL_DEMO=1`` and ``FAST_OUTPUT=1`` in ``os.environ`` for
        the whole process. The previous values are not restored.
    """
    work = Path(work_dir)
    work.mkdir(parents=True, exist_ok=True)
    input_path = Path(input_image_path)

    # Process-wide flags, presumably consumed by the pipeline to trim its
    # output for the demo (TODO confirm in sorghum_pipeline). Never reset.
    os.environ['MINIMAL_DEMO'] = '1'
    os.environ['FAST_OUTPUT'] = '1'

    cfg = Config()
    cfg.paths = Paths(
        input_folder=str(work),
        output_folder=str(work),
        boundingbox_dir=str(work),
    )
    pipeline = SorghumPipeline(config=cfg, single_plant_mode=single_plant_mode)

    for stage_result in pipeline.run_with_progress(
        single_image_path=str(input_path),
        progress_callback=progress_callback,
        single_plant_mode=single_plant_mode,
    ):
        # Rebuild the snapshot from disk after every stage so newly written
        # artifacts show up in the gallery as soon as they exist.
        yield _collect_outputs(work, stage_result.get('plants', {}))
|
|
|
|
|
|
|
|
def _collect_outputs(work: Path, plants: Dict[str, Any]) -> Dict[str, str]:
    """Collect the artifacts and summary statistics currently available.

    This is called after every pipeline stage, so any of the files below may
    not exist yet. Only the files present on disk are returned.

    Args:
        work: Working directory the pipeline writes into.
        plants: Per-plant result data from a pipeline stage. Only the first
            entry is summarised. Assumed shape (not validated upstream; confirm
            against the pipeline):
            ``{name: {'vegetation_indices': {...}, 'morphology_features': {...}}}``.

    Returns:
        A mapping of gallery label -> file path. When any statistics could be
        extracted, it also holds ``'StatsText'``, a newline-joined summary.
        Key order is NDVI, GNDVI, SAVI, Overlay, Mask, Composite, InputImage,
        StatsText (present keys only).
    """
    _print_artifact_listing(work)

    gallery_files = (
        ('NDVI', 'Vegetation_indices_images/ndvi.png'),
        ('GNDVI', 'Vegetation_indices_images/gndvi.png'),
        ('SAVI', 'Vegetation_indices_images/savi.png'),
        ('Overlay', 'results/overlay.png'),
        ('Mask', 'results/mask.png'),
        ('Composite', 'results/composite.png'),
        ('InputImage', 'results/input_image.png'),
    )
    outputs: Dict[str, str] = {}
    for label, relative in gallery_files:
        path = work / relative
        if path.exists():
            outputs[label] = str(path)

    stats_lines = _first_plant_stat_lines(plants)
    if stats_lines:
        outputs['StatsText'] = "\n".join(stats_lines)
    return outputs


def _print_artifact_listing(work: Path) -> None:
    """Print the files present in each artifact subfolder (diagnostic only)."""
    for sub in ['results', 'Vegetation_indices_images', 'texture_output']:
        folder = work / sub
        try:
            if folder.exists():
                files = sorted(str(x.name) for x in folder.iterdir() if x.is_file())
                print(f"Artifacts in {sub}: {files}")
        except OSError:
            # The folder may be removed or unreadable mid-run. The listing is
            # informational only, so skip it rather than abort collection.
            continue


def _first_plant_stat_lines(plants: Any) -> list:
    """Build summary lines for the first plant in ``plants``; [] if none."""
    if not isinstance(plants, dict) or not plants:
        return []
    pdata = next(iter(plants.values()))
    if not isinstance(pdata, dict):
        return []

    lines = []
    # Each section is independent: a malformed field in one must not discard
    # the valid lines produced by the other.
    for section in (_vegetation_stat_lines, _height_stat_lines):
        try:
            lines.extend(section(pdata))
        except Exception as exc:  # best-effort: stats must never break the gallery
            print(f"Could not summarise {section.__name__}: {exc!r}")
    return lines


def _vegetation_stat_lines(pdata: Dict[str, Any]) -> list:
    """Return 'NAME: mean=…, std=…' lines for NDVI, GNDVI and SAVI when present."""
    veg = pdata.get('vegetation_indices', {})
    if not isinstance(veg, dict):
        return []
    lines = []
    for name in ('NDVI', 'GNDVI', 'SAVI'):
        entry = veg.get(name, {})
        stats = entry.get('statistics', {}) if isinstance(entry, dict) else {}
        if isinstance(stats, dict) and stats:
            mean = _format_or_na(stats.get('mean', 0), '.3f')
            std = _format_or_na(stats.get('std', 0), '.3f')
            lines.append(f"{name}: mean={mean}, std={std}")
    return lines


def _height_stat_lines(pdata: Dict[str, Any]) -> list:
    """Return plant-count / height lines from ``morphology_features['traits']``."""
    morph = pdata.get('morphology_features', {})
    traits = morph.get('traits', {}) if isinstance(morph, dict) else {}
    if not isinstance(traits, dict):
        return []

    plant_heights = traits.get('plant_heights', {})
    num_plants = traits.get('num_plants', 0)
    try:
        has_count = num_plants > 0
    except TypeError:  # e.g. None or a string: treat the count as unknown
        has_count = False
    per_plant = bool(has_count) and isinstance(plant_heights, dict)

    if per_plant and num_plants != 1 and len(plant_heights) != 1:
        # Several plants: list each one, ordered by the numeric suffix of its name.
        lines = [f"Number of plants: {num_plants}"]
        for plant_name in sorted(plant_heights, key=_plant_sort_key):
            height_text = _format_or_na(plant_heights[plant_name], '.2f', ' cm')
            lines.append(f" Plant {_plant_label(plant_name)}: {height_text}")
        return lines

    if per_plant and plant_heights:
        # Exactly one plant reported (by count or by heights dict).
        only_height = next(iter(plant_heights.values()))
        return ["Number of plants: 1",
                f"Plant height: {_format_or_na(only_height, '.2f', ' cm')}"]

    # No usable per-plant heights (including num_plants == 1 with an empty
    # heights dict, which previously raised IndexError): fall back to the
    # aggregate height trait.
    height_cm = traits.get('plant_height_cm')
    if isinstance(height_cm, (int, float)) and height_cm > 0:
        return ["Number of plants: 1", f"Plant height: {height_cm:.2f} cm"]
    return []


def _plant_label(plant_name: Any) -> str:
    """Return the second '_'-separated part of a name ('plant_3' -> '3'), else the name."""
    parts = str(plant_name).split('_')
    return parts[1] if len(parts) > 1 else str(plant_name)


def _plant_sort_key(plant_name: Any) -> tuple:
    """Sort numeric labels numerically first; any other labels follow alphabetically."""
    label = _plant_label(plant_name)
    try:
        return (0, int(label), '')
    except ValueError:
        return (1, 0, label)


def _format_or_na(value: Any, spec: str, suffix: str = '') -> str:
    """Format ``value`` with ``spec`` (plus ``suffix``); return 'n/a' if it is not formattable."""
    try:
        return format(value, spec) + suffix
    except (TypeError, ValueError):
        return 'n/a'
|
|
|