dapper.da_methods.ensemble

The EnKF and other ensemble-based methods.

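A minimal usage sketch, assuming the pre-configured Lorenz-63 `HMM` shipped with DAPPER in `dapper.mods.Lorenz63.sakov2012` (any other `HiddenMarkovModel` is driven the same way):

    from dapper.mods.Lorenz63.sakov2012 import HMM   # assumed example model
    from dapper.da_methods.ensemble import EnKF

    xx, yy = HMM.simulate()                       # generate truth and observations
    xp = EnKF("Sqrt", N=20, infl=1.02, rot=True)  # ETKF flavour with 20 members
    xp.assimilate(HMM, xx, yy)
    xp.stats.average_in_time()                    # time-averaged diagnostics
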
   1"""The EnKF and other ensemble-based methods."""
   2
   3import numpy as np
   4import scipy.linalg as sla
   5from numpy import diag, eye, sqrt, zeros
   6
   7import dapper.tools.multiproc as multiproc
   8from dapper.stats import center, inflate_ens, mean0
   9from dapper.tools.linalg import mldiv, mrdiv, pad0, svd0, svdi, tinv, tsvd
  10from dapper.tools.matrices import funm_psd, genOG_1
  11from dapper.tools.progressbar import progbar
  12from dapper.tools.randvars import GaussRV
  13from dapper.tools.seeding import rng
  14
  15from . import da_method
  16
  17
  18@da_method
  19class ens_method:
  20    """Declare default ensemble arguments."""
  21
  22    infl: float = 1.0
  23    rot: bool = False
  24    fnoise_treatm: str = "Stoch"
  25
  26
  27@ens_method
  28class EnKF:
  29    """The ensemble Kalman filter.
  30
  31    Refs: `bib.evensen2009ensemble`.
  32    """
  33
  34    upd_a: str
  35    N: int
  36
  37    def assimilate(self, HMM, xx, yy):
  38        # Init
  39        E = HMM.X0.sample(self.N)
  40        self.stats.assess(0, E=E)
  41
  42        # Cycle
  43        for k, ko, t, dt in progbar(HMM.tseq.ticker):
  44            E = HMM.Dyn(E, t - dt, dt)
  45            E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)
  46
  47            # Analysis update
  48            if ko is not None:
  49                self.stats.assess(k, ko, "f", E=E)
  50                E = EnKF_analysis(
  51                    E,
  52                    HMM.Obs(ko)(E),
  53                    HMM.Obs(ko).noise,
  54                    yy[ko],
  55                    self.upd_a,
  56                    self.stats,
  57                    ko,
  58                )
  59                E = post_process(E, self.infl, self.rot)
  60
  61            self.stats.assess(k, ko, E=E)
  62
  63
  64def EnKF_analysis(E, Eo, hnoise, y, upd_a, stats=None, ko=None):
  65    """Perform the EnKF analysis update.
  66
  67    This implementation includes several flavours and forms,
  68    specified by `upd_a`.
  69
  70    Main references: `bib.sakov2008deterministic`,
  71    `bib.sakov2008implications`, `bib.hoteit2015mitigating`
  72    """
  73    R = hnoise.C  # Obs noise cov
  74    N, Nx = E.shape  # Dimensionality
  75    N1 = N - 1  # Ens size - 1
  76
  77    mu = np.mean(E, 0)  # Ens mean
  78    A = E - mu  # Ens anomalies
  79
  80    xo = np.mean(Eo, 0)  # Obs ens mean
  81    Y = Eo - xo  # Obs ens anomalies
  82    dy = y - xo  # Mean "innovation"
  83
  84    if "PertObs" in upd_a:
  85        # Uses classic, perturbed observations (Burgers'98)
  86        C = Y.T @ Y + R.full * N1
  87        D = mean0(hnoise.sample(N))
  88        YC = mrdiv(Y, C)
  89        KG = A.T @ YC
  90        HK = Y.T @ YC
  91        dE = (KG @ (y - D - Eo).T).T
  92        E = E + dE
  93
  94    elif "Sqrt" in upd_a:
  95        # Uses a symmetric square root (ETKF)
  96        # to deterministically transform the ensemble.
  97
  98        # The various versions below differ only numerically.
  99        # EVD is default, but for large N use SVD version.
 100        if upd_a == "Sqrt" and N > Nx:
 101            upd_a = "Sqrt svd"
 102
 103        if "explicit" in upd_a:
 104            # Not recommended due to numerical costs and instability.
 105            # Implementation using inv (in ens space)
 106            Pw = sla.inv(Y @ R.inv @ Y.T + N1 * eye(N))
 107            T = sla.sqrtm(Pw) * sqrt(N1)
 108            HK = R.inv @ Y.T @ Pw @ Y
 109            # KG = R.inv @ Y.T @ Pw @ A
 110        elif "svd" in upd_a:
 111            # Implementation using svd of Y R^{-1/2}.
 112            V, s, _ = svd0(Y @ R.sym_sqrt_inv.T)
 113            d = pad0(s**2, N) + N1
 114            Pw = (V * d ** (-1.0)) @ V.T
 115            T = (V * d ** (-0.5)) @ V.T * sqrt(N1)
 116            # docs/snippets/trHK.jpg
 117            trHK = np.sum((s**2 + N1) ** (-1.0) * s**2)
 118        elif "sS" in upd_a:
 119            # Same as 'svd', but with slightly different notation
 120            # (sometimes used by Sakov) using the normalization sqrt(N1).
 121            S = Y @ R.sym_sqrt_inv.T / sqrt(N1)
 122            V, s, _ = svd0(S)
 123            d = pad0(s**2, N) + 1
 124            Pw = (V * d ** (-1.0)) @ V.T / N1  # = G/(N1)
 125            T = (V * d ** (-0.5)) @ V.T
 126            # docs/snippets/trHK.jpg
 127            trHK = np.sum((s**2 + 1) ** (-1.0) * s**2)
 128        else:  # 'eig' in upd_a:
 129            # Implementation using eig. val. decomp.
 130            d, V = sla.eigh(Y @ R.inv @ Y.T + N1 * eye(N))
 131            T = V @ diag(d ** (-0.5)) @ V.T * sqrt(N1)
 132            Pw = V @ diag(d ** (-1.0)) @ V.T
 133            HK = R.inv @ Y.T @ (V @ diag(d ** (-1)) @ V.T) @ Y
 134        w = dy @ R.inv @ Y.T @ Pw
 135        E = mu + w @ A + T @ A
 136
 137    elif "Serial" in upd_a:
 138        # Observations assimilated one-at-a-time:
 139        inds = serial_inds(upd_a, y, R, A)
 140        #  Requires de-correlation:
 141        dy = dy @ R.sym_sqrt_inv.T
 142        Y = Y @ R.sym_sqrt_inv.T
 143        # Enhancement in the nonlinear case:
 144        # re-compute Y each scalar obs assim.
 145        # But: little benefit, model costly (?),
 146        # updates cannot be accumulated on S and T.
 147
 148        if any(x in upd_a for x in ["Stoch", "ESOPS", "Var1"]):
 149            # More details: Misc/Serial_ESOPS.py.
 150            for i, j in enumerate(inds):
 151                # Perturbation creation
 152                if "ESOPS" in upd_a:
 153                    # "2nd-O exact perturbation sampling"
 154                    if i == 0:
 155                        # Init -- increase nullspace by 1
 156                        V, s, UT = svd0(A)
 157                        s[N - 2 :] = 0
 158                        A = svdi(V, s, UT)
 159                        v = V[:, N - 2]
 160                    else:
 161                        # Orthogonalize v wrt. the new A
 162                        #
 163                        # v = Zj - Yj (from paper) requires Y==HX.
  164                        # Instead: `mult` should be c*ones(Nx) so we can
 165                        # project v into ker(A) such that v@A is null.
 166                        mult = (v @ A) / (Yj @ A)  # noqa
 167                        v = v - mult[0] * Yj  # noqa
 168                        v /= sqrt(v @ v)
 169                    Zj = v * sqrt(N1)  # Standardized perturbation along v
 170                    Zj *= np.sign(rng.standard_normal() - 0.5)  # Random sign
 171                else:
 172                    # The usual stochastic perturbations.
 173                    Zj = mean0(rng.standard_normal(N))  # Un-coloured noise
 174                    if "Var1" in upd_a:
 175                        Zj *= sqrt(N / (Zj @ Zj))
 176
 177                # Select j-th obs
 178                Yj = Y[:, j]  # [j] obs anomalies
 179                dyj = dy[j]  # [j] innov mean
 180                DYj = Zj - Yj  # [j] innov anomalies
 181                DYj = DYj[:, None]  # Make 2d vertical
 182
 183                # Kalman gain computation
 184                C = Yj @ Yj + N1  # Total obs cov
 185                KGx = Yj @ A / C  # KG to update state
 186                KGy = Yj @ Y / C  # KG to update obs
 187
 188                # Updates
 189                A += DYj * KGx
 190                mu += dyj * KGx
 191                Y += DYj * KGy
 192                dy -= dyj * KGy
 193            E = mu + A
 194        else:
 195            # "Potter scheme", "EnSRF"
 196            # - EAKF's two-stage "update-regress" form yields
 197            #   the same *ensemble* as this.
 198            # - The form below may be derived as "serial ETKF",
 199            #   but does not yield the same
 200            #   ensemble as 'Sqrt' (which processes obs as a batch)
 201            #   -- only the same mean/cov.
 202            T = eye(N)
 203            for j in inds:
 204                Yj = Y[:, j]
 205                C = Yj @ Yj + N1
 206                Tj = np.outer(Yj, Yj / (C + sqrt(N1 * C)))
 207                T -= Tj @ T
 208                Y -= Tj @ Y
 209            w = dy @ Y.T @ T / N1
 210            E = mu + w @ A + T @ A
 211
 212    elif "DEnKF" == upd_a:
 213        # Uses "Deterministic EnKF" (sakov'08)
 214        C = Y.T @ Y + R.full * N1
 215        YC = mrdiv(Y, C)
 216        KG = A.T @ YC
 217        HK = Y.T @ YC
 218        E = E + KG @ dy - 0.5 * (KG @ Y.T).T
 219
 220    else:
 221        raise KeyError("No analysis update method found: '" + upd_a + "'.")
 222
 223    # Diagnostic: relative influence of observations
 224    if stats is not None:
 225        if "trHK" in locals():
 226            stats.trHK[ko] = trHK / hnoise.M
 227        elif "HK" in locals():
 228            stats.trHK[ko] = HK.trace() / hnoise.M
 229
 230    return E
 231
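# Illustrative sketch: `EnKF_analysis` operates on plain arrays; the obs-noise
# argument is assumed to be a `GaussRV` (as used in `local_analyses` below).
#
#     rnd = np.random.default_rng(42)
#     N, Nx, Ny = 10, 4, 3
#     E  = rnd.standard_normal((N, Nx))   # prior ensemble
#     Eo = E[:, :Ny]                      # observed ensemble (direct obs. of first 3 vars)
#     y  = rnd.standard_normal(Ny)        # observation
#     Ea = EnKF_analysis(E, Eo, GaussRV(C=1, M=Ny), y, "Sqrt")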
 232
 233def post_process(E, infl, rot):
 234    """Inflate, Rotate.
 235
 236    To avoid recomputing/recombining anomalies,
  237    this should have been inside `EnKF_analysis`.
  238
  239    But it is kept as a separate function:
  240
  241    - for readability;
  242    - to avoid inflating/rotating smoothed states (for the `EnKS`).
 243    """
 244    do_infl = infl != 1.0 and infl != "-N"
 245
 246    if do_infl or rot:
 247        A, mu = center(E)
 248        N, Nx = E.shape
 249        T = eye(N)
 250
 251        if do_infl:
 252            T = infl * T
 253
 254        if rot:
 255            T = genOG_1(N, rot) @ T
 256
 257        E = mu + T @ A
 258    return E
 259
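# Illustrative sketch: multiplicative inflation of 2% about the ensemble mean.
#     E = post_process(E, infl=1.02, rot=False)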
 260
 261def add_noise(E, dt, noise, method):
 262    """Treatment of additive noise for ensembles.
 263
 264    Refs: `bib.raanes2014ext`
 265    """
 266    if noise.C == 0:
 267        return E
 268
 269    N, Nx = E.shape
 270    A, mu = center(E)
 271    Q12 = noise.C.Left
 272    Q = noise.C.full
 273
 274    def sqrt_core():
 275        T = np.nan  # cause error if used
 276        Qa12 = np.nan  # cause error if used
 277        A2 = A.copy()  # Instead of using (the implicitly nonlocal) A,
 278        # which changes A outside as well. NB: This is a bug in Datum!
 279        if N <= Nx:
 280            Ainv = tinv(A2.T)
 281            Qa12 = Ainv @ Q12
 282            T = funm_psd(eye(N) + dt * (N - 1) * (Qa12 @ Qa12.T), sqrt)
 283            A2 = T @ A2
 284        else:  # "Left-multiplying" form
 285            P = A2.T @ A2 / (N - 1)
 286            L = funm_psd(eye(Nx) + dt * mrdiv(Q, P), sqrt)
 287            A2 = A2 @ L.T
 288        E = mu + A2
 289        return E, T, Qa12
 290
 291    if method == "Stoch":
 292        # In-place addition works (also) for empty [] noise sample.
 293        E += sqrt(dt) * noise.sample(N)
 294
 295    elif method == "none":
 296        pass
 297
 298    elif method == "Mult-1":
 299        varE = np.var(E, axis=0, ddof=1).sum()
 300        ratio = (varE + dt * diag(Q).sum()) / varE
 301        E = mu + sqrt(ratio) * A
 302        E = svdi(*tsvd(E, 0.999))  # Explained in Datum
 303
 304    elif method == "Mult-M":
 305        varE = np.var(E, axis=0)
 306        ratios = sqrt((varE + dt * diag(Q)) / varE)
 307        E = mu + A * ratios
 308        E = svdi(*tsvd(E, 0.999))  # Explained in Datum
 309
 310    elif method == "Sqrt-Core":
 311        E = sqrt_core()[0]
 312
 313    elif method == "Sqrt-Mult-1":
 314        varE0 = np.var(E, axis=0, ddof=1).sum()
 315        varE2 = varE0 + dt * diag(Q).sum()
 316        E, _, Qa12 = sqrt_core()
 317        if N <= Nx:
 318            A, mu = center(E)
 319            varE1 = np.var(E, axis=0, ddof=1).sum()
 320            ratio = varE2 / varE1
 321            E = mu + sqrt(ratio) * A
 322            E = svdi(*tsvd(E, 0.999))  # Explained in Datum
 323
 324    elif method == "Sqrt-Add-Z":
 325        E, _, Qa12 = sqrt_core()
 326        if N <= Nx:
 327            Z = Q12 - A.T @ Qa12
 328            E += sqrt(dt) * (Z @ rng.standard_normal((Z.shape[1], N))).T
 329
 330    elif method == "Sqrt-Dep":
 331        E, T, Qa12 = sqrt_core()
 332        if N <= Nx:
 333            # Q_hat12: reuse svd for both inversion and projection.
 334            Q_hat12 = A.T @ Qa12
 335            U, s, VT = tsvd(Q_hat12, 0.99)
 336            Q_hat12_inv = (VT.T * s ** (-1.0)) @ U.T
 337            Q_hat12_proj = VT.T @ VT
 338            rQ = Q12.shape[1]
 339            # Calc D_til
 340            Z = Q12 - Q_hat12
 341            D_hat = A.T @ (T - eye(N))
 342            Xi_hat = Q_hat12_inv @ D_hat
 343            Xi_til = (eye(rQ) - Q_hat12_proj) @ rng.standard_normal((rQ, N))
 344            D_til = Z @ (Xi_hat + sqrt(dt) * Xi_til)
 345            E += D_til.T
 346
 347    else:
 348        raise KeyError("No such method")
 349
 350    return E
 351
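# Illustrative sketch, mirroring the calls made in the `assimilate` loops:
#     E = add_noise(E, dt, HMM.Dyn.noise, "Stoch")      # stochastic (additive) treatment
#     E = add_noise(E, dt, HMM.Dyn.noise, "Sqrt-Core")  # deterministic (sqrt) treatment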
 352
 353@ens_method
 354class EnKS:
 355    """The ensemble Kalman smoother.
 356
 357    Refs: `bib.evensen2009ensemble`
 358
 359    The only difference to the EnKF
 360    is the management of the lag and the reshapings.
 361    """
 362
 363    upd_a: str
 364    N: int
 365    Lag: int
 366
 367    # Reshapings used in smoothers to go to/from
 368    # 3D arrays, where the 0th axis is the Lag index.
 369    def reshape_to(self, E):
 370        K, N, Nx = E.shape
 371        return E.transpose([1, 0, 2]).reshape((N, K * Nx))
 372
 373    def reshape_fr(self, E, Nx):
 374        N, Km = E.shape
 375        K = Km // Nx
 376        return E.reshape((N, K, Nx)).transpose([1, 0, 2])
 377
 378    def assimilate(self, HMM, xx, yy):
 379        # Inefficient version, storing full time series ensemble.
 380        # See iEnKS for a "rolling" version.
 381        E = zeros((HMM.tseq.K + 1, self.N, HMM.Dyn.M))
 382        E[0] = HMM.X0.sample(self.N)
 383
 384        for k, ko, t, dt in progbar(HMM.tseq.ticker):
 385            E[k] = HMM.Dyn(E[k - 1], t - dt, dt)
 386            E[k] = add_noise(E[k], dt, HMM.Dyn.noise, self.fnoise_treatm)
 387
 388            if ko is not None:
 389                self.stats.assess(k, ko, "f", E=E[k])
 390
 391                Eo = HMM.Obs(ko)(E[k])
 392                y = yy[ko]
 393
 394                # Inds within Lag
 395                kk = range(max(0, k - self.Lag * HMM.tseq.dko), k + 1)
 396
 397                EE = E[kk]
 398
 399                EE = self.reshape_to(EE)
 400                EE = EnKF_analysis(
 401                    EE, Eo, HMM.Obs(ko).noise, y, self.upd_a, self.stats, ko
 402                )
 403                E[kk] = self.reshape_fr(EE, HMM.Dyn.M)
 404                E[k] = post_process(E[k], self.infl, self.rot)
 405                self.stats.assess(k, ko, "a", E=E[k])
 406
 407        for k, ko, _, _ in progbar(HMM.tseq.ticker, desc="Assessing"):
 408            self.stats.assess(k, ko, "u", E=E[k])
 409            if ko is not None:
 410                self.stats.assess(k, ko, "s", E=E[k])
 411
 412
 413@ens_method
 414class EnRTS:
 415    """EnRTS (Rauch-Tung-Striebel) smoother.
 416
 417    Refs: `bib.raanes2016thesis`
 418    """
 419
 420    upd_a: str
 421    N: int
 422    DeCorr: float
 423
 424    def assimilate(self, HMM, xx, yy):
 425        E = zeros((HMM.tseq.K + 1, self.N, HMM.Dyn.M))
 426        Ef = E.copy()
 427        E[0] = HMM.X0.sample(self.N)
 428
 429        # Forward pass
 430        for k, ko, t, dt in progbar(HMM.tseq.ticker):
 431            E[k] = HMM.Dyn(E[k - 1], t - dt, dt)
 432            E[k] = add_noise(E[k], dt, HMM.Dyn.noise, self.fnoise_treatm)
 433            Ef[k] = E[k]
 434
 435            if ko is not None:
 436                self.stats.assess(k, ko, "f", E=E[k])
 437                Eo = HMM.Obs(ko)(E[k])
 438                y = yy[ko]
 439                E[k] = EnKF_analysis(
 440                    E[k], Eo, HMM.Obs(ko).noise, y, self.upd_a, self.stats, ko
 441                )
 442                E[k] = post_process(E[k], self.infl, self.rot)
 443                self.stats.assess(k, ko, "a", E=E[k])
 444
 445        # Backward pass
 446        for k in progbar(range(HMM.tseq.K)[::-1]):
 447            A = center(E[k])[0]
 448            Af = center(Ef[k + 1])[0]
 449
 450            J = tinv(Af) @ A
 451            J *= self.DeCorr
 452
 453            E[k] += (E[k + 1] - Ef[k + 1]) @ J
 454
 455        for k, ko, _, _ in progbar(HMM.tseq.ticker, desc="Assessing"):
 456            self.stats.assess(k, ko, "u", E=E[k])
 457            if ko is not None:
 458                self.stats.assess(k, ko, "s", E=E[k])
 459
 460
 461def serial_inds(upd_a, y, cvR, A):
 462    """Get the indices used for serial updating.
 463
 464    - Default: random ordering
 465    - if "mono" in `upd_a`: `1, 2, ..., len(y)`
 466    - if "sorted" in `upd_a`: sort by variance
 467    """
 468    if "mono" in upd_a:
 469        # Not robust?
 470        inds = np.arange(len(y))
 471    elif "sorted" in upd_a:
 472        N = len(A)
 473        dC = cvR.diag
 474        if np.all(dC == dC[0]):
 475            # Sort y by P
 476            dC = np.sum(A * A, 0) / (N - 1)
 477        inds = np.argsort(dC)
 478    else:  # Default: random ordering
 479        inds = rng.permutation(len(y))
 480    return inds
 481
 482
 483@ens_method
 484class SL_EAKF:
 485    """Serial, covariance-localized EAKF.
 486
 487    Refs: `bib.karspeck2007experimental`.
 488
 489    In contrast with LETKF, this iterates over the observations rather
 490    than over the state (batches).
 491
 492    Used without localization, this should be equivalent (full ensemble equality)
 493    to the `EnKF` with `upd_a='Serial'`.
 494    """
 495
 496    N: int
 497    loc_rad: float
 498    taper: str = "GC"
 499    ordr: str = "rand"
 500
 501    def assimilate(self, HMM, xx, yy):
 502        N1 = self.N - 1
 503
 504        E = HMM.X0.sample(self.N)
 505        self.stats.assess(0, E=E)
 506
 507        for k, ko, t, dt in progbar(HMM.tseq.ticker):
 508            E = HMM.Dyn(E, t - dt, dt)
 509            E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)
 510
 511            if ko is not None:
 512                self.stats.assess(k, ko, "f", E=E)
 513                Obs = HMM.Obs(ko)
 514                R = Obs.noise
 515                y = yy[ko]
 516                inds = serial_inds(self.ordr, y, R, center(E)[0])
 517                Rm12 = Obs.noise.C.sym_sqrt_inv
 518
 519                state_taperer = Obs.localizer(self.loc_rad, "y2x", self.taper)
 520                for j in inds:
 521                    # Prep:
 522                    # ------------------------------------------------------
 523                    Eo = Obs(E)
 524                    xo = np.mean(Eo, 0)
 525                    Y = Eo - xo
 526                    mu = np.mean(E, 0)
 527                    A = E - mu
 528                    # Update j-th component of observed ensemble:
 529                    # ------------------------------------------------------
 530                    Y_j = Rm12[j, :] @ Y.T
 531                    dy_j = Rm12[j, :] @ (y - xo)
 532                    # Prior var * N1:
 533                    sig2_j = Y_j @ Y_j
 534                    if sig2_j < 1e-9:
 535                        continue
 536                    # Update (below, we drop the locality subscript: _j)
 537                    sig2_u = 1 / (1 / sig2_j + 1 / N1)  # Postr. var * N1
 538                    alpha = (N1 / (N1 + sig2_j)) ** (0.5)  # Update contraction factor
 539                    dy2 = sig2_u * dy_j / N1  # Mean update
 540                    Y2 = alpha * Y_j  # Anomaly update
 541                    # Update state (regress update from obs space, using localization)
 542                    # ------------------------------------------------------
 543                    ii, tapering = state_taperer(j)
 544                    # ii, tapering = ..., 1  # cancel localization
 545                    if len(ii) == 0:
 546                        continue
 547                    Xi = A[:, ii] * tapering
 548                    Regression = Xi.T @ Y_j / np.sum(Y_j**2)
 549                    mu[ii] += Regression * dy2
 550                    A[:, ii] += np.outer(Y2 - Y_j, Regression)
 551                    E = mu + A
 552
 553                E = post_process(E, self.infl, self.rot)
 554
 555            self.stats.assess(k, ko, E=E)
 556
 557
 558def local_analyses(E, Eo, R, y, state_batches, obs_taperer, mp=map, xN=None, g=0):
 559    """Perform local analysis update for the LETKF."""
 560
 561    def local_analysis(ii):
 562        """Perform analysis, for state index batch `ii`."""
 563        # Locate local domain
 564        oBatch, tapering = obs_taperer(ii)
 565        Eii = E[:, ii]
 566
 567        # No update
 568        if len(oBatch) == 0:
 569            return Eii, 1
 570
 571        # Localize
 572        Yl = Y[:, oBatch]
 573        dyl = dy[oBatch]
 574        tpr = sqrt(tapering)
 575
 576        # Adaptive inflation estimation.
 577        # NB: Localisation is not 100% compatible with the EnKF-N, since
 578        # - After localisation there is much less need for inflation.
 579        # - Tapered values (Y, dy) are too neat
 580        #   (the EnKF-N expects a normal amount of sampling error).
 581        # One fix is to tune xN (maybe set it to 2 or 3). Thanks to adaptivity,
 582        # this should still be easier than tuning the inflation factor.
 583        infl1 = 1 if xN is None else sqrt(N1 / effective_N(Yl, dyl, xN, g))
 584        Eii, Yl = inflate_ens(Eii, infl1), Yl * infl1
  585        # Since R^{-1/2} was already applied (necessary for effective_N), now use R=Id.
  586        # TODO 4: the cost of re-initializing this R might not always be insignificant.
 587        R = GaussRV(C=1, M=len(dyl))
 588
 589        # Update
 590        Eii = EnKF_analysis(Eii, Yl * tpr, R, dyl * tpr, "Sqrt")
 591
 592        return Eii, infl1
 593
 594    # Prepare analysis
 595    N1 = len(E) - 1
 596    Y, xo = center(Eo)
 597    # Transform obs space
 598    Y = Y @ R.sym_sqrt_inv.T
 599    dy = (y - xo) @ R.sym_sqrt_inv.T
 600
 601    # Run
 602    result = mp(local_analysis, state_batches)
 603
 604    # Assign
 605    E_batches, infl1 = zip(*result)
 606    # TODO: this overwrites E, possibly unbeknownst to caller
 607    for ii, Eii in zip(state_batches, E_batches):
 608        E[:, ii] = Eii
 609
 610    return E, dict(ad_inf=sqrt(np.mean(np.array(infl1) ** 2)))
 611
 612
 613@ens_method
 614class LETKF:
 615    """Same as EnKF (Sqrt), but with localization.
 616
 617    Refs: `bib.hunt2007efficient`.
 618
  619    NB: Multiprocessing yields a slow-down for `dapper.mods.Lorenz96`,
  620    even with `batch_size=(1,)`. But for `dapper.mods.QG`
  621    (`batch_size=(2,2)` or less) it is quicker.
  622
  623    NB: If `len(ii)` is small, the analysis may be slowed down by '-N' inflation.
 624    """
 625
 626    N: int
 627    loc_rad: float
 628    taper: str = "GC"
 629    xN: float = None
 630    g: int = 0
 631    mp: bool = False
 632
 633    def assimilate(self, HMM, xx, yy):
 634        E = HMM.X0.sample(self.N)
 635        self.stats.assess(0, E=E)
 636        self.stats.new_series("ad_inf", 1, HMM.tseq.Ko + 1)
 637
 638        with multiproc.Pool(self.mp) as pool:
 639            for k, ko, t, dt in progbar(HMM.tseq.ticker):
 640                E = HMM.Dyn(E, t - dt, dt)
 641                E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)
 642
 643                if ko is not None:
 644                    self.stats.assess(k, ko, "f", E=E)
 645                    Obs = HMM.Obs(ko)
 646                    batch, taper = Obs.localizer(self.loc_rad, "x2y", self.taper)
 647                    E, stats = local_analyses(
 648                        E,
 649                        Obs(E),
 650                        Obs.noise.C,
 651                        yy[ko],
 652                        batch,
 653                        taper,
 654                        pool.map,
 655                        self.xN,
 656                        self.g,
 657                    )
 658                    self.stats.write(stats, k, ko, "a")
 659                    E = post_process(E, self.infl, self.rot)
 660
 661                self.stats.assess(k, ko, E=E)
 662
 663
 664def effective_N(YR, dyR, xN, g):
 665    """Effective ensemble size N.
 666
 667    As measured by the finite-size EnKF-N
 668    """
 669    N, Ny = YR.shape
 670    N1 = N - 1
 671
 672    V, s, UT = svd0(YR)
 673    du = UT @ dyR
 674
 675    eN, cL = hyperprior_coeffs(s, N, xN, g)
 676
 677    def pad_rk(arr):
 678        return pad0(arr, min(N, Ny))
 679
 680    def dgn_rk(l1):
 681        return pad_rk((l1 * s) ** 2) + N1
 682
 683    # Make dual cost function (in terms of l1)
 684    def J(l1):
 685        val = np.sum(du**2 / dgn_rk(l1)) + eN / l1**2 + cL * np.log(l1**2)
 686        return val
 687
 688    # Derivatives (not required with minimize_scalar):
 689    def Jp(l1):
 690        val = (
 691            -2 * l1 * np.sum(pad_rk(s**2) * du**2 / dgn_rk(l1) ** 2)
 692            + -2 * eN / l1**3
 693            + 2 * cL / l1
 694        )
 695        return val
 696
 697    def Jpp(l1):
 698        val = (
 699            8 * l1**2 * np.sum(pad_rk(s**4) * du**2 / dgn_rk(l1) ** 3)
 700            + 6 * eN / l1**4
 701            + -2 * cL / l1**2
 702        )
 703        return val
 704
 705    # Find inflation factor (optimize)
 706    l1 = Newton_m(Jp, Jpp, 1.0)
 707    # l1 = fmin_bfgs(J, x0=[1], gtol=1e-4, disp=0)
 708    # l1 = minimize_scalar(J, bracket=(sqrt(prior_mode), 1e2), tol=1e-4).x
 709
 710    za = N1 / l1**2
 711    return za
 712
 713
 714# Notes on optimizers for the 'dual' EnKF-N:
 715# ----------------------------------------
 716#  Using minimize_scalar:
 717#  - Doesn't take dJdx. Advantage: only need J
 718#  - method='bounded' not necessary and slower than 'brent'.
 719#  - bracket not necessary either...
 720#  Using multivariate minimization: fmin_cg, fmin_bfgs, fmin_ncg
 721#  - these also accept dJdx. But only fmin_bfgs approaches
 722#    the speed of the scalar minimizers.
 723#  Using scalar root-finders:
 724#  - brenth(dJ1, LowB, 1e2,     xtol=1e-6) # Same speed as minimization
 725#  - newton(dJ1,1.0, fprime=dJ2, tol=1e-6) # No improvement
 726#  - newton(dJ1,1.0, fprime=dJ2, tol=1e-6, fprime2=dJ3) # No improvement
 727#  - Newton_m(dJ1,dJ2, 1.0) # Significantly faster. Also slightly better CV?
  728# => Despite the inconvenience of defining analytic derivatives,
 729#    Newton_m seems like the best option.
 730#  - In extreme (or just non-linear Obs.mod) cases,
 731#    the EnKF-N cost function may have multiple minima.
 732#    Then: should use more robust optimizer!
 733#
 734# For 'primal'
 735# ----------------------------------------
 736# Similarly, Newton_m seems like the best option,
 737# although alternatives are provided (commented out).
 738#
 739def Newton_m(
 740    fun, deriv, x0, is_inverted=False, conf=1.0, xtol=1e-4, ytol=1e-7, itermax=10**2
 741):
 742    """Find root of `fun`.
 743
 744    This is a simple (and pretty fast) implementation of Newton's method.
 745    """
 746    itr = 0
 747    dx = np.inf
 748    Jx = fun(x0)
 749
 750    def norm(x):
 751        return sqrt(np.sum(x**2))
 752
 753    while ytol < norm(Jx) and xtol < norm(dx) and itr < itermax:
 754        Dx = deriv(x0)
 755        if is_inverted:
 756            dx = Dx @ Jx
 757        elif isinstance(Dx, float):
 758            dx = Jx / Dx
 759        else:
 760            dx = mldiv(Dx, Jx)
 761        dx *= conf
 762        x0 -= dx
 763        Jx = fun(x0)
 764        itr += 1
 765    return x0
 766
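# Illustrative sketch: Newton iteration for the positive root of x**2 - 2.
#     Newton_m(lambda x: x**2 - 2, lambda x: 2 * x, 1.0)  # ≈ 1.41421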
 767
 768def hyperprior_coeffs(s, N, xN=1, g=0):
 769    r"""Set EnKF-N inflation hyperparams.
 770
 771    The EnKF-N prior may be specified by the constants:
 772
 773    - `eN`: Effect of unknown mean
 774    - `cL`: Coeff in front of log term
 775
 776    These are trivial constants in the original EnKF-N,
 777    but are further adjusted (corrected and tuned) for the following reasons.
 778
 779    - Reason 1: mode correction.
 780      These parameters bridge the Jeffreys (`xN=1`) and Dirac (`xN=Inf`) hyperpriors
 781      for the prior covariance, B, as discussed in `bib.bocquet2015expanding`.
 782      Indeed, mode correction becomes necessary when $$ R \rightarrow \infty $$
 783      because then there should be no ensemble update (and also no inflation!).
 784      More specifically, the mode of `l1`'s should be adjusted towards 1
 785      as a function of $$ I - K H $$ ("prior's weight").
 786      PS: why do we leave the prior mode below 1 at all?
 787      Because it sets up "tension" (negative feedback) in the inflation cycle:
 788      the prior pulls downwards, while the likelihood tends to pull upwards.
 789
 790    - Reason 2: Boosting the inflation prior's certainty from N to xN*N.
 791      The aim is to take advantage of the fact that the ensemble may not
 792      have quite as much sampling error as a fully stochastic sample,
 793      as illustrated in section 2.1 of `bib.raanes2019adaptive`.
 794
 795    - Its damping effect is similar to work done by J. Anderson.
 796
 797    The tuning is controlled by:
 798
 799    - `xN=1`: is fully agnostic, i.e. assumes the ensemble is generated
 800      from a highly chaotic or stochastic model.
 801    - `xN>1`: increases the certainty of the hyper-prior,
 802      which is appropriate for more linear and deterministic systems.
 803    - `xN<1`: yields a more (than 'fully') agnostic hyper-prior,
 804      as if N were smaller than it truly is.
 805    - `xN<=0` is not meaningful.
 806    """
 807    N1 = N - 1
 808
 809    eN = (N + 1) / N
 810    cL = (N + g) / N1
 811
 812    # Mode correction (almost) as in eqn 36 of `bib.bocquet2015expanding`
 813    prior_mode = eN / cL  # Mode of l1 (before correction)
 814    diagonal = pad0(s**2, N) + N1  # diag of Y@R.inv@Y + N1*I
 815    #                                           (Hessian of J)
 816    I_KH = np.mean(diagonal ** (-1)) * N1  # ≈ 1/(1 + HBH/R)
 817    # I_KH      = 1/(1 + (s**2).sum()/N1)     # Scalar alternative: use tr(HBH/R).
 818    mc = sqrt(prior_mode**I_KH)  # Correction coeff
 819
 820    # Apply correction
 821    eN /= mc
 822    cL *= mc
 823
 824    # Boost by xN
 825    eN *= xN
 826    cL *= xN
 827
 828    return eN, cL
 829
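# Illustrative sketch: `s` holds the singular values of the whitened obs.
# anomalies (as computed in `effective_N` and `EnKF_N`); the values here are made up.
#     eN, cL = hyperprior_coeffs(s=np.ones(10), N=20, xN=2)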
 830
 831def zeta_a(eN, cL, w):
 832    """EnKF-N inflation estimation via w.
 833
 834    Returns `zeta_a = (N-1)/pre-inflation^2`.
 835
 836    Using this inside an iterative minimization as in the
 837    `dapper.da_methods.variational.iEnKS` effectively blends
 838    the distinction between the primal and dual EnKF-N.
 839    """
 840    N = len(w)
 841    N1 = N - 1
 842    za = N1 * cL / (eN + w @ w)
 843    return za
 844
 845
 846@ens_method
 847class EnKF_N:
 848    """Finite-size EnKF (EnKF-N).
 849
 850    Refs: `bib.bocquet2011ensemble`, `bib.bocquet2015expanding`
 851
 852    This implementation is pedagogical, prioritizing the "dual" form.
 853    In consequence, the efficiency of the "primal" form suffers a bit.
 854    The primal form is included for completeness and to demonstrate equivalence.
 855    In `dapper.da_methods.variational.iEnKS`, however,
 856    the primal form is preferred because it
 857    already does optimization for w (as treatment for nonlinear models).
 858
 859    `infl` should be unnecessary (assuming no model error, or that Q is correct).
 860
 861    `Hess`: use non-approx Hessian for ensemble transform matrix?
 862
  863    `g` is the nullity of A (the state anomalies), i.e. g=max(1,N-Nx),
 864    compensating for the redundancy in the space of w.
 865    But we have made it an input argument instead, with default 0,
 866    because mode-finding (of p(x) via the dual) completely ignores this redundancy,
  867    and the mode gets (undesirably) modified by g.
 868
 869    `xN` allows tuning the hyper-prior for the inflation.
 870    Usually, I just try setting it to 1 (default), or 2.
 871    Further description in hyperprior_coeffs().
 872    """
 873
 874    N: int
 875    dual: bool = False
 876    Hess: bool = False
 877    xN: float = 1.0
 878    g: int = 0
 879
 880    def assimilate(self, HMM, xx, yy):
 881        N, N1 = self.N, self.N - 1
 882
 883        # Init
 884        E = HMM.X0.sample(N)
 885        self.stats.assess(0, E=E)
 886
 887        # Cycle
 888        for k, ko, t, dt in progbar(HMM.tseq.ticker):
 889            # Forecast
 890            E = HMM.Dyn(E, t - dt, dt)
 891            E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)
 892
 893            # Analysis
 894            if ko is not None:
 895                self.stats.assess(k, ko, "f", E=E)
 896                Eo = HMM.Obs(ko)(E)
 897                y = yy[ko]
 898
 899                mu = np.mean(E, 0)
 900                A = E - mu
 901
 902                xo = np.mean(Eo, 0)
 903                Y = Eo - xo
 904                dy = y - xo
 905
 906                R = HMM.Obs(ko).noise.C
 907                V, s, UT = svd0(Y @ R.sym_sqrt_inv.T)
 908                du = UT @ (dy @ R.sym_sqrt_inv.T)
 909
 910                def dgn_N(l1):
 911                    return pad0((l1 * s) ** 2, N) + N1
 912
 913                # Adjust hyper-prior
 914                # xN_ = noise_level(self.xN, self.stats, HMM.tseq, N1, ko, A,
 915                #                   locals().get('A_old', None))
 916                eN, cL = hyperprior_coeffs(s, N, self.xN, self.g)
 917
 918                if self.dual:
 919                    # Make dual cost function (in terms of l1)
 920                    def pad_rk(arr):
 921                        return pad0(arr, min(N, len(y)))
 922
 923                    def dgn_rk(l1):
 924                        return pad_rk((l1 * s) ** 2) + N1
 925
 926                    def J(l1):
 927                        val = (
 928                            np.sum(du**2 / dgn_rk(l1)) + eN / l1**2 + cL * np.log(l1**2)
 929                        )
 930                        return val
 931
 932                    # Derivatives (not required with minimize_scalar):
 933                    def Jp(l1):
 934                        val = (
 935                            -2 * l1 * np.sum(pad_rk(s**2) * du**2 / dgn_rk(l1) ** 2)
 936                            + -2 * eN / l1**3
 937                            + 2 * cL / l1
 938                        )
 939                        return val
 940
 941                    def Jpp(l1):
 942                        val = (
 943                            8 * l1**2 * np.sum(pad_rk(s**4) * du**2 / dgn_rk(l1) ** 3)
 944                            + 6 * eN / l1**4
 945                            + -2 * cL / l1**2
 946                        )
 947                        return val
 948
 949                    # Find inflation factor (optimize)
 950                    l1 = Newton_m(Jp, Jpp, 1.0)
 951                    # l1 = fmin_bfgs(J, x0=[1], gtol=1e-4, disp=0)
 952                    # l1 = minimize_scalar(J, bracket=(sqrt(prior_mode), 1e2),
 953                    #                      tol=1e-4).x
 954
 955                else:
 956                    # Primal form, in a fully linearized version.
 957                    def za(w):
 958                        return zeta_a(eN, cL, w)
 959
 960                    def J(w):
 961                        return 0.5 * np.sum(
 962                            ((dy - w @ Y) @ R.sym_sqrt_inv.T) ** 2
 963                        ) + 0.5 * N1 * cL * np.log(eN + w @ w)
 964
 965                    # Derivatives (not required with fmin_bfgs):
 966                    def Jp(w):
 967                        return -Y @ R.inv @ (dy - w @ Y) + w * za(w)
 968
 969                    # Jpp   = lambda w:  Y@R.inv@Y.T + \
 970                    #     za(w)*(eye(N) - 2*np.outer(w,w)/(eN + w@w))
 971                    # Approx: no radial-angular cross-deriv:
 972                    # Jpp   = lambda w:  Y@R.inv@Y.T + za(w)*eye(N)
 973
 974                    def nvrs(w):
 975                        # inverse of Jpp-approx
 976                        return (V * (pad0(s**2, N) + za(w)) ** -1.0) @ V.T
 977
 978                    # Find w (optimize)
 979                    wa = Newton_m(Jp, nvrs, zeros(N), is_inverted=True)
 980                    # wa   = Newton_m(Jp,Jpp ,zeros(N))
 981                    # wa   = fmin_bfgs(J,zeros(N),Jp,disp=0)
 982                    l1 = sqrt(N1 / za(wa))
 983
 984                # Uncomment to revert to ETKF
 985                # l1 = 1.0
 986
 987                # Explicitly inflate prior
 988                # => formulae look different from `bib.bocquet2015expanding`.
 989                A *= l1
 990                Y *= l1
 991
 992                # Compute sqrt update
 993                Pw = (V * dgn_N(l1) ** (-1.0)) @ V.T
 994                w = dy @ R.inv @ Y.T @ Pw
 995                # For the anomalies:
 996                if not self.Hess:
 997                    # Regular ETKF (i.e. sym sqrt) update (with inflation)
 998                    T = (V * dgn_N(l1) ** (-0.5)) @ V.T * sqrt(N1)
 999                    # = (Y@R.inv@Y.T/N1 + eye(N))**(-0.5)
1000                else:
1001                    # Also include angular-radial co-dependence.
 1002                # Note: denominator not squared because
1003                    # unlike `bib.bocquet2015expanding` we have inflated Y.
1004                    Hw = (
1005                        Y @ R.inv @ Y.T / N1
1006                        + eye(N)
1007                        - 2 * np.outer(w, w) / (eN + w @ w)
1008                    )
1009                    T = funm_psd(Hw, lambda x: x**-0.5)  # is there a sqrtm Woodbury?
1010
1011                E = mu + w @ A + T @ A
1012                E = post_process(E, self.infl, self.rot)
1013
1014                self.stats.infl[ko] = l1
1015                self.stats.trHK[ko] = (
1016                    ((l1 * s) ** 2 + N1) ** (-1.0) * s**2
1017                ).sum() / len(y)
1018
1019            self.stats.assess(k, ko, E=E)
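
Like the other methods above, `EnKF_N` is configured through its dataclass fields. A minimal sketch, reusing the `HMM`, `xx`, `yy` assumed in the example at the top of this page:

    xp = EnKF_N(N=20, xN=2.0)   # adaptive inflation; manual `infl` tuning should be unnecessary
    xp.assimilate(HMM, xx, yy)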
@ens_method
class EnKF:
28@ens_method
29class EnKF:
30    """The ensemble Kalman filter.
31
32    Refs: `bib.evensen2009ensemble`.
33    """
34
35    upd_a: str
36    N: int
37
38    def assimilate(self, HMM, xx, yy):
39        # Init
40        E = HMM.X0.sample(self.N)
41        self.stats.assess(0, E=E)
42
43        # Cycle
44        for k, ko, t, dt in progbar(HMM.tseq.ticker):
45            E = HMM.Dyn(E, t - dt, dt)
46            E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)
47
48            # Analysis update
49            if ko is not None:
50                self.stats.assess(k, ko, "f", E=E)
51                E = EnKF_analysis(
52                    E,
53                    HMM.Obs(ko)(E),
54                    HMM.Obs(ko).noise,
55                    yy[ko],
56                    self.upd_a,
57                    self.stats,
58                    ko,
59                )
60                E = post_process(E, self.infl, self.rot)
61
62            self.stats.assess(k, ko, E=E)

The ensemble Kalman filter.

Refs: bib.evensen2009ensemble.

EnKF( upd_a: str, N: int, infl: float = 1.0, rot: bool = False, fnoise_treatm: str = 'Stoch')
upd_a: str
N: int
def assimilate(self, HMM, xx, yy):
38    def assimilate(self, HMM, xx, yy):
39        # Init
40        E = HMM.X0.sample(self.N)
41        self.stats.assess(0, E=E)
42
43        # Cycle
44        for k, ko, t, dt in progbar(HMM.tseq.ticker):
45            E = HMM.Dyn(E, t - dt, dt)
46            E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)
47
48            # Analysis update
49            if ko is not None:
50                self.stats.assess(k, ko, "f", E=E)
51                E = EnKF_analysis(
52                    E,
53                    HMM.Obs(ko)(E),
54                    HMM.Obs(ko).noise,
55                    yy[ko],
56                    self.upd_a,
57                    self.stats,
58                    ko,
59                )
60                E = post_process(E, self.infl, self.rot)
61
62            self.stats.assess(k, ko, E=E)
infl: float = 1.0
rot: bool = False
fnoise_treatm: str = 'Stoch'
def stat(self, name, value):
138        def stat(self, name, value):
139            dapper.stats.register_stat(self.stats, name, value)
da_method = 'EnKF'
def EnKF_analysis(E, Eo, hnoise, y, upd_a, stats=None, ko=None):
 65def EnKF_analysis(E, Eo, hnoise, y, upd_a, stats=None, ko=None):
 66    """Perform the EnKF analysis update.
 67
 68    This implementation includes several flavours and forms,
 69    specified by `upd_a`.
 70
 71    Main references: `bib.sakov2008deterministic`,
 72    `bib.sakov2008implications`, `bib.hoteit2015mitigating`
 73    """
 74    R = hnoise.C  # Obs noise cov
 75    N, Nx = E.shape  # Dimensionality
 76    N1 = N - 1  # Ens size - 1
 77
 78    mu = np.mean(E, 0)  # Ens mean
 79    A = E - mu  # Ens anomalies
 80
 81    xo = np.mean(Eo, 0)  # Obs ens mean
 82    Y = Eo - xo  # Obs ens anomalies
 83    dy = y - xo  # Mean "innovation"
 84
 85    if "PertObs" in upd_a:
 86        # Uses classic, perturbed observations (Burgers'98)
 87        C = Y.T @ Y + R.full * N1
 88        D = mean0(hnoise.sample(N))
 89        YC = mrdiv(Y, C)
 90        KG = A.T @ YC
 91        HK = Y.T @ YC
 92        dE = (KG @ (y - D - Eo).T).T
 93        E = E + dE
 94
 95    elif "Sqrt" in upd_a:
 96        # Uses a symmetric square root (ETKF)
 97        # to deterministically transform the ensemble.
 98
 99        # The various versions below differ only numerically.
100        # EVD is default, but for large N use SVD version.
101        if upd_a == "Sqrt" and N > Nx:
102            upd_a = "Sqrt svd"
103
104        if "explicit" in upd_a:
105            # Not recommended due to numerical costs and instability.
106            # Implementation using inv (in ens space)
107            Pw = sla.inv(Y @ R.inv @ Y.T + N1 * eye(N))
108            T = sla.sqrtm(Pw) * sqrt(N1)
109            HK = R.inv @ Y.T @ Pw @ Y
110            # KG = R.inv @ Y.T @ Pw @ A
111        elif "svd" in upd_a:
112            # Implementation using svd of Y R^{-1/2}.
113            V, s, _ = svd0(Y @ R.sym_sqrt_inv.T)
114            d = pad0(s**2, N) + N1
115            Pw = (V * d ** (-1.0)) @ V.T
116            T = (V * d ** (-0.5)) @ V.T * sqrt(N1)
117            # docs/snippets/trHK.jpg
118            trHK = np.sum((s**2 + N1) ** (-1.0) * s**2)
119        elif "sS" in upd_a:
120            # Same as 'svd', but with slightly different notation
121            # (sometimes used by Sakov) using the normalization sqrt(N1).
122            S = Y @ R.sym_sqrt_inv.T / sqrt(N1)
123            V, s, _ = svd0(S)
124            d = pad0(s**2, N) + 1
125            Pw = (V * d ** (-1.0)) @ V.T / N1  # = G/(N1)
126            T = (V * d ** (-0.5)) @ V.T
127            # docs/snippets/trHK.jpg
128            trHK = np.sum((s**2 + 1) ** (-1.0) * s**2)
129        else:  # 'eig' in upd_a:
130            # Implementation using eig. val. decomp.
131            d, V = sla.eigh(Y @ R.inv @ Y.T + N1 * eye(N))
132            T = V @ diag(d ** (-0.5)) @ V.T * sqrt(N1)
133            Pw = V @ diag(d ** (-1.0)) @ V.T
134            HK = R.inv @ Y.T @ (V @ diag(d ** (-1)) @ V.T) @ Y
135        w = dy @ R.inv @ Y.T @ Pw
136        E = mu + w @ A + T @ A
137
138    elif "Serial" in upd_a:
139        # Observations assimilated one-at-a-time:
140        inds = serial_inds(upd_a, y, R, A)
141        #  Requires de-correlation:
142        dy = dy @ R.sym_sqrt_inv.T
143        Y = Y @ R.sym_sqrt_inv.T
144        # Enhancement in the nonlinear case:
145        # re-compute Y each scalar obs assim.
146        # But: little benefit, model costly (?),
147        # updates cannot be accumulated on S and T.
148
149        if any(x in upd_a for x in ["Stoch", "ESOPS", "Var1"]):
150            # More details: Misc/Serial_ESOPS.py.
151            for i, j in enumerate(inds):
152                # Perturbation creation
153                if "ESOPS" in upd_a:
154                    # "2nd-O exact perturbation sampling"
155                    if i == 0:
156                        # Init -- increase nullspace by 1
157                        V, s, UT = svd0(A)
158                        s[N - 2 :] = 0
159                        A = svdi(V, s, UT)
160                        v = V[:, N - 2]
161                    else:
162                        # Orthogonalize v wrt. the new A
163                        #
164                        # v = Zj - Yj (from paper) requires Y==HX.
165                        # Instead: mult` should be c*ones(Nx) so we can
166                        # project v into ker(A) such that v@A is null.
167                        mult = (v @ A) / (Yj @ A)  # noqa
168                        v = v - mult[0] * Yj  # noqa
169                        v /= sqrt(v @ v)
170                    Zj = v * sqrt(N1)  # Standardized perturbation along v
171                    Zj *= np.sign(rng.standard_normal() - 0.5)  # Random sign
172                else:
173                    # The usual stochastic perturbations.
174                    Zj = mean0(rng.standard_normal(N))  # Un-coloured noise
175                    if "Var1" in upd_a:
176                        Zj *= sqrt(N / (Zj @ Zj))
177
178                # Select j-th obs
179                Yj = Y[:, j]  # [j] obs anomalies
180                dyj = dy[j]  # [j] innov mean
181                DYj = Zj - Yj  # [j] innov anomalies
182                DYj = DYj[:, None]  # Make 2d vertical
183
184                # Kalman gain computation
185                C = Yj @ Yj + N1  # Total obs cov
186                KGx = Yj @ A / C  # KG to update state
187                KGy = Yj @ Y / C  # KG to update obs
188
189                # Updates
190                A += DYj * KGx
191                mu += dyj * KGx
192                Y += DYj * KGy
193                dy -= dyj * KGy
194            E = mu + A
195        else:
196            # "Potter scheme", "EnSRF"
197            # - EAKF's two-stage "update-regress" form yields
198            #   the same *ensemble* as this.
199            # - The form below may be derived as "serial ETKF",
200            #   but does not yield the same
201            #   ensemble as 'Sqrt' (which processes obs as a batch)
202            #   -- only the same mean/cov.
203            T = eye(N)
204            for j in inds:
205                Yj = Y[:, j]
206                C = Yj @ Yj + N1
207                Tj = np.outer(Yj, Yj / (C + sqrt(N1 * C)))
208                T -= Tj @ T
209                Y -= Tj @ Y
210            w = dy @ Y.T @ T / N1
211            E = mu + w @ A + T @ A
212
213    elif "DEnKF" == upd_a:
214        # Uses "Deterministic EnKF" (sakov'08)
215        C = Y.T @ Y + R.full * N1
216        YC = mrdiv(Y, C)
217        KG = A.T @ YC
218        HK = Y.T @ YC
219        E = E + KG @ dy - 0.5 * (KG @ Y.T).T
220
221    else:
222        raise KeyError("No analysis update method found: '" + upd_a + "'.")
223
224    # Diagnostic: relative influence of observations
225    if stats is not None:
226        if "trHK" in locals():
227            stats.trHK[ko] = trHK / hnoise.M
228        elif "HK" in locals():
229            stats.trHK[ko] = HK.trace() / hnoise.M
230
231    return E

Perform the EnKF analysis update.

This implementation includes several flavours and forms, specified by upd_a.

Main references: bib.sakov2008deterministic, bib.sakov2008implications, bib.hoteit2015mitigating

def post_process(E, infl, rot):
234def post_process(E, infl, rot):
235    """Inflate, Rotate.
236
237    To avoid recomputing/recombining anomalies,
238    this should have been inside `EnKF_analysis`
239
240    But it is kept as a separate function
241
242    - for readability;
243    - to avoid inflating/rotationg smoothed states (for the `EnKS`).
244    """
245    do_infl = infl != 1.0 and infl != "-N"
246
247    if do_infl or rot:
248        A, mu = center(E)
249        N, Nx = E.shape
250        T = eye(N)
251
252        if do_infl:
253            T = infl * T
254
255        if rot:
256            T = genOG_1(N, rot) @ T
257
258        E = mu + T @ A
259    return E

Inflate, Rotate.

To avoid recomputing/recombining anomalies, this should have been inside EnKF_analysis

But it is kept as a separate function

  • for readability;
  • to avoid inflating/rotationg smoothed states (for the EnKS).
def add_noise(E, dt, noise, method):
262def add_noise(E, dt, noise, method):
263    """Treatment of additive noise for ensembles.
264
265    Refs: `bib.raanes2014ext`
266    """
267    if noise.C == 0:
268        return E
269
270    N, Nx = E.shape
271    A, mu = center(E)
272    Q12 = noise.C.Left
273    Q = noise.C.full
274
275    def sqrt_core():
276        T = np.nan  # cause error if used
277        Qa12 = np.nan  # cause error if used
278        A2 = A.copy()  # Instead of using (the implicitly nonlocal) A,
279        # which changes A outside as well. NB: This is a bug in Datum!
280        if N <= Nx:
281            Ainv = tinv(A2.T)
282            Qa12 = Ainv @ Q12
283            T = funm_psd(eye(N) + dt * (N - 1) * (Qa12 @ Qa12.T), sqrt)
284            A2 = T @ A2
285        else:  # "Left-multiplying" form
286            P = A2.T @ A2 / (N - 1)
287            L = funm_psd(eye(Nx) + dt * mrdiv(Q, P), sqrt)
288            A2 = A2 @ L.T
289        E = mu + A2
290        return E, T, Qa12
291
292    if method == "Stoch":
293        # In-place addition works (also) for empty [] noise sample.
294        E += sqrt(dt) * noise.sample(N)
295
296    elif method == "none":
297        pass
298
299    elif method == "Mult-1":
300        varE = np.var(E, axis=0, ddof=1).sum()
301        ratio = (varE + dt * diag(Q).sum()) / varE
302        E = mu + sqrt(ratio) * A
303        E = svdi(*tsvd(E, 0.999))  # Explained in Datum
304
305    elif method == "Mult-M":
306        varE = np.var(E, axis=0)
307        ratios = sqrt((varE + dt * diag(Q)) / varE)
308        E = mu + A * ratios
309        E = svdi(*tsvd(E, 0.999))  # Explained in Datum
310
311    elif method == "Sqrt-Core":
312        E = sqrt_core()[0]
313
314    elif method == "Sqrt-Mult-1":
315        varE0 = np.var(E, axis=0, ddof=1).sum()
316        varE2 = varE0 + dt * diag(Q).sum()
317        E, _, Qa12 = sqrt_core()
318        if N <= Nx:
319            A, mu = center(E)
320            varE1 = np.var(E, axis=0, ddof=1).sum()
321            ratio = varE2 / varE1
322            E = mu + sqrt(ratio) * A
323            E = svdi(*tsvd(E, 0.999))  # Explained in Datum
324
325    elif method == "Sqrt-Add-Z":
326        E, _, Qa12 = sqrt_core()
327        if N <= Nx:
328            Z = Q12 - A.T @ Qa12
329            E += sqrt(dt) * (Z @ rng.standard_normal((Z.shape[1], N))).T
330
331    elif method == "Sqrt-Dep":
332        E, T, Qa12 = sqrt_core()
333        if N <= Nx:
334            # Q_hat12: reuse svd for both inversion and projection.
335            Q_hat12 = A.T @ Qa12
336            U, s, VT = tsvd(Q_hat12, 0.99)
337            Q_hat12_inv = (VT.T * s ** (-1.0)) @ U.T
338            Q_hat12_proj = VT.T @ VT
339            rQ = Q12.shape[1]
340            # Calc D_til
341            Z = Q12 - Q_hat12
342            D_hat = A.T @ (T - eye(N))
343            Xi_hat = Q_hat12_inv @ D_hat
344            Xi_til = (eye(rQ) - Q_hat12_proj) @ rng.standard_normal((rQ, N))
345            D_til = Z @ (Xi_hat + sqrt(dt) * Xi_til)
346            E += D_til.T
347
348    else:
349        raise KeyError("No such method")
350
351    return E

Treatment of additive noise for ensembles.

Refs: bib.raanes2014ext

@ens_method
class EnKS:
354@ens_method
355class EnKS:
356    """The ensemble Kalman smoother.
357
358    Refs: `bib.evensen2009ensemble`
359
360    The only difference to the EnKF
361    is the management of the lag and the reshapings.
362    """
363
364    upd_a: str
365    N: int
366    Lag: int
367
368    # Reshapings used in smoothers to go to/from
369    # 3D arrays, where the 0th axis is the Lag index.
370    def reshape_to(self, E):
371        K, N, Nx = E.shape
372        return E.transpose([1, 0, 2]).reshape((N, K * Nx))
373
374    def reshape_fr(self, E, Nx):
375        N, Km = E.shape
376        K = Km // Nx
377        return E.reshape((N, K, Nx)).transpose([1, 0, 2])
378
379    def assimilate(self, HMM, xx, yy):
380        # Inefficient version, storing full time series ensemble.
381        # See iEnKS for a "rolling" version.
382        E = zeros((HMM.tseq.K + 1, self.N, HMM.Dyn.M))
383        E[0] = HMM.X0.sample(self.N)
384
385        for k, ko, t, dt in progbar(HMM.tseq.ticker):
386            E[k] = HMM.Dyn(E[k - 1], t - dt, dt)
387            E[k] = add_noise(E[k], dt, HMM.Dyn.noise, self.fnoise_treatm)
388
389            if ko is not None:
390                self.stats.assess(k, ko, "f", E=E[k])
391
392                Eo = HMM.Obs(ko)(E[k])
393                y = yy[ko]
394
395                # Inds within Lag
396                kk = range(max(0, k - self.Lag * HMM.tseq.dko), k + 1)
397
398                EE = E[kk]
399
400                EE = self.reshape_to(EE)
401                EE = EnKF_analysis(
402                    EE, Eo, HMM.Obs(ko).noise, y, self.upd_a, self.stats, ko
403                )
404                E[kk] = self.reshape_fr(EE, HMM.Dyn.M)
405                E[k] = post_process(E[k], self.infl, self.rot)
406                self.stats.assess(k, ko, "a", E=E[k])
407
408        for k, ko, _, _ in progbar(HMM.tseq.ticker, desc="Assessing"):
409            self.stats.assess(k, ko, "u", E=E[k])
410            if ko is not None:
411                self.stats.assess(k, ko, "s", E=E[k])

The ensemble Kalman smoother.

Refs: bib.evensen2009ensemble

The only difference to the EnKF is the management of the lag and the reshapings.

EnKS( upd_a: str, N: int, Lag: int, infl: float = 1.0, rot: bool = False, fnoise_treatm: str = 'Stoch')
upd_a: str
N: int
Lag: int
def reshape_to(self, E):
370    def reshape_to(self, E):
371        K, N, Nx = E.shape
372        return E.transpose([1, 0, 2]).reshape((N, K * Nx))
def reshape_fr(self, E, Nx):
374    def reshape_fr(self, E, Nx):
375        N, Km = E.shape
376        K = Km // Nx
377        return E.reshape((N, K, Nx)).transpose([1, 0, 2])
def assimilate(self, HMM, xx, yy):
379    def assimilate(self, HMM, xx, yy):
380        # Inefficient version, storing full time series ensemble.
381        # See iEnKS for a "rolling" version.
382        E = zeros((HMM.tseq.K + 1, self.N, HMM.Dyn.M))
383        E[0] = HMM.X0.sample(self.N)
384
385        for k, ko, t, dt in progbar(HMM.tseq.ticker):
386            E[k] = HMM.Dyn(E[k - 1], t - dt, dt)
387            E[k] = add_noise(E[k], dt, HMM.Dyn.noise, self.fnoise_treatm)
388
389            if ko is not None:
390                self.stats.assess(k, ko, "f", E=E[k])
391
392                Eo = HMM.Obs(ko)(E[k])
393                y = yy[ko]
394
395                # Inds within Lag
396                kk = range(max(0, k - self.Lag * HMM.tseq.dko), k + 1)
397
398                EE = E[kk]
399
400                EE = self.reshape_to(EE)
401                EE = EnKF_analysis(
402                    EE, Eo, HMM.Obs(ko).noise, y, self.upd_a, self.stats, ko
403                )
404                E[kk] = self.reshape_fr(EE, HMM.Dyn.M)
405                E[k] = post_process(E[k], self.infl, self.rot)
406                self.stats.assess(k, ko, "a", E=E[k])
407
408        for k, ko, _, _ in progbar(HMM.tseq.ticker, desc="Assessing"):
409            self.stats.assess(k, ko, "u", E=E[k])
410            if ko is not None:
411                self.stats.assess(k, ko, "s", E=E[k])
infl: float = 1.0
rot: bool = False
fnoise_treatm: str = 'Stoch'
def stat(self, name, value):
138        def stat(self, name, value):
139            dapper.stats.register_stat(self.stats, name, value)
da_method = 'EnKS'
@ens_method
class EnRTS:
414@ens_method
415class EnRTS:
416    """EnRTS (Rauch-Tung-Striebel) smoother.
417
418    Refs: `bib.raanes2016thesis`
419    """
420
421    upd_a: str
422    N: int
423    DeCorr: float
424
425    def assimilate(self, HMM, xx, yy):
426        E = zeros((HMM.tseq.K + 1, self.N, HMM.Dyn.M))
427        Ef = E.copy()
428        E[0] = HMM.X0.sample(self.N)
429
430        # Forward pass
431        for k, ko, t, dt in progbar(HMM.tseq.ticker):
432            E[k] = HMM.Dyn(E[k - 1], t - dt, dt)
433            E[k] = add_noise(E[k], dt, HMM.Dyn.noise, self.fnoise_treatm)
434            Ef[k] = E[k]
435
436            if ko is not None:
437                self.stats.assess(k, ko, "f", E=E[k])
438                Eo = HMM.Obs(ko)(E[k])
439                y = yy[ko]
440                E[k] = EnKF_analysis(
441                    E[k], Eo, HMM.Obs(ko).noise, y, self.upd_a, self.stats, ko
442                )
443                E[k] = post_process(E[k], self.infl, self.rot)
444                self.stats.assess(k, ko, "a", E=E[k])
445
446        # Backward pass
447        for k in progbar(range(HMM.tseq.K)[::-1]):
448            A = center(E[k])[0]
449            Af = center(Ef[k + 1])[0]
450
451            J = tinv(Af) @ A
452            J *= self.DeCorr
453
454            E[k] += (E[k + 1] - Ef[k + 1]) @ J
455
456        for k, ko, _, _ in progbar(HMM.tseq.ticker, desc="Assessing"):
457            self.stats.assess(k, ko, "u", E=E[k])
458            if ko is not None:
459                self.stats.assess(k, ko, "s", E=E[k])

EnRTS (Rauch-Tung-Striebel) smoother.

Refs: bib.raanes2016thesis

EnRTS(upd_a: str, N: int, DeCorr: float, infl: float = 1.0, rot: bool = False, fnoise_treatm: str = 'Stoch')
upd_a: str
N: int
DeCorr: float
infl: float = 1.0
rot: bool = False
fnoise_treatm: str = 'Stoch'
da_method = 'EnRTS'
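A companion sketch for the EnRTS, reusing the assumed HMM, xx, yy from the EnKS example above. DeCorr scales the backward-pass gain: 1.0 gives the plain RTS recursion, while values below 1 damp the backward corrections:

    import dapper.da_methods as da

    xp = da.EnRTS("Sqrt", N=25, DeCorr=0.99)  # illustrative values
    xp.assimilate(HMM, xx, yy)                # HMM, xx, yy as in the EnKS sketch
    xp.stats.average_in_time()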
def serial_inds(upd_a, y, cvR, A):
462def serial_inds(upd_a, y, cvR, A):
463    """Get the indices used for serial updating.
464
465    - Default: random ordering
466    - if "mono" in `upd_a`: `1, 2, ..., len(y)`
467    - if "sorted" in `upd_a`: sort by variance
468    """
469    if "mono" in upd_a:
470        # Not robust?
471        inds = np.arange(len(y))
472    elif "sorted" in upd_a:
473        N = len(A)
474        dC = cvR.diag
475        if np.all(dC == dC[0]):
476            # Sort y by P
477            dC = np.sum(A * A, 0) / (N - 1)
478        inds = np.argsort(dC)
479    else:  # Default: random ordering
480        inds = rng.permutation(len(y))
481    return inds

Get the indices used for serial updating.

  • Default: random ordering
  • if "mono" in upd_a: 1, 2, ..., len(y)
  • if "sorted" in upd_a: sort by variance
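A self-contained toy illustration of the three orderings. The SimpleNamespace is only a stand-in for an observation-noise covariance exposing a .diag attribute:

    import numpy as np
    from types import SimpleNamespace
    from dapper.da_methods.ensemble import serial_inds

    N, Ny = 10, 5
    A = np.random.randn(N, Ny)               # toy observed-ensemble anomalies
    y = np.zeros(Ny)                          # only len(y) is used here
    cvR = SimpleNamespace(diag=np.ones(Ny))   # stand-in for R (homogeneous variances)

    print(serial_inds("mono", y, cvR, A))     # 0, 1, ..., len(y)-1
    print(serial_inds("sorted", y, cvR, A))   # ascending ensemble variance (R is flat)
    print(serial_inds("", y, cvR, A))         # default: random permutation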
@ens_method
class SL_EAKF:
484@ens_method
485class SL_EAKF:
486    """Serial, covariance-localized EAKF.
487
488    Refs: `bib.karspeck2007experimental`.
489
490    In contrast with LETKF, this iterates over the observations rather
491    than over the state (batches).
492
493    Used without localization, this should be equivalent (full ensemble equality)
494    to the `EnKF` with `upd_a='Serial'`.
495    """
496
497    N: int
498    loc_rad: float
499    taper: str = "GC"
500    ordr: str = "rand"
501
502    def assimilate(self, HMM, xx, yy):
503        N1 = self.N - 1
504
505        E = HMM.X0.sample(self.N)
506        self.stats.assess(0, E=E)
507
508        for k, ko, t, dt in progbar(HMM.tseq.ticker):
509            E = HMM.Dyn(E, t - dt, dt)
510            E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)
511
512            if ko is not None:
513                self.stats.assess(k, ko, "f", E=E)
514                Obs = HMM.Obs(ko)
515                R = Obs.noise
516                y = yy[ko]
517                inds = serial_inds(self.ordr, y, R, center(E)[0])
518                Rm12 = Obs.noise.C.sym_sqrt_inv
519
520                state_taperer = Obs.localizer(self.loc_rad, "y2x", self.taper)
521                for j in inds:
522                    # Prep:
523                    # ------------------------------------------------------
524                    Eo = Obs(E)
525                    xo = np.mean(Eo, 0)
526                    Y = Eo - xo
527                    mu = np.mean(E, 0)
528                    A = E - mu
529                    # Update j-th component of observed ensemble:
530                    # ------------------------------------------------------
531                    Y_j = Rm12[j, :] @ Y.T
532                    dy_j = Rm12[j, :] @ (y - xo)
533                    # Prior var * N1:
534                    sig2_j = Y_j @ Y_j
535                    if sig2_j < 1e-9:
536                        continue
537                    # Update (below, we drop the locality subscript: _j)
538                    sig2_u = 1 / (1 / sig2_j + 1 / N1)  # Postr. var * N1
539                    alpha = (N1 / (N1 + sig2_j)) ** (0.5)  # Update contraction factor
540                    dy2 = sig2_u * dy_j / N1  # Mean update
541                    Y2 = alpha * Y_j  # Anomaly update
542                    # Update state (regress update from obs space, using localization)
543                    # ------------------------------------------------------
544                    ii, tapering = state_taperer(j)
545                    # ii, tapering = ..., 1  # cancel localization
546                    if len(ii) == 0:
547                        continue
548                    Xi = A[:, ii] * tapering
549                    Regression = Xi.T @ Y_j / np.sum(Y_j**2)
550                    mu[ii] += Regression * dy2
551                    A[:, ii] += np.outer(Y2 - Y_j, Regression)
552                    E = mu + A
553
554                E = post_process(E, self.infl, self.rot)
555
556            self.stats.assess(k, ko, E=E)

Serial, covariance-localized EAKF.

Refs: bib.karspeck2007experimental.

In contrast with LETKF, this iterates over the observations rather than over the state (batches).

Used without localization, this should be equivalent (full ensemble equality) to the EnKF with upd_a='Serial'.

SL_EAKF(N: int, loc_rad: float, taper: str = 'GC', ordr: str = 'rand', infl: float = 1.0, rot: bool = False, fnoise_treatm: str = 'Stoch')
N: int
loc_rad: float
taper: str = 'GC'
ordr: str = 'rand'
infl: float = 1.0
rot: bool = False
fnoise_treatm: str = 'Stoch'
da_method = 'SL_EAKF'
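A usage sketch. It assumes an HMM whose observation operator provides a localizer (e.g. one of the Lorenz-96 configurations); the import path and the localization radius are illustrative assumptions:

    import dapper.da_methods as da
    from dapper.mods.Lorenz96.sakov2008 import HMM  # assumed: a config defining Obs.localizer

    xx, yy = HMM.simulate()
    xp = da.SL_EAKF(N=20, loc_rad=4, taper="GC", ordr="rand")  # loc_rad is illustrative
    xp.assimilate(HMM, xx, yy)
    xp.stats.average_in_time()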
def local_analyses(E, Eo, R, y, state_batches, obs_taperer, mp=map, xN=None, g=0):
559def local_analyses(E, Eo, R, y, state_batches, obs_taperer, mp=map, xN=None, g=0):
560    """Perform local analysis update for the LETKF."""
561
562    def local_analysis(ii):
563        """Perform analysis, for state index batch `ii`."""
564        # Locate local domain
565        oBatch, tapering = obs_taperer(ii)
566        Eii = E[:, ii]
567
568        # No update
569        if len(oBatch) == 0:
570            return Eii, 1
571
572        # Localize
573        Yl = Y[:, oBatch]
574        dyl = dy[oBatch]
575        tpr = sqrt(tapering)
576
577        # Adaptive inflation estimation.
578        # NB: Localisation is not 100% compatible with the EnKF-N, since
579        # - After localisation there is much less need for inflation.
580        # - Tapered values (Y, dy) are too neat
581        #   (the EnKF-N expects a normal amount of sampling error).
582        # One fix is to tune xN (maybe set it to 2 or 3). Thanks to adaptivity,
583        # this should still be easier than tuning the inflation factor.
584        infl1 = 1 if xN is None else sqrt(N1 / effective_N(Yl, dyl, xN, g))
585        Eii, Yl = inflate_ens(Eii, infl1), Yl * infl1
586        # Since R^{-1/2} was already applied (necessary for effective_N), now use R=Id.
587        # TODO 4: the cost of re-init this R might not always be insignificant.
588        R = GaussRV(C=1, M=len(dyl))
589
590        # Update
591        Eii = EnKF_analysis(Eii, Yl * tpr, R, dyl * tpr, "Sqrt")
592
593        return Eii, infl1
594
595    # Prepare analysis
596    N1 = len(E) - 1
597    Y, xo = center(Eo)
598    # Transform obs space
599    Y = Y @ R.sym_sqrt_inv.T
600    dy = (y - xo) @ R.sym_sqrt_inv.T
601
602    # Run
603    result = mp(local_analysis, state_batches)
604
605    # Assign
606    E_batches, infl1 = zip(*result)
607    # TODO: this overwrites E, possibly unbeknownst to caller
608    for ii, Eii in zip(state_batches, E_batches):
609        E[:, ii] = Eii
610
611    return E, dict(ad_inf=sqrt(np.mean(np.array(infl1) ** 2)))

Perform local analysis update for the LETKF.

@ens_method
class LETKF:
614@ens_method
615class LETKF:
616    """Same as EnKF (Sqrt), but with localization.
617
618    Refs: `bib.hunt2007efficient`.
619
620    NB: Multiproc. yields slow-down for `dapper.mods.Lorenz96`,
621    even with `batch_size=(1,)`. But for `dapper.mods.QG`
622    (`batch_size=(2,2)` or less) it is quicker.
623
624    NB: If `len(ii)` is small, analysis may be slowed down with '-N' infl.
625    """
626
627    N: int
628    loc_rad: float
629    taper: str = "GC"
630    xN: float = None
631    g: int = 0
632    mp: bool = False
633
634    def assimilate(self, HMM, xx, yy):
635        E = HMM.X0.sample(self.N)
636        self.stats.assess(0, E=E)
637        self.stats.new_series("ad_inf", 1, HMM.tseq.Ko + 1)
638
639        with multiproc.Pool(self.mp) as pool:
640            for k, ko, t, dt in progbar(HMM.tseq.ticker):
641                E = HMM.Dyn(E, t - dt, dt)
642                E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)
643
644                if ko is not None:
645                    self.stats.assess(k, ko, "f", E=E)
646                    Obs = HMM.Obs(ko)
647                    batch, taper = Obs.localizer(self.loc_rad, "x2y", self.taper)
648                    E, stats = local_analyses(
649                        E,
650                        Obs(E),
651                        Obs.noise.C,
652                        yy[ko],
653                        batch,
654                        taper,
655                        pool.map,
656                        self.xN,
657                        self.g,
658                    )
659                    self.stats.write(stats, k, ko, "a")
660                    E = post_process(E, self.infl, self.rot)
661
662                self.stats.assess(k, ko, E=E)

Same as EnKF (Sqrt), but with localization.

Refs: bib.hunt2007efficient.

NB: Multiproc. yields slow-down for dapper.mods.Lorenz96, even with batch_size=(1,). But for dapper.mods.QG (batch_size=(2,2) or less) it is quicker.

NB: If len(ii) is small, analysis may be slowed down with '-N' infl.

LETKF(N: int, loc_rad: float, taper: str = 'GC', xN: float = None, g: int = 0, mp: bool = False, infl: float = 1.0, rot: bool = False, fnoise_treatm: str = 'Stoch')
N: int
loc_rad: float
taper: str = 'GC'
xN: float = None
g: int = 0
mp: bool = False
infl: float = 1.0
rot: bool = False
fnoise_treatm: str = 'Stoch'
da_method = 'LETKF'
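A usage sketch, reusing the localizable HMM, xx, yy assumed in the SL_EAKF example. mp=True would hand the local analyses to a multiprocessing pool, and setting xN switches on the adaptive (EnKF-N style) inflation inside each local domain:

    import dapper.da_methods as da

    xp = da.LETKF(N=20, loc_rad=4, taper="GC", xN=2, mp=False)  # illustrative values
    xp.assimilate(HMM, xx, yy)
    xp.stats.average_in_time()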
def effective_N(YR, dyR, xN, g):
665def effective_N(YR, dyR, xN, g):
666    """Effective ensemble size N.
667
668    As measured by the finite-size EnKF-N
669    """
670    N, Ny = YR.shape
671    N1 = N - 1
672
673    V, s, UT = svd0(YR)
674    du = UT @ dyR
675
676    eN, cL = hyperprior_coeffs(s, N, xN, g)
677
678    def pad_rk(arr):
679        return pad0(arr, min(N, Ny))
680
681    def dgn_rk(l1):
682        return pad_rk((l1 * s) ** 2) + N1
683
684    # Make dual cost function (in terms of l1)
685    def J(l1):
686        val = np.sum(du**2 / dgn_rk(l1)) + eN / l1**2 + cL * np.log(l1**2)
687        return val
688
689    # Derivatives (not required with minimize_scalar):
690    def Jp(l1):
691        val = (
692            -2 * l1 * np.sum(pad_rk(s**2) * du**2 / dgn_rk(l1) ** 2)
693            + -2 * eN / l1**3
694            + 2 * cL / l1
695        )
696        return val
697
698    def Jpp(l1):
699        val = (
700            8 * l1**2 * np.sum(pad_rk(s**4) * du**2 / dgn_rk(l1) ** 3)
701            + 6 * eN / l1**4
702            + -2 * cL / l1**2
703        )
704        return val
705
706    # Find inflation factor (optimize)
707    l1 = Newton_m(Jp, Jpp, 1.0)
708    # l1 = fmin_bfgs(J, x0=[1], gtol=1e-4, disp=0)
709    # l1 = minimize_scalar(J, bracket=(sqrt(prior_mode), 1e2), tol=1e-4).x
710
711    za = N1 / l1**2
712    return za

Effective ensemble size N.

As measured by the finite-size EnKF-N
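A toy call with random inputs, sized like a small local-analysis domain. As in local_analyses above, YR and dyR are assumed to already be scaled by R^{-1/2}:

    import numpy as np
    from dapper.da_methods.ensemble import effective_N
    from dapper.stats import center

    N, Ny = 20, 3
    YR, _ = center(np.random.randn(N, Ny))  # toy observed anomalies, R^{-1/2}-scaled
    dyR = np.random.randn(Ny)               # toy innovation, likewise scaled
    za = effective_N(YR, dyR, xN=1, g=0)    # returns (N-1)/inflation^2
    print(np.sqrt((N - 1) / za))            # implied inflation factor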

def Newton_m(fun, deriv, x0, is_inverted=False, conf=1.0, xtol=0.0001, ytol=1e-07, itermax=100):
740def Newton_m(
741    fun, deriv, x0, is_inverted=False, conf=1.0, xtol=1e-4, ytol=1e-7, itermax=10**2
742):
743    """Find root of `fun`.
744
745    This is a simple (and pretty fast) implementation of Newton's method.
746    """
747    itr = 0
748    dx = np.inf
749    Jx = fun(x0)
750
751    def norm(x):
752        return sqrt(np.sum(x**2))
753
754    while ytol < norm(Jx) and xtol < norm(dx) and itr < itermax:
755        Dx = deriv(x0)
756        if is_inverted:
757            dx = Dx @ Jx
758        elif isinstance(Dx, float):
759            dx = Jx / Dx
760        else:
761            dx = mldiv(Dx, Jx)
762        dx *= conf
763        x0 -= dx
764        Jx = fun(x0)
765        itr += 1
766    return x0

Find root of fun.

This is a simple (and pretty fast) implementation of Newton's method.
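It can also be used standalone on a scalar problem; for instance, the (purely illustrative) call below finds sqrt(2) as the root of x**2 - 2:

    from dapper.da_methods.ensemble import Newton_m

    fun = lambda x: x**2 - 2          # root at sqrt(2)
    deriv = lambda x: 2 * x
    print(Newton_m(fun, deriv, 1.0))  # ≈ 1.41421

With is_inverted=True, deriv should instead return the inverse of the Jacobian (applied as dx = Dx @ Jx), which is how the EnKF-N primal solve below uses it.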

def hyperprior_coeffs(s, N, xN=1, g=0):
769def hyperprior_coeffs(s, N, xN=1, g=0):
770    r"""Set EnKF-N inflation hyperparams.
771
772    The EnKF-N prior may be specified by the constants:
773
774    - `eN`: Effect of unknown mean
775    - `cL`: Coeff in front of log term
776
777    These are trivial constants in the original EnKF-N,
778    but are further adjusted (corrected and tuned) for the following reasons.
779
780    - Reason 1: mode correction.
781      These parameters bridge the Jeffreys (`xN=1`) and Dirac (`xN=Inf`) hyperpriors
782      for the prior covariance, B, as discussed in `bib.bocquet2015expanding`.
783      Indeed, mode correction becomes necessary when $$ R \rightarrow \infty $$
784      because then there should be no ensemble update (and also no inflation!).
785      More specifically, the mode of `l1`'s should be adjusted towards 1
786      as a function of $$ I - K H $$ ("prior's weight").
787      PS: why do we leave the prior mode below 1 at all?
788      Because it sets up "tension" (negative feedback) in the inflation cycle:
789      the prior pulls downwards, while the likelihood tends to pull upwards.
790
791    - Reason 2: Boosting the inflation prior's certainty from N to xN*N.
792      The aim is to take advantage of the fact that the ensemble may not
793      have quite as much sampling error as a fully stochastic sample,
794      as illustrated in section 2.1 of `bib.raanes2019adaptive`.
795
796    - Its damping effect is similar to work done by J. Anderson.
797
798    The tuning is controlled by:
799
800    - `xN=1`: is fully agnostic, i.e. assumes the ensemble is generated
801      from a highly chaotic or stochastic model.
802    - `xN>1`: increases the certainty of the hyper-prior,
803      which is appropriate for more linear and deterministic systems.
804    - `xN<1`: yields a more (than 'fully') agnostic hyper-prior,
805      as if N were smaller than it truly is.
806    - `xN<=0` is not meaningful.
807    """
808    N1 = N - 1
809
810    eN = (N + 1) / N
811    cL = (N + g) / N1
812
813    # Mode correction (almost) as in eqn 36 of `bib.bocquet2015expanding`
814    prior_mode = eN / cL  # Mode of l1 (before correction)
815    diagonal = pad0(s**2, N) + N1  # diag of Y@R.inv@Y + N1*I
816    #                                           (Hessian of J)
817    I_KH = np.mean(diagonal ** (-1)) * N1  # ≈ 1/(1 + HBH/R)
818    # I_KH      = 1/(1 + (s**2).sum()/N1)     # Scalar alternative: use tr(HBH/R).
819    mc = sqrt(prior_mode**I_KH)  # Correction coeff
820
821    # Apply correction
822    eN /= mc
823    cL *= mc
824
825    # Boost by xN
826    eN *= xN
827    cL *= xN
828
829    return eN, cL

Set EnKF-N inflation hyperparams.

The EnKF-N prior may be specified by the constants:

  • eN: Effect of unknown mean
  • cL: Coeff in front of log term

These are trivial constants in the original EnKF-N, but are further adjusted (corrected and tuned) for the following reasons.

  • Reason 1: mode correction. These parameters bridge the Jeffreys (xN=1) and Dirac (xN=Inf) hyperpriors for the prior covariance, B, as discussed in bib.bocquet2015expanding. Indeed, mode correction becomes necessary when R → ∞, because then there should be no ensemble update (and also no inflation!). More specifically, the mode of l1's should be adjusted towards 1 as a function of I - KH ("prior's weight"). PS: why do we leave the prior mode below 1 at all? Because it sets up "tension" (negative feedback) in the inflation cycle: the prior pulls downwards, while the likelihood tends to pull upwards.

  • Reason 2: Boosting the inflation prior's certainty from N to xN*N. The aim is to take advantage of the fact that the ensemble may not have quite as much sampling error as a fully stochastic sample, as illustrated in section 2.1 of bib.raanes2019adaptive.

  • Its damping effect is similar to work done by J. Anderson.

The tuning is controlled by:

  • xN=1: is fully agnostic, i.e. assumes the ensemble is generated from a highly chaotic or stochastic model.
  • xN>1: increases the certainty of the hyper-prior, which is appropriate for more linear and deterministic systems.
  • xN<1: yields a more (than 'fully') agnostic hyper-prior, as if N were smaller than it truly is.
  • xN<=0 is not meaningful.
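A small numerical illustration with toy singular values; the print shows how the mode correction and the xN boost move the constants away from their raw values (N+1)/N and (N+g)/(N-1):

    import numpy as np
    from dapper.da_methods.ensemble import hyperprior_coeffs

    N = 20
    s = np.ones(10)                      # toy singular values of Y @ R^{-1/2}
    eN, cL = hyperprior_coeffs(s, N, xN=1, g=0)
    print(eN, cL)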
def zeta_a(eN, cL, w):
832def zeta_a(eN, cL, w):
833    """EnKF-N inflation estimation via w.
834
835    Returns `zeta_a = (N-1)/pre-inflation^2`.
836
837    Using this inside an iterative minimization as in the
838    `dapper.da_methods.variational.iEnKS` effectively blends
839    the distinction between the primal and dual EnKF-N.
840    """
841    N = len(w)
842    N1 = N - 1
843    za = N1 * cL / (eN + w @ w)
844    return za

EnKF-N inflation estimation via w.

Returns zeta_a = (N-1)/pre-inflation^2.

Using this inside an iterative minimization as in the dapper.da_methods.variational.iEnKS effectively blends the distinction between the primal and dual EnKF-N.
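In other words, za = (N-1) * cL / (eN + w @ w). A minimal check at the prior mode w = 0, reusing the toy eN, cL from the previous example:

    import numpy as np
    from dapper.da_methods.ensemble import hyperprior_coeffs, zeta_a

    N = 20
    eN, cL = hyperprior_coeffs(np.ones(10), N)
    w = np.zeros(N)                  # no update: prior mode
    za = zeta_a(eN, cL, w)
    print(np.sqrt((N - 1) / za))     # corresponding pre-inflation factor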

@ens_method
class EnKF_N:
 847@ens_method
 848class EnKF_N:
 849    """Finite-size EnKF (EnKF-N).
 850
 851    Refs: `bib.bocquet2011ensemble`, `bib.bocquet2015expanding`
 852
 853    This implementation is pedagogical, prioritizing the "dual" form.
 854    In consequence, the efficiency of the "primal" form suffers a bit.
 855    The primal form is included for completeness and to demonstrate equivalence.
 856    In `dapper.da_methods.variational.iEnKS`, however,
 857    the primal form is preferred because it
 858    already does optimization for w (as treatment for nonlinear models).
 859
 860    `infl` should be unnecessary (assuming no model error, or that Q is correct).
 861
 862    `Hess`: use non-approx Hessian for ensemble transform matrix?
 863
 864    `g` is the nullity of A (the state anomalies), i.e. g=max(1,N-Nx),
 865    compensating for the redundancy in the space of w.
 866    But we have made it an input argument instead, with default 0,
 867    because mode-finding (of p(x) via the dual) completely ignores this redundancy,
 868    and the mode gets (undesirably) modified by g.
 869
 870    `xN` allows tuning the hyper-prior for the inflation.
 871    Usually, I just try setting it to 1 (default), or 2.
 872    Further description in hyperprior_coeffs().
 873    """
 874
 875    N: int
 876    dual: bool = False
 877    Hess: bool = False
 878    xN: float = 1.0
 879    g: int = 0
 880
 881    def assimilate(self, HMM, xx, yy):
 882        N, N1 = self.N, self.N - 1
 883
 884        # Init
 885        E = HMM.X0.sample(N)
 886        self.stats.assess(0, E=E)
 887
 888        # Cycle
 889        for k, ko, t, dt in progbar(HMM.tseq.ticker):
 890            # Forecast
 891            E = HMM.Dyn(E, t - dt, dt)
 892            E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)
 893
 894            # Analysis
 895            if ko is not None:
 896                self.stats.assess(k, ko, "f", E=E)
 897                Eo = HMM.Obs(ko)(E)
 898                y = yy[ko]
 899
 900                mu = np.mean(E, 0)
 901                A = E - mu
 902
 903                xo = np.mean(Eo, 0)
 904                Y = Eo - xo
 905                dy = y - xo
 906
 907                R = HMM.Obs(ko).noise.C
 908                V, s, UT = svd0(Y @ R.sym_sqrt_inv.T)
 909                du = UT @ (dy @ R.sym_sqrt_inv.T)
 910
 911                def dgn_N(l1):
 912                    return pad0((l1 * s) ** 2, N) + N1
 913
 914                # Adjust hyper-prior
 915                # xN_ = noise_level(self.xN, self.stats, HMM.tseq, N1, ko, A,
 916                #                   locals().get('A_old', None))
 917                eN, cL = hyperprior_coeffs(s, N, self.xN, self.g)
 918
 919                if self.dual:
 920                    # Make dual cost function (in terms of l1)
 921                    def pad_rk(arr):
 922                        return pad0(arr, min(N, len(y)))
 923
 924                    def dgn_rk(l1):
 925                        return pad_rk((l1 * s) ** 2) + N1
 926
 927                    def J(l1):
 928                        val = (
 929                            np.sum(du**2 / dgn_rk(l1)) + eN / l1**2 + cL * np.log(l1**2)
 930                        )
 931                        return val
 932
 933                    # Derivatives (not required with minimize_scalar):
 934                    def Jp(l1):
 935                        val = (
 936                            -2 * l1 * np.sum(pad_rk(s**2) * du**2 / dgn_rk(l1) ** 2)
 937                            + -2 * eN / l1**3
 938                            + 2 * cL / l1
 939                        )
 940                        return val
 941
 942                    def Jpp(l1):
 943                        val = (
 944                            8 * l1**2 * np.sum(pad_rk(s**4) * du**2 / dgn_rk(l1) ** 3)
 945                            + 6 * eN / l1**4
 946                            + -2 * cL / l1**2
 947                        )
 948                        return val
 949
 950                    # Find inflation factor (optimize)
 951                    l1 = Newton_m(Jp, Jpp, 1.0)
 952                    # l1 = fmin_bfgs(J, x0=[1], gtol=1e-4, disp=0)
 953                    # l1 = minimize_scalar(J, bracket=(sqrt(prior_mode), 1e2),
 954                    #                      tol=1e-4).x
 955
 956                else:
 957                    # Primal form, in a fully linearized version.
 958                    def za(w):
 959                        return zeta_a(eN, cL, w)
 960
 961                    def J(w):
 962                        return 0.5 * np.sum(
 963                            ((dy - w @ Y) @ R.sym_sqrt_inv.T) ** 2
 964                        ) + 0.5 * N1 * cL * np.log(eN + w @ w)
 965
 966                    # Derivatives (not required with fmin_bfgs):
 967                    def Jp(w):
 968                        return -Y @ R.inv @ (dy - w @ Y) + w * za(w)
 969
 970                    # Jpp   = lambda w:  Y@R.inv@Y.T + \
 971                    #     za(w)*(eye(N) - 2*np.outer(w,w)/(eN + w@w))
 972                    # Approx: no radial-angular cross-deriv:
 973                    # Jpp   = lambda w:  Y@R.inv@Y.T + za(w)*eye(N)
 974
 975                    def nvrs(w):
 976                        # inverse of Jpp-approx
 977                        return (V * (pad0(s**2, N) + za(w)) ** -1.0) @ V.T
 978
 979                    # Find w (optimize)
 980                    wa = Newton_m(Jp, nvrs, zeros(N), is_inverted=True)
 981                    # wa   = Newton_m(Jp,Jpp ,zeros(N))
 982                    # wa   = fmin_bfgs(J,zeros(N),Jp,disp=0)
 983                    l1 = sqrt(N1 / za(wa))
 984
 985                # Uncomment to revert to ETKF
 986                # l1 = 1.0
 987
 988                # Explicitly inflate prior
 989                # => formulae look different from `bib.bocquet2015expanding`.
 990                A *= l1
 991                Y *= l1
 992
 993                # Compute sqrt update
 994                Pw = (V * dgn_N(l1) ** (-1.0)) @ V.T
 995                w = dy @ R.inv @ Y.T @ Pw
 996                # For the anomalies:
 997                if not self.Hess:
 998                    # Regular ETKF (i.e. sym sqrt) update (with inflation)
 999                    T = (V * dgn_N(l1) ** (-0.5)) @ V.T * sqrt(N1)
1000                    # = (Y@R.inv@Y.T/N1 + eye(N))**(-0.5)
1001                else:
1002                    # Also include angular-radial co-dependence.
1003                    # Note: denominator not squared coz
1004                    # unlike `bib.bocquet2015expanding` we have inflated Y.
1005                    Hw = (
1006                        Y @ R.inv @ Y.T / N1
1007                        + eye(N)
1008                        - 2 * np.outer(w, w) / (eN + w @ w)
1009                    )
1010                    T = funm_psd(Hw, lambda x: x**-0.5)  # is there a sqrtm Woodbury?
1011
1012                E = mu + w @ A + T @ A
1013                E = post_process(E, self.infl, self.rot)
1014
1015                self.stats.infl[ko] = l1
1016                self.stats.trHK[ko] = (
1017                    ((l1 * s) ** 2 + N1) ** (-1.0) * s**2
1018                ).sum() / len(y)
1019
1020            self.stats.assess(k, ko, E=E)

Finite-size EnKF (EnKF-N).

Refs: bib.bocquet2011ensemble, bib.bocquet2015expanding

This implementation is pedagogical, prioritizing the "dual" form. In consequence, the efficiency of the "primal" form suffers a bit. The primal form is included for completeness and to demonstrate equivalence. In dapper.da_methods.variational.iEnKS, however, the primal form is preferred because it already does optimization for w (as treatment for nonlinear models).

infl should be unnecessary (assuming no model error, or that Q is correct).

Hess: use non-approx Hessian for ensemble transform matrix?

g is the nullity of A (the state anomalies), i.e. g=max(1,N-Nx), compensating for the redundancy in the space of w. But we have made it an input argument instead, with default 0, because mode-finding (of p(x) via the dual) completely ignores this redundancy, and the mode gets (undesirably) modified by g.

xN allows tuning the hyper-prior for the inflation. Usually, I just try setting it to 1 (default), or 2. Further description in hyperprior_coeffs().

EnKF_N(N: int, dual: bool = False, Hess: bool = False, xN: float = 1.0, g: int = 0, infl: float = 1.0, rot: bool = False, fnoise_treatm: str = 'Stoch')
N: int
dual: bool = False
Hess: bool = False
xN: float = 1.0
g: int = 0
infl: float = 1.0
rot: bool = False
fnoise_treatm: str = 'Stoch'
da_method = 'EnKF_N'
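A usage sketch, reusing the assumed HMM, xx, yy from the first example. The dual and primal forms should give practically the same result, and xN tightens the inflation hyper-prior as described in hyperprior_coeffs():

    import dapper.da_methods as da

    xp_primal = da.EnKF_N(N=20)                   # default: primal form
    xp_dual = da.EnKF_N(N=20, dual=True, xN=2)    # dual form, tighter hyper-prior
    for xp in (xp_primal, xp_dual):
        xp.assimilate(HMM, xx, yy)                # HMM, xx, yy as before
        xp.stats.average_in_time()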