1from concurrent import futures
2
3import bioimage_cpp as bic
4import numpy as np
5
6from .blockwise_mc_impl import _relabel_from_zero, _remap_edges_sum_costs
7
8
def find_inner_lifted_edges(lifted_uv_ids, node_list):
    """@private

    Select the lifted edges whose two endpoints are both contained in ``node_list``.

    Args:
        lifted_uv_ids: array of lifted node pairs, read as ``[:, 0]`` (u) and
            ``[:, 1]`` (v).
        node_list: node ids that define the "inside" set.

    Returns:
        uint64 array with the (ascending) row indices of ``lifted_uv_ids`` for
        which both u and v are in ``node_list``.
    """
    u_inside = np.isin(lifted_uv_ids[:, 0], node_list)
    v_inside = np.isin(lifted_uv_ids[:, 1], node_list)
    # an edge is inner only if both endpoints lie inside the node set
    (positions,) = np.nonzero(u_inside & v_inside)
    return positions.astype("uint64")
20
21
def solve_subproblems(graph, costs, lifted_uv_ids, lifted_costs,
                      segmentation, solver, blocking, halo, n_threads):
    """@private

    Solve one lifted multicut sub-problem per block and collect the merge edges.

    For every block of ``blocking`` (its outer block when ``halo`` is given),
    the node ids occurring in ``segmentation`` inside the block select a
    sub-graph of ``graph``. Its edges and the lifted edges between those nodes
    are relabeled to a local id space and passed to ``solver``. Edges cut by the
    block solution, together with the ``outer_edges`` reported by
    ``graph.extract_subgraph_from_nodes``, are recorded as cut.

    Args:
        graph: graph of the current level.
        costs: per-edge costs, indexed by the edge ids of ``graph``.
        lifted_uv_ids: lifted node pairs, read as ``[:, 0]`` / ``[:, 1]``.
        lifted_costs: per-lifted-edge costs.
        segmentation: label array whose values are used as node ids of ``graph``.
        solver: callable ``solver(graph, costs, lifted_uv_ids, lifted_costs)``
            returning one label per node of the given sub-graph.
        blocking: block decomposition of ``segmentation``'s extent.
        halo: ``None`` or a halo passed (as a list) to ``get_block_with_halo``.
        n_threads: number of worker threads used to process blocks.

    Returns:
        Boolean array with one entry per edge of ``graph``; ``True`` marks edges
        that no block recorded as cut, i.e. edges that can be contracted.
    """
    uv_ids = graph.uv_ids()

    def solve_subproblem(block_id):
        # extract the nodes covered by this block (optionally with halo)
        if halo is None:
            block = blocking.get_block(block_id)
        else:
            block = blocking.get_block_with_halo(block_id, list(halo)).outer_block
        bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
        node_ids = np.unique(segmentation[bb]).astype("uint64")

        # get the sub-graph corresponding to the nodes
        inner_edges, outer_edges = graph.extract_subgraph_from_nodes(node_ids)
        sub_uvs = uv_ids[inner_edges]

        # relabel the sub-nodes to a compact local id space
        _, max_id, mapping = _relabel_from_zero(node_ids)
        sub_uvs = bic.utils.take_dict(mapping, np.ascontiguousarray(sub_uvs, dtype="uint64"))
        n_local_nodes = max_id + 1
        sub_graph = bic.graph.UndirectedGraph.from_edges(n_local_nodes, sub_uvs)

        sub_costs = costs[inner_edges]
        assert len(sub_costs) == sub_graph.number_of_edges

        # lifted edges with both endpoints inside the block, in local ids
        inner_lifted_edges = find_inner_lifted_edges(lifted_uv_ids, node_ids)
        sub_lifted_uvs = bic.utils.take_dict(
            mapping, np.ascontiguousarray(lifted_uv_ids[inner_lifted_edges], dtype="uint64"),
        )
        sub_lifted_costs = lifted_costs[inner_lifted_edges]

        # solve the lifted multicut for the sub-graph
        sub_result = solver(sub_graph, sub_costs, sub_lifted_uvs, sub_lifted_costs)
        assert len(sub_result) == len(node_ids), "%i, %i" % (len(sub_result), len(node_ids))

        # an inner edge is cut if its endpoints received different local labels
        sub_edgeresult = sub_result[sub_uvs[:, 0]] != sub_result[sub_uvs[:, 1]]
        assert len(sub_edgeresult) == len(inner_edges)
        cut_edge_ids = inner_edges[sub_edgeresult]
        # outer edges are always recorded as cut for this block
        return np.concatenate([cut_edge_ids, outer_edges])

    with futures.ThreadPoolExecutor(n_threads) as tp:
        tasks = [tp.submit(solve_subproblem, block_id)
                 for block_id in range(blocking.number_of_blocks)]
        results = [t.result() for t in tasks]

    # An edge is merged only if no block marked it as cut. A boolean mask is
    # used instead of a per-edge hit counter: a small integer counter wraps to 0
    # once an edge is hit often enough (e.g. 65536 times for uint16), which would
    # silently turn a cut edge into a merge edge.
    is_cut = np.zeros(graph.number_of_edges, dtype=bool)
    for cut_edge_ids in results:
        is_cut[cut_edge_ids] = True
    return ~is_cut
