elf.segmentation.blockwise_lmc_impl

  1from concurrent import futures
  2
  3import numpy as np
  4import nifty
  5from vigra.analysis import relabelConsecutive
  6
  7
  8def find_inner_lifted_edges(lifted_uv_ids, node_list):
  9    """@private
 10    """
 11    lifted_indices = np.arange(len(lifted_uv_ids), dtype="uint64")
 12    # find overlap of node_list with u-edges
 13    inner_us = np.in1d(lifted_uv_ids[:, 0], node_list)
 14    inner_indices = lifted_indices[inner_us]
 15    inner_uvs = lifted_uv_ids[inner_us]
 16    # find overlap of node_list with v-edges
 17    inner_vs = np.in1d(inner_uvs[:, 1], node_list)
 18    return inner_indices[inner_vs]
 19
 20
 21def solve_subproblems(graph, costs, lifted_uv_ids, lifted_costs,
 22                      segmentation, solver, blocking, halo, n_threads):
 23    """@private
 24    """
 25
 26    uv_ids = graph.uvIds()
 27
 28    # solve sub-problem from one block
 29    def solve_subproblem(block_id):
 30
 31        # extract nodes from this block
 32        block = blocking.getBlock(block_id) if halo is None else\
 33            blocking.getBlockWithHalo(block_id, halo).outerBlock
 34        bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
 35        node_ids = np.unique(segmentation[bb])
 36
 37        # get the sub-graph corresponding to the nodes
 38        inner_edges, outer_edges = graph.extractSubgraphFromNodes(node_ids)
 39        sub_uvs = uv_ids[inner_edges]
 40
 41        # relabel the sub-nodes and associated uv-ids for more efficient processing
 42        nodes_relabeled, max_id, mapping = relabelConsecutive(node_ids,
 43                                                              start_label=0,
 44                                                              keep_zeros=False)
 45        sub_uvs = nifty.tools.takeDict(mapping, sub_uvs)
 46        n_local_nodes = max_id + 1
 47        sub_graph = nifty.graph.undirectedGraph(n_local_nodes)
 48        sub_graph.insertEdges(sub_uvs)
 49
 50        sub_costs = costs[inner_edges]
 51        assert len(sub_costs) == sub_graph.numberOfEdges
 52
 53        # get the inner lifted edges and costs
 54        inner_lifted_edges = find_inner_lifted_edges(lifted_uv_ids, node_ids)
 55        sub_lifted_uvs = nifty.tools.takeDict(mapping, lifted_uv_ids[inner_lifted_edges])
 56        sub_lifted_costs = lifted_costs[inner_lifted_edges]
 57
 58        # solve multicut for the sub-graph
 59        sub_result = solver(sub_graph, sub_costs, sub_lifted_uvs, sub_lifted_costs)
 60        assert len(sub_result) == len(node_ids), "%i, %i" % (len(sub_result), len(node_ids))
 61
 62        sub_edgeresult = sub_result[sub_uvs[:, 0]] != sub_result[sub_uvs[:, 1]]
 63        assert len(sub_edgeresult) == len(inner_edges)
 64        cut_edge_ids = inner_edges[sub_edgeresult]
 65        cut_edge_ids = np.concatenate([cut_edge_ids, outer_edges])
 66        return cut_edge_ids
 67
 68    with futures.ThreadPoolExecutor(n_threads) as tp:
 69        tasks = [tp.submit(solve_subproblem, block_id)
 70                 for block_id in range(blocking.numberOfBlocks)]
 71        results = [t.result() for t in tasks]
 72
 73    # merge the edge results to get all merge edges
 74    cut_edges = np.zeros(graph.numberOfEdges, dtype="uint16")
 75    for res in results:
 76        cut_edges[res] += 1
 77    return cut_edges == 0
 78
 79
 80def update_edges(uv_ids, costs, labels, n_threads):
 81    """@private
 82    """
 83    edge_mapping = nifty.tools.EdgeMapping(uv_ids, labels, numberOfThreads=n_threads)
 84    new_uv_ids = edge_mapping.newUvIds()
 85    new_costs = edge_mapping.mapEdgeValues(costs, "sum", numberOfThreads=n_threads)
 86    assert len(new_uv_ids) == len(new_costs)
 87    return new_uv_ids, new_costs
 88
 89
 90def reduce_problem(graph, costs, lifted_uv_ids, lifted_costs, merge_edges, n_threads):
 91    """@private
 92    """
 93    # merge node pairs with ufd
 94    nodes = np.arange(graph.numberOfNodes, dtype="uint64")
 95    uv_ids = graph.uvIds()
 96    ufd = nifty.ufd.ufd(graph.numberOfNodes)
 97    ufd.merge(uv_ids[merge_edges])
 98
 99    # get then new node labels
