
The EnKF and other ensemble-based methods.

   1"""The EnKF and other ensemble-based methods."""
   3import numpy as np
   4import scipy.linalg as sla
   5from numpy import diag, eye, sqrt, zeros
   7import dapper.tools.multiproc as multiproc
   8from dapper.stats import center, inflate_ens, mean0
   9from dapper.tools.linalg import mldiv, mrdiv, pad0, svd0, svdi, tinv, tsvd
  10from dapper.tools.matrices import funm_psd, genOG_1
  11from dapper.tools.progressbar import progbar
  12from dapper.tools.randvars import GaussRV
  13from dapper.tools.seeding import rng
  15from . import da_method
class ens_method:
    """Declare default ensemble arguments.

    The ensemble methods in this module read these attributes via `self`.
    """

    # Post-analysis inflation factor, forwarded to `post_process`.
    # The values `1.0` and `"-N"` make `post_process` skip the inflation.
    infl: float = 1.0
    # Forwarded to `post_process`: when truthy, the anomalies are additionally
    # multiplied by the matrix `genOG_1(N, rot)`.
    rot: bool = False
    # Name of the forecast-noise treatment, forwarded to `add_noise`
    # as its `method` argument.
    fnoise_treatm: str = "Stoch"
  28class EnKF:
  29    """The ensemble Kalman filter.
  31    Refs: `bib.evensen2009ensemble`.
  32    """
  34    upd_a: str
  35    N: int
  37    def assimilate(self, HMM, xx, yy):
  38        # Init
  39        E = HMM.X0.sample(self.N)
  40        self.stats.assess(0, E=E)
  42        # Cycle
  43        for k, ko, t, dt in progbar(HMM.tseq.ticker):
  44            E = HMM.Dyn(E, t - dt, dt)
  45            E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)
  47            # Analysis update
  48            if ko is not None:
  49                self.stats.assess(k, ko, "f", E=E)
  50                E = EnKF_analysis(
  51                    E,
  52                    HMM.Obs(ko)(E),
  53                    HMM.Obs(ko).noise,
  54                    yy[ko],
  55                    self.upd_a,
  56                    self.stats,
  57                    ko,
  58                )
  59                E = post_process(E, self.infl, self.rot)
  61            self.stats.assess(k, ko, E=E)
  64def EnKF_analysis(E, Eo, hnoise, y, upd_a, stats=None, ko=None):
  65    """Perform the EnKF analysis update.
  67    This implementation includes several flavours and forms,
  68    specified by `upd_a`.
  70    Main references: `bib.sakov2008deterministic`,
  71    `bib.sakov2008implications`, `bib.hoteit2015mitigating`
  72    """
  73    R = hnoise.C  # Obs noise cov
  74    N, Nx = E.shape  # Dimensionality
  75    N1 = N - 1  # Ens size - 1
  77    mu = np.mean(E, 0)  # Ens mean
  78    A = E - mu  # Ens anomalies
  80    xo = np.mean(Eo, 0)  # Obs ens mean
  81    Y = Eo - xo  # Obs ens anomalies
  82    dy = y - xo  # Mean "innovation"
  84    if "PertObs" in upd_a:
  85        # Uses classic, perturbed observations (Burgers'98)
  86        C = Y.T @ Y + R.full * N1
  87        D = mean0(hnoise.sample(N))
  88        YC = mrdiv(Y, C)
  89        KG = A.T @ YC
  90        HK = Y.T @ YC
  91        dE = (KG @ (y - D - Eo).T).T
  92        E = E + dE
  94    elif "Sqrt" in upd_a:
  95        # Uses a symmetric square root (ETKF)
  96        # to deterministically transform the ensemble.
  98        # The various versions below differ only numerically.
  99        # EVD is default, but for large N use SVD version.
 100        if upd_a == "Sqrt" and N > Nx:
 101            upd_a = "Sqrt svd"
 103        if "explicit" in upd_a:
 104            # Not recommended due to numerical costs and instability.
 105            # Implementation using inv (in ens space)
 106            Pw = sla.inv(Y @ R.inv @ Y.T + N1 * eye(N))
 107            T = sla.sqrtm(Pw) * sqrt(N1)
 108            HK = R.inv @ Y.T @ Pw @ Y
 109            # KG = R.inv @ Y.T @ Pw @ A
 110        elif "svd" in upd_a:
 111            # Implementation using svd of Y R^{-1/2}.
 112            V, s, _ = svd0(Y @ R.sym_sqrt_inv.T)
 113            d = pad0(s**2, N) + N1
 114            Pw = (V * d ** (-1.0)) @ V.T
 115            T = (V * d ** (-0.5)) @ V.T * sqrt(N1)
 116            # docs/snippets/trHK.jpg
 117            trHK = np.sum((s**2 + N1) ** (-1.0) * s**2)
 118        elif "sS" in upd_a:
 119            # Same as 'svd', but with slightly different notation
 120            # (sometimes used by Sakov) using the normalization sqrt(N1).
 121            S = Y @ R.sym_sqrt_inv.T / sqrt(N1)
 122            V, s, _ = svd0(S)
 123            d = pad0(s**2, N) + 1
 124            Pw = (V * d ** (-1.0)) @ V.T / N1  # = G/(N1)
 125            T = (V * d ** (-0.5)) @ V.T
 126            # docs/snippets/trHK.jpg
 127            trHK = np.sum((s**2 + 1) ** (-1.0) * s**2)
 128        else:  # 'eig' in upd_a:
 129            # Implementation using eig. val. decomp.
 130            d, V = sla.eigh(Y @ R.inv @ Y.T + N1 * eye(N))
 131            T = V @ diag(d ** (-0.5)) @ V.T * sqrt(N1)
 132            Pw = V @ diag(d ** (-1.0)) @ V.T
 133            HK = R.inv @ Y.T @ (V @ diag(d ** (-1)) @ V.T) @ Y
 134        w = dy @ R.inv @ Y.T @ Pw
 135        E = mu + w @ A + T @ A
 137    elif "Serial" in upd_a:
 138        # Observations assimilated one-at-a-time:
 139        inds = serial_inds(upd_a, y, R, A)
 140        #  Requires de-correlation:
 141        dy = dy @ R.sym_sqrt_inv.T
 142        Y = Y @ R.sym_sqrt_inv.T
 143        # Enhancement in the nonlinear case:
 144        # re-compute Y each scalar obs assim.
 145        # But: little benefit, model costly (?),
 146        # updates cannot be accumulated on S and T.
 148        if any(x in upd_a for x in ["Stoch", "ESOPS", "Var1"]):
 149            # More details: Misc/Serial_ESOPS.py.
 150            for i, j in enumerate(inds):
 151                # Perturbation creation
 152                if "ESOPS" in upd_a:
 153                    # "2nd-O exact perturbation sampling"
 154                    if i == 0:
 155                        # Init -- increase nullspace by 1
 156                        V, s, UT = svd0(A)
 157                        s[N - 2 :] = 0
 158                        A = svdi(V, s, UT)
 159                        v = V[:, N - 2]
 160                    else:
 161                        # Orthogonalize v wrt. the new A
 162                        #
 163                        # v = Zj - Yj (from paper) requires Y==HX.
 164                        # Instead: mult` should be c*ones(Nx) so we can
 165                        # project v into ker(A) such that v@A is null.
 166                        mult = (v @ A) / (Yj @ A)  # noqa
 167                        v = v - mult[0] * Yj  # noqa
 168                        v /= sqrt(v @ v)
 169                    Zj = v * sqrt(N1)  # Standardized perturbation along v
 170                    Zj *= np.sign(rng.standard_normal() - 0.5)  # Random sign
 171                else:
 172                    # The usual stochastic perturbations.
 173                    Zj = mean0(rng.standard_normal(N))  # Un-coloured noise
 174                    if "Var1" in upd_a:
 175                        Zj *= sqrt(N / (Zj @ Zj))
 177                # Select j-th obs
 178                Yj = Y[:, j]  # [j] obs anomalies
 179                dyj = dy[j]  # [j] innov mean
 180                DYj = Zj - Yj  # [j] innov anomalies
 181                DYj = DYj[:, None]  # Make 2d vertical
 183                # Kalman gain computation
 184                C = Yj @ Yj + N1  # Total obs cov
 185                KGx = Yj @ A / C  # KG to update state
 186                KGy = Yj @ Y / C  # KG to update obs
 188                # Updates
 189                A += DYj * KGx
 190                mu += dyj * KGx
 191                Y += DYj * KGy
 192                dy -= dyj * KGy
 193            E = mu + A
 194        else:
 195            # "Potter scheme", "EnSRF"
 196            # - EAKF's two-stage "update-regress" form yields
 197            #   the same *ensemble* as this.
 198            # - The form below may be derived as "serial ETKF",
 199            #   but does not yield the same
 200            #   ensemble as 'Sqrt' (which processes obs as a batch)
 201            #   -- only the same mean/cov.
 202            T = eye(N)
 203            for j in inds:
 204                Yj = Y[:, j]
 205                C = Yj @ Yj + N1
 206                Tj = np.outer(Yj, Yj / (C + sqrt(N1 * C)))
 207                T -= Tj @ T
 208                Y -= Tj @ Y
 209            w = dy @ Y.T @ T / N1
 210            E = mu + w @ A + T @ A
 212    elif "DEnKF" == upd_a:
 213        # Uses "Deterministic EnKF" (sakov'08)
 214        C = Y.T @ Y + R.full * N1
 215        YC = mrdiv(Y, C)
 216        KG = A.T @ YC
 217        HK = Y.T @ YC
 218        E = E + KG @ dy - 0.5 * (KG @ Y.T).T
 220    else:
 221        raise KeyError("No analysis update method found: '" + upd_a + "'.")
 223    # Diagnostic: relative influence of observations
 224    if stats is not None:
 225        if "trHK" in locals():
 226            stats.trHK[ko] = trHK / hnoise.M
 227        elif "HK" in locals():
 228            stats.trHK[ko] = HK.trace() / hnoise.M
 230    return E
 233def post_process(E, infl, rot):
 234    """Inflate, Rotate.
 236    To avoid recomputing/recombining anomalies,
 237    this should have been inside `EnKF_analysis`
 239    But it is kept as a separate function
 241    - for readability;
 242    - to avoid inflating/rotationg smoothed states (for the `EnKS`).
 243    """
 244    do_infl = infl != 1.0 and infl != "-N"
 246    if do_infl or rot:
 247        A, mu = center(E)
 248        N, Nx = E.shape
 249        T = eye(N)
 251        if do_infl:
 252            T = infl * T
 254        if rot:
 255            T = genOG_1(N, rot) @ T
 257        E = mu + T @ A
 258    return E
 261def add_noise(E, dt, noise, method):
 262    """Treatment of additive noise for ensembles.
 264    Refs: `bib.raanes2014ext`
 265    """
 266    if noise.C == 0:
 267        return E
 269    N, Nx = E.shape
 270    A, mu = center(E)
 271    Q12 = noise.C.Left
 272    Q = noise.C.full
 274    def sqrt_core():
 275        T = np.nan  # cause error if used
 276        Qa12 = np.nan  # cause error if used
 277        A2 = A.copy()  # Instead of using (the implicitly nonlocal) A,
 278        # which changes A outside as well. NB: This is a bug in Datum!
 279        if N <= Nx:
 280            Ainv = tinv(A2.T)
 281            Qa12 = Ainv @ Q12
 282            T = funm_psd(eye(N) + dt * (N - 1) * (Qa12 @ Qa12.T), sqrt)
 283            A2 = T @ A2
 284        else:  # "Left-multiplying" form
 285            P = A2.T @ A2 / (N - 1)
 286            L = funm_psd(eye(Nx) + dt * mrdiv(Q, P), sqrt)
 287            A2 = A2 @ L.T
 288        E = mu + A2
 289        return E, T, Qa12
 291    if method == "Stoch":
 292        # In-place addition works (also) for empty [] noise sample.
 293        E += sqrt(dt) * noise.sample(N)
 295    elif method == "none":
 296        pass
 298    elif method == "Mult-1":
 299        varE = np.var(E, axis=0, ddof=1).sum()
 300        ratio = (varE + dt * diag(Q).sum()) / varE
 301        E = mu + sqrt(ratio) * A
 302        E = svdi(*tsvd(E, 0.999))  # Explained in Datum
 304    elif method == "Mult-M":
 305        varE = np.var(E, axis=0)
 306        ratios = sqrt((varE + dt * diag(Q)) / varE)
 307        E = mu + A * ratios
 308        E = svdi(*tsvd(E, 0.999))  # Explained in Datum
 310    elif method == "Sqrt-Core":
 311        E = sqrt_core()[0]
 313    elif method == "Sqrt-Mult-1":
 314        varE0 = np.var(E, axis=0, ddof=1).sum()
 315        varE2 = varE0 + dt * diag(Q).sum()
 316        E, _, Qa12 = sqrt_core()
 317        if N <= Nx:
 318            A, mu = center(E)
 319            varE1 = np.var(E, axis=0, ddof=1).sum()
 320            ratio = varE2 / varE1
 321            E = mu + sqrt(ratio) * A
 322            E = svdi(*tsvd(E, 0.999))  # Explained in Datum
 324    elif method == "Sqrt-Add-Z":
 325        E, _, Qa12 = sqrt_core()
 326        if N <= Nx:
 327            Z = Q12 - A.T @ Qa12
 328            E += sqrt(dt) * (Z @ rng.standard_normal((Z.shape[1], N))).T
 330    elif method == "Sqrt-Dep":
 331        E, T, Qa12 = sqrt_core()
 332        if N <= Nx:
 333            # Q_hat12: reuse svd for both inversion and projection.
 334            Q_hat12 = A.T @ Qa12
 335            U, s, VT = tsvd(Q_hat12, 0.99)
 336            Q_hat12_inv = (VT.T * s ** (-1.0)) @ U.T
 337            Q_hat12_proj = VT.T @ VT
 338            rQ = Q12.shape[1]
 339            # Calc D_til
 340            Z = Q12 - Q_hat12
 341            D_hat = A.T @ (T - eye(N))
 342            Xi_hat = Q_hat12_inv @ D_hat
 343            Xi_til = (eye(rQ) - Q_hat12_proj) @ rng.standard_normal((rQ, N))
 344            D_til = Z @ (Xi_hat + sqrt(dt) * Xi_til)
 345            E += D_til.T
 347    else:
 348        raise KeyError("No such method")
 350    return E
 354class EnKS:
 355    """The ensemble Kalman smoother.
 357    Refs: `bib.evensen2009ensemble`
 359    The only difference to the EnKF
 360    is the management of the lag and the reshapings.
 361    """
 363    upd_a: str
 364    N: int
 365    Lag: int
 367    # Reshapings used in smoothers to go to/from
 368    # 3D arrays, where the 0th axis is the Lag index.
 369    def reshape_to(self, E):
 370        K, N, Nx = E.shape
 371        return E.transpose([1, 0, 2]).reshape((N, K * Nx))
 373    def reshape_fr(self, E, Nx):
 374        N, Km = E.shape
 375        K = Km // Nx
 376        return E.reshape((N, K, Nx)).transpose([1, 0, 2])
 378    def assimilate(self, HMM, xx, yy):
 379        # Inefficient version, storing full time series ensemble.
 380        # See iEnKS for a "rolling" version.
 381        E = zeros((HMM.tseq.K + 1, self.N, HMM.Dyn.M))
 382        E[0] = HMM.X0.sample(self.N)
 384        for k, ko, t, dt in progbar(HMM.tseq.ticker):
 385            E[k] = HMM.Dyn(E[k - 1], t - dt, dt)
 386            E[k] = add_noise(E[k], dt, HMM.Dyn.noise, self.fnoise_treatm)
 388            if ko is not None:
 389                self.stats.assess(k, ko, "f", E=E[k])
 391                Eo = HMM.Obs(ko)(E[k])
 392                y = yy[ko]
 394                # Inds within Lag
 395                kk = range(max(0, k - self.Lag * HMM.tseq.dko), k + 1)
 397                EE = E[kk]
 399                EE = self.reshape_to(EE)
 400                EE = EnKF_analysis(
 401                    EE, Eo, HMM.Obs(ko).noise, y, self.upd_a, self.stats, ko
 402                )
 403                E[kk] = self.reshape_fr(EE, HMM.Dyn.M)
 404                E[k] = post_process(E[k], self.infl, self.rot)
 405                self.stats.assess(k, ko, "a", E=E[k])
 407        for k, ko, _, _ in progbar(HMM.tseq.ticker, desc="Assessing"):
 408            self.stats.assess(k, ko, "u", E=E[k])
 409            if ko is not None:
 410                self.stats.assess(k, ko, "s", E=E[k])
 414class EnRTS:
 415    """EnRTS (Rauch-Tung-Striebel) smoother.
 417    Refs: `bib.raanes2016thesis`
 418    """
 420    upd_a: str
 421    N: int
 422    DeCorr: float
 424    def assimilate(self, HMM, xx, yy):
 425        E = zeros((HMM.tseq.K + 1, self.N, HMM.Dyn.M))
 426        Ef = E.copy()
 427        E[0] = HMM.X0.sample(self.N)
 429        # Forward pass
 430        for k, ko, t, dt in progbar(HMM.tseq.ticker):
 431            E[k] = HMM.Dyn(E[k - 1], t - dt, dt)
 432            E[k] = add_noise(E[k], dt, HMM.Dyn.noise, self.fnoise_treatm)
 433            Ef[k] = E[k]
 435            if ko is not None:
 436                self.stats.assess(k, ko, "f", E=E[k])
 437                Eo = HMM.Obs(ko)(E[k])
 438                y = yy[ko]
 439                E[k] = EnKF_analysis(
 440                    E[k], Eo, HMM.Obs(ko).noise, y, self.upd_a, self.stats, ko
 441                )
 442                E[k] = post_process(E[k], self.infl, self.rot)
 443                self.stats.assess(k, ko, "a", E=E[k])
 445        # Backward pass
 446        for k in progbar(range(HMM.tseq.K)[::-1]):
 447            A = center(E[k])[0]
 448            Af = center(Ef[k + 1])[0]
 450            J = tinv(Af) @ A
 451            J *= self.DeCorr
 453            E[k] += (E[k + 1] - Ef[k + 1]) @ J
 455        for k, ko, _, _ in progbar(HMM.tseq.ticker, desc="Assessing"):
 456            self.stats.assess(k, ko, "u", E=E[k])
 457            if ko is not None:
 458                self.stats.assess(k, ko, "s", E=E[k])
 461def serial_inds(upd_a, y, cvR, A):
 462    """Get the indices used for serial updating.
 464    - Default: random ordering
 465    - if "mono" in `upd_a`: `1, 2, ..., len(y)`
 466    - if "sorted" in `upd_a`: sort by variance
 467    """
 468    if "mono" in upd_a:
 469        # Not robust?
 470        inds = np.arange(len(y))
 471    elif "sorted" in upd_a:
 472        N = len(A)
 473        dC = cvR.diag
 474        if np.all(dC == dC[0]):
 475            # Sort y by P
 476            dC = np.sum(A * A, 0) / (N - 1)
 477        inds = np.argsort(dC)
 478    else:  # Default: random ordering
 479        inds = rng.permutation(len(y))
 480    return inds
 484class SL_EAKF:
 485    """Serial, covariance-localized EAKF.
 487    Refs: `bib.karspeck2007experimental`.
 489    In contrast with LETKF, this iterates over the observations rather
 490    than over the state (batches).
 492    Used without localization, this should be equivalent (full ensemble equality)
 493    to the `EnKF` with `upd_a='Serial'`.
 494    """
 496    N: int
 497    loc_rad: float
 498    taper: str = "GC"
 499    ordr: str = "rand"
 501    def assimilate(self, HMM, xx, yy):
 502        N1 = self.N - 1
 504        E = HMM.X0.sample(self.N)
 505        self.stats.assess(0, E=E)
 507        for k, ko, t, dt in progbar(HMM.tseq.ticker):
 508            E = HMM.Dyn(E, t - dt, dt)
 509            E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)
 511            if ko is not None:
 512                self.stats.assess(k, ko, "f", E=E)
 513                Obs = HMM.Obs(ko)
 514                R = Obs.noise
 515                y = yy[ko]
 516                inds = serial_inds(self.ordr, y, R, center(E)[0])
 517                Rm12 = Obs.noise.C.sym_sqrt_inv
 519                state_taperer = Obs.localizer(self.loc_rad, "y2x", self.taper)
 520                for j in inds:
 521                    # Prep:
 522                    # ------------------------------------------------------
 523                    Eo = Obs(E)
 524                    xo = np.mean(Eo, 0)
 525                    Y = Eo - xo
 526                    mu = np.mean(E, 0)
 527                    A = E - mu
 528                    # Update j-th component of observed ensemble:
 529                    # ------------------------------------------------------
 530                    Y_j = Rm12[j, :] @ Y.T
 531                    dy_j = Rm12[j, :] @ (y - xo)
 532                    # Prior var * N1:
 533                    sig2_j = Y_j @ Y_j
 534                    if sig2_j < 1e-9:
 535                        continue
 536                    # Update (below, we drop the locality subscript: _j)
 537                    sig2_u = 1 / (1 / sig2_j + 1 / N1)  # Postr. var * N1
 538                    alpha = (N1 / (N1 + sig2_j)) ** (0.5)  # Update contraction factor
 539                    dy2 = sig2_u * dy_j / N1  # Mean update
 540                    Y2 = alpha * Y_j  # Anomaly update
 541                    # Update state (regress update from obs space, using localization)
 542                    # ------------------------------------------------------
 543                    ii, tapering = state_taperer(j)
 544                    # ii, tapering = ..., 1  # cancel localization
 545                    if len(ii) == 0:
 546                        continue
 547                    Xi = A[:, ii] * tapering
 548                    Regression = Xi.T @ Y_j / np.sum(Y_j**2)
 549                    mu[ii] += Regression * dy2
 550                    A[:, ii] += np.outer(Y2 - Y_j, Regression)
 551                    E = mu + A
 553                E = post_process(E, self.infl, self.rot)
 555            self.stats.assess(k, ko, E=E)
 558def local_analyses(E, Eo, R, y, state_batches, obs_taperer, mp=map, xN=None, g=0):
 559    """Perform local analysis update for the LETKF."""
 561    def local_analysis(ii):
 562        """Perform analysis, for state index batch `ii`."""
 563        # Locate local domain
 564        oBatch, tapering = obs_taperer(ii)
 565        Eii = E[:, ii]
 567        # No update
 568        if len(oBatch) == 0:
 569            return Eii, 1
 571        # Localize
 572        Yl = Y[:, oBatch]
 573        dyl = dy[oBatch]
 574        tpr = sqrt(tapering)
 576        # Adaptive inflation estimation.
 577        # NB: Localisation is not 100% compatible with the EnKF-N, since
 578        # - After localisation there is much less need for inflation.
 579        # - Tapered values (Y, dy) are too neat
 580        #   (the EnKF-N expects a normal amount of sampling error).
 581        # One fix is to tune xN (maybe set it to 2 or 3). Thanks to adaptivity,
 582        # this should still be easier than tuning the inflation factor.
 583        infl1 = 1 if xN is None else sqrt(N1 / effective_N(Yl, dyl, xN, g))
 584        Eii, Yl = inflate_ens(Eii, infl1), Yl * infl1
 585        # Since R^{-1/2} was already applied (necesry for effective_N), now use R=Id.
 586        # TODO 4: the cost of re-init this R might not always be insignificant.
 587        R = GaussRV(C=1, M=len(dyl))
 589        # Update
 590        Eii = EnKF_analysis(Eii, Yl * tpr, R, dyl * tpr, "Sqrt")
 592        return Eii, infl1
 594    # Prepare analysis
 595    N1 = len(E) - 1
 596    Y, xo = center(Eo)
 597    # Transform obs space
 598    Y = Y @ R.sym_sqrt_inv.T
 599    dy = (y - xo) @ R.sym_sqrt_inv.T
 601    # Run
 602    result = mp(local_analysis, state_batches)
 604    # Assign
 605    E_batches, infl1 = zip(*result)
 606    # TODO: this overwrites E, possibly unbeknownst to caller
 607    for ii, Eii in zip(state_batches, E_batches):
 608        E[:, ii] = Eii
 610    return E, dict(ad_inf=sqrt(np.mean(np.array(infl1) ** 2)))
 614class LETKF:
 615    """Same as EnKF (Sqrt), but with localization.
 617    Refs: `bib.hunt2007efficient`.
 619    NB: Multiproc. yields slow-down for `dapper.mods.Lorenz96`,
 620    even with `batch_size=(1,)`. But for `dapper.mods.QG`
 621    (`batch_size=(2,2)` or less) it is quicker.
 623    NB: If `len(ii)` is small, analysis may be slowed-down with '-N' infl.
 624    """
 626    N: int
 627    loc_rad: float
 628    taper: str = "GC"
 629    xN: float = None
 630    g: int = 0
 631    mp: bool = False
 633    def assimilate(self, HMM, xx, yy):
 634        E = HMM.X0.sample(self.N)
 635        self.stats.assess(0, E=E)
 636        self.stats.new_series("ad_inf", 1, HMM.tseq.Ko + 1)
 638        with multiproc.Pool(self.mp) as pool:
 639            for k, ko, t, dt in progbar(HMM.tseq.ticker):
 640                E = HMM.Dyn(E, t - dt, dt)
 641                E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)
 643                if ko is not None:
 644                    self.stats.assess(k, ko, "f", E=E)
 645                    Obs = HMM.Obs(ko)
 646                    batch, taper = Obs.localizer(self.loc_rad, "x2y", self.taper)
 647                    E, stats = local_analyses(
 648                        E,
 649                        Obs(E),
 650                        Obs.noise.C,
 651                        yy[ko],
 652                        batch,
 653                        taper,
 654                        pool.map,
 655                        self.xN,
 656                        self.g,
 657                    )
 658                    self.stats.write(stats, k, ko, "a")
 659                    E = post_process(E, self.infl, self.rot)
 661                self.stats.assess(k, ko, E=E)
 664def effective_N(YR, dyR, xN, g):
 665    """Effective ensemble size N.
 667    As measured by the finite-size EnKF-N
 668    """
 669    N, Ny = YR.shape
 670    N1 = N - 1
 672    V, s, UT = svd0(YR)
 673    du = UT @ dyR
 675    eN, cL = hyperprior_coeffs(s, N, xN, g)
 677    def pad_rk(arr):
 678        return pad0(arr, min(N, Ny))
 680    def dgn_rk(l1):
 681        return pad_rk((l1 * s) ** 2) + N1
 683    # Make dual cost function (in terms of l1)
 684    def J(l1):
 685        val = np.sum(du**2 / dgn_rk(l1)) + eN / l1**2 + cL * np.log(l1**2)
 686        return val
 688    # Derivatives (not required with minimize_scalar):
 689    def Jp(l1):
 690        val = (
 691            -2 * l1 * np.sum(pad_rk(s**2) * du**2 / dgn_rk(l1) ** 2)
 692            + -2 * eN / l1**3
 693            + 2 * cL / l1
 694        )
 695        return val
 697    def Jpp(l1):
 698        val = (
 699            8 * l1**2 * np.sum(pad_rk(s**4) * du**2 / dgn_rk(l1) ** 3)
 700            + 6 * eN / l1**4
 701            + -2 * cL / l1**2
 702        )
 703        return val
 705    # Find inflation factor (optimize)
 706    l1 = Newton_m(Jp, Jpp, 1.0)
 707    # l1 = fmin_bfgs(J, x0=[1], gtol=1e-4, disp=0)
 708    # l1 = minimize_scalar(J, bracket=(sqrt(prior_mode), 1e2), tol=1e-4).x
 710    za = N1 / l1**2
 711    return za
 714# Notes on optimizers for the 'dual' EnKF-N:
 715# ----------------------------------------
 716#  Using minimize_scalar:
 717#  - Doesn't take dJdx. Advantage: only need J
 718#  - method='bounded' not necessary and slower than 'brent'.
 719#  - bracket not necessary either...
 720#  Using multivariate minimization: fmin_cg, fmin_bfgs, fmin_ncg
 721#  - these also accept dJdx. But only fmin_bfgs approaches
 722#    the speed of the scalar minimizers.
 723#  Using scalar root-finders:
 724#  - brenth(dJ1, LowB, 1e2,     xtol=1e-6) # Same speed as minimization
 725#  - newton(dJ1,1.0, fprime=dJ2, tol=1e-6) # No improvement
 726#  - newton(dJ1,1.0, fprime=dJ2, tol=1e-6, fprime2=dJ3) # No improvement
 727#  - Newton_m(dJ1,dJ2, 1.0) # Significantly faster. Also slightly better CV?
