import streamlit as st
from dnafiber.inference import infer
from dnafiber.postprocess.core import refine_segmentation
import numpy as np
from dnafiber.deployment import _get_model
import torch


@st.cache_data
def ui_inference(_model, _image, _device, postprocess=True, id=None):
    # Streamlit skips hashing arguments prefixed with "_", so the hashable
    # `id` argument acts as the cache key for a given image/model pair.
    return ui_inference_cacheless(
        _model, _image, _device, postprocess=postprocess, id=id
    )


@st.cache_resource
def get_model(model_name):
    # Loaded once per model name and shared across sessions; prefers the GPU
    # when one is available.
    model = _get_model(
        device="cuda" if torch.cuda.is_available() else "cpu",
        revision=model_name,
    )
    return model
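
# A minimal sketch (hypothetical revision names) of the ensemble path handled
# below: passing a list of revision strings makes ui_inference_cacheless
# resolve each name through get_model and average the per-model probability
# maps before taking the argmax.
#
#     mask = ui_inference_cacheless(
#         ["revision-a", "revision-b"],  # hypothetical model revisions
#         image,
#         "cuda" if torch.cuda.is_available() else "cpu",
#     )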


def ui_inference_cacheless(_model, _image, _device, postprocess=True, id=None):
    """Cacheless version of ``ui_inference``.

    Runs the same pipeline without Streamlit caching, for scenarios where
    caching is not desired. ``id`` is unused here; it only keeps the
    signature identical to the cached wrapper.
    """
    with st.spinner("Sliding window segmentation in progress..."):
        if isinstance(_model, list):
            # Ensemble: average the per-model probability maps, then take
            # the argmax over classes to obtain the final label map.
            output = None
            for model in _model:
                if isinstance(model, str):
                    # Revision names are resolved (and cached) on the fly.
                    model = get_model(model)
                with st.spinner(text=f"Segmenting with model: {model}"):
                    probabilities = infer(
                        model,
                        image=_image,
                        device=_device,
                        scale=st.session_state.get("pixel_size", 0.13),
                        only_probabilities=True,
                    ).cpu()
                output = probabilities if output is None else output + probabilities
            output = (output / len(_model)).argmax(1).squeeze().numpy()
        else:
            output = infer(
                _model,
                image=_image,
                device=_device,
                scale=st.session_state.get("pixel_size", 0.13),
            )
    output = output.astype(np.uint8)
    if postprocess:
        with st.spinner("Post-processing segmentation..."):
            # ``postprocess`` is necessarily True inside this branch, so
            # junction fixing is always enabled here.
            output = refine_segmentation(output, fix_junctions=True)
    return output
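

# Example usage (a minimal sketch, not part of the original module): wiring
# these helpers into a Streamlit page. The revision name "dnafiber-default",
# the uploader flow, and the display step are illustrative assumptions.
#
#     import cv2
#
#     uploaded = st.file_uploader("Upload a fiber image", type=["png", "tif"])
#     if uploaded is not None:
#         data = np.frombuffer(uploaded.read(), dtype=np.uint8)
#         image = cv2.imdecode(data, cv2.IMREAD_COLOR)
#         model = get_model("dnafiber-default")  # hypothetical revision name
#         device = "cuda" if torch.cuda.is_available() else "cpu"
#         # `id` keys the Streamlit cache; the file name stands in for a hash.
#         mask = ui_inference(model, image, device, id=uploaded.name)
#         st.image(mask * 127, caption="Segmentation", clamp=True)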