100    new_labels = ufd.find(nodes)
101
102    # merge the edges and costs
103    new_uv_ids, new_costs = update_edges(uv_ids, costs, new_labels, n_threads)
104    new_lifted_uvs, new_lifted_costs = update_edges(lifted_uv_ids,
105                                                    lifted_costs,
106                                                    new_labels, n_threads)
107    # build the new graph
108    n_new_nodes = int(new_uv_ids.max()) + 1
109    new_graph = nifty.graph.undirectedGraph(n_new_nodes)
110    new_graph.insertEdges(new_uv_ids)
111
112    return new_graph, new_costs, new_lifted_costs, new_lifted_costs, new_labels
113
114
def hierarchy_level(graph, costs, lifted_uv_ids, lifted_costs,
                    labels, segmentation, blocking,
                    internal_solver, n_threads, halo):
    """@private

    Run one level of the hierarchy: solve the block sub-problems, then reduce the problem.

    `labels` is the node labeling accumulated over the previous levels, or
    None for the first level. The labeling of this level is composed with it
    so that the returned labels are still indexed like the incoming `labels`.

    Returns:
        Tuple `(graph, costs, lifted_uv_ids, lifted_costs, labels)` of the reduced problem.
    """
    # edges that were not cut in any block
    merge_edges = solve_subproblems(graph, costs, lifted_uv_ids, lifted_costs,
                                    segmentation, internal_solver,
                                    blocking, halo, n_threads)

    # merge along these edges to obtain the reduced problem
    reduced = reduce_problem(graph, costs, lifted_uv_ids, lifted_costs,
                             merge_edges, n_threads)
    reduced_graph, reduced_costs, reduced_lifted_uvs, reduced_lifted_costs, level_labels = reduced

    # chain the labeling of this level onto the labeling of the previous levels
    composed_labels = level_labels if labels is None else level_labels[labels]

    return reduced_graph, reduced_costs, reduced_lifted_uvs, reduced_lifted_costs, composed_labels
133
134
def blockwise_lmc_impl(graph, costs, lifted_uv_ids, lifted_costs,
                       segmentation, internal_solver,
                       block_shape, n_threads, n_levels=1, halo=None):
    """@private

    Solve a lifted problem blockwise over `n_levels` hierarchy levels.

    At every level the volume is split into blocks, the sub-problem of each
    block is solved with `internal_solver`, and the problem is reduced by
    merging along the edges that were not cut. The block shape is doubled after
    every level. The remaining reduced problem is solved with `internal_solver`
    and the result is mapped back to the nodes of the input graph.

    Args:
        graph: The graph of the initial problem.
        costs: Costs of the edges of `graph`.
        lifted_uv_ids: Node pairs of the lifted edges.
        lifted_costs: Costs of the lifted edges.
        segmentation: Array-like with a `shape`; its values are used as node ids.
        internal_solver: Callable invoked as
            `internal_solver(graph, costs, lifted_uv_ids, lifted_costs)` that
            returns one label per node.
        block_shape: Block shape of the first level, one entry per dimension.
        n_threads: Number of threads for the sub-problems.
        n_levels: Number of hierarchy levels.
        halo: Optional halo to enlarge the blocks, or None.

    Returns:
        The node labels for the initial graph.
    """
    shape = segmentation.shape
    graph_, costs_ = graph, costs
    lifted_uv_ids_, lifted_costs_ = lifted_uv_ids, lifted_costs
    block_shape_ = block_shape
    labels = None

    # NOTE(review): every level reads node ids from the unchanged `segmentation`
    # while `graph_` is the reduced graph for levels > 0 - confirm the ids agree.
    for _ in range(n_levels):
        # the roi starts at the origin; derive its dimensionality from the
        # volume instead of hard-coding three dimensions
        blocking = nifty.tools.blocking([0] * len(shape), shape, block_shape_)
        graph_, costs_, lifted_uv_ids_, lifted_costs_, labels = hierarchy_level(graph_, costs_,
                                                                                lifted_uv_ids_, lifted_costs_,
                                                                                labels, segmentation, blocking,
                                                                                internal_solver, n_threads, halo)
        # double the *current* block shape so that it keeps growing per level
        # (doubling the initial shape would stay at factor two for all later levels)
        block_shape_ = [bs * 2 for bs in block_shape_]

    # solve the final reduced problem; the solver gets the lifted edges as well,
    # matching how it is called for the block sub-problems
    final_labels = internal_solver(graph_, costs_, lifted_uv_ids_, lifted_costs_)

    # without any hierarchy level there is no node mapping to undo
    if labels is None:
        return final_labels

    # bring reduced problem back to the initial graph
    return final_labels[labels]