1from concurrent import futures
2
3import numpy as np
4import nifty
5from vigra.analysis import relabelConsecutive
6
7
def find_inner_lifted_edges(lifted_uv_ids, node_list):
    """@private

    Find the lifted edges whose two end-points are both contained in ``node_list``.

    Args:
        lifted_uv_ids: array of lifted edges, one ``(u, v)`` pair per row.
        node_list: node ids that make up the sub-problem.

    Returns:
        uint64 array with the row indices into ``lifted_uv_ids`` of the inner lifted edges.
    """
    # reshape so that an empty (1-D) input still has a u and a v column
    lifted_uv_ids = np.asarray(lifted_uv_ids).reshape(-1, 2)
    # first restrict to edges whose u-node is inside, then test only those v-nodes;
    # this keeps the second membership test small
    candidates = np.flatnonzero(np.isin(lifted_uv_ids[:, 0], node_list))
    inner_vs = np.isin(lifted_uv_ids[candidates, 1], node_list)
    return candidates[inner_vs].astype("uint64", copy=False)
19
20
def solve_subproblems(graph, costs, lifted_uv_ids, lifted_costs,
                      segmentation, solver, blocking, halo, n_threads):
    """@private

    Solve the lifted multicut sub-problem of every block and combine the results.

    For each block of ``blocking`` (enlarged by ``halo`` if it is not None), the node set
    consists of the unique values of ``segmentation`` inside the block. The sub-graph on
    these nodes, its edge costs and the lifted edges with both end-points inside the block
    are relabeled to consecutive local ids. They are then passed to
    ``solver(sub_graph, sub_costs, sub_lifted_uvs, sub_lifted_costs)``, which must return
    one label per local node.

    Returns:
        Boolean mask over the edges of ``graph``. True marks an edge that no block cut,
        i.e. an edge that can be merged. Edges leaving a block count as cut for that block.
    """
    uv_ids = graph.uvIds()

    def solve_subproblem(block_id):
        """Return the ids of all edges cut by the sub-problem of ``block_id``."""
        # extract the nodes of this block
        if halo is None:
            block = blocking.getBlock(block_id)
        else:
            block = blocking.getBlockWithHalo(block_id, halo).outerBlock
        bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
        node_ids = np.unique(segmentation[bb])

        # get the sub-graph corresponding to the nodes
        inner_edges, outer_edges = graph.extractSubgraphFromNodes(node_ids)
        sub_uvs = uv_ids[inner_edges]

        # relabel the sub-nodes and associated uv-ids to consecutive ids starting at 0
        _, max_id, mapping = relabelConsecutive(node_ids, start_label=0, keep_zeros=False)
        sub_uvs = nifty.tools.takeDict(mapping, sub_uvs)
        sub_graph = nifty.graph.undirectedGraph(max_id + 1)
        sub_graph.insertEdges(sub_uvs)

        sub_costs = costs[inner_edges]
        assert len(sub_costs) == sub_graph.numberOfEdges

        # get the inner lifted edges and costs, in local node ids
        inner_lifted_edges = find_inner_lifted_edges(lifted_uv_ids, node_ids)
        sub_lifted_uvs = nifty.tools.takeDict(mapping, lifted_uv_ids[inner_lifted_edges])
        sub_lifted_costs = lifted_costs[inner_lifted_edges]

        # solve the lifted multicut for the sub-graph
        sub_result = solver(sub_graph, sub_costs, sub_lifted_uvs, sub_lifted_costs)
        assert len(sub_result) == len(node_ids), "%i, %i" % (len(sub_result), len(node_ids))

        # an inner edge is cut if its two end-points received different labels
        sub_edgeresult = sub_result[sub_uvs[:, 0]] != sub_result[sub_uvs[:, 1]]
        assert len(sub_edgeresult) == len(inner_edges)
        # edges leaving the block are not decided here, so they are reported as cut
        return np.concatenate([inner_edges[sub_edgeresult], outer_edges])

    with futures.ThreadPoolExecutor(n_threads) as tp:
        tasks = [tp.submit(solve_subproblem, block_id)
                 for block_id in range(blocking.numberOfBlocks)]
        results = [t.result() for t in tasks]

    # Merge an edge only if no block cut it. A boolean mask replaces the former uint16
    # counter, which wraps around after 65535 increments and would then turn a cut edge
    # into a merge edge.
    is_cut = np.zeros(graph.numberOfEdges, dtype=bool)
    for cut_edge_ids in results:
        is_cut[cut_edge_ids] = True
    return ~is_cut
78
79
def update_edges(uv_ids, costs, labels, n_threads):
    """@private

    Map edges and their costs onto the node labeling ``labels``.

    ``labels[node]`` is the new id of ``node``. The edge mapping is computed with
    ``nifty.tools.EdgeMapping``, and edge costs are carried over with its ``"sum"``
    reduction.

    Returns:
        Tuple ``(new_uv_ids, new_costs)`` with one cost per new edge.
    """
    mapper = nifty.tools.EdgeMapping(uv_ids, labels, numberOfThreads=n_threads)
    mapped_uv_ids = mapper.newUvIds()
    mapped_costs = mapper.mapEdgeValues(costs, "sum", numberOfThreads=n_threads)
    # sanity check: the mapping must yield exactly one cost per new edge
    assert len(mapped_uv_ids) == len(mapped_costs)
    result = (mapped_uv_ids, mapped_costs)
    return result