78
79
def update_edges(uv_ids, costs, labels, n_threads):
    """@private

    Remap edges through a node labeling and combine their costs.

    Delegates to ``_remap_edges_sum_costs(uv_ids, labels, costs)`` and returns
    its result unchanged (a ``(new_uv_ids, new_costs)`` pair, as unpacked by
    the callers in this module).

    ``n_threads`` is accepted for interface compatibility and is not used.
    """
    remapped = _remap_edges_sum_costs(uv_ids, labels, costs)
    return remapped
84
85
def reduce_problem(graph, costs, lifted_uv_ids, lifted_costs, merge_edges, n_threads):
    """@private

    Contract the merge edges of ``graph`` and build the reduced problem.

    Args:
        graph: graph of the current level.
        costs: per-edge costs of ``graph``.
        lifted_uv_ids: lifted node pairs.
        lifted_costs: per-lifted-edge costs.
        merge_edges: mask over the edges of ``graph``; selected edges have
            their endpoints merged with a union-find.
        n_threads: forwarded to ``update_edges``.

    Returns:
        ``(new_graph, new_costs, new_lifted_uvs, new_lifted_costs, new_labels)``
        where ``new_labels[i]`` is the union-find representative of node ``i``.
    """
    # merge node pairs with ufd
    n_nodes = graph.number_of_nodes
    nodes = np.arange(n_nodes, dtype="uint64")
    uv_ids = np.asarray(graph.uv_ids(), dtype="uint64")
    ufd = bic.utils.UnionFind(n_nodes)
    ufd.merge(np.ascontiguousarray(uv_ids[merge_edges], dtype="uint64"))

    # get the new node labels
    new_labels = ufd.find(nodes)

    # merge the edges and costs
    new_uv_ids, new_costs = update_edges(uv_ids, costs, new_labels, n_threads)
    new_lifted_uvs, new_lifted_costs = update_edges(
        np.asarray(lifted_uv_ids, dtype="uint64"), lifted_costs, new_labels, n_threads,
    )

    # Size the reduced graph so every representative id is a valid node.
    # Using only new_uv_ids.max() fails on an empty edge array (all edges
    # merged) and drops representatives that have no remaining edges, which
    # makes a later `final_labels[labels]` lookup go out of bounds.
    n_new_nodes = 0
    if len(new_labels) > 0:
        n_new_nodes = int(np.max(new_labels)) + 1
    if new_uv_ids.size > 0:
        n_new_nodes = max(n_new_nodes, int(new_uv_ids.max()) + 1)

    # build the new graph
    new_graph = bic.graph.UndirectedGraph.from_edges(
        n_new_nodes, np.ascontiguousarray(new_uv_ids, dtype="uint64"),
    )

    return new_graph, new_costs, new_lifted_uvs, new_lifted_costs, new_labels
