001 /*--------------------------------------------------------------------------+
002 $Id: MaxWeightMatching.java 26283 2010-02-18 11:18:57Z juergens $
003 | |
004 | Copyright 2005-2010 Technische Universitaet Muenchen |
005 | |
006 | Licensed under the Apache License, Version 2.0 (the "License"); |
007 | you may not use this file except in compliance with the License. |
008 | You may obtain a copy of the License at |
009 | |
010 | http://www.apache.org/licenses/LICENSE-2.0 |
011 | |
012 | Unless required by applicable law or agreed to in writing, software |
013 | distributed under the License is distributed on an "AS IS" BASIS, |
014 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
015 | See the License for the specific language governing permissions and |
016 | limitations under the License. |
017 +--------------------------------------------------------------------------*/
018 package edu.tum.cs.commons.algo;
019
020 import java.util.Arrays;
021 import java.util.List;
022
023 import edu.tum.cs.commons.collections.PairList;
024
025 /**
026 * A class for calculating maximum weighted matching using an augmenting path
027 * algorithm running in O(n^3*m), where n is the size of the smaller node set
028 * and m the size of the larger one. In practice the running time is much less.
029 * <p>
030 * This class is not thread save!
031 *
032 * @author hummelb
033 * @author $Author: juergens $
034 * @version $Rev: 26283 $
035 * @levd.rating GREEN Hash: 2069DC784F078E4503328061B520BBB1
036 *
037 * @param <N1>
038 * The first node type
039 * @param <N2>
040 * The second node type
041 */
042 public class MaxWeightMatching<N1, N2> {
043
044 /**
045 * Flag indicating whether we are running in swapped mode. Swapped mode is
046 * needed as our algorithm requires the second set of nodes not to be
047 * smaller than the first set. If this is not the case, we just swap these
048 * sets, but we need this flag to adjust some parts of the code.
049 */
050 private boolean swapped;
051
052 /** Size of the first (or second if {@link #swapped}) node set. */
053 private int size1;
054
055 /** Size of the second (or first if {@link #swapped}) node set. */
056 private int size2;
057
058 /** The first node set. */
059 private List<N1> nodes1;
060
061 /** The second node set. */
062 private List<N2> nodes2;
063
064 /** The provider for the weights (i.e. weight matrix). */
065 private IWeightProvider<N1, N2> weightProvider;
066
067 /**
068 * This array stores for each node of the second set the index of the node
069 * from the first set, it is matched to (or -1 if is not in matching). If
070 * {@link #swapped}, first and second set change meaning.
071 */
072 private int[] mate = new int[16];
073
074 /**
075 * This is used while searching shortest path and stores the node index we
076 * came from.
077 */
078 private int[] from = new int[16];
079
080 /**
081 * This is used while searching shortest path and stores the distance (i.e.
082 * weight sum) to this node.
083 */
084 private double[] dist = new double[16];
085
086 /**
087 * Calculate the weighted bipartite matching.
088 *
089 * @param matching
090 * if this is non <code>null</code>, the matching (i.e. the pairs of nodes
091 * matched onto each other) will be put into it.
092 *
093 * @return the weight of the matching.
094 */
095 public double calculateMatching(List<N1> nodes1, List<N2> nodes2,
096 IWeightProvider<N1, N2> weightProvider, PairList<N1, N2> matching) {
097
098 if (matching != null) {
099 matching.clear();
100 }
101
102 if (nodes1.isEmpty() || nodes2.isEmpty()) {
103 return 0;
104 }
105
106 init(nodes1, nodes2, weightProvider);
107 prepareInternalArrays();
108
109 for (int i = 0; i < size1; ++i) {
110 augmentFrom(i);
111 }
112
113 double res = 0;
114 for (int i = 0; i < size2; ++i) {
115 if (mate[i] >= 0) {
116 if (matching != null) {
117 if (swapped) {
118 matching.add(nodes1.get(i), nodes2.get(mate[i]));
119 } else {
120 matching.add(nodes1.get(mate[i]), nodes2.get(i));
121 }
122 }
123 res += getWeight(mate[i], i);
124 }
125 }
126 return res;
127 }
128
129 /**
130 * Initializes the data structures from the parameters to the
131 * {@link #calculateMatching(List, List, edu.tum.cs.commons.algo.MaxWeightMatching.IWeightProvider, PairList)}
132 * method.
133 */
134 private void init(List<N1> nodes1, List<N2> nodes2,
135 IWeightProvider<N1, N2> weightProvider) {
136 if (nodes1.size() <= nodes2.size()) {
137 size1 = nodes1.size();
138 size2 = nodes2.size();
139 swapped = false;
140 } else {
141 size1 = nodes2.size();
142 size2 = nodes1.size();
143 swapped = true;
144 }
145 this.nodes1 = nodes1;
146 this.nodes2 = nodes2;
147 this.weightProvider = weightProvider;
148 }
149
150 /** Make sure all internal arrays are large enough. */
151 private void prepareInternalArrays() {
152 if (size2 > mate.length) {
153 int newSize = mate.length;
154 while (newSize < size2) {
155 newSize *= 2;
156 }
157 mate = new int[newSize];
158 from = new int[newSize];
159 dist = new double[newSize];
160 }
161
162 Arrays.fill(mate, 0, size2, -1);
163 }
164
165 /**
166 * Calculate shortest augmenting path and augment along it starting from the
167 * given node (index).
168 */
169 private void augmentFrom(int u) {
170 for (int i = 0; i < size2; ++i) {
171 from[i] = -1;
172 dist[i] = getWeight(u, i);
173 }
174 bellmanFord();
175 int target = findBestUnmatchedTarget();
176 augmentAlongPath(u, target);
177 }
178
179 /** Calculate the shortest path using Bellman-Ford algorithm. */
180 private void bellmanFord() {
181 boolean changed = true;
182 while (changed) {
183 changed = false;
184 for (int i = 0; i < size2; ++i) {
185 if (mate[i] < 0) {
186 continue;
187 }
188 double w = getWeight(mate[i], i);
189 for (int j = 0; j < size2; ++j) {
190 if (i == j) {
191 continue;
192 }
193 double newDist = dist[i] - w + getWeight(mate[i], j);
194 if (newDist - 1e-15 > dist[j]) {
195 dist[j] = newDist;
196 from[j] = i;
197 changed = true;
198 }
199 }
200 }
201 }
202 }
203
204 /** Find the best target which is not yet in the matching. */
205 private int findBestUnmatchedTarget() {
206 int target = -1;
207 for (int i = 0; i < size2; ++i) {
208 if (mate[i] < 0) {
209 if (target < 0 || dist[i] > dist[target]) {
210 target = i;
211 }
212 }
213 }
214 return target;
215 }
216
217 /** Augment along the given path to the target by adjusting the mate array. */
218 private void augmentAlongPath(int u, int target) {
219 while (from[target] >= 0) {
220 mate[target] = mate[from[target]];
221 target = from[target];
222 }
223 mate[target] = u;
224 }
225
226 /**
227 * Returns the weight between two nodes (=indices) handling swapping
228 * transparently.
229 */
230 private double getWeight(int i1, int i2) {
231 if (swapped) {
232 return weightProvider.getConnectionWeight(nodes1.get(i2), nodes2
233 .get(i1));
234 }
235 return weightProvider.getConnectionWeight(nodes1.get(i1), nodes2
236 .get(i2));
237 }
238
239 /** A class providing the weight for a connection between two nodes. */
240 public interface IWeightProvider<N1, N2> {
241
242 /** Returns the weight of the connection between both nodes. */
243 double getConnectionWeight(N1 node1, N2 node2);
244 }
245 }