88
89
def reduce_problem(graph, costs, lifted_uv_ids, lifted_costs, merge_edges, n_threads):
    """@private

    Contract the edges selected by ``merge_edges`` and build the reduced problem.

    Args:
        graph: current graph.
        costs: costs of the edges of ``graph``.
        lifted_uv_ids: lifted edges of the current problem.
        lifted_costs: costs of the lifted edges.
        merge_edges: edge selector (mask or indices into ``graph.uvIds()``) of edges to merge.
        n_threads: number of threads used for the edge mapping.

    Returns:
        Tuple ``(new_graph, new_costs, new_lifted_uv_ids, new_lifted_costs, new_labels)``,
        where ``new_labels[old_node]`` is the id of ``old_node`` in ``new_graph``.
    """
    n_nodes = graph.numberOfNodes

    # merge node pairs with ufd
    nodes = np.arange(n_nodes, dtype="uint64")
    uv_ids = graph.uvIds()
    ufd = nifty.ufd.ufd(n_nodes)
    ufd.merge(uv_ids[merge_edges])

    # get the new node labels
    new_labels = ufd.find(nodes)

    # merge the edges and costs
    new_uv_ids, new_costs = update_edges(uv_ids, costs, new_labels, n_threads)
    new_lifted_uvs, new_lifted_costs = update_edges(lifted_uv_ids, lifted_costs,
                                                    new_labels, n_threads)

    # The ids returned by ``ufd.find`` are not guaranteed to be consecutive. Size the
    # graph by the largest label rather than the largest edge id, so that every label is
    # a valid node. This also covers labels without any remaining edge, and the case
    # where no edges remain at all.
    n_new_nodes = int(new_labels.max()) + 1 if n_nodes > 0 else 0
    new_graph = nifty.graph.undirectedGraph(n_new_nodes)
    new_graph.insertEdges(new_uv_ids)

    return new_graph, new_costs, new_lifted_uvs, new_lifted_costs, new_labels
113
114
class _LabelMappedSegmentation:
    """Read-only view of a segmentation whose values are mapped through ``labels``.

    Supports only ``shape`` and ``[]`` slicing. The mapping is applied to each requested
    sub-array, so the full volume is never relabeled (or loaded) at once.
    """

    def __init__(self, segmentation, labels):
        self._segmentation = segmentation
        self._labels = labels

    @property
    def shape(self):
        return self._segmentation.shape

    def __getitem__(self, key):
        return self._labels[self._segmentation[key]]


def hierarchy_level(graph, costs, lifted_uv_ids, lifted_costs,
                    labels, segmentation, blocking,
                    internal_solver, n_threads, halo):
    """@private

    Run one level of the hierarchy: solve all block sub-problems, then contract the
    edges that no block cut.

    Args:
        labels: mapping from the original segmentation ids to the nodes of ``graph``,
            or None if ``graph`` is still the original graph.
        segmentation: segmentation holding the original node ids.

    Returns:
        Tuple ``(graph, costs, lifted_uv_ids, lifted_costs, labels)`` of the reduced problem.
        Here ``labels`` maps the original segmentation ids to the nodes of the new graph.
    """
    # From the second level on, the graph nodes are the reduced ids in ``labels``, while
    # ``segmentation`` still holds the original ids. Map them so that the per-block
    # node sets refer to nodes of the current graph.
    if labels is None:
        level_segmentation = segmentation
    else:
        level_segmentation = _LabelMappedSegmentation(segmentation, labels)

    merge_edges = solve_subproblems(graph, costs, lifted_uv_ids, lifted_costs,
                                    level_segmentation, internal_solver,
                                    blocking, halo, n_threads)
    graph, costs, lifted_uv_ids, lifted_costs, new_labels = reduce_problem(graph, costs,
                                                                           lifted_uv_ids, lifted_costs,
                                                                           merge_edges, n_threads)

    # compose original -> previous level -> new level
    if labels is None:
        labels = new_labels
    else:
        labels = new_labels[labels]

    return graph, costs, lifted_uv_ids, lifted_costs, labels
133
134
def blockwise_lmc_impl(graph, costs, lifted_uv_ids, lifted_costs,
                       segmentation, internal_solver,
                       block_shape, n_threads, n_levels=1, halo=None):
    """@private

    Hierarchical blockwise lifted multicut.

    At every level, the lifted multicut sub-problems of all blocks are solved. Edges that
    no block cut are contracted, and the block shape is doubled for the next level.
    Finally, ``internal_solver`` is applied to the remaining reduced problem.

    Args:
        internal_solver: callable ``(graph, costs, lifted_uv_ids, lifted_costs)`` returning
            one label per node of ``graph``.
        block_shape: block shape of the first level; it is doubled at every further level.
        n_levels: number of hierarchy levels (0 solves the full problem directly).
        halo: optional halo added to every block.

    Returns:
        One label per node of the original ``graph`` (i.e. per original segmentation id).
    """
    shape = segmentation.shape
    roi_begin = [0] * len(shape)
    graph_, costs_ = graph, costs
    lifted_uv_ids_, lifted_costs_ = lifted_uv_ids, lifted_costs
    block_shape_ = block_shape
    labels = None

    for _ in range(n_levels):
        blocking = nifty.tools.blocking(roi_begin, shape, block_shape_)
        graph_, costs_, lifted_uv_ids_, lifted_costs_, labels = hierarchy_level(graph_, costs_,
                                                                                lifted_uv_ids_, lifted_costs_,
                                                                                labels, segmentation, blocking,
                                                                                internal_solver, n_threads, halo)
        # grow from the previous level's shape, not the initial one, so blocks keep doubling
        block_shape_ = [bs * 2 for bs in block_shape_]

    # solve the final reduced problem, including its lifted edges
    final_labels = internal_solver(graph_, costs_, lifted_uv_ids_, lifted_costs_)
    if labels is None:
        # no level was run: the final problem is the original one
        return final_labels
    # bring reduced problem back to the initial graph
    return final_labels[labels]