elf.segmentation.blockwise_lmc_impl

  1from concurrent import futures
  2
  3import bioimage_cpp as bic
  4import numpy as np
  5
  6from .blockwise_mc_impl import _relabel_from_zero, _remap_edges_sum_costs
  7
  8
  9def find_inner_lifted_edges(lifted_uv_ids, node_list):
 10    """@private
 11    """
 12    lifted_indices = np.arange(len(lifted_uv_ids), dtype="uint64")
 13    # find overlap of node_list with u-edges
 14    inner_us = np.isin(lifted_uv_ids[:, 0], node_list)
 15    inner_indices = lifted_indices[inner_us]
 16    inner_uvs = lifted_uv_ids[inner_us]
 17    # find overlap of node_list with v-edges
 18    inner_vs = np.isin(inner_uvs[:, 1], node_list)
 19    return inner_indices[inner_vs]
 20
 21
 22def solve_subproblems(graph, costs, lifted_uv_ids, lifted_costs,
 23                      segmentation, solver, blocking, halo, n_threads):
 24    """@private
 25    """
 26
 27    uv_ids = graph.uv_ids()
 28
 29    # solve sub-problem from one block
 30    def solve_subproblem(block_id):
 31
 32        # extract nodes from this block
 33        block = blocking.get_block(block_id) if halo is None else\
 34            blocking.get_block_with_halo(block_id, list(halo)).outer_block
 35        bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
 36        node_ids = np.unique(segmentation[bb]).astype("uint64")
 37
 38        # get the sub-graph corresponding to the nodes
 39        inner_edges, outer_edges = graph.extract_subgraph_from_nodes(node_ids)
 40        sub_uvs = uv_ids[inner_edges]
 41
 42        # relabel the sub-nodes and associated uv-ids for more efficient processing
 43        nodes_relabeled, max_id, mapping = _relabel_from_zero(node_ids)
 44        sub_uvs = bic.utils.take_dict(mapping, np.ascontiguousarray(sub_uvs, dtype="uint64"))
 45        n_local_nodes = max_id + 1
 46        sub_graph = bic.graph.UndirectedGraph.from_edges(n_local_nodes, sub_uvs)
 47
 48        sub_costs = costs[inner_edges]
 49        assert len(sub_costs) == sub_graph.number_of_edges
 50
 51        # get the inner lifted edges and costs
 52        inner_lifted_edges = find_inner_lifted_edges(lifted_uv_ids, node_ids)
 53        sub_lifted_uvs = bic.utils.take_dict(
 54            mapping, np.ascontiguousarray(lifted_uv_ids[inner_lifted_edges], dtype="uint64"),
 55        )
 56        sub_lifted_costs = lifted_costs[inner_lifted_edges]
 57
 58        # solve multicut for the sub-graph
 59        sub_result = solver(sub_graph, sub_costs, sub_lifted_uvs, sub_lifted_costs)
 60        assert len(sub_result) == len(node_ids), "%i, %i" % (len(sub_result), len(node_ids))
 61
 62        sub_edgeresult = sub_result[sub_uvs[:, 0]] != sub_result[sub_uvs[:, 1]]
 63        assert len(sub_edgeresult) == len(inner_edges)
 64        cut_edge_ids = inner_edges[sub_edgeresult]
 65        cut_edge_ids = np.concatenate([cut_edge_ids, outer_edges])
 66        return cut_edge_ids
 67
 68    with futures.ThreadPoolExecutor(n_threads) as tp:
 69        tasks = [tp.submit(solve_subproblem, block_id)
 70                 for block_id in range(blocking.number_of_blocks)]
 71        results = [t.result() for t in tasks]
 72
 73    # merge the edge results to get all merge edges
 74    cut_edges = np.zeros(graph.number_of_edges, dtype="uint16")
 75    for res in results:
 76        cut_edges[res] += 1
 77    return cut_edges == 0
 78
 79
 80def update_edges(uv_ids, costs, labels, n_threads):
 81    """@private
 82    """
 83    return _remap_edges_sum_costs(uv_ids, labels, costs)
 84
 85
 86def reduce_problem(graph, costs, lifted_uv_ids, lifted_costs, merge_edges, n_threads):
 87    """@private
 88    """
 89    # merge node pairs with ufd
 90    n_nodes = graph.number_of_nodes
 91    nodes = np.arange(n_nodes, dtype="uint64")
 92    uv_ids = np.asarray(graph.uv_ids(), dtype="uint64")
 93    ufd = bic.utils.UnionFind(n_nodes)
 94    ufd.merge(np.ascontiguousarray(uv_ids[merge_edges], dtype="uint64"))
 95
 96    # get then new node labels
 97    new_labels = ufd.find(nodes)
 98
 99    # merge the edges and costs