# => Despite the inconvenience of defining analytic derivatives,
#    Newton_m seems like the best option.
 730#  - In extreme (or just non-linear Obs.mod) cases,
 731#    the EnKF-N cost function may have multiple minima.
 732#    Then: should use more robust optimizer!
 734# For 'primal'
 735# ----------------------------------------
 736# Similarly, Newton_m seems like the best option,
 737# although alternatives are provided (commented out).
 739def Newton_m(
 740    fun, deriv, x0, is_inverted=False, conf=1.0, xtol=1e-4, ytol=1e-7, itermax=10**2
 742    """Find root of `fun`.
 744    This is a simple (and pretty fast) implementation of Newton's method.
 745    """
 746    itr = 0
 747    dx = np.inf
 748    Jx = fun(x0)
 750    def norm(x):
 751        return sqrt(np.sum(x**2))
 753    while ytol < norm(Jx) and xtol < norm(dx) and itr < itermax:
 754        Dx = deriv(x0)
 755        if is_inverted:
 756            dx = Dx @ Jx
 757        elif isinstance(Dx, float):
 758            dx = Jx / Dx
 759        else:
 760            dx = mldiv(Dx, Jx)
 761        dx *= conf
 762        x0 -= dx
 763        Jx = fun(x0)
 764        itr += 1
 765    return x0
 768def hyperprior_coeffs(s, N, xN=1, g=0):
 769    r"""Set EnKF-N inflation hyperparams.
 771    The EnKF-N prior may be specified by the constants:
 773    - `eN`: Effect of unknown mean
 774    - `cL`: Coeff in front of log term
 776    These are trivial constants in the original EnKF-N,
 777    but are further adjusted (corrected and tuned) for the following reasons.
 779    - Reason 1: mode correction.
 780      These parameters bridge the Jeffreys (`xN=1`) and Dirac (`xN=Inf`) hyperpriors
 781      for the prior covariance, B, as discussed in `bib.bocquet2015expanding`.
 782      Indeed, mode correction becomes necessary when $$ R \rightarrow \infty $$
 783      because then there should be no ensemble update (and also no inflation!).
 784      More specifically, the mode of `l1`'s should be adjusted towards 1
 785      as a function of $$ I - K H $$ ("prior's weight").
 786      PS: why do we leave the prior mode below 1 at all?
 787      Because it sets up "tension" (negative feedback) in the inflation cycle:
 788      the prior pulls downwards, while the likelihood tends to pull upwards.
 790    - Reason 2: Boosting the inflation prior's certainty from N to xN*N.
 791      The aim is to take advantage of the fact that the ensemble may not
 792      have quite as much sampling error as a fully stochastic sample,
 793      as illustrated in section 2.1 of `bib.raanes2019adaptive`.
 795    - Its damping effect is similar to work done by J. Anderson.
 797    The tuning is controlled by:
 799    - `xN=1`: is fully agnostic, i.e. assumes the ensemble is generated
 800      from a highly chaotic or stochastic model.
 801    - `xN>1`: increases the certainty of the hyper-prior,
 802      which is appropriate for more linear and deterministic systems.
 803    - `xN<1`: yields a more (than 'fully') agnostic hyper-prior,
 804      as if N were smaller than it truly is.
 805    - `xN<=0` is not meaningful.
 806    """
 807    N1 = N - 1
 809    eN = (N + 1) / N
 810    cL = (N + g) / N1
 812    # Mode correction (almost) as in eqn 36 of `bib.bocquet2015expanding`
 813    prior_mode = eN / cL  # Mode of l1 (before correction)
 814    diagonal = pad0(s**2, N) + N1  # diag of Y@R.inv@Y + N1*I
 815    #                                           (Hessian of J)
 816    I_KH = np.mean(diagonal ** (-1)) * N1  # ≈ 1/(1 + HBH/R)
 817    # I_KH      = 1/(1 + (s**2).sum()/N1)     # Scalar alternative: use tr(HBH/R).
 818    mc = sqrt(prior_mode**I_KH)  # Correction coeff
 820    # Apply correction
 821    eN /= mc
 822    cL *= mc
 824    # Boost by xN
 825    eN *= xN
 826    cL *= xN
 828    return eN, cL
 831def zeta_a(eN, cL, w):
 832    """EnKF-N inflation estimation via w.
 834    Returns `zeta_a = (N-1)/pre-inflation^2`.
 836    Using this inside an iterative minimization as in the
 837    `dapper.da_methods.variational.iEnKS` effectively blends
 838    the distinction between the primal and dual EnKF-N.
 839    """
 840    N = len(w)
 841    N1 = N - 1
 842    za = N1 * cL / (eN + w @ w)
 843    return za
 847class EnKF_N:
 848    """Finite-size EnKF (EnKF-N).
 850    Refs: `bib.bocquet2011ensemble`, `bib.bocquet2015expanding`
 852    This implementation is pedagogical, prioritizing the "dual" form.
 853    In consequence, the efficiency of the "primal" form suffers a bit.
 854    The primal form is included for completeness and to demonstrate equivalence.
 855    In `dapper.da_methods.variational.iEnKS`, however,
 856    the primal form is preferred because it
 857    already does optimization for w (as treatment for nonlinear models).
 859    `infl` should be unnecessary (assuming no model error, or that Q is correct).
 861    `Hess`: use non-approx Hessian for ensemble transform matrix?
 863    `g` is the nullity of A (state anomalies's), ie. g=max(1,N-Nx),
 864    compensating for the redundancy in the space of w.
 865    But we have made it an input argument instead, with default 0,
 866    because mode-finding (of p(x) via the dual) completely ignores this redundancy,
 867    and the mode gets (undesireably) modified by g.
 869    `xN` allows tuning the hyper-prior for the inflation.
 870    Usually, I just try setting it to 1 (default), or 2.
 871    Further description in hyperprior_coeffs().
 872    """
 874    N: int
 875    dual: bool = False
 876    Hess: bool = False
 877    xN: float = 1.0
 878    g: int = 0
 880    def assimilate(self, HMM, xx, yy):
 881        N, N1 = self.N, self.N - 1
 883        # Init
 884        E = HMM.X0.sample(N)
 885        self.stats.assess(0, E=E)
 887        # Cycle
 888        for k, ko, t, dt in progbar(HMM.tseq.ticker):
 889            # Forecast
 890            E = HMM.Dyn(E, t - dt, dt)
 891            E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)
 893            # Analysis
 894            if ko is not None:
 895                self.stats.assess(k, ko, "f", E=E)
 896                Eo = HMM.Obs(ko)(E)
 897                y = yy[ko]
 899                mu = np.mean(E, 0)
 900                A = E - mu
 902                xo = np.mean(Eo, 0)
 903                Y = Eo - xo
 904                dy = y - xo
 906                R = HMM.Obs(ko).noise.C
 907                V, s, UT = svd0(Y @ R.sym_sqrt_inv.T)
 908                du = UT @ (dy @ R.sym_sqrt_inv.T)
 910                def dgn_N(l1):
 911                    return pad0((l1 * s) ** 2, N) + N1
 913                # Adjust hyper-prior
 914                # xN_ = noise_level(self.xN, self.stats, HMM.tseq, N1, ko, A,
 915                #                   locals().get('A_old', None))
 916                eN, cL = hyperprior_coeffs(s, N, self.xN, self.g)
 918                if self.dual:
 919                    # Make dual cost function (in terms of l1)
 920                    def pad_rk(arr):
 921                        return pad0(arr, min(N, len(y)))
 923                    def dgn_rk(l1):
 924                        return pad_rk((l1 * s) ** 2) + N1
 926                    def J(l1):
 927                        val = (
 928                            np.sum(du**2 / dgn_rk(l1)) + eN / l1**2 + cL * np.log(l1**2)
 929                        )
 930                        return val
 932                    # Derivatives (not required with minimize_scalar):
 933                    def Jp(l1):
 934                        val = (
 935                            -2 * l1 * np.sum(pad_rk(s**2) * du**2 / dgn_rk(l1) ** 2)
 936                            + -2 * eN / l1**3
 937                            + 2 * cL / l1
 938                        )
 939                        return val
 941                    def Jpp(l1):
 942                        val = (
 943                            8 * l1**2 * np.sum(pad_rk(s**4) * du**2 / dgn_rk(l1) ** 3)
 944                            + 6 * eN / l1**4
 945                            + -2 * cL / l1**2
 946                        )
 947                        return val
 949                    # Find inflation factor (optimize)
 950                    l1 = Newton_m(Jp, Jpp, 1.0)
 951                    # l1 = fmin_bfgs(J, x0=[1], gtol=1e-4, disp=0)
 952                    # l1 = minimize_scalar(J, bracket=(sqrt(prior_mode), 1e2),
 953                    #                      tol=1e-4).x
 955                else:
 956                    # Primal form, in a fully linearized version.
 957                    def za(w):
 958                        return zeta_a(eN, cL, w)
 960                    def J(w):
 961                        return 0.5 * np.sum(
 962                            ((dy - w @ Y) @ R.sym_sqrt_inv.T) ** 2
 963                        ) + 0.5 * N1 * cL * np.log(eN + w @ w)
 965                    # Derivatives (not required with fmin_bfgs):
 966                    def Jp(w):
 967                        return -Y @ R.inv @ (dy - w @ Y) + w * za(w)
 969                    # Jpp   = lambda w:  Y@R.inv@Y.T + \
 970                    #     za(w)*(eye(N) - 2*np.outer(w,w)/(eN + w@w))
 971                    # Approx: no radial-angular cross-deriv:
 972                    # Jpp   = lambda w:  Y@R.inv@Y.T + za(w)*eye(N)
 974                    def nvrs(w):
 975                        # inverse of Jpp-approx
 976                        return (V * (pad0(s**2, N) + za(w)) ** -1.0) @ V.T
 978                    # Find w (optimize)
 979                    wa = Newton_m(Jp, nvrs, zeros(N), is_inverted=True)
 980                    # wa   = Newton_m(Jp,Jpp ,zeros(N))
 981                    # wa   = fmin_bfgs(J,zeros(N),Jp,disp=0)
 982                    l1 = sqrt(N1 / za(wa))
 984                # Uncomment to revert to ETKF
 985                # l1 = 1.0
 987                # Explicitly inflate prior
 988                # => formulae look different from `bib.bocquet2015expanding`.
 989                A *= l1
 990                Y *= l1
 992                # Compute sqrt update
 993                Pw = (V * dgn_N(l1) ** (-1.0)) @ V.T
 994                w = dy @ R.inv @ Y.T @ Pw
 995                # For the anomalies:
 996                if not self.Hess:
 997                    # Regular ETKF (i.e. sym sqrt) update (with inflation)
 998                    T = (V * dgn_N(l1) ** (-0.5)) @ V.T * sqrt(N1)
 999                    # = (Y@R.inv@Y.T/N1 + eye(N))**(-0.5)
1000                else:
1001                    # Also include angular-radial co-dependence.
1002                    # Note: denominator not squared coz
1003                    # unlike `bib.bocquet2015expanding` we have inflated Y.
1004                    Hw = (
1005                        Y @ R.inv @ Y.T / N1
1006                        + eye(N)
1007                        - 2 * np.outer(w, w) / (eN + w @ w)
1008                    )
1009                    T = funm_psd(Hw, lambda x: x**-0.5)  # is there a sqrtm Woodbury?
1011                E = mu + w @ A + T @ A
1012                E = post_process(E, self.infl, self.rot)
1014                self.stats.infl[ko] = l1
1015                self.stats.trHK[ko] = (
1016                    ((l1 * s) ** 2 + N1) ** (-1.0) * s**2
1017                ).sum() / len(y)
1019            self.stats.assess(k, ko, E=E)