111
112
def hierarchy_level(graph, costs, lifted_uv_ids, lifted_costs,
                    labels, segmentation, blocking,
                    internal_solver, n_threads, halo):
    """@private

    Run one level of the hierarchical blockwise lifted multicut.

    Solves the per-block sub-problems, contracts the edges no block cut and
    composes the resulting node mapping with ``labels``.

    Args:
        labels: ``None`` on the first level; otherwise an array mapping the
            original nodes to the nodes of ``graph``.
        (remaining arguments are forwarded to ``solve_subproblems`` /
        ``reduce_problem``)

    Returns:
        ``(graph, costs, lifted_uv_ids, lifted_costs, labels)`` of the reduced
        problem, with ``labels`` mapping the original nodes to its nodes.
    """
    merge_edges = solve_subproblems(
        graph, costs, lifted_uv_ids, lifted_costs,
        segmentation, internal_solver, blocking, halo, n_threads,
    )
    reduced = reduce_problem(
        graph, costs, lifted_uv_ids, lifted_costs, merge_edges, n_threads,
    )
    new_graph, new_costs, new_lifted_uvs, new_lifted_costs, level_labels = reduced

    # chain the mapping of this level onto the mapping of previous levels
    composed = level_labels if labels is None else level_labels[labels]
    return new_graph, new_costs, new_lifted_uvs, new_lifted_costs, composed
131
132
def blockwise_lmc_impl(graph, costs, lifted_uv_ids, lifted_costs,
                       segmentation, internal_solver,
                       block_shape, n_threads, n_levels=1, halo=None):
    """@private

    Solve a lifted multicut problem with a hierarchical blockwise scheme.

    Each level solves independent sub-problems on blocks of ``segmentation``,
    contracts the edges that no block cut, and doubles the block shape for the
    next level. The remaining reduced problem is solved with
    ``internal_solver`` and projected back onto the nodes of ``graph``.

    Args:
        graph, costs: the (non-lifted) graph and its per-edge costs.
        lifted_uv_ids, lifted_costs: lifted node pairs and their costs.
        segmentation: label array whose values are node ids of ``graph``.
        internal_solver: callable ``solver(graph, costs, lifted_uv_ids,
            lifted_costs)`` returning one label per node.
        block_shape: block shape used at the first level.
        n_threads: number of threads for the per-block solves.
        n_levels: number of hierarchy levels; ``0`` solves the full problem
            directly with ``internal_solver``.
        halo: optional halo for the blocks, or ``None``.

    Returns:
        Node labeling of ``graph``.
    """
    shape = segmentation.shape
    graph_, costs_ = graph, costs
    lifted_uv_ids_, lifted_costs_ = lifted_uv_ids, lifted_costs
    block_shape_ = list(block_shape)
    labels = None

    # NOTE(review): every level uses the original `segmentation`, while the
    # graph of deeper levels is the reduced one; confirm that the reduced node
    # ids are meant to be looked up by their original segment ids.
    for _ in range(n_levels):
        blocking = bic.utils.Blocking(
            roi_begin=[0] * len(shape),
            roi_end=list(shape),
            block_shape=list(block_shape_),
        )
        graph_, costs_, lifted_uv_ids_, lifted_costs_, labels = hierarchy_level(
            graph_, costs_, lifted_uv_ids_, lifted_costs_,
            labels, segmentation, blocking,
            internal_solver, n_threads, halo,
        )
        # grow blocks geometrically: double the shape used at this level
        block_shape_ = [bs * 2 for bs in block_shape_]

    # solve the final reduced problem
    final_labels = internal_solver(graph_, costs_, lifted_uv_ids_, lifted_costs_)
    if labels is None:
        # no level was run, so the solver already worked on `graph` itself
        return final_labels
    # bring reduced problem back to the initial graph
    return final_labels[labels]