100    new_uv_ids, new_costs = update_edges(uv_ids, costs, new_labels, n_threads)
101    new_lifted_uvs, new_lifted_costs = update_edges(
102        np.asarray(lifted_uv_ids, dtype="uint64"), lifted_costs, new_labels, n_threads,
103    )
104    # build the new graph
105    n_new_nodes = int(new_uv_ids.max()) + 1
106    new_graph = bic.graph.UndirectedGraph.from_edges(
107        n_new_nodes, np.ascontiguousarray(new_uv_ids, dtype="uint64"),
108    )
109
110    return new_graph, new_costs, new_lifted_uvs, new_lifted_costs, new_labels
111
112
class _RelabeledSegmentation:
    """@private

    Read-only view of a segmentation whose values are mapped through ``labels`` on access.

    Only the attributes used by the blockwise solver are provided: ``shape`` and
    slicing via ``__getitem__``. The mapping is applied per accessed block, so the
    full segmentation is never relabeled in memory.
    """

    def __init__(self, segmentation, labels):
        self._segmentation = segmentation
        self._labels = np.asarray(labels)
        self.shape = segmentation.shape

    def __getitem__(self, key):
        return self._labels[np.asarray(self._segmentation[key])]


def hierarchy_level(graph, costs, lifted_uv_ids, lifted_costs,
                    labels, segmentation, blocking,
                    internal_solver, n_threads, halo):
    """@private

    Run one level of the blockwise hierarchy: solve all block sub-problems, then
    contract the edges that no block cut.

    Args:
        graph, costs, lifted_uv_ids, lifted_costs: the problem of this level.
        labels: mapping from the initial node ids (the values of ``segmentation``)
            to the node ids of ``graph``, or None for the first level.
        segmentation: the initial segmentation.
        blocking: blocking of this level.
        internal_solver: solver for the sub-problems.
        n_threads: number of threads.
        halo: optional block halo.

    Returns:
        Tuple ``(graph, costs, lifted_uv_ids, lifted_costs, labels)`` of the reduced
        problem, where ``labels`` maps the initial node ids to the new node ids.
    """
    # After the first level the nodes of ``graph`` are merged nodes, so the initial
    # segmentation ids have to be mapped to them before the nodes of a block are
    # collected.
    if labels is None:
        block_segmentation = segmentation
    else:
        block_segmentation = _RelabeledSegmentation(segmentation, labels)

    merge_edges = solve_subproblems(graph, costs, lifted_uv_ids, lifted_costs,
                                    block_segmentation, internal_solver,
                                    blocking, halo, n_threads)
    graph, costs, lifted_uv_ids, lifted_costs, new_labels = reduce_problem(
        graph, costs, lifted_uv_ids, lifted_costs, merge_edges, n_threads,
    )

    # compose with the mapping of the previous levels
    labels = new_labels if labels is None else new_labels[labels]

    return graph, costs, lifted_uv_ids, lifted_costs, labels
131
132
def blockwise_lmc_impl(graph, costs, lifted_uv_ids, lifted_costs,
                       segmentation, internal_solver,
                       block_shape, n_threads, n_levels=1, halo=None):
    """@private

    Solve a lifted multicut problem by hierarchical blockwise reduction.

    In each of the ``n_levels`` levels, the sub-problems of all blocks are solved
    and the edges that no block cut are contracted. The block shape doubles from
    one level to the next. The remaining reduced problem is then solved with
    ``internal_solver``.

    Args:
        graph: graph of the problem.
        costs: per-edge costs.
        lifted_uv_ids: node pairs of the lifted edges.
        lifted_costs: costs of the lifted edges.
        segmentation: segmentation whose values are the node ids of ``graph``.
        internal_solver: callable ``(graph, costs, lifted_uv_ids, lifted_costs)``
            returning one label per node.
        block_shape: block shape of the first level.
        n_threads: number of threads.
        n_levels: number of blockwise reduction levels; 0 solves the problem directly.
        halo: optional halo added to every block.

    Returns:
        One label per node of ``graph``.
    """
    shape = segmentation.shape
    graph_, costs_ = graph, costs
    lifted_uv_ids_, lifted_costs_ = lifted_uv_ids, lifted_costs
    block_shape_ = list(block_shape)
    labels = None

    for _ in range(n_levels):
        blocking = bic.utils.Blocking(
            roi_begin=[0] * len(shape),
            roi_end=list(shape),
            block_shape=list(block_shape_),
        )
        graph_, costs_, lifted_uv_ids_, lifted_costs_, labels = hierarchy_level(
            graph_, costs_, lifted_uv_ids_, lifted_costs_,
            labels, segmentation, blocking,
            internal_solver, n_threads, halo,
        )
        # double the previous level's block shape (not the initial one),
        # so that every level is coarser than the one before
        block_shape_ = [bs * 2 for bs in block_shape_]

    # solve the final reduced problem
    final_labels = internal_solver(graph_, costs_, lifted_uv_ids_, lifted_costs_)
    # without any reduction level the final labels already refer to the initial graph
    if labels is None:
        return final_labels
    # bring reduced problem back to the initial graph
    return final_labels[labels]