tags:
- transport
- parameterization
Rectified Linear Parameterization
Currently, we use functions of the following form to represent a transport map component: given a multi-index set
where
pros:
Alternative to this, we have the separable linear parameterization, where
where
pros:
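What we need for a KR map is to ensure that each component is strictly increasing in its last input, i.e., $\partial_{x_d} T(x) > 0$ for every $x$.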
However, what really happens is that the integrated rectifier "monotonizes" in the
with positive coefficients
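The integrated-rectifier form is not written out above. The sketch below therefore assumes the monotone-component construction used in MParT: $T(x) = f(x_{<d}, 0) + \int_0^{x_d} g(\partial_d f(x_{<d}, t))\,dt$, where $f$ is a multi-index expansion. The monomial basis, the softplus choice for $g$, the 20-point Gauss-Legendre rule, and all function names are illustrative assumptions, not the note's implementation.

```python
import numpy as np


def softplus(s):
    """Positive rectifier g(s) = log(1 + e^s), computed stably."""
    return np.logaddexp(0.0, s)


def f_and_df(midxs, coeffs, x_prev, t):
    """Hypothetical monomial expansion f(x_prev, t) = sum_a c_a * prod_j x_j^a_j * t^a_d,
    returned together with its partial derivative in the last input t."""
    f = df = 0.0
    for a, c in zip(midxs, coeffs):
        off = np.prod([xj ** aj for xj, aj in zip(x_prev, a[:-1])])  # off-diagonal part
        k = a[-1]  # degree in the last input
        f += c * off * t ** k
        if k > 0:
            df += c * off * k * t ** (k - 1)
    return f, df


def integrated_rectifier(midxs, coeffs, x, n_quad=20):
    """T(x) = f(x_prev, 0) + int_0^{x_d} g(d_d f(x_prev, t)) dt via Gauss-Legendre quadrature."""
    x_prev, xd = list(x[:-1]), x[-1]
    f0, _ = f_and_df(midxs, coeffs, x_prev, 0.0)
    nodes, weights = np.polynomial.legendre.leggauss(n_quad)
    ts = 0.5 * xd * (nodes + 1.0)  # map [-1, 1] onto [0, x_d]
    integrand = np.array([softplus(f_and_df(midxs, coeffs, x_prev, t)[1]) for t in ts])
    return f0 + 0.5 * xd * np.dot(weights, integrand)
```

Differentiating under the integral gives $\partial_{x_d} T = g(\partial_d f(x_{<d}, x_d)) > 0$ whatever the signs of the coefficients. In this sketch, that is the sense in which the rectifier "monotonizes" in $x_d$. The cost is one quadrature per evaluation.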
The major difference is that we're now adding terms dependent on both
Consider a fixed
I think so! It doesn't look like the formulation depends on the diagonal components not being dependent on
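One way to see this, as a sketch, uses the rectified form computed by the forward-evaluation branch of the InputDerivative pseudocode. Here $d$ is the last input (`dim` there) and $\psi_0 \equiv 1$. The following are assumptions made for this sketch: a rectifier $r \ge 0$, coefficients $c_\alpha \ge 0$ on every term with $\alpha_d > 0$, and diagonal basis functions $\psi_k$ that are nondecreasing.

$$
T(x) = c_0 + \sum_{\alpha \neq 0,\ \alpha_d = 0} c_\alpha\, r\!\Big(\prod_{j<d} \psi_{\alpha_j}(x_j)\Big) + \sum_{\alpha_d > 0} c_\alpha\, r\!\Big(\prod_{j<d} \psi_{\alpha_j}(x_j)\Big)\, \psi_{\alpha_d}(x_d)
$$

$$
\partial_{x_d} T(x) = \sum_{\alpha_d > 0} \underbrace{c_\alpha\, r\!\Big(\prod_{j<d} \psi_{\alpha_j}(x_j)\Big)}_{\ge 0,\ \text{constant once } x_{<d} \text{ is fixed}} \psi'_{\alpha_d}(x_d) \;\ge\; 0
$$

The $\alpha_d = 0$ terms drop out. The diagonal terms can depend on $x_{<d}$ freely, because that dependence only scales $\psi'_{\alpha_d}$ by a nonnegative constant. Strict positivity additionally needs $r > 0$ (e.g., softplus) and at least one term with $c_\alpha > 0$ and $\psi'_{\alpha_d}(x_d) > 0$.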
Addendum: see Optimizing Linear Maps (new draft)
InputDerivative: gradient with respect to input
set f = 0, grad = [0,0,...,0]
for each midx (with coefficient coeff(midx)) and each wrt in {-1, 0, ..., dim} (wrt = -1 for forward eval):
  if midx == 0: // constant term
    if wrt == -1:
      add coeff(midx) to f
    end if
    continue // a constant has zero derivative
  end if
  initialize termVal = 1, wrtVal = 1, wrtDeriv = 0 // wrtVal = 1 keeps it neutral when wrt == -1 or wrt is not in midx
  J = nonzero_length(midx) - 1 // position of the last nonzero entry
  for each nonzero midx input dimension j = dimension[k], k = 0, ..., J-1 (all but the last):
    if j == wrt: // this is the input we're differentiating wrt
      wrtVal = value of psi[wrt] degree midx[wrt]
      wrtDeriv = deriv of psi[wrt] degree midx[wrt]
    else:
      termVal *= value of psi[j] degree midx[j]
    end if
  end for
  // Without rectification, everything below must reduce to the original loop body; extra code only runs when rectifying
  d = last nonzero dimension of midx, i.e., dimension[J]
  if d == wrt: // differentiating wrt the last nonconstant input
    wrtDeriv = deriv of psi[d] degree midx[d]
    if constexpr(rectify):
      if d == dim: // derivative wrt last input: rectify(off-diagonal product) * psi[dim]
        termVal = rectify(termVal) * wrtDeriv
      else: // no midx entry on last input: the whole product is rectified
        wrtVal = value of psi[d] degree midx[d]
        termVal = deriv rectify(termVal*wrtVal) * termVal * wrtDeriv
      end if
    else: // reduces to the loop body
      termVal *= wrtDeriv
    end if
  else: // wrt was handled in the loop, does not appear in midx, or is -1
    lastVal = value of psi[d] degree midx[d] // d is not differentiated, so no lastDeriv is needed
    if constexpr(rectify):
      rectVal = d == dim ? termVal*wrtVal : termVal*lastVal*wrtVal
      if wrt == -1:
        diagVal = d == dim ? lastVal : 1
        termVal = rectify(rectVal) * diagVal
      else:
        // d_x f(p(x)q(y))g(z) = f'(p(x)q(y)) p'(x) q(y) g(z)
        // the derivative is never of g: lastVal plays the role of g if d == dim, and is part of q otherwise
        termVal = deriv rectify(rectVal) * termVal * wrtDeriv * lastVal
      end if
    else: // reduces to the loop body
      if wrt == -1:
        termVal = termVal * lastVal
      else:
        termVal = termVal * wrtDeriv * lastVal
      end if
    end if
  end if
  if wrt == -1:
    f += coeff(midx) * termVal
  else:
    grad(wrt) += coeff(midx) * termVal
  end if
end for
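A minimal runnable sketch of the InputDerivative pseudocode in Python follows. Several assumptions here are not from the note: the 1-D basis is monomials $\psi_k(x) = x^k$, `rectify` is softplus with the logistic sigmoid as its derivative, multi-indices are plain tuples, and the names `input_derivative`, `psi`, and `dpsi` are hypothetical. The sketch mirrors the pseudocode's branch structure. Two details matter: `wrt_val` starts at 1 so it is a neutral factor in the forward pass, and the chain-rule branch uses the rectifier's derivative.

```python
import math
from typing import List, Sequence, Tuple


def psi(k: int, x: float) -> float:
    """Hypothetical 1-D basis: monomial x**k."""
    return x ** k


def dpsi(k: int, x: float) -> float:
    """Derivative of the hypothetical monomial basis."""
    return k * x ** (k - 1) if k > 0 else 0.0


def softplus(s: float) -> float:
    """Numerically stable log(1 + e^s); stands in for `rectify`."""
    return max(s, 0.0) + math.log1p(math.exp(-abs(s)))


def dsoftplus(s: float) -> float:
    """Logistic sigmoid, the derivative of softplus (`deriv rectify`)."""
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)


def input_derivative(midxs: Sequence[Tuple[int, ...]], coeffs: Sequence[float],
                     x: Sequence[float], rectify: bool = True) -> Tuple[float, List[float]]:
    """Evaluate the expansion (wrt = -1) and its gradient w.r.t. every input."""
    dim = len(x) - 1  # index of the last (diagonal) input
    f, grad = 0.0, [0.0] * len(x)
    for midx, c in zip(midxs, coeffs):
        nz = [j for j, a in enumerate(midx) if a != 0]  # nonzero dimensions
        for wrt in range(-1, dim + 1):
            if not nz:  # constant term: contributes only to f
                if wrt == -1:
                    f += c
                continue
            term, wrt_val, wrt_deriv = 1.0, 1.0, 0.0
            d = nz[-1]  # last nonzero dimension, handled after the loop
            for j in nz[:-1]:
                if j == wrt:
                    wrt_val, wrt_deriv = psi(midx[j], x[j]), dpsi(midx[j], x[j])
                else:
                    term *= psi(midx[j], x[j])
            if d == wrt:  # differentiating w.r.t. the last nonzero input
                wrt_deriv = dpsi(midx[d], x[d])
                if not rectify:
                    term *= wrt_deriv
                elif d == dim:  # r(off-diagonal product) * psi_d'(x_d)
                    term = softplus(term) * wrt_deriv
                else:  # the whole product is rectified
                    wrt_val = psi(midx[d], x[d])
                    term = dsoftplus(term * wrt_val) * term * wrt_deriv
            else:
                last_val = psi(midx[d], x[d])
                if not rectify:
                    term *= last_val if wrt == -1 else wrt_deriv * last_val
                else:
                    rect_val = term * wrt_val * (1.0 if d == dim else last_val)
                    if wrt == -1:
                        term = softplus(rect_val) * (last_val if d == dim else 1.0)
                    else:
                        term = dsoftplus(rect_val) * term * wrt_deriv * last_val
            if wrt == -1:
                f += c * term
            else:
                grad[wrt] += c * term
    return f, grad
```

Calling it with `rectify=False` should reproduce the plain polynomial expansion. That gives a quick check that the non-rectified branches reduce to the original loop body. Comparing `grad` against central finite differences of `f` checks the rectified chain-rule terms.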