Wasserstein-Fisher-Rao Gradient Flow

Fisher-Rao

Suppose we have two probability measures $\mu$ and $\nu$, and we would like to find the distance between them. While one naturally might consider the Wasserstein metric, which (by Benamou-Brenier) is given by

$$
W_2^2(\mu, \nu) = \inf_{(\rho_t, v_t)} \int_0^1 \int \|v_t(x)\|^2 \, d\rho_t(x) \, dt \quad \text{s.t.} \quad \partial_t \rho_t + \nabla \cdot (\rho_t v_t) = 0, \quad \rho_0 = \mu, \ \rho_1 = \nu,
$$
this is well-studied and has well-known restrictions. An alternative here would be the Fisher-Rao distance,

$$
FR^2(\mu, \nu) = \inf_{(\rho_t, r_t)} \int_0^1 \int r_t(x)^2 \, d\rho_t(x) \, dt \quad \text{s.t.} \quad \partial_t \rho_t = \rho_t r_t, \quad \rho_0 = \mu, \ \rho_1 = \nu.
$$
Note that this distance is constrained by a reaction PDE instead of a continuity equation. We can see that, if we replace the reaction by its centered version $r_t - \bar{r}_t$,

$$
\frac{d}{dt} \int d\rho_t = \int \rho_t \, (r_t - \bar{r}_t) \, dx = \int (r_t - \bar{r}_t) \, d\rho_t,
$$

where $\bar{r}_t := \int r_t \, d\rho_t$. Since we know $\int (r_t - \bar{r}_t) \, d\rho_t = 0$, the total mass can't possibly change (consider discretizing this via Euler and noting that the mass won't change for any step-size $h$, so it is conserved as $h \to 0$). One can see that the (adjusted) Fisher-Rao induces a closed-form geodesic interpolating the square-root densities, but this is irrelevant to the topic at hand.
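To spell out the Euler argument above (a quick sanity check, using the centered reaction): one explicit step of size $h$ updates the density to $(1 + h(r_t - \bar{r}_t))\,\rho_t$, so

$$
\int d\rho_{t+h} = \int \bigl(1 + h(r_t - \bar{r}_t)\bigr) \, d\rho_t = \int d\rho_t + h \int (r_t - \bar{r}_t) \, d\rho_t = \int d\rho_t,
$$

and every Euler iterate carries the same total mass, independent of $h$.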

Wasserstein-Fisher-Rao

The PDE constraining the Wasserstein-Fisher-Rao metric will intuitively be

$$
\partial_t \rho_t + \nabla \cdot (\rho_t v_t) = \rho_t r_t,
$$
i.e., we use velocity field $v_t$ and "reaction function" $r_t$ to govern the dynamics of $\rho_t$. For a given functional $\mathcal{F}[\rho]$, we would like to find a "Wasserstein-Fisher-Rao" gradient flow. In particular, we would like to find a velocity field $v_t$ and reaction function $r_t$ which determine a $\rho_t$ s.t. $\rho_t$ minimizes $\mathcal{F}$ using constant-speed geodesics in the WFR geometry. To do this, we first must define a "Wasserstein-Fisher-Rao gradient". Generally one defines such gradients by writing $\frac{d}{dt}\mathcal{F}[\rho_t]$ as an inner product against the pair $(v_t, r_t)$ governing the PDE (TODO: elaborate?). Observe that

$$
\frac{d}{dt}\mathcal{F}[\rho_t] = \int \frac{\delta \mathcal{F}}{\delta \rho}[\rho_t] \, \partial_t \rho_t \, dx = \int \left( \nabla \frac{\delta \mathcal{F}}{\delta \rho}[\rho_t] \cdot v_t + \frac{\delta \mathcal{F}}{\delta \rho}[\rho_t] \, r_t \right) d\rho_t,
$$

so the steepest-descent choice is $v_t = -\nabla \frac{\delta \mathcal{F}}{\delta \rho}[\rho_t]$ and $r_t = -\frac{\delta \mathcal{F}}{\delta \rho}[\rho_t]$. Then, if I have a discrete measure $\rho_t = \sum_i w_i(t) \, \delta_{x_i(t)}$, I would end up with something like

$$
\dot{x}_i = -\nabla \frac{\delta \mathcal{F}}{\delta \rho}[\rho_t](x_i), \qquad \dot{w}_i = -w_i \, \frac{\delta \mathcal{F}}{\delta \rho}[\rho_t](x_i).
$$
See, e.g., YWR23
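For the elaboration flagged above, a sketch of why this is the right notion of "gradient": the WFR metric assigns the tangent direction represented by $(v_t, r_t)$ the squared norm $\int (\|v_t\|^2 + r_t^2) \, d\rho_t$, so by Cauchy-Schwarz,

$$
\frac{d}{dt}\mathcal{F}[\rho_t] = \left\langle \left( \nabla \tfrac{\delta \mathcal{F}}{\delta \rho}, \tfrac{\delta \mathcal{F}}{\delta \rho} \right), (v_t, r_t) \right\rangle_{L^2(\rho_t)} \ \geq\ - \left\| \left( \nabla \tfrac{\delta \mathcal{F}}{\delta \rho}, \tfrac{\delta \mathcal{F}}{\delta \rho} \right) \right\|_{L^2(\rho_t)} \left\| (v_t, r_t) \right\|_{L^2(\rho_t)},
$$

with equality exactly when $(v_t, r_t) \propto -\left( \nabla \frac{\delta \mathcal{F}}{\delta \rho}, \frac{\delta \mathcal{F}}{\delta \rho} \right)$; hence that choice is the direction of steepest descent per unit WFR speed.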

KL functional

At first, we may be interested in the KL-divergence functional often used in gradient flows, for which (given a target distribution $\pi$) the pertinent information is

$$
\mathcal{F}[\rho] = \mathrm{KL}(\rho \,\|\, \pi) = \int \log \frac{\rho}{\pi} \, d\rho, \qquad \frac{\delta \mathcal{F}}{\delta \rho} = \log \frac{\rho}{\pi} + 1,
$$
but if we keep $\rho_t$ as a discrete measure, we will not be able to ever estimate $\log \rho_t$. It is possible to use a kernelized version of $\rho_t$ in the vein of SVGD, but I'm not sold. Alternatively, I could use score matching. For the record, this gives

$$
\dot{x}_i = \nabla \log \pi(x_i) - \nabla \log \rho_t(x_i), \qquad \dot{w}_i = -w_i \left( \log \frac{\rho_t(x_i)}{\pi(x_i)} - \sum_j w_j \log \frac{\rho_t(x_j)}{\pi(x_j)} \right).
$$
I don't think I need to center the weight derivative? Anyway, if $\rho_t$ is truly a discrete distribution, then $\log \rho_t$ is undefined. If you kernelize it... I'm not entirely sure, but it'd give something, I guess? On the other hand, perhaps you could match the score (or score ratio). See Polynomial Score Matching, which would need to be modified to account for $w$, which weights the true $\rho_t$.
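For concreteness, a minimal sketch of the kernelized route (not a committed method, just the "in the vein of SVGD" option above): estimate $\rho_t$ with a weighted Gaussian KDE and plug its log-density and score into the update. The bandwidth `h` and the helper name are assumptions for illustration.

```python
import numpy as np

def kde_log_density_and_score(x, particles, weights, h=0.5):
    """Weighted Gaussian-KDE estimates of log rho(x) and grad log rho(x).

    x:         (d,) query point
    particles: (n, d) particle locations
    weights:   (n,) nonnegative weights summing to one
    h:         kernel bandwidth (assumed tuning parameter)
    """
    d = particles.shape[1]
    diffs = x - particles                                   # (n, d)
    sq = np.sum(diffs**2, axis=1)                           # (n,)
    kern = np.exp(-sq / (2 * h**2)) / (2 * np.pi * h**2) ** (d / 2)
    dens = np.dot(weights, kern)                            # rho_hat(x)
    grad = -(weights * kern) @ diffs / h**2                 # grad rho_hat(x)
    return np.log(dens), grad / dens                        # log-density, score
```

One would evaluate these at each particle to form $\dot{x}_i$ and $\dot{w}_i$; whether the resulting flow behaves well is exactly the question left open above.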

MMD functional?

Suppose we now have

$$
\mathcal{F}[\rho] = \frac{1}{2} \, \mathrm{MMD}^2(\rho, \pi) = \frac{1}{2} \left\| \int k(x, \cdot) \, d\rho(x) - \int k(x, \cdot) \, d\pi(x) \right\|_{\mathcal{H}}^2,
$$
where $\|\cdot\|_{\mathcal{H}}$ is the norm in the RKHS $\mathcal{H}$ induced by the kernel $k$. One can find that

$$
\frac{\delta \mathcal{F}}{\delta \rho}(y) = \int k(x, y) \, d\rho(x) - \int k(x, y) \, d\pi(x), \qquad \nabla \frac{\delta \mathcal{F}}{\delta \rho}(y) = \int \nabla_2 k(x, y) \, d\rho(x) - \int \nabla_2 k(x, y) \, d\pi(x),
$$
where $\nabla_2 k$ refers to the gradient in the second argument. Suppose that $k(x, y) = \exp\left(-\|x - y\|^2 / (2\sigma^2)\right)$. Then,
$\nabla_2 k(x, y) = \frac{x - y}{\sigma^2} \, k(x, y)$. More generally, if $k(x, y) = \phi(d(x, y))$ for positive $\phi$ and symmetric distance $d$, then you get $\nabla_2 k(x, y) = \phi'(d(x, y)) \, \nabla_2 d(x, y)$, which is generally pretty easy to compute. This gives the update equation as

$$
\dot{x}_i = -\left( \sum_j w_j \, \nabla_2 k(x_j, x_i) - \mathbb{E}_\pi\!\left[\nabla_2 k(X, x_i)\right] \right), \qquad \dot{w}_i = -w_i \left( \sum_j w_j \, k(x_j, x_i) - \mathbb{E}_\pi\!\left[k(X, x_i)\right] \right).
$$
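A minimal numerical sketch of one explicit-Euler step of this update, with the squared-exponential kernel and $\mathbb{E}_\pi[\cdot]$ replaced by a sample average over draws `Y` from $\pi$; the function names, step size `dt`, and bandwidth `sigma` are assumptions for illustration.

```python
import numpy as np

def se_kernel(A, B, sigma=1.0):
    """Squared-exponential kernel matrix K[a, b] = k(A[a], B[b])."""
    sq = np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq / (2 * sigma**2))

def wfr_mmd_step(X, w, Y, dt=0.05, sigma=1.0):
    """One Euler step of the WFR gradient flow of (1/2) MMD^2(rho, pi).

    X: (n, d) particle locations, w: (n,) weights, Y: (N, d) samples from pi.
    """
    Kxx = se_kernel(X, X, sigma)                       # k(x_j, x_i)
    Kyx = se_kernel(Y, X, sigma)                       # k(y_m, x_i)
    # delta F / delta rho evaluated at each particle
    witness = Kxx.T @ w - Kyx.mean(axis=0)             # (n,)
    # gradient of delta F / delta rho at each particle (nabla_2 of SE kernel)
    grad_rho = ((w[:, None, None] * Kxx[:, :, None]) *
                (X[:, None, :] - X[None, :, :])).sum(axis=0) / sigma**2
    grad_pi = (Kyx[:, :, None] * (Y[:, None, :] - X[None, :, :])).mean(axis=0) / sigma**2
    X_new = X - dt * (grad_rho - grad_pi)              # transport update
    w_new = w - dt * w * witness                       # reaction update
    return X_new, w_new
```

Iterating `wfr_mmd_step` from some initial `(X, w)` then gives a crude discretization of the flow.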
Why is this nice?

What's nice here is that this is "forgetful" of our initial state $\rho_0$, assuming the convexity and all that (probably not the case, but still). The choice of $\rho_0$ is one of preconditioning; it's not an end-all-be-all choice in theory. So far in our transport quadrature, our choice of reference basically entirely constrains the points to be the pushforward of the reference quadrature under the map, with no change to the weights, whereas here this isn't necessarily the case.

Alternative "static" formulation

Take our Hilbert space to be polynomials up to degree $m$ with inner product

$$
\langle f, g \rangle = \int f(x) \, g(x) \, d\pi(x),
$$
so our kernel becomes the Christoffel-Darboux kernel of $\pi$,

$$
k(x, y) = \sum_{j=0}^{m} \varphi_j(x) \, \varphi_j(y),
$$

with $\{\varphi_j\}_{j=0}^{m}$ the polynomials orthonormal in the $L^2(\pi)$ norm. I.e., this kernel projects onto the polynomials $\varphi_0, \dots, \varphi_m$, which are orthogonal w.r.t. $\pi$. Without loss of generality, assume $\varphi_0 \equiv 1$ (as $\pi$ is a probability measure). Obviously, for any function $f \in L^2(\pi)$, we see

$$
\langle f, k(x, \cdot) \rangle = \sum_{j=0}^{m} \varphi_j(x) \int f \, \varphi_j \, d\pi = (P_m f)(x),
$$
which means the embedding of an arbitrary function into our RKHS is exactly the orthogonal projection $P_m f$ of $f$ onto polynomials of degree at most $m$. Our maximum mean discrepancy becomes

$$
\mathrm{MMD}^2(\rho, \pi) = \sum_{j=0}^{m} \left( \int \varphi_j \, d\rho - \int \varphi_j \, d\pi \right)^2 = \sum_{j=0}^{m} \left( \int \varphi_j \, d\rho - \delta_{j0} \right)^2,
$$

using orthonormality to evaluate $\int \varphi_j \, d\pi = \langle \varphi_j, \varphi_0 \rangle = \delta_{j0}$.
Assume we want to match points $\{x_i\}_{i=1}^{n}$ and weights $\{w_i\}_{i=1}^{n}$, i.e., discrete measure $\rho = \sum_{i=1}^{n} w_i \delta_{x_i}$, to some target $\pi$. We see

$$
\mathrm{MMD}^2(\rho, \pi) = \sum_{j=0}^{m} \left( \sum_{i=1}^{n} w_i \, \varphi_j(x_i) - \delta_{j0} \right)^2 = \left\| \Phi(X) \, w - e_1 \right\|_2^2
$$
^3be7e2
with $\Phi(X) \in \mathbb{R}^{(m+1) \times n}$, $\Phi(X)_{ji} = \varphi_j(x_i)$, and $e_1 = (1, 0, \dots, 0)^\top$. Then, we get that

$$
\nabla_w \, \mathrm{MMD}^2(\rho, \pi) = 2 \, \Phi(X)^\top \left( \Phi(X) \, w - e_1 \right),
$$
and somewhat less obviously, suppose we abuse notation to consider $\Phi'(X)$ the matrix with entries $\Phi'(X)_{ji} = \varphi_j'(x_i)$ and $\mathrm{diag}(dX)$ a diagonal matrix with $i$th diagonal entry equal to $dx_i$; matrix calculus dictates

$$
d \left\| \Phi(X) \, w - e_1 \right\|_2^2 = 2 \left( \Phi(X) \, w - e_1 \right)^\top d\Phi(X) \, w = 2 \, \mathrm{tr}\!\left( w \left( \Phi(X) \, w - e_1 \right)^\top \Phi'(X) \, \mathrm{diag}(dX) \right),
$$

and setting $dX = e_i$ to read off $\partial / \partial x_i$, the trace is operating on a matrix with only one nonzero column (the $i$th); therefore, the entry of $\Phi'(X)\,\mathrm{diag}(e_i)$ is $\varphi_j'(x_i)$ in the $j$th row of the $i$th column and zero otherwise. This corresponds to

$$
\frac{\partial}{\partial x_i} \, \mathrm{MMD}^2(\rho, \pi) = 2 \, w_i \sum_{j=0}^{m} \varphi_j'(x_i) \left( \sum_{\ell=1}^{n} w_\ell \, \varphi_j(x_\ell) - \delta_{j0} \right).
$$
Here we maintain control of three things: the points $x$, the weights $w$, and the degree $m$ of the kernel. We know that $\pi$ induces the orthogonal polynomial family $\{\varphi_j\}$, and this gets normalized (with respect to $\pi$), so we don't get to choose the scaling for the family. Notably, for fixed $x$, this becomes a quadratic loss in $w$, so we can use a linear solve to find the solution. For the case of finding $x$ as well, it seems straightforward enough to attempt to just choose an $n$ and optimize over these values. Unfortunately, even in one dimension, the Gaussian quadrature being unique does not mean that it will be easy to find in this setting. In higher dimensions, this will be very difficult without something of a greedy procedure.
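A minimal sketch of the weight solve for fixed nodes, assuming (purely for illustration) that $\pi = \mathcal{N}(0, 1)$ in one dimension, so the orthonormal family is the normalized probabilists' Hermite polynomials; the function names are not from the original.

```python
import numpy as np
from numpy.polynomial import hermite_e as H
from math import factorial

def orthonormal_hermite_matrix(x, m):
    """Phi with Phi[j, i] = phi_j(x_i): probabilists' Hermite polynomials,
    orthonormalized against pi = N(0, 1) (an assumed example target)."""
    Phi = np.zeros((m + 1, len(x)))
    for j in range(m + 1):
        coeffs = np.zeros(j + 1)
        coeffs[j] = 1.0
        Phi[j] = H.hermeval(x, coeffs) / np.sqrt(factorial(j))
    return Phi

def solve_weights(x, m):
    """For fixed nodes x, the MMD^2 loss is quadratic in w:
    minimize ||Phi(x) w - e_1||^2 via least squares."""
    Phi = orthonormal_hermite_matrix(np.asarray(x, dtype=float), m)
    e1 = np.zeros(m + 1)
    e1[0] = 1.0
    w, *_ = np.linalg.lstsq(Phi, e1, rcond=None)
    return w

# e.g. w = solve_weights([-1.732, 0.0, 1.732], m=4)  # near Gauss-Hermite weights
```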

Returning to dynamic

We now assume data from $\pi$ and want to estimate a quadrature via discrete $\rho_t = \sum_{i=1}^{n} w_i(t) \, \delta_{x_i(t)}$, and we have reference distribution $\rho_0$. Taking $\rho_0$ as our "preconditioner", we estimate $\mathbb{E}_\pi[k(X, \cdot)]$ and $\mathbb{E}_\pi[\nabla_2 k(X, \cdot)]$ from data (perhaps), then we have (for $n$ quadrature points)

$$
\dot{x}_i = -\left( \sum_{j=1}^{n} w_j \, \nabla_2 k(x_j, x_i) - \mathbb{E}_\pi\!\left[\nabla_2 k(X, x_i)\right] \right), \qquad \dot{w}_i = -w_i \left( \sum_{j=1}^{n} w_j \, k(x_j, x_i) - \mathbb{E}_\pi\!\left[k(X, x_i)\right] \right), \qquad i = 1, \dots, n.
$$
One can compare and contrast this with the minimization problem given in ^021c59. While they are very similar, note that the (Euclidean) gradient of the static problem places the scaling by $w_i$ on the points, while the WFR flow places the scaling by $w_i$ on the weights. This is a slight change, but can make remarkable differences due to imposing different geometries.
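Concretely (writing $r_j := \sum_\ell w_\ell \, \varphi_j(x_\ell) - \delta_{j0}$ for the $j$th moment residual, and specializing the WFR flow to the Christoffel-Darboux kernel so the two objectives agree; the factor of 2 only reflects $\mathrm{MMD}^2$ versus $\tfrac12 \mathrm{MMD}^2$), the two updates read

$$
\text{static:}\quad \frac{\partial \, \mathrm{MMD}^2}{\partial x_i} = 2 \, w_i \sum_j \varphi_j'(x_i) \, r_j, \qquad \text{WFR:}\quad \dot{x}_i = -\sum_j \varphi_j'(x_i) \, r_j, \quad \dot{w}_i = -w_i \sum_j \varphi_j(x_i) \, r_j,
$$

so the factor of $w_i$ moves from the point update to the weight update.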

One node case

Suppose we have a single node $(x, w)$ and target measure $\pi$, and we use the squared-exponential kernel $k(x, y) = \exp\left(-\|x - y\|^2 / (2\sigma^2)\right)$. Then, recall the form of WFR-MMD-Quad

$$
\dot{x} = -\left( w \, \nabla_2 k(x, x) - \mathbb{E}_\pi\!\left[\nabla_2 k(X, x)\right] \right), \qquad \dot{w} = -w \left( w \, k(x, x) - \mathbb{E}_\pi\!\left[k(X, x)\right] \right).
$$
In this squared-exponential kernel case, we know that $k(x, x) = 1$ and $\nabla_2 k(x, x) = 0$ for any $x$. Therefore, we reduce the equations to

$$
\dot{x} = \mathbb{E}_\pi\!\left[\nabla_2 k(X, x)\right] = \frac{1}{\sigma^2} \, \mathbb{E}_\pi\!\left[ (X - x) \, k(X, x) \right], \qquad \dot{w} = -w \left( w - \mathbb{E}_\pi\!\left[k(X, x)\right] \right).
$$
We know that, if we set the RHS to zero, we might get a steady-state solution (i.e., a fixed point of the ODE) such that $\dot{x} = 0$ and $\dot{w} = 0$ for all $t$. Therefore, assuming $w > 0$ strictly, we must have

$$
w = \mathbb{E}_\pi\!\left[k(X, x)\right], \qquad x = \frac{\mathbb{E}_\pi\!\left[X \, k(X, x)\right]}{\mathbb{E}_\pi\!\left[k(X, x)\right]}.
$$
The expressions on the right are necessitated by the form of the squared-exponential kernel.
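As a worked example (an illustration not in the original): for a one-dimensional Gaussian target $\pi = \mathcal{N}(\mu, s^2)$, both expectations are Gaussian integrals,

$$
\mathbb{E}_\pi\!\left[k(X, x)\right] = \frac{\sigma}{\sqrt{\sigma^2 + s^2}} \exp\!\left( -\frac{(x - \mu)^2}{2(\sigma^2 + s^2)} \right), \qquad \frac{\mathbb{E}_\pi\!\left[X \, k(X, x)\right]}{\mathbb{E}_\pi\!\left[k(X, x)\right]} = \frac{\mu \, \sigma^2 + x \, s^2}{\sigma^2 + s^2},
$$

so the unique fixed point with $w > 0$ is $x = \mu$ and $w = \sigma / \sqrt{\sigma^2 + s^2} < 1$: the node sits at the mean, and the weight shrinks as the kernel becomes narrow relative to the target's spread.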

Dirac case

If $\pi = \delta_{x^*}$ for some $x^*$, then consider $x = x^*$ and $w = \mathbb{E}_\pi\!\left[k(X, x^*)\right] = k(x^*, x^*) = 1$. Therefore, we know that $(x, w) = (x^*, 1)$ is a valid steady-state solution.

Empirical distribution case

Suppose $\pi$ has a density with respect to Lebesgue measure $\lambda$. Then, we consider the empirical distribution $\hat{\pi}_N = \frac{1}{N} \sum_{n=1}^{N} \delta_{y_n}$ for given samples $y_1, \dots, y_N \sim \pi$. Then, we note that $\mathbb{E}_{\hat{\pi}_N}\!\left[k(Y, x)\right] = \frac{1}{N} \sum_{n} k(y_n, x)$, which is creating a KDE from the samples (up to the kernel's normalizing constant) and evaluating it at $x$. Similarly, the RHS of the expression for $x$ is estimating a kernel-weighted first moment of $\pi$ from its KDE. If $N = 1$, then we get that the fixed point is $x = y_1$ independent of $\sigma$, and therefore I'm pretty sure that, as $t \to \infty$, $x_t \to y_1$ regardless of $x_0$. Similarly, I think that $w_t$ approaches the expectation of $k(Y, x_\infty)$ under $\hat{\pi}_N$.

I'm not sure of either of these things, though.
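One could at least check them numerically; a minimal sketch (Euler-discretizing the reduced one-node equations above against an empirical target, with assumed bandwidth and step size):

```python
import numpy as np

def one_node_flow(y, x0=0.0, w0=0.5, sigma=1.0, dt=0.01, steps=20_000):
    """Euler discretization of the one-node WFR-MMD ODE against the
    empirical measure of 1D samples y. Returns the final (x, w)."""
    y = np.asarray(y, dtype=float)
    x, w = x0, w0
    for _ in range(steps):
        kern = np.exp(-(y - x) ** 2 / (2 * sigma**2))
        Ek = kern.mean()                              # E_pi_hat[k(Y, x)]
        Egrad = ((y - x) * kern).mean() / sigma**2    # E_pi_hat[grad_2 k(Y, x)]
        x += dt * Egrad                               # weight-free point update
        w += dt * (-w * (w - Ek))                     # reaction update
    return x, w

# e.g., with a single sample: one_node_flow([1.3]) should approach (1.3, 1.0)
```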

Time-dependent Dirac case

We return to the original ODE when $\pi = \delta_{x^*(t)}$ is a Dirac at a moving target point $x^*(t)$. Then,

$$
\dot{x} = \frac{x^*(t) - x}{\sigma^2} \, k\!\left(x^*(t), x\right), \qquad \dot{w} = -w \left( w - k\!\left(x^*(t), x\right) \right).
$$
Particle Mirror Descent