Rectified Linear Parameterization

Integrated rectifier parameterization

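Currently, we use functions of the following form to represent a transport map component: given a multi-index set $\mathcal{A}$, we set

$$S(x_{1:d}) = f(x_{1:d-1}, 0) + \int_0^{x_d} g\big(\partial_d f(x_{1:d-1}, t)\big)\,dt, \qquad f(x_{1:d}) = \sum_{\alpha \in \mathcal{A}} c_\alpha \prod_{j=1}^{d} \psi_{\alpha_j}(x_j),$$

where $g$ is a positive monotone function (e.g., $\exp$ or softplus).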
pros:

  • stellar accuracy when you get the parameters right
  • has a good way of expressing dependence between the diagonal variable and the off-diagonal variables

problems:
  • Nonlinear and nonconvex in the parameters
  • Requires significant effort to evaluate (each evaluation requires quadrature, which itself requires several evaluations of each multivariate polynomial, which requires several evaluations of each polynomial)
  • Similar to above, a serious chore to invert
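
To make the evaluation cost concrete, here is a minimal sketch of an integrated rectifier component in two dimensions. The specific expansion `f`, the softplus rectifier, and the midpoint quadrature rule are illustrative assumptions, not the implementation these notes refer to.

```python
import math

def softplus(z):
    # Numerically stable log(1 + exp(z)); strictly positive.
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def f(x1, x2, c):
    # Hypothetical expansion over the multi-indices (0,0), (1,0), (0,1), (1,1), (0,2).
    return c[0] + c[1] * x1 + c[2] * x2 + c[3] * x1 * x2 + c[4] * x2 ** 2

def df_dx2(x1, x2, c):
    # Partial derivative of f with respect to the diagonal variable x2.
    return c[2] + c[3] * x1 + 2.0 * c[4] * x2

def integrated_rectifier(x1, x2, c, n_quad=32):
    # S(x1, x2) = f(x1, 0) + int_0^{x2} softplus(d_2 f(x1, t)) dt,
    # with the integral approximated by a composite midpoint rule.
    h = x2 / n_quad
    integral = h * sum(softplus(df_dx2(x1, (i + 0.5) * h, c)) for i in range(n_quad))
    return f(x1, 0.0, c) + integral

# Coefficients chosen so that f itself is decreasing in x2; S is still increasing.
c = [0.3, -1.0, -2.0, 1.5, -0.7]
values = [integrated_rectifier(0.5, x2, c) for x2 in (-2.0, -1.0, 0.0, 1.0, 2.0)]
print(values)
```

Each call to `integrated_rectifier` evaluates the derivative of the expansion `n_quad` times, which is the cost the list above complains about. Inversion would wrap this in a root-finding loop.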

Separable linear parameterization

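As an alternative, we have the separable linear parameterization, where

$$S(x_{1:d}) = \sum_{\alpha \in \mathcal{A}} c_\alpha \prod_{j=1}^{d-1} \psi_{\alpha_j}(x_j) + \sum_{i} b_i f_i(x_d),$$

where the $f_i$ are monotone functions of some kind, and $b_i \geq 0$.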
pros:

  • Really fast. No quadrature
  • Given the diagonal coefficients, you can find the off-diagonal coefficients in closed form

cons:

  • As described in the draft with Max, these do not capture cross-terms, i.e., any interaction between the diagonal variable and the off-diagonal variables cannot be described by this model. This almost surely creates bias in virtually all realistic situations (at least without particular care).
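
The closed-form step can be sketched as follows. This assumes the objective is the usual sample-average KL objective against a standard normal reference, $\tfrac12 S^2 - \log \partial_{x_d} S$, which these notes do not state. Under that assumption the log term does not involve the off-diagonal coefficients, so they solve a least-squares problem. The basis $\{1, x_1\}$, the monotone function, and the sample values are all made up for illustration.

```python
def solve_offdiag_coeffs(x1_samples, diag_values):
    """Least-squares (c0, c1) minimizing sum((c0 + c1*x1 + diag)^2) for a fixed diagonal part."""
    n = len(x1_samples)
    # Normal equations for the basis {1, x1}: (Psi^T Psi) c = -Psi^T diag
    s1 = sum(x1_samples)
    s2 = sum(x * x for x in x1_samples)
    r0 = -sum(diag_values)
    r1 = -sum(x * d for x, d in zip(x1_samples, diag_values))
    det = n * s2 - s1 * s1
    c0 = (s2 * r0 - s1 * r1) / det
    c1 = (n * r1 - s1 * r0) / det
    return c0, c1

# Hypothetical samples and a fixed diagonal part b * f(x2) with f(x2) = x2 + x2^3.
x1 = [-1.0, -0.3, 0.2, 0.9, 1.5]
x2 = [-0.8, 0.1, 0.4, 1.1, 1.3]
b = 0.8
diag = [b * (t + t ** 3) for t in x2]
c0, c1 = solve_offdiag_coeffs(x1, diag)
print(c0, c1)
```

No iteration is needed for the off-diagonal coefficients, so an outer optimizer would only have to search over the diagonal ones.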

Idea: cross-terms in linear parameterization

What we need for a KR map is to ensure $\partial_{x_d} S(x_{1:d}) > 0$ for all $x_{1:d}$. The idea of the integrated rectifier functions is easy to see from this: just ensure the derivative is positive and then integrate it! This gives properties that are very easy to verify analytically.

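However, what really happens is that the integrated rectifier "monotonizes" in the $x_d$ direction (a cartoon would be to make something that looks like a bunch of sigmoids), and then just "positivizes" in the $x_{1:d-1}$ directions. If we already have a monotone function in the $x_d$ direction, then we only need to "positivize" the $x_{1:d-1}$ components. In this case, we get something of the following form:

$$S(x_{1:d}) = h(x_{1:d-1}) + \sum_{i} b_i\, g\big(p_i(x_{1:d-1})\big)\, f_i(x_d),$$

where $h$ collects the off-diagonal terms, the $p_i$ are polynomials in the off-diagonal variables, the $f_i$ are monotone in $x_d$, and $g$ is the rectifier, with positive coefficients $b_i$.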

What's happening here?

The major difference is that we're now adding terms dependent on both the diagonal and the off-diagonal variables in the diagonal derivative of the map. However, this only requires the addition of one evaluation of the rectifier (compared with the quadrature, which requires many new evaluations of the polynomials).
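
A minimal sketch of such a cross-term component in two dimensions. The particular terms (a linear off-diagonal part, `x2` and `tanh(x2)` as monotone functions, softplus as rectifier) are illustrative assumptions.

```python
import math

def softplus(z):
    # Numerically stable log(1 + exp(z)); strictly positive.
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def rectified_linear_map(x1, x2, c, b):
    # S(x1, x2) = c0 + c1*x1 + b0*x2 + b1*softplus(x1)*tanh(x2), with b0, b1 > 0.
    off_diag = c[0] + c[1] * x1
    diag = b[0] * x2 + b[1] * softplus(x1) * math.tanh(x2)
    return off_diag + diag

def diag_derivative(x1, x2, b):
    # dS/dx2 in closed form: one rectifier evaluation, no quadrature.
    return b[0] + b[1] * softplus(x1) * (1.0 - math.tanh(x2) ** 2)

c = (0.1, -0.4)
b = (0.5, 1.2)
for x1 in (-2.0, 0.0, 2.0):
    print(x1, diag_derivative(x1, 0.3, b))
```

The printed derivative changes with `x1`, which is exactly the dependence the separable form cannot express, while the map stays linear in `c` and `b`.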

What properties do we get?

Consider a fixed $x_d$. Then we can sweep all of the $x_d$-terms into constants and just consider that we approximate the remaining function using a log-sum. This looks very similar to approximating the diagonal derivative in log-space using the cross-terms (indeed, if we instead chose to separate the diagonal terms from the rectified off-diagonal terms and take the product there, we would be approximating the log-determinant using polynomials). On the other hand, we still keep linearity in the coefficients.

Do we keep the closed-form solution for the off-diagonal coefficients?

I think so! It doesn't look like the formulation depends on the diagonal components being independent of the off-diagonal variables, only on the off-diagonal components being independent of the diagonal variable!
Addendum: see Optimizing Linear Maps (new draft)
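
A short derivation sketch of why the closed form should survive. The objective below (sample-average KL to a standard normal reference) and the symbols $\Psi$, $D$, $c$, $b$ are assumptions introduced here for illustration. $D$ collects every term that depends on $x_d$, cross-terms included.

```latex
\begin{align*}
S(x_{1:d}) &= \Psi(x_{1:d-1})^\top c + D(x_{1:d}; b), \\
J(c, b) &= \frac{1}{N}\sum_{n=1}^{N}\Big[\tfrac{1}{2}\, S\big(x^{(n)}\big)^2
           - \log \partial_{x_d} S\big(x^{(n)}\big)\Big], \\
\partial_{x_d} S &= \partial_{x_d} D(x_{1:d}; b)
           \quad \text{(no dependence on } c\text{)}, \\
c^\star(b) &= \arg\min_{c}\ \frac{1}{2N}\sum_{n=1}^{N}
           \Big(\Psi\big(x^{(n)}_{1:d-1}\big)^\top c + D\big(x^{(n)}; b\big)\Big)^2
           = -\big(\Psi^\top \Psi\big)^{-1}\Psi^\top D(b).
\end{align*}
```

In the last line, $\Psi$ denotes the $N \times |c|$ matrix of off-diagonal basis evaluations and $D(b)$ the vector of diagonal-part evaluations. The argument only uses that $\Psi$ does not involve $x_d$.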

Implementation

InputDerivative: gradient with respect to input
set f = 0, grad = [0,0,...,0]
for each midx and wrt each input (-1 for forward eval):
	if the midx == 0:
		if wrt == -1
			add coeff(0) to f
		end if
		continue
	end if
	
	initialize termVal=1, hasDeriv=false, wrtVal = 1, wrtDeriv = 0 // wrtVal = 1 so the forward eval (wrt == -1) does not zero out rectVal
	
	for each nonzero midx input dimension j up to J=nonzero_length(midx)-1:
		if dimension[j] == wrt (if this is the input we're differentiating wrt):
			wrtVal = value of psi[wrt] degree midx[wrt]
			wrtDeriv = deriv of psi[wrt] degree midx[wrt]
		else:
			termVal *= value of psi[j] degree midx[j]
		end if
	end for
	// Want to revert to loop body if not rectifying, only add code if we are
	d = last nonzero dimension of midx, i.e., dimension[nonzero_length(midx)]
	if d == wrt // If we differentiate wrt last nonconstant input
		wrtDeriv = deriv of psi[d] degree midx[d]
		if constexpr(rectify): // If there's rectification
			if d == dim: // Derivative wrt last input
				termVal = rectify(termVal) * wrtDeriv
			else: // midx has no entry on last input
				wrtVal = value of psi[d] degree midx[d]
				termVal = deriv rectify(termVal*wrtVal)*termVal*wrtDeriv
			end if
		else: // Reduces to loop body
			termVal *= wrtDeriv
		end if
	else: // If we've "already" differentiated
		if constexpr(rectify):
			lastVal = value psi[d] degree midx[d]
			rectVal = d == dim ? termVal*wrtVal : termVal*lastVal*wrtVal
			if wrt == -1:
				diagVal = d == dim ? lastVal : 1
				termVal = rectify(rectVal)*diagVal
			else:
				// d_x f(p(x)q(y))g(z) = f'(p(x)q(y))p'(x)q(y)g(z)
				// we know deriv is not of g, so if d != dim, q takes place of g as lastVal
				termVal = deriv rectify(rectVal)*termVal*wrtDeriv*lastVal
			end if
		else // Reduce to loop body
			lastVal = value of psi[d] degree midx[d] // already diff'd, don't need lastDeriv
			if wrt == -1:
				termVal = termVal*lastVal
			else:
				termVal = termVal*wrtDeriv*lastVal
			end if
		end if
	end if
	if wrt == -1
		f += coeff(midx) * termVal
	else
		grad(wrt) += coeff(midx) * termVal
	end if
end for
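
The pseudocode above can be transcribed into a runnable Python sketch. This transcription assumes a monomial basis `psi_k(x) = x**k` and a softplus rectifier (both stand-ins chosen for illustration), 0-based input indices with the last input as the diagonal, a neutral seed `wrt_val = 1.0` for the forward evaluation, the derivative of the rectifier in the chain-rule branch, and each term scaled by its coefficient.

```python
import math

def softplus(z):
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def sigmoid(z):
    # Derivative of softplus.
    return 1.0 / (1.0 + math.exp(-z))

def psi(deg, x):
    return x ** deg

def dpsi(deg, x):
    return deg * x ** (deg - 1) if deg > 0 else 0.0

def eval_and_grad(midxs, coeffs, x, rectify=True):
    """Return (f, grad) of the expansion; the last entry of x is the diagonal input."""
    last = len(x) - 1
    f, grad = 0.0, [0.0] * len(x)
    for midx, c in zip(midxs, coeffs):
        nz = [j for j, a in enumerate(midx) if a > 0]
        if not nz:
            f += c  # constant term: not rectified, zero gradient
            continue
        d = nz[-1]  # last nonzero dimension of this multi-index
        last_val = psi(midx[d], x[d])
        for wrt in [-1] + list(range(len(x))):  # -1 is the forward evaluation
            term_val, wrt_val, wrt_deriv = 1.0, 1.0, 0.0
            for j in nz[:-1]:
                if j == wrt:
                    wrt_val = psi(midx[j], x[j])
                    wrt_deriv = dpsi(midx[j], x[j])
                else:
                    term_val *= psi(midx[j], x[j])
            if d == wrt:  # differentiate wrt the last nonconstant input
                wrt_deriv = dpsi(midx[d], x[d])
                if not rectify:
                    term_val *= wrt_deriv
                elif d == last:
                    term_val = softplus(term_val) * wrt_deriv
                else:
                    term_val = sigmoid(term_val * last_val) * term_val * wrt_deriv
            elif not rectify:
                term_val *= last_val * (1.0 if wrt == -1 else wrt_deriv)
            else:
                rect_arg = term_val * wrt_val * (1.0 if d == last else last_val)
                if wrt == -1:
                    term_val = softplus(rect_arg) * (last_val if d == last else 1.0)
                else:
                    # d_x g(p(x) q(y)) h(z) = g'(p(x) q(y)) p'(x) q(y) h(z)
                    term_val = sigmoid(rect_arg) * term_val * wrt_deriv * last_val
            if wrt == -1:
                f += c * term_val
            else:
                grad[wrt] += c * term_val
    return f, grad

midxs = [(0, 0, 0), (1, 0, 0), (0, 2, 0), (1, 1, 0), (0, 0, 1), (1, 0, 1), (2, 1, 3)]
coeffs = [0.5, -1.0, 0.3, 0.8, 1.1, 0.4, 0.2]
x = [0.4, -0.7, 1.2]
print(eval_and_grad(midxs, coeffs, x))
```

Comparing the returned gradient against central finite differences of the forward value, with and without rectification, is a cheap way to catch sign or chain-rule slips in the branches.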