Integrating interacting particles and gradient flows for clustering and quantization

By: Daniel Sharp

Recently, I finished a paper with Ayoub Belhadji and Youssef Marzouk entitled "Weighted quantization using MMD: From mean field to mean shift via gradient flows", found here: https://arxiv.org/abs/2502.10600. Here, I have no interest in presenting the derivations of the mathematics (please see the paper if you are interested in those technical details); rather, I find it illuminating to show the process graphically. Further, there are many deep computational questions that MATLAB can help answer, which do not warrant their own paper. But I am getting ahead of myself! First, we ask the obvious question...

What are quantization and clustering?

While these are not necessarily interchangeable, I will consider them the same problem for the time being because they both use nearly identical tools. In brief, clustering answers the question: How do we best represent a large dataset with a substantially smaller one? First, let's create a two-dimensional dataset...
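rng(1)                                   % fix the seed for reproducibility
N_centroid = 20;                         % number of centroids K
N_data = 5000;                           % number of data points N
state_dim = 2;                           % dimension d
% "Smile": a banana-shaped distribution built from a standard normal
data_norm = randn(N_data, state_dim);
data_smile = [data_norm(:,1), -data_norm(:,2) + data_norm(:,1).^2];
% "Eyes": two small anisotropic Gaussian blobs above the smile
data_left_eye = 0.25 * randn(N_data, state_dim) * diag([1.0, 1.75]) + [-1.2 9.5];
data_right_eye = 0.25 * randn(N_data, state_dim) * diag([1.0, 1.75]) + [1.2 9.5];
% Subsample N_data points from the union of the three pieces, then standardize each coordinate
data_all = [data_smile; data_left_eye; data_right_eye];
data = data_all(randperm(3*N_data, N_data), :);
data = (data - mean(data, 1)) ./ std(data, 1);
set(0,'Defaultlinelinewidth',3)
scatter(data(:, 1), data(:, 2), 'filled', MarkerFaceAlpha=0.25);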
As a side-note, this beautiful distribution (which I will whimsically refer to as the "Joker") can also be found in NumPy here: https://gist.github.com/dannys4/817394000cb75eeef22aefcd2a5645c2 .
The most common quantization and clustering algorithm, then, is Lloyd's k-means algorithm. Roughly, for N data points X(1), ..., X(N) in dimension d and K centroids, it proceeds as follows:
Initialize centroids Y_0 with shape (K, d)
For i = 0, 1, 2, ... (until the labels stop changing)
    For j = 1, ..., N   % Scatter
        Set the label L(j) to the index of the centroid in Y_i closest to data point X(j)
    For ell = 1, ..., K   % Gather
        Set Y_{i+1}(ell) to the average of the points labelled ell, i.e., the average of X(j) over all j with L(j) == ell
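The scatter/gather loop above translates almost line for line into base MATLAB. The following is a minimal sketch on an assumed two-blob toy dataset with a hand-picked initialization. MATLAB's own `kmeans` adds k-means++ seeding, replicates, and more. The sketch saves and restores the global random number generator state, so the rest of this script is unaffected.

```matlab
% A minimal Lloyd (k-means) sketch in base MATLAB; the toy data and initialization are illustrative.
rng_state = rng;                      % save the global RNG state so the script below is unaffected
rng(0)
X = [randn(200, 2); randn(200, 2) + [6 0]];   % toy data: two well-separated blobs (N x d)
K = 2;
Y = X([1 201], :);                    % initial centroids (K x d): one data point from each blob
for iter = 1:50
    % Scatter: squared distances to every centroid (N x K), then label by nearest centroid
    D2 = sum(X.^2, 2) + sum(Y.^2, 2).' - 2 * (X * Y.');
    [~, L] = min(D2, [], 2);
    % Gather: move each centroid to the mean of the points labelled with it
    Y_new = Y;
    for ell = 1:K
        members = (L == ell);
        if any(members)
            Y_new(ell, :) = mean(X(members, :), 1);
        end
    end
    converged = isequal(Y_new, Y);    % labels (hence centroids) stopped changing
    Y = Y_new;
    if converged, break; end
end
w = accumarray(L, 1, [K 1]) / size(X, 1);   % fraction of the data in each Voronoi cell
rng(rng_state);                       % restore the global RNG state
```

The final line computes the Voronoi-cell weights that the next plots use as marker sizes.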
Luckily, there's a simple kmeans implementation in MATLAB (in the Statistics and Machine Learning Toolbox) that we'll use.
[idx, C] = kmeans(data, N_centroid);   % idx: label of each data point; C: centroids (N_centroid x 2)
% Visualize the cluster centroids and their Voronoi cells
scatter(data(:, 1), data(:, 2), 'filled', MarkerFaceAlpha=0.25);
hold on;
scatter(C(:, 1), C(:, 2), 100, 'r', 'filled', 'MarkerEdgeColor', 'k');
h = voronoi(C(:, 1), C(:, 2));
h(2).Color = 'w';                      % draw the Voronoi cell boundaries in white
hold off;
Here, we also plotted the Voronoi diagram, whose edges are the dividing lines between labels. Each centroid is the "average" of all the blue points lying in the (perhaps unbounded) polygon surrounding it. While these centroids may have unusual dynamics during the k-means iterations, the key idea is:

The transient dynamics do not matter; what matters is the steady-state solution

With that, we first ask:

What's wrong with k-means?

This is the crux of the question we consider. At its best, k-means is actually quite good, hence its general use as a workhorse in clustering. However, you can notice its tendency to do a few undesirable things.
  1. Notice that one eye tends to end up with significantly more centroids than the other, despite their nearly identical structure. This is a problem of initialization; in fact, the most widely used adaptation of k-means is to initialize with the "k-means++" algorithm, which is what MATLAB does by default. One can view this simply as a failure of repulsion between the centroids themselves.
  2. Note that, despite the vanishing probability of being in the tails of the smile (a "banana" distribution, for those familiar), we still end up with centroids out near the edges. While most readers will find the initialization problem unsurprising, this one is much more unusual and requires some delving into the mathematics.
What's nice, though, is that our approach to fixing problem 2 necessarily fixes problem 1 as well. Without further ado, then...

What's up with the samples in vanishing regions?

Let's plot the same thing as above, but now scale each centroid's marker by the weight of its Voronoi cell, i.e., the fraction of data points whose closest centroid it is:
% Weight of each Voronoi cell: fraction of the N_data points assigned to its centroid
wts = reshape(arrayfun(@(ell) sum(idx == ell), 1:N_centroid), N_centroid, 1) / N_data;
% Visualize the cluster centroids, with marker area proportional to weight
scatter(data(:, 1), data(:, 2), 'filled', MarkerFaceAlpha=0.25);
hold on;
scatter(C(:, 1), C(:, 2), 500 * wts / max(wts), 'r', 'filled', 'MarkerEdgeColor', 'k');
h = voronoi(C(:, 1), C(:, 2));
h(2).Color = 'w';
hold off;
We can really see now that the weights in the tail of the smile (the "scars", if you are one for Christopher Nolan) are vanishingly small. Therefore, if we view the centroids as the optimum of some objective, that objective must require Y to represent all of the data, not just its bulk. To spoil the fun, this is exactly what happens! Indeed, Du, Faber, and Gunzburger (1999) prove that Lloyd's algorithm is a preconditioned gradient descent of the Wasserstein distance.

What's the Wasserstein?

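While this is a familiar object in uncertainty quantification and probability, it may be unfamiliar to data scientists looking for more clustering algorithms! The 2-Wasserstein distance $W_2$, a metric between probability measures, is defined by
$$W_2^2(\mu,\pi) = \inf_{\gamma\in\Gamma(\mu,\pi)} \int \|y - x\|^2\,d\gamma(y,x).$$
Here, $\Gamma(\mu,\pi)$ is the set of couplings between the two distributions: if μ and π are densities, a coupling γ satisfies $\int \gamma(y,x)\,dx = \mu(y)$ and $\int \gamma(y,x)\,dy = \pi(x)$. In this data-driven setting, however, we have no densities; such objects are rather foreign to practitioners. Suppose, then, that μ is our centroids, represented by a weighted set of points, $\mu = \sum_{\ell=1}^K w_\ell\,\delta_{Y_\ell}$ with $w_\ell \ge 0$ and $\sum_\ell w_\ell = 1$. Then, the target dataset π is represented by the unweighted points $\pi = \frac{1}{N}\sum_{j=1}^N \delta_{X_j}$, and the coupling constraints hold with sums in place of integrals. The exact 2-Wasserstein distance between two such discrete measures is the solution of a linear program (a discrete optimal transport problem). When the weights are multiples of 1/N, this reduces to an assignment problem, which the Hungarian algorithm solves.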
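To make the assignment-problem remark concrete, here is a small sketch for the special case of two uniformly weighted point clouds of the same size. In that case the optimal coupling is a permutation. It uses MATLAB's `matchpairs`, which solves the linear assignment problem with its own algorithm, not necessarily the Hungarian one. The toy 1-D samples are an assumption, chosen so the answer can be checked against sorting.

```matlab
% Exact squared 2-Wasserstein distance between two equal-size, uniformly weighted point clouds.
rng_state = rng;                       % save the global RNG state
rng(2)
n = 50;
a = randn(n, 1);                       % samples representing mu (1-D so the answer can be checked)
b = 2 + 0.5 * randn(n, 1);             % samples representing pi
Cost = (a - b.').^2;                   % Cost(i, j) = |a_i - b_j|^2
% matchpairs solves the linear assignment problem; a huge unmatched cost forces a perfect matching
pairs = matchpairs(Cost, 1e6 * max(Cost(:)));
W2sq = mean(Cost(sub2ind(size(Cost), pairs(:, 1), pairs(:, 2))));
% In 1-D, matching the sorted samples is optimal, which gives an independent check
W2sq_sorted = mean((sort(a) - sort(b)).^2);
fprintf("W2^2 via assignment: %.6f, via sorting: %.6f\n", W2sq, W2sq_sorted);
rng(rng_state);                        % restore the global RNG state
```

Note that this only evaluates the distance for fixed points and weights. It says nothing about how to move the points, which is the question below.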

What's k-means doing then?

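While the Hungarian algorithm can compute the distance between these two datasets, it does not minimize this distance over Y and w, which is what we're truly interested in! First, one can show that, for a fixed Y, the optimal weights are
$$w^*_\ell(Y) = \int \mathbb{I}_{V_\ell(Y)}(x)\,d\pi(x) = \frac{1}{N}\sum_{j=1}^N \mathbb{I}_{V_\ell(Y)}(X_j), \qquad V_\ell(Y) = \{x : \|x - Y_\ell\| \le \|x - Y_m\| \text{ for all } m\}.$$
Here, $\mathbb{I}_{V_\ell(Y)}$ denotes the indicator function of the set $V_\ell(Y)$, which is the Voronoi cell of centroid $Y_\ell$! For i.i.d. unweighted data $X_1,\dots,X_N$, this is exactly the weight we calculated in the plot!
Therefore, we have reduced the problem to figuring out how to choose Y, given that the weights will be $w^*(Y)$. Lloyd's algorithm is built around the functional
$$\mathcal{F}(Y) = \min_{w} W_2^2\big(\mu_{Y,w},\pi\big) = W_2^2\big(\mu_{Y,w^*(Y)},\pi\big), \qquad \mu_{Y,w} = \sum_{\ell=1}^K w_\ell\,\delta_{Y_\ell}.$$
This asks: given a set of K centroids Y, what is the best (squared) 2-Wasserstein distance I can get by optimizing over the weights of these points? While we know how to calculate the best weights using the Voronoi tessellation, it is actually nontrivial to calculate $\mathcal{F}(Y)$ for a general target π. Luckily, we are not interested in calculating $\mathcal{F}$, as long as we can minimize it! Du, Faber, and Gunzburger prove the following:
$$\nabla_{Y_\ell}\mathcal{F}(Y) = 2\,w^*_\ell(Y)\,\big(Y_\ell - \Phi_\ell(Y)\big), \qquad \Phi_\ell(Y) = \frac{1}{w^*_\ell(Y)}\int x\,\mathbb{I}_{V_\ell(Y)}(x)\,d\pi(x),$$
where $\Phi(Y)$, which moves each centroid to the mean of the data in its Voronoi cell, is the exact k-means map! Here, then, we can rearrange terms and use $\tfrac12\operatorname{diag}(w^*(Y))^{-1}$ as a preconditioner:
$$\Phi(Y) = Y - \tfrac12\operatorname{diag}\big(w^*(Y)\big)^{-1}\,\nabla\mathcal{F}(Y).$$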
In other words, the k-means map is just a preconditioned gradient descent step on the squared 2-Wasserstein distance between the measure defined by the centroids Y and the target data π. This seems, at first glance, rather peculiar: how can we calculate this map with significantly lower time complexity than computing the 2-Wasserstein distance itself? Something to consider, but we won't touch it. What is important is that this explains the geometric behavior seen above! Intuitively, we thus have the following idea:
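Here is a numerical sanity check of this statement. It is a sketch on an assumed toy dataset, not taken from the paper. For uniformly weighted data, we evaluate $\mathcal{F}(Y)$ as the average squared distance from each data point to its nearest centroid, which is what the optimal-weight squared 2-Wasserstein distance reduces to when each point may be sent to its nearest centroid. We then differentiate it by central finite differences and compare the preconditioned gradient step with the k-means centroid update. The data, the centroid choice, and `step` are illustrative.

```matlab
% Check that one Lloyd step equals a preconditioned gradient step on
% F(Y) = mean_j min_ell ||X_j - Y_ell||^2 (toy data; all choices are illustrative).
rng_state = rng;                           % save the global RNG state
rng(3)
X = randn(300, 2);                         % toy data X_1..X_N, each with weight 1/N
Y = X(1:4, :);                             % centroids at data points, so no Voronoi cell is empty
[N, d] = size(X);
K = size(Y, 1);
sqdist = @(A, B) sum(A.^2, 2) + sum(B.^2, 2).' - 2 * (A * B.');   % squared distances
F = @(Z) mean(min(sqdist(X, Z), [], 2));
% Voronoi labels, optimal weights w*(Y), and the exact k-means map Phi(Y)
[~, L] = min(sqdist(X, Y), [], 2);
w = accumarray(L, 1, [K 1]) / N;
Phi = zeros(K, d);
for ell = 1:K
    Phi(ell, :) = mean(X(L == ell, :), 1);
end
% Central finite-difference gradient of F with respect to each entry of Y
step = 1e-6;
gradF = zeros(K, d);
for k = 1:numel(Y)
    E = zeros(K, d);
    E(k) = step;
    gradF(k) = (F(Y + E) - F(Y - E)) / (2 * step);
end
% Preconditioned gradient step with preconditioner (1/2) diag(w*)^(-1), applied row by row
Phi_from_grad = Y - 0.5 * (gradF ./ w);
fprintf("max |Phi - Phi_from_grad| = %.2e\n", max(abs(Phi(:) - Phi_from_grad(:))));
rng(rng_state);                            % restore the global RNG state
```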

If k-means minimizes Wasserstein, it must account for even the points furthest from the bulk of target π!

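This idea comes exactly from noting the nested optimization problem we get when substituting in the definition of the 2-Wasserstein distance. The joint minimization over the weights and the coupling is solved by sending each data point to its nearest centroid:
$$\mathcal{F}(Y) = \min_{w}\,\inf_{\gamma\in\Gamma(\mu_{Y,w},\pi)} \int \|y - x\|^2\,d\gamma(y,x) = \int \min_{\ell}\,\|x - Y_\ell\|^2\,d\pi(x).$$
At a high level: for any point X from the data distribution, its closest centroid $Y_\ell$ cannot be "too far", or else the Wasserstein distance can increase significantly, since X contributes its squared distance to that centroid.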

If not Wasserstein, then who/what?

In our paper, one approach we take is to instead consider the maximum mean discrepancy (MMD). This is, unfortunately, more complicated to explain due to the less intuitive nature of kernelized geometry, but I'll sketch the idea. Again, for more in-depth detail, consider reading our work!
The MMD is instead defined using a reproducing kernel Hilbert space, or RKHS. These are often used in data science to embed data in feature spaces with simple geometry; many a data scientist has referred to simple, possibly incorrect, RKHS properties as the "kernel trick". Some interesting connections are support vector machines, Gaussian process regression, quasi-Monte Carlo methods [1] [2], and determinantal point processes (on which Ayoub, my co-author, has phenomenal work [3]). In machine learning, the MMD is often defined as
$$\mathrm{MMD}^2(\mu,\pi) = \iint \kappa(y,y')\,d\mu(y)\,d\mu(y') - 2\iint \kappa(y,x)\,d\mu(y)\,d\pi(x) + \iint \kappa(x,x')\,d\pi(x)\,d\pi(x')$$
for a positive-definite kernel κ mapping two vectors to a positive number that is maximized when the two arguments are the same. While our work considers a more general class, here we consider $\kappa(x,y) = \exp\big(-\|x-y\|^2/(2\sigma^2)\big)$, which is often called the RBF or Gaussian kernel. I prefer calling it the "squared exponential kernel" to avoid confusion both with distance-based kernels generated by other radial basis functions and with the fact that neither μ nor π will be Gaussian.
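For empirical measures, every integral in the MMD becomes a finite sum, so the definition can be evaluated directly. The sketch below computes the MMD² between a weighted set of centroids and uniformly weighted data. It then compares uniform weights with the weights that minimize this quadratic in w. The toy data, the centroid positions, and `sig` are illustrative assumptions, and `kern` is a hypothetical inline helper.

```matlab
% Empirical MMD^2 between sum_ell w_ell delta_{Y_ell} and uniformly weighted data (illustrative values).
rng_state = rng;                              % save the global RNG state
rng(5)
sig = 0.5;                                    % kernel length-scale
X = randn(300, 2);                            % data X_1..X_N, each with weight 1/N
Y = [-1 -1; -1 1; 1 -1; 1 1; 0 0; 2 0];       % centroids Y_1..Y_K
% Hypothetical helper: squared exponential kernel matrix between the rows of A and B
kern = @(A, B) exp(-0.5 * max(sum(A.^2, 2) + sum(B.^2, 2).' - 2 * (A * B.'), 0) / sig^2);
K_YY = kern(Y, Y);                            % kernel matrix (K x K)
v0 = mean(kern(Y, X), 2);                     % kernel mean embedding of the data at each Y_ell
K_XX = kern(X, X);
c_XX = mean(K_XX(:));                         % double integral of kappa against pi x pi
% MMD^2 as a quadratic function of the weights w
mmd2 = @(w) w.' * K_YY * w - 2 * w.' * v0 + c_XX;
w_unif = ones(size(Y, 1), 1) / size(Y, 1);    % uniform weights on the centroids
w_opt = K_YY \ v0;                            % unconstrained minimizer of the quadratic
fprintf("MMD^2 with uniform weights: %.4e, with optimal weights: %.4e\n", mmd2(w_unif), mmd2(w_opt));
rng(rng_state);                               % restore the global RNG state
```

The minimizer `w_opt` solves a linear system with the kernel matrix. This is the same structure as the MMD-optimal weights used by MSIP.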
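While it may not immediately seem clear why this measures how different μ and π are, a simple way to see it is to let the length-scale σ in κ go to zero. Suppose μ and π have densities, and let $Z_\sigma = \int \kappa(x,0)\,dx$ be the constant that normalizes κ into a density. Then we get
$$\lim_{\sigma\to 0} \frac{\mathrm{MMD}^2(\mu,\pi)}{Z_\sigma} = \int \big(\mu(x) - \pi(x)\big)^2\,dx = \langle\mu,\mu\rangle - 2\langle\mu,\pi\rangle + \langle\pi,\pi\rangle,$$
where $\langle f,g\rangle = \int f(x)\,g(x)\,dx$ is the $L^2$ inner product. Recall that $\langle u,v\rangle$ measures the similarity between unit vectors (i.e., the cosine of the angle between them). It then seems reasonable that this limiting case of the MMD measures the difference between two functions that are normalized much like unit vectors (here, both μ and π integrate to unity).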
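If Lloyd's map is a preconditioned gradient step for the 2-Wasserstein distance, our method, mean shift interacting particles (MSIP), does the same for the MMD: define
$$\mathcal{F}_{\mathrm{MMD}}(Y) = \min_{w\in\mathbb{R}^K} \mathrm{MMD}^2\Big(\sum_{\ell=1}^K w_\ell\,\delta_{Y_\ell},\ \pi\Big).$$
We end up recovering the MSIP map as
$$\Psi(Y) = Y - P(Y)\,\nabla\mathcal{F}_{\mathrm{MMD}}(Y),$$
where $P(Y)$ is a matrix-valued preconditioning function. Look familiar? Hopefully! Here, we define the MMD-optimal weights for a fixed Y as
$$w^\kappa(Y) = \mathbf{K}(Y)^{-1}v_0(Y), \qquad \mathbf{K}(Y)_{\ell m} = \kappa(Y_\ell, Y_m), \qquad v_0(Y)_\ell = \int \kappa(Y_\ell, x)\,d\pi(x).$$
Here, $\mathbf{K}(Y)$ is the kernel matrix and $v_0(Y)$ is the kernel mean embedding of π evaluated at the centroids. With the kernel-optimal weights $w^\kappa(Y)$ of this form, we define
$$v_1(Y)_\ell = \int x\,\kappa(Y_\ell, x)\,d\pi(x), \qquad \Psi(Y) = \big(\mathbf{K}(Y)\operatorname{diag}(w^\kappa(Y))\big)^{-1}v_1(Y).$$
The notation $v_1$ is intended to evoke the idea of a "kernelized first moment" (in contrast to the "kernelized zeroth moment", $v_0$). Compare these with the zeroth and first moments over the Voronoi cell in k-means, $\int \mathbb{I}_{V_\ell(Y)}(x)\,d\pi(x)$ and $\int x\,\mathbb{I}_{V_\ell(Y)}(x)\,d\pi(x)$. Whereas k-means must assign every data point X to exactly one Voronoi cell, using the MMD softens this requirement: every data point contributes to every centroid $Y_\ell$, weighted by $\kappa(Y_\ell, X)$.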
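For the squared exponential kernel, the preconditioner can be written out explicitly. The short derivation below sketches one way to see it. It uses the envelope theorem, which lets us differentiate with the weights held fixed at their optimum. It assumes every entry of $w^\kappa(Y)$ is nonzero and ignores any regularization. Here $\nabla\mathcal{F}_{\mathrm{MMD}}(Y)$ is the K×d matrix whose ℓ-th row is $\nabla_{Y_\ell}\mathcal{F}_{\mathrm{MMD}}(Y)$, and $P(Y)$ multiplies it from the left.

$$
\begin{aligned}
&\mathbf{K}(Y)_{\ell m} = \kappa(Y_\ell, Y_m),\quad v_0(Y)_\ell = \int \kappa(Y_\ell,x)\,d\pi(x),\quad v_1(Y)_\ell = \int x\,\kappa(Y_\ell,x)\,d\pi(x),\quad w^\kappa = \mathbf{K}^{-1}v_0,\\
&\mathcal{F}_{\mathrm{MMD}}(Y) = \min_{w\in\mathbb{R}^K}\big[w^\top \mathbf{K}(Y)\,w - 2\,w^\top v_0(Y)\big] + \iint \kappa(x,x')\,d\pi(x)\,d\pi(x'),\\
&\nabla_{Y_\ell}\kappa(Y_\ell,z) = -\frac{Y_\ell - z}{\sigma^2}\,\kappa(Y_\ell,z)
\;\Rightarrow\;
\nabla_{Y_\ell}\mathcal{F}_{\mathrm{MMD}}(Y) = \frac{2\,w^\kappa_\ell}{\sigma^2}\Big[\sum_m \kappa(Y_\ell,Y_m)\,w^\kappa_m\,(Y_m - Y_\ell) + Y_\ell\,v_0(Y)_\ell - v_1(Y)_\ell\Big],\\
&\text{using } \textstyle\sum_m \kappa(Y_\ell,Y_m)\,w^\kappa_m = v_0(Y)_\ell:\quad
\nabla\mathcal{F}_{\mathrm{MMD}}(Y) = \frac{2}{\sigma^2}\operatorname{diag}(w^\kappa)\big[\mathbf{K}\operatorname{diag}(w^\kappa)\,Y - v_1\big]
= \frac{2}{\sigma^2}\operatorname{diag}(w^\kappa)\,\mathbf{K}\operatorname{diag}(w^\kappa)\,\big(Y - \Psi(Y)\big),\\
&\Psi(Y) = Y - \underbrace{\frac{\sigma^2}{2}\Big(\operatorname{diag}(w^\kappa)\,\mathbf{K}(Y)\operatorname{diag}(w^\kappa)\Big)^{-1}}_{P(Y)}\,\nabla\mathcal{F}_{\mathrm{MMD}}(Y).
\end{aligned}
$$

Compared with k-means, the diagonal preconditioner $\tfrac12\operatorname{diag}(w^*)^{-1}$ becomes a dense matrix, because every centroid interacts with every other through the kernel.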

The implementation problem

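Now that we understand where this method comes from, we can implement it. In contrast with k-means, which applies its map directly, we integrate a flow of the MSIP map, denoted Ψ:
$$\dot Y(t) = \Psi\big(Y(t)\big) - Y(t).$$
We subtract Y so that the velocity vanishes exactly at the fixed points of Ψ. Since $\Psi(Y) - Y$ is a preconditioned negative gradient of the MMD objective, this is akin to performing gradient descent over a preconditioned space in continuous time. A forward Euler step of size one recovers the fixed-point iteration $Y \leftarrow \Psi(Y)$.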
%% Create ODE right-hand side
function [K, diff_XY] = kernel_matrix(X, Y, sigma)
% Squared exponential kernel matrix K(i,j) = exp(-||X_i - Y_j||^2 / (2 sigma^2)),
% also returning the pairwise differences diff_XY(i,j,:) = X_i - Y_j.
X_a = reshape(X, [], 1, size(X,2));
Y_a = reshape(Y, 1, [], size(Y,2));
diff_XY = X_a - Y_a;
K = squeeze(exp( -0.5 * sum( (diff_XY/sigma).^2, 3)));
end

function f = msip_ode_rhs_nomass(y, NY, data, sigma, nugget)
% Right-hand side of dY/dt = Psi(Y) - Y for the MSIP flow.
% Use nugget to prevent ill-conditioned matrices.
Y = reshape(y, NY, []);                     % vector -> matrix (N_centroids, dim)
K_YY = kernel_matrix(Y, Y, sigma);
K_YX = kernel_matrix(Y, data, sigma);
v_0 = sum(K_YX, 2);                         % kernelized zeroth moment (unnormalized sum over data)
w = (K_YY + nugget * eye(NY)) \ v_0;        % MMD-optimal weights
K_YY_w = K_YY * diag(w) + nugget * eye(NY);
v_1 = K_YX * data;                          % kernelized first moment (unnormalized sum over data)
f = reshape((K_YY_w \ v_1) - Y, [], 1);     % Psi(Y) - Y, matrix -> vector
end
We first use the tried-and-true ode45 to integrate this system, even though it may be stiff.
Y0 = data(randperm(N_data, N_centroid), :);    % initialize centroids at random data points
sigma = 0.5;                                    % kernel length-scale
nugget = 1e-5;                                  % regularization for the kernel solves
final_time = 1000;
ode_fn = @(t,y) msip_ode_rhs_nomass(y, N_centroid, data, sigma, nugget);
[t,y_msip_ode] = ode45(ode_fn, [0.0, final_time], Y0(:));
As with the k-means results, we plot the final centroids, with marker area proportional to their MMD-optimal weights:
% Extract the final centroids and their MMD-optimal weights
y_msip_final = reshape(y_msip_ode(end,:), N_centroid, []);
k_matrix_final = kernel_matrix(y_msip_final, y_msip_final, sigma);
wts_msip = (k_matrix_final + nugget * eye(N_centroid)) \ mean(kernel_matrix(y_msip_final, data, sigma), 2);
sz_msip = 500 * wts_msip / max(abs(wts_msip));   % marker area proportional to weight
% Visualize the cluster centroids
clf
scatter(data(:, 1), data(:, 2), 'filled', MarkerFaceAlpha=0.25);
hold on;
scatter(y_msip_final(:, 1), y_msip_final(:, 2), sz_msip, 'r', 'filled', 'MarkerEdgeColor', 'k');
hold off;
Note here that the resulting centroids are split significantly more evenly between the three clusters, their weights are more even, and fewer centroids linger in the tails of the distribution.

Can we precondition better?

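Note that, here, we implement dynamics of the form
$$\dot Y = \big(\mathbf{K}(Y)\operatorname{diag}(w^\kappa(Y))\big)^{-1}v_1(Y) - Y, \qquad w^\kappa(Y) = \mathbf{K}(Y)^{-1}v_0(Y),$$
where $\mathbf{K}(Y)$ is the kernel matrix and $v_0(Y)$, $v_1(Y)$ are the kernelized zeroth and first moments.
For those more familiar with numerical analysis, treating the weights w as unknowns turns this into precisely an index-one semi-linear differential-algebraic equation (DAE) with a non-constant mass matrix (say that three times fast!). It is index one because the algebraic equation can be solved for w whenever $\mathbf{K}(Y)$ is invertible. More specifically, with $\mathbf{y} = \operatorname{vec}(Y)$, we create the system
$$\begin{aligned}\big(I_d \otimes \mathbf{K}(Y)\operatorname{diag}(w)\big)\,\dot{\mathbf{y}} &= \operatorname{vec}\big(v_1(Y) - \mathbf{K}(Y)\operatorname{diag}(w)\,Y\big),\\ 0 &= v_0(Y) - \mathbf{K}(Y)\,w.\end{aligned}$$
In this way, we completely remove the inversion of any matrix from the right-hand side and could, in theory, use only matrix multiplication. Here, ⊗ denotes the Kronecker product and $I_d$ the d×d identity. While we don't need to actually form the (generally sparse) block-diagonal mass matrix, we do so in the following code since we are only experimenting here. As in the ODE, a small nugget is added to $\mathbf{K}(Y)\operatorname{diag}(w)$ and $\mathbf{K}(Y)$ for invertibility.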
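The sketch below checks that the mass-matrix form describes the same velocity as the explicit ODE whenever the algebraic constraint holds, by evaluating both at one toy configuration. The data, centroids, `sig`, and `nug` are illustrative assumptions, and `kern` is a hypothetical inline stand-in for `kernel_matrix`. The identity being exercised is $(I_d\otimes A)\operatorname{vec}(Z) = \operatorname{vec}(AZ)$ with $A = \mathbf{K}(Y)\operatorname{diag}(w) + \text{nugget}\cdot I$.

```matlab
% Compare the explicit MSIP velocity with the mass-matrix (DAE) form at one state (illustrative values).
rng_state = rng;                            % save the global RNG state
rng(4)
sig = 0.5; nug = 1e-8;                      % kernel length-scale and nugget
X = 2 * randn(400, 2);                      % toy data (N x d)
Y = [-1.5 0; -0.75 0; 0 0; 0.75 0; 1.5 0];  % toy centroids (K x d)
[K, d] = size(Y);
% Hypothetical helper: squared exponential kernel matrix between the rows of A and B
kern = @(A, B) exp(-0.5 * max(sum(A.^2, 2) + sum(B.^2, 2).' - 2 * (A * B.'), 0) / sig^2);
K_YY = kern(Y, Y);
K_YX = kern(Y, X);
v0 = mean(K_YX, 2);                         % kernelized zeroth moment (K x 1)
v1 = K_YX * X / size(X, 1);                 % kernelized first moment (K x d)
w = (K_YY + nug * eye(K)) \ v0;             % weights satisfying the algebraic constraint
A = K_YY * diag(w) + nug * eye(K);
vel_ode = A \ v1 - Y;                       % explicit form: Psi(Y) - Y
% Mass-matrix form: (I_d kron A) * vec(dY/dt) = vec(v1 - A*Y)
Mass = kron(eye(d), A);
vel_dae = reshape(Mass \ reshape(v1 - A * Y, [], 1), K, d);
rng(rng_state);                             % restore the global RNG state
```

The two forms agree only on states where the weights satisfy the constraint. During a DAE solve, that constraint is enforced by the solver rather than by an explicit linear solve.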

Why would we wish to do this?

DAEs are often used to enforce specific invariants in dynamics. For example, the first DAE many people see is the pendulum in Cartesian coordinates, where an algebraic equation enforces that the bob stays a fixed distance from the pivot throughout the simulation. Symplectic integrators are often used in simulating molecular systems for a similar reason: to preserve structure. It stands to reason that one might be curious how such an integrator behaves here.
function M = msip_dae_mass_matrix(u, sigma, N, nugget)
% State-dependent mass matrix blkdiag(I_d kron (K diag(w) + nugget*I), 0) for u = [vec(Y); w]
Y = reshape(u(1:end-N), N, []);
w = u(end-N+1:end);
K_YY = kernel_matrix(Y, Y, sigma);
K_YY_w = K_YY * diag(w) + nugget * eye(N); % Again, add nugget for invertibility
M = blkdiag(kron(eye(size(Y,2)),K_YY_w), zeros(N)); % zero block: the weight equations are algebraic
end

function f = msip_data_dae_rhs(u, NY, data, sigma, nugget)
% Right-hand side [vec(v_1 - (K diag(w) + nugget*I) Y); v_0 - (K + nugget*I) w]
Y = reshape(u(1:end-NY), NY, []);
w = u(end-NY+1:end);
K_YY = kernel_matrix(Y, Y, sigma);
K_YX = kernel_matrix(Y, data, sigma);
v_0 = mean(K_YX, 2);                       % kernelized zeroth moment
v_1 = (K_YX * data) / size(data, 1);       % kernelized first moment
diff_Y = reshape(v_1 - (K_YY*diag(w) + nugget * eye(NY))*Y, [], 1);
diff_w = v_0-(K_YY + nugget * eye(NY))*w;  % algebraic constraint (nugget added for invertibility)
f = [diff_Y; diff_w];
end
Of course, implicit solvers such as ode15s often benefit from the Jacobian of f, the right-hand side of the DAE. I'm providing this at the end, because it's not an enlightening calculation. Here, we just check it against finite differences using the checkGradients function from the Optimization Toolbox.
u_final = [y_msip_final(:);wts_msip(:)];
checkGradients(@(u) msip_data_dae_rhs_and_jac(u, N_centroid, data, sigma, nugget), u_final, Display='on');
Objective function derivatives:
Maximum relative difference between supplied
and finite-difference derivatives = 8.22546e-08.
checkGradients successfully passed.
Now, we implement the DAE:
% Initial weights satisfying the algebraic constraint (K(Y0) + nugget*I) w0 = v_0(Y0)
w0 = (kernel_matrix(Y0, Y0, sigma) + nugget * eye(N_centroid)) \ mean(kernel_matrix(Y0, data, sigma), 2);
u0 = [Y0(:); w0(:)];                        % state: centroids stacked column-wise, then weights
f_rhs = @(t,u) msip_data_dae_rhs(u, N_centroid, data, sigma, nugget);
f_rhs_jac = @(t,u) msip_data_dae_rhs_jac(u, N_centroid, data, sigma, nugget);
% Initial slope: solve M*yp0 = f for the centroid block. The weight rows of M are zero,
% so the weight entries of the slope are not constrained by M*yp0 = f; here they are set to w0.
M_0 = msip_dae_mass_matrix(u0, sigma, N_centroid, nugget);
u0_rhs = f_rhs(0.,u0);
u0_prime = [M_0(1:end-N_centroid, 1:end-N_centroid) \ u0_rhs(1:end-N_centroid); u0(end-N_centroid+1:end)];
opts = odeset(...
Mass=@(t,y) msip_dae_mass_matrix(y, sigma, N_centroid, nugget),... % state-dependent mass matrix
MStateDependence='strong',...
MassSingular='yes',...                      % algebraic rows make the mass matrix singular
InitialSlope=u0_prime,...
Jacobian=f_rhs_jac...
);
And then we solve...
[t, y] = ode15s(f_rhs, [0. final_time], u0, opts);
Warning: Failure at t=5.033265e-04. Unable to meet integration tolerances without reducing the step size below the smallest value allowed (1.734723e-18) at time t.
... and we fail?

As Ayoub asked me: Why fail?

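Consider a configuration of points Y where $Y_\ell = v + \varepsilon_\ell$ for a fixed vector v and a small perturbation $\varepsilon_\ell$ for each point. Then every entry $\kappa(Y_\ell, Y_m) = \exp\big(-\|\varepsilon_\ell - \varepsilon_m\|^2/(2\sigma^2)\big)$ is close to one. The kernel matrix $\mathbf{K}(Y)$ is therefore a small perturbation of the rank-one all-ones matrix, which is rather problematic (below, v = 0):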
y_prtrb = 1e-3 * randn(N_centroid, state_dim);
K_prtrb = kernel_matrix(y_prtrb, y_prtrb, sigma);
fprintf("Condition number: %e\n", cond(K_prtrb));
Condition number: 2.167394e+18
This is not good! However, it isn't very important that we actually satisfy the algebraic constraint exactly. This is suggested by the main motivation of this work:

The transient dynamics do not matter; what matters is the steady-state solution

Satisfying any constraint is only important in the steady-state limit. Before then, satisfying it (colloquially) only encourages quicker convergence, so transient instability introduced by the kernel matrix is rather unimportant. Adaptive step sizing that tries to resolve stiffness is important for general transient dynamics, but it is not vital for our use cases.
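As a minimal sketch of this philosophy (an illustration under stated assumptions, not the implementation from our paper), the loop below drops adaptive time stepping entirely. It takes fixed forward Euler steps of size `dt` on $\dot Y = \Psi(Y) - Y$ and monitors only how far Y is from a fixed point. The two-blob toy data, the initial centroids, `sig`, `nug`, `dt`, and `n_iter` are all illustrative choices, and `kern` is a hypothetical inline stand-in for `kernel_matrix`.

```matlab
% Fixed-step forward Euler on dY/dt = Psi(Y) - Y; all parameters below are illustrative.
rng_state = rng;                                    % save the global RNG state
rng(6)
X = [randn(300, 2); 0.5 * randn(300, 2) + [4 4]];   % toy data: two blobs (N x d)
N = size(X, 1);
Y = [0 0; 1 1; -1 -1; 4 4; 4.5 3.5; 3.5 4.5];      % initial centroids (K x d)
K = size(Y, 1);
sig = 0.75; nug = 1e-6; dt = 0.5; n_iter = 200;
% Hypothetical helper: squared exponential kernel matrix between the rows of A and B
kern = @(A, B) exp(-0.5 * max(sum(A.^2, 2) + sum(B.^2, 2).' - 2 * (A * B.'), 0) / sig^2);
residual = zeros(n_iter, 1);
for it = 1:n_iter
    K_YY = kern(Y, Y);
    K_YX = kern(Y, X);
    v0 = mean(K_YX, 2);                             % kernelized zeroth moment
    v1 = K_YX * X / N;                              % kernelized first moment
    w = (K_YY + nug * eye(K)) \ v0;                 % MMD-optimal weights (regularized)
    velocity = (K_YY * diag(w) + nug * eye(K)) \ v1 - Y;   % Psi(Y) - Y
    residual(it) = norm(velocity, 'fro');           % zero exactly at a fixed point of Psi
    Y = Y + dt * velocity;                          % fixed step size, no error control
end
fprintf("fixed-point residual: %.2e -> %.2e\n", residual(1), residual(end));
rng(rng_state);                                     % restore the global RNG state
```

The design choice is deliberate: the transient trajectory is resolved only crudely, and the fixed-point residual is the quantity to watch.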

Conclusions

This is not intended as some kind of dig at adaptive, variable-order DAE solving; there are great solvers for this in MATLAB, as well as in Julia (DifferentialEquations.jl) and C/C++ (SUNDIALS, Chrono, etc.). I think the point here is that, in interacting particle systems, people often concentrate on resolving the transient dynamics faithfully. They do this by fixed-size Euler time stepping, by taking mean-field particle limits, or by using a consistent discretization of a flow-based PDE (see, e.g., Wasserstein or Hellinger-Kantorovich gradient flows). Instead, the lesson here is that we should seek dynamics that reach the correct steady-state solutions. Solving the transient dynamics of an autonomous system exactly? It is often neither desired nor crucial.

APPENDIX: Jacobian of DAE RHS, f

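function jac_f = msip_data_dae_rhs_jac(u, NY, data, sigma, nugget)
% Jacobian of msip_data_dae_rhs with respect to the state u = [vec(Y); w], arranged as
% [d(diff_Y)/dY, d(diff_Y)/dw; d(diff_w)/dY, d(diff_w)/dw].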
Y = reshape(u(1:end-NY), NY, []);
N_data = size(data,1);
dim = size(Y, 2);
w = reshape(u(end-NY+1:end), [], 1);
[K_YY, diff_YY] = kernel_matrix(Y, Y, sigma);
K_YX = kernel_matrix(Y, data, sigma);
v_0 = mean(K_YX, 2);
v_1 = (K_YX * data) / N_data;
v_2_diag = tensorprod(K_YX, reshape(data, [], dim, 1) .* reshape(data, [], 1, dim), 2, 1);
v_2_diag = v_2_diag / N_data;
v_0_jac_Y = zeros(NY, NY, dim);
v_1_jac_Y = zeros(NY, dim, NY, dim);
Kw_jac_Y = zeros(NY, NY, dim);
K_diagw_Y_jac_Y = zeros(NY, dim, NY, dim);
diff_Y_jac_w_tens = zeros(NY, dim, NY);
 
for j = 1:NY
v_0_jac_Y(j, j, :) = v_1(j,:) - (Y(j, :) * v_0(j));
v_1_jac_Y(j, :, j, :) = reshape(v_2_diag(j, :, :), dim, dim) - (reshape(v_1(j, :), dim, 1) * Y(j, :));
diff_Y_jac_w_tens(:, :, j) = - (K_YY(:, j) .* Y(j, :));
for i = 1:NY
delta_term = sigma^2 * K_YY(j, i) * w(i) * eye(dim);
if i == j
diff_YY_slice = reshape(diff_YY(:, i, :), NY, dim);
K_diagw_Y_divY = (K_YY(i, :) .* w') * diff_YY_slice;
Kw_jac_Y(j, i, :) = K_diagw_Y_divY;
sum_term = Y.' * ((K_YY(i, :)' .* w) .* diff_YY_slice) + nugget * sigma^2 * eye(dim);
K_diagw_Y_jac_Y(j, :, i, :) = reshape(sum_term + delta_term, 1, dim, 1, dim);
else
K_diagw_Y_divY = w(i) * K_YY(i, j) * (Y(j, :) - Y(i, :));
Kw_jac_Y(j, i, :) = K_diagw_Y_divY;
sum_term = reshape(Y(i, :), dim, 1) * (w(i) * K_YY(i, j) * (Y(j, :) - Y(i, :)));
K_diagw_Y_jac_Y(j, :, i, :) = reshape(sum_term + delta_term, 1, dim, 1, dim);
end
end
end
diff_Y_jac_Y_tens = v_1_jac_Y - K_diagw_Y_jac_Y;
diff_w_jac_Y_tens = v_0_jac_Y - Kw_jac_Y;
diff_Y_jac_Y = reshape(diff_Y_jac_Y_tens, dim * NY, dim * NY) / (sigma * sigma);
diff_w_jac_Y = reshape(diff_w_jac_Y_tens, NY, dim * NY) / (sigma * sigma);
diff_Y_jac_w = reshape(diff_Y_jac_w_tens, dim * NY, NY);
diff_w_jac_w = -(K_YY + nugget * eye(NY));
jac_f = [diff_Y_jac_Y diff_Y_jac_w; diff_w_jac_Y diff_w_jac_w];
end
 
function [fval, fgrad] = msip_data_dae_rhs_and_jac(u, NY, data, sigma, nugget)
fval = msip_data_dae_rhs(u, NY, data, sigma, nugget);
fgrad = msip_data_dae_rhs_jac(u, NY, data, sigma, nugget);
end