dapper.da_methods.ensemble
The EnKF and other ensemble-based methods.
1"""The EnKF and other ensemble-based methods.""" 2 3import numpy as np 4import scipy.linalg as sla 5from numpy import diag, eye, sqrt, zeros 6 7import dapper.tools.multiproc as multiproc 8from dapper.stats import center, inflate_ens, mean0 9from dapper.tools.linalg import mldiv, mrdiv, pad0, svd0, svdi, tinv, tsvd 10from dapper.tools.matrices import funm_psd, genOG_1 11from dapper.tools.progressbar import progbar 12from dapper.tools.randvars import GaussRV 13from dapper.tools.seeding import rng 14 15from . import da_method 16 17 18@da_method 19class ens_method: 20 """Declare default ensemble arguments.""" 21 22 infl: float = 1.0 23 rot: bool = False 24 fnoise_treatm: str = "Stoch" 25 26 27@ens_method 28class EnKF: 29 """The ensemble Kalman filter. 30 31 Refs: `bib.evensen2009ensemble`. 32 """ 33 34 upd_a: str 35 N: int 36 37 def assimilate(self, HMM, xx, yy): 38 # Init 39 E = HMM.X0.sample(self.N) 40 self.stats.assess(0, E=E) 41 42 # Cycle 43 for k, ko, t, dt in progbar(HMM.tseq.ticker): 44 E = HMM.Dyn(E, t - dt, dt) 45 E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm) 46 47 # Analysis update 48 if ko is not None: 49 self.stats.assess(k, ko, "f", E=E) 50 E = EnKF_analysis( 51 E, 52 HMM.Obs(ko)(E), 53 HMM.Obs(ko).noise, 54 yy[ko], 55 self.upd_a, 56 self.stats, 57 ko, 58 ) 59 E = post_process(E, self.infl, self.rot) 60 61 self.stats.assess(k, ko, E=E) 62 63 64def EnKF_analysis(E, Eo, hnoise, y, upd_a, stats=None, ko=None): 65 """Perform the EnKF analysis update. 66 67 This implementation includes several flavours and forms, 68 specified by `upd_a`. 69 70 Main references: `bib.sakov2008deterministic`, 71 `bib.sakov2008implications`, `bib.hoteit2015mitigating` 72 """ 73 R = hnoise.C # Obs noise cov 74 N, Nx = E.shape # Dimensionality 75 N1 = N - 1 # Ens size - 1 76 77 mu = np.mean(E, 0) # Ens mean 78 A = E - mu # Ens anomalies 79 80 xo = np.mean(Eo, 0) # Obs ens mean 81 Y = Eo - xo # Obs ens anomalies 82 dy = y - xo # Mean "innovation" 83 84 if "PertObs" in upd_a: 85 # Uses classic, perturbed observations (Burgers'98) 86 C = Y.T @ Y + R.full * N1 87 D = mean0(hnoise.sample(N)) 88 YC = mrdiv(Y, C) 89 KG = A.T @ YC 90 HK = Y.T @ YC 91 dE = (KG @ (y - D - Eo).T).T 92 E = E + dE 93 94 elif "Sqrt" in upd_a: 95 # Uses a symmetric square root (ETKF) 96 # to deterministically transform the ensemble. 97 98 # The various versions below differ only numerically. 99 # EVD is default, but for large N use SVD version. 100 if upd_a == "Sqrt" and N > Nx: 101 upd_a = "Sqrt svd" 102 103 if "explicit" in upd_a: 104 # Not recommended due to numerical costs and instability. 105 # Implementation using inv (in ens space) 106 Pw = sla.inv(Y @ R.inv @ Y.T + N1 * eye(N)) 107 T = sla.sqrtm(Pw) * sqrt(N1) 108 HK = R.inv @ Y.T @ Pw @ Y 109 # KG = R.inv @ Y.T @ Pw @ A 110 elif "svd" in upd_a: 111 # Implementation using svd of Y R^{-1/2}. 112 V, s, _ = svd0(Y @ R.sym_sqrt_inv.T) 113 d = pad0(s**2, N) + N1 114 Pw = (V * d ** (-1.0)) @ V.T 115 T = (V * d ** (-0.5)) @ V.T * sqrt(N1) 116 # docs/snippets/trHK.jpg 117 trHK = np.sum((s**2 + N1) ** (-1.0) * s**2) 118 elif "sS" in upd_a: 119 # Same as 'svd', but with slightly different notation 120 # (sometimes used by Sakov) using the normalization sqrt(N1). 121 S = Y @ R.sym_sqrt_inv.T / sqrt(N1) 122 V, s, _ = svd0(S) 123 d = pad0(s**2, N) + 1 124 Pw = (V * d ** (-1.0)) @ V.T / N1 # = G/(N1) 125 T = (V * d ** (-0.5)) @ V.T 126 # docs/snippets/trHK.jpg 127 trHK = np.sum((s**2 + 1) ** (-1.0) * s**2) 128 else: # 'eig' in upd_a: 129 # Implementation using eig. val. decomp. 
            d, V = sla.eigh(Y @ R.inv @ Y.T + N1 * eye(N))
            T = V @ diag(d ** (-0.5)) @ V.T * sqrt(N1)
            Pw = V @ diag(d ** (-1.0)) @ V.T
            HK = R.inv @ Y.T @ (V @ diag(d ** (-1)) @ V.T) @ Y
        w = dy @ R.inv @ Y.T @ Pw
        E = mu + w @ A + T @ A

    elif "Serial" in upd_a:
        # Observations assimilated one-at-a-time:
        inds = serial_inds(upd_a, y, R, A)
        # Requires de-correlation:
        dy = dy @ R.sym_sqrt_inv.T
        Y = Y @ R.sym_sqrt_inv.T
        # Enhancement in the nonlinear case:
        # re-compute Y after each scalar obs assim.
        # But: little benefit, model costly (?),
        # updates cannot be accumulated on S and T.

        if any(x in upd_a for x in ["Stoch", "ESOPS", "Var1"]):
            # More details: Misc/Serial_ESOPS.py.
            for i, j in enumerate(inds):
                # Perturbation creation
                if "ESOPS" in upd_a:
                    # "2nd-O exact perturbation sampling"
                    if i == 0:
                        # Init -- increase nullspace by 1
                        V, s, UT = svd0(A)
                        s[N - 2 :] = 0
                        A = svdi(V, s, UT)
                        v = V[:, N - 2]
                    else:
                        # Orthogonalize v wrt. the new A
                        #
                        # v = Zj - Yj (from paper) requires Y==HX.
                        # Instead: `mult` should be c*ones(Nx) so we can
                        # project v into ker(A) such that v@A is null.
                        mult = (v @ A) / (Yj @ A)  # noqa
                        v = v - mult[0] * Yj  # noqa
                        v /= sqrt(v @ v)
                    Zj = v * sqrt(N1)  # Standardized perturbation along v
                    Zj *= np.sign(rng.standard_normal() - 0.5)  # Random sign
                else:
                    # The usual stochastic perturbations.
                    Zj = mean0(rng.standard_normal(N))  # Un-coloured noise
                    if "Var1" in upd_a:
                        Zj *= sqrt(N / (Zj @ Zj))

                # Select j-th obs
                Yj = Y[:, j]  # [j] obs anomalies
                dyj = dy[j]  # [j] innov mean
                DYj = Zj - Yj  # [j] innov anomalies
                DYj = DYj[:, None]  # Make 2d vertical

                # Kalman gain computation
                C = Yj @ Yj + N1  # Total obs cov
                KGx = Yj @ A / C  # KG to update state
                KGy = Yj @ Y / C  # KG to update obs

                # Updates
                A += DYj * KGx
                mu += dyj * KGx
                Y += DYj * KGy
                dy -= dyj * KGy
            E = mu + A
        else:
            # "Potter scheme", "EnSRF"
            # - EAKF's two-stage "update-regress" form yields
            #   the same *ensemble* as this.
            # - The form below may be derived as "serial ETKF",
            #   but does not yield the same
            #   ensemble as 'Sqrt' (which processes obs as a batch)
            #   -- only the same mean/cov.
            T = eye(N)
            for j in inds:
                Yj = Y[:, j]
                C = Yj @ Yj + N1
                Tj = np.outer(Yj, Yj / (C + sqrt(N1 * C)))
                T -= Tj @ T
                Y -= Tj @ Y
            w = dy @ Y.T @ T / N1
            E = mu + w @ A + T @ A

    elif "DEnKF" == upd_a:
        # Uses "Deterministic EnKF" (sakov'08)
        C = Y.T @ Y + R.full * N1
        YC = mrdiv(Y, C)
        KG = A.T @ YC
        HK = Y.T @ YC
        E = E + KG @ dy - 0.5 * (KG @ Y.T).T

    else:
        raise KeyError("No analysis update method found: '" + upd_a + "'.")

    # Diagnostic: relative influence of observations
    if stats is not None:
        if "trHK" in locals():
            stats.trHK[ko] = trHK / hnoise.M
        elif "HK" in locals():
            stats.trHK[ko] = HK.trace() / hnoise.M

    return E


def post_process(E, infl, rot):
    """Inflate, Rotate.

    To avoid recomputing/recombining anomalies,
    this should have been inside `EnKF_analysis`.

    But it is kept as a separate function

    - for readability;
    - to avoid inflating/rotating smoothed states (for the `EnKS`).
243 """ 244 do_infl = infl != 1.0 and infl != "-N" 245 246 if do_infl or rot: 247 A, mu = center(E) 248 N, Nx = E.shape 249 T = eye(N) 250 251 if do_infl: 252 T = infl * T 253 254 if rot: 255 T = genOG_1(N, rot) @ T 256 257 E = mu + T @ A 258 return E 259 260 261def add_noise(E, dt, noise, method): 262 """Treatment of additive noise for ensembles. 263 264 Refs: `bib.raanes2014ext` 265 """ 266 if noise.C == 0: 267 return E 268 269 N, Nx = E.shape 270 A, mu = center(E) 271 Q12 = noise.C.Left 272 Q = noise.C.full 273 274 def sqrt_core(): 275 T = np.nan # cause error if used 276 Qa12 = np.nan # cause error if used 277 A2 = A.copy() # Instead of using (the implicitly nonlocal) A, 278 # which changes A outside as well. NB: This is a bug in Datum! 279 if N <= Nx: 280 Ainv = tinv(A2.T) 281 Qa12 = Ainv @ Q12 282 T = funm_psd(eye(N) + dt * (N - 1) * (Qa12 @ Qa12.T), sqrt) 283 A2 = T @ A2 284 else: # "Left-multiplying" form 285 P = A2.T @ A2 / (N - 1) 286 L = funm_psd(eye(Nx) + dt * mrdiv(Q, P), sqrt) 287 A2 = A2 @ L.T 288 E = mu + A2 289 return E, T, Qa12 290 291 if method == "Stoch": 292 # In-place addition works (also) for empty [] noise sample. 293 E += sqrt(dt) * noise.sample(N) 294 295 elif method == "none": 296 pass 297 298 elif method == "Mult-1": 299 varE = np.var(E, axis=0, ddof=1).sum() 300 ratio = (varE + dt * diag(Q).sum()) / varE 301 E = mu + sqrt(ratio) * A 302 E = svdi(*tsvd(E, 0.999)) # Explained in Datum 303 304 elif method == "Mult-M": 305 varE = np.var(E, axis=0) 306 ratios = sqrt((varE + dt * diag(Q)) / varE) 307 E = mu + A * ratios 308 E = svdi(*tsvd(E, 0.999)) # Explained in Datum 309 310 elif method == "Sqrt-Core": 311 E = sqrt_core()[0] 312 313 elif method == "Sqrt-Mult-1": 314 varE0 = np.var(E, axis=0, ddof=1).sum() 315 varE2 = varE0 + dt * diag(Q).sum() 316 E, _, Qa12 = sqrt_core() 317 if N <= Nx: 318 A, mu = center(E) 319 varE1 = np.var(E, axis=0, ddof=1).sum() 320 ratio = varE2 / varE1 321 E = mu + sqrt(ratio) * A 322 E = svdi(*tsvd(E, 0.999)) # Explained in Datum 323 324 elif method == "Sqrt-Add-Z": 325 E, _, Qa12 = sqrt_core() 326 if N <= Nx: 327 Z = Q12 - A.T @ Qa12 328 E += sqrt(dt) * (Z @ rng.standard_normal((Z.shape[1], N))).T 329 330 elif method == "Sqrt-Dep": 331 E, T, Qa12 = sqrt_core() 332 if N <= Nx: 333 # Q_hat12: reuse svd for both inversion and projection. 334 Q_hat12 = A.T @ Qa12 335 U, s, VT = tsvd(Q_hat12, 0.99) 336 Q_hat12_inv = (VT.T * s ** (-1.0)) @ U.T 337 Q_hat12_proj = VT.T @ VT 338 rQ = Q12.shape[1] 339 # Calc D_til 340 Z = Q12 - Q_hat12 341 D_hat = A.T @ (T - eye(N)) 342 Xi_hat = Q_hat12_inv @ D_hat 343 Xi_til = (eye(rQ) - Q_hat12_proj) @ rng.standard_normal((rQ, N)) 344 D_til = Z @ (Xi_hat + sqrt(dt) * Xi_til) 345 E += D_til.T 346 347 else: 348 raise KeyError("No such method") 349 350 return E 351 352 353@ens_method 354class EnKS: 355 """The ensemble Kalman smoother. 356 357 Refs: `bib.evensen2009ensemble` 358 359 The only difference to the EnKF 360 is the management of the lag and the reshapings. 361 """ 362 363 upd_a: str 364 N: int 365 Lag: int 366 367 # Reshapings used in smoothers to go to/from 368 # 3D arrays, where the 0th axis is the Lag index. 369 def reshape_to(self, E): 370 K, N, Nx = E.shape 371 return E.transpose([1, 0, 2]).reshape((N, K * Nx)) 372 373 def reshape_fr(self, E, Nx): 374 N, Km = E.shape 375 K = Km // Nx 376 return E.reshape((N, K, Nx)).transpose([1, 0, 2]) 377 378 def assimilate(self, HMM, xx, yy): 379 # Inefficient version, storing full time series ensemble. 380 # See iEnKS for a "rolling" version. 
        E = zeros((HMM.tseq.K + 1, self.N, HMM.Dyn.M))
        E[0] = HMM.X0.sample(self.N)

        for k, ko, t, dt in progbar(HMM.tseq.ticker):
            E[k] = HMM.Dyn(E[k - 1], t - dt, dt)
            E[k] = add_noise(E[k], dt, HMM.Dyn.noise, self.fnoise_treatm)

            if ko is not None:
                self.stats.assess(k, ko, "f", E=E[k])

                Eo = HMM.Obs(ko)(E[k])
                y = yy[ko]

                # Inds within Lag
                kk = range(max(0, k - self.Lag * HMM.tseq.dko), k + 1)

                EE = E[kk]

                EE = self.reshape_to(EE)
                EE = EnKF_analysis(
                    EE, Eo, HMM.Obs(ko).noise, y, self.upd_a, self.stats, ko
                )
                E[kk] = self.reshape_fr(EE, HMM.Dyn.M)
                E[k] = post_process(E[k], self.infl, self.rot)
                self.stats.assess(k, ko, "a", E=E[k])

        for k, ko, _, _ in progbar(HMM.tseq.ticker, desc="Assessing"):
            self.stats.assess(k, ko, "u", E=E[k])
            if ko is not None:
                self.stats.assess(k, ko, "s", E=E[k])


@ens_method
class EnRTS:
    """EnRTS (Rauch-Tung-Striebel) smoother.

    Refs: `bib.raanes2016thesis`
    """

    upd_a: str
    N: int
    DeCorr: float

    def assimilate(self, HMM, xx, yy):
        E = zeros((HMM.tseq.K + 1, self.N, HMM.Dyn.M))
        Ef = E.copy()
        E[0] = HMM.X0.sample(self.N)

        # Forward pass
        for k, ko, t, dt in progbar(HMM.tseq.ticker):
            E[k] = HMM.Dyn(E[k - 1], t - dt, dt)
            E[k] = add_noise(E[k], dt, HMM.Dyn.noise, self.fnoise_treatm)
            Ef[k] = E[k]

            if ko is not None:
                self.stats.assess(k, ko, "f", E=E[k])
                Eo = HMM.Obs(ko)(E[k])
                y = yy[ko]
                E[k] = EnKF_analysis(
                    E[k], Eo, HMM.Obs(ko).noise, y, self.upd_a, self.stats, ko
                )
                E[k] = post_process(E[k], self.infl, self.rot)
                self.stats.assess(k, ko, "a", E=E[k])

        # Backward pass
        for k in progbar(range(HMM.tseq.K)[::-1]):
            A = center(E[k])[0]
            Af = center(Ef[k + 1])[0]

            J = tinv(Af) @ A
            J *= self.DeCorr

            E[k] += (E[k + 1] - Ef[k + 1]) @ J

        for k, ko, _, _ in progbar(HMM.tseq.ticker, desc="Assessing"):
            self.stats.assess(k, ko, "u", E=E[k])
            if ko is not None:
                self.stats.assess(k, ko, "s", E=E[k])


def serial_inds(upd_a, y, cvR, A):
    """Get the indices used for serial updating.

    - Default: random ordering
    - if "mono" in `upd_a`: `0, 1, ..., len(y)-1`
    - if "sorted" in `upd_a`: sort by variance
    """
    if "mono" in upd_a:
        # Not robust?
        inds = np.arange(len(y))
    elif "sorted" in upd_a:
        N = len(A)
        dC = cvR.diag
        if np.all(dC == dC[0]):
            # Sort y by P
            dC = np.sum(A * A, 0) / (N - 1)
        inds = np.argsort(dC)
    else:  # Default: random ordering
        inds = rng.permutation(len(y))
    return inds


@ens_method
class SL_EAKF:
    """Serial, covariance-localized EAKF.

    Refs: `bib.karspeck2007experimental`.

    In contrast with LETKF, this iterates over the observations rather
    than over the state (batches).

    Used without localization, this should be equivalent (full ensemble equality)
    to the `EnKF` with `upd_a='Serial'`.
494 """ 495 496 N: int 497 loc_rad: float 498 taper: str = "GC" 499 ordr: str = "rand" 500 501 def assimilate(self, HMM, xx, yy): 502 N1 = self.N - 1 503 504 E = HMM.X0.sample(self.N) 505 self.stats.assess(0, E=E) 506 507 for k, ko, t, dt in progbar(HMM.tseq.ticker): 508 E = HMM.Dyn(E, t - dt, dt) 509 E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm) 510 511 if ko is not None: 512 self.stats.assess(k, ko, "f", E=E) 513 Obs = HMM.Obs(ko) 514 R = Obs.noise 515 y = yy[ko] 516 inds = serial_inds(self.ordr, y, R, center(E)[0]) 517 Rm12 = Obs.noise.C.sym_sqrt_inv 518 519 state_taperer = Obs.localizer(self.loc_rad, "y2x", self.taper) 520 for j in inds: 521 # Prep: 522 # ------------------------------------------------------ 523 Eo = Obs(E) 524 xo = np.mean(Eo, 0) 525 Y = Eo - xo 526 mu = np.mean(E, 0) 527 A = E - mu 528 # Update j-th component of observed ensemble: 529 # ------------------------------------------------------ 530 Y_j = Rm12[j, :] @ Y.T 531 dy_j = Rm12[j, :] @ (y - xo) 532 # Prior var * N1: 533 sig2_j = Y_j @ Y_j 534 if sig2_j < 1e-9: 535 continue 536 # Update (below, we drop the locality subscript: _j) 537 sig2_u = 1 / (1 / sig2_j + 1 / N1) # Postr. var * N1 538 alpha = (N1 / (N1 + sig2_j)) ** (0.5) # Update contraction factor 539 dy2 = sig2_u * dy_j / N1 # Mean update 540 Y2 = alpha * Y_j # Anomaly update 541 # Update state (regress update from obs space, using localization) 542 # ------------------------------------------------------ 543 ii, tapering = state_taperer(j) 544 # ii, tapering = ..., 1 # cancel localization 545 if len(ii) == 0: 546 continue 547 Xi = A[:, ii] * tapering 548 Regression = Xi.T @ Y_j / np.sum(Y_j**2) 549 mu[ii] += Regression * dy2 550 A[:, ii] += np.outer(Y2 - Y_j, Regression) 551 E = mu + A 552 553 E = post_process(E, self.infl, self.rot) 554 555 self.stats.assess(k, ko, E=E) 556 557 558def local_analyses(E, Eo, R, y, state_batches, obs_taperer, mp=map, xN=None, g=0): 559 """Perform local analysis update for the LETKF.""" 560 561 def local_analysis(ii): 562 """Perform analysis, for state index batch `ii`.""" 563 # Locate local domain 564 oBatch, tapering = obs_taperer(ii) 565 Eii = E[:, ii] 566 567 # No update 568 if len(oBatch) == 0: 569 return Eii, 1 570 571 # Localize 572 Yl = Y[:, oBatch] 573 dyl = dy[oBatch] 574 tpr = sqrt(tapering) 575 576 # Adaptive inflation estimation. 577 # NB: Localisation is not 100% compatible with the EnKF-N, since 578 # - After localisation there is much less need for inflation. 579 # - Tapered values (Y, dy) are too neat 580 # (the EnKF-N expects a normal amount of sampling error). 581 # One fix is to tune xN (maybe set it to 2 or 3). Thanks to adaptivity, 582 # this should still be easier than tuning the inflation factor. 583 infl1 = 1 if xN is None else sqrt(N1 / effective_N(Yl, dyl, xN, g)) 584 Eii, Yl = inflate_ens(Eii, infl1), Yl * infl1 585 # Since R^{-1/2} was already applied (necesry for effective_N), now use R=Id. 586 # TODO 4: the cost of re-init this R might not always be insignificant. 
        R = GaussRV(C=1, M=len(dyl))

        # Update
        Eii = EnKF_analysis(Eii, Yl * tpr, R, dyl * tpr, "Sqrt")

        return Eii, infl1

    # Prepare analysis
    N1 = len(E) - 1
    Y, xo = center(Eo)
    # Transform obs space
    Y = Y @ R.sym_sqrt_inv.T
    dy = (y - xo) @ R.sym_sqrt_inv.T

    # Run
    result = mp(local_analysis, state_batches)

    # Assign
    E_batches, infl1 = zip(*result)
    # TODO: this overwrites E, possibly unbeknownst to caller
    for ii, Eii in zip(state_batches, E_batches):
        E[:, ii] = Eii

    return E, dict(ad_inf=sqrt(np.mean(np.array(infl1) ** 2)))


@ens_method
class LETKF:
    """Same as EnKF (Sqrt), but with localization.

    Refs: `bib.hunt2007efficient`.

    NB: Multiproc. yields slow-down for `dapper.mods.Lorenz96`,
    even with `batch_size=(1,)`. But for `dapper.mods.QG`
    (`batch_size=(2,2)` or less) it is quicker.

    NB: If `len(ii)` is small, analysis may be slowed-down with '-N' infl.
    """

    N: int
    loc_rad: float
    taper: str = "GC"
    xN: float = None
    g: int = 0
    mp: bool = False

    def assimilate(self, HMM, xx, yy):
        E = HMM.X0.sample(self.N)
        self.stats.assess(0, E=E)
        self.stats.new_series("ad_inf", 1, HMM.tseq.Ko + 1)

        with multiproc.Pool(self.mp) as pool:
            for k, ko, t, dt in progbar(HMM.tseq.ticker):
                E = HMM.Dyn(E, t - dt, dt)
                E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)

                if ko is not None:
                    self.stats.assess(k, ko, "f", E=E)
                    Obs = HMM.Obs(ko)
                    batch, taper = Obs.localizer(self.loc_rad, "x2y", self.taper)
                    E, stats = local_analyses(
                        E,
                        Obs(E),
                        Obs.noise.C,
                        yy[ko],
                        batch,
                        taper,
                        pool.map,
                        self.xN,
                        self.g,
                    )
                    self.stats.write(stats, k, ko, "a")
                    E = post_process(E, self.infl, self.rot)

                self.stats.assess(k, ko, E=E)


def effective_N(YR, dyR, xN, g):
    """Effective ensemble size N.

    As measured by the finite-size EnKF-N.
    """
    N, Ny = YR.shape
    N1 = N - 1

    V, s, UT = svd0(YR)
    du = UT @ dyR

    eN, cL = hyperprior_coeffs(s, N, xN, g)

    def pad_rk(arr):
        return pad0(arr, min(N, Ny))

    def dgn_rk(l1):
        return pad_rk((l1 * s) ** 2) + N1

    # Make dual cost function (in terms of l1)
    def J(l1):
        val = np.sum(du**2 / dgn_rk(l1)) + eN / l1**2 + cL * np.log(l1**2)
        return val

    # Derivatives (not required with minimize_scalar):
    def Jp(l1):
        val = (
            -2 * l1 * np.sum(pad_rk(s**2) * du**2 / dgn_rk(l1) ** 2)
            + -2 * eN / l1**3
            + 2 * cL / l1
        )
        return val

    def Jpp(l1):
        val = (
            8 * l1**2 * np.sum(pad_rk(s**4) * du**2 / dgn_rk(l1) ** 3)
            + 6 * eN / l1**4
            + -2 * cL / l1**2
        )
        return val

    # Find inflation factor (optimize)
    l1 = Newton_m(Jp, Jpp, 1.0)
    # l1 = fmin_bfgs(J, x0=[1], gtol=1e-4, disp=0)
    # l1 = minimize_scalar(J, bracket=(sqrt(prior_mode), 1e2), tol=1e-4).x

    za = N1 / l1**2
    return za


# Notes on optimizers for the 'dual' EnKF-N:
# ----------------------------------------
# Using minimize_scalar:
# - Doesn't take dJdx. Advantage: only need J
# - method='bounded' not necessary and slower than 'brent'.
# - bracket not necessary either...
# Using multivariate minimization: fmin_cg, fmin_bfgs, fmin_ncg
# - these also accept dJdx. But only fmin_bfgs approaches
#   the speed of the scalar minimizers.
# Using scalar root-finders:
# - brenth(dJ1, LowB, 1e2, xtol=1e-6)  # Same speed as minimization
# - newton(dJ1, 1.0, fprime=dJ2, tol=1e-6)  # No improvement
# - newton(dJ1, 1.0, fprime=dJ2, tol=1e-6, fprime2=dJ3)  # No improvement
# - Newton_m(dJ1, dJ2, 1.0)  # Significantly faster. Also slightly better CV?
# => Despite the inconvenience of defining analytic derivatives,
#    Newton_m seems like the best option.
# - In extreme (or just non-linear Obs.mod) cases,
#   the EnKF-N cost function may have multiple minima.
#   Then: should use more robust optimizer!
#
# For 'primal'
# ----------------------------------------
# Similarly, Newton_m seems like the best option,
# although alternatives are provided (commented out).
#
def Newton_m(
    fun, deriv, x0, is_inverted=False, conf=1.0, xtol=1e-4, ytol=1e-7, itermax=10**2
):
    """Find root of `fun`.

    This is a simple (and pretty fast) implementation of Newton's method.
    """
    itr = 0
    dx = np.inf
    Jx = fun(x0)

    def norm(x):
        return sqrt(np.sum(x**2))

    while ytol < norm(Jx) and xtol < norm(dx) and itr < itermax:
        Dx = deriv(x0)
        if is_inverted:
            dx = Dx @ Jx
        elif isinstance(Dx, float):
            dx = Jx / Dx
        else:
            dx = mldiv(Dx, Jx)
        dx *= conf
        x0 -= dx
        Jx = fun(x0)
        itr += 1
    return x0


def hyperprior_coeffs(s, N, xN=1, g=0):
    r"""Set EnKF-N inflation hyperparams.

    The EnKF-N prior may be specified by the constants:

    - `eN`: Effect of unknown mean
    - `cL`: Coeff in front of log term

    These are trivial constants in the original EnKF-N,
    but are further adjusted (corrected and tuned) for the following reasons.

    - Reason 1: mode correction.
      These parameters bridge the Jeffreys (`xN=1`) and Dirac (`xN=Inf`) hyperpriors
      for the prior covariance, B, as discussed in `bib.bocquet2015expanding`.
      Indeed, mode correction becomes necessary when $$ R \rightarrow \infty $$
      because then there should be no ensemble update (and also no inflation!).
      More specifically, the mode of `l1` should be adjusted towards 1
      as a function of $$ I - K H $$ ("prior's weight").
      PS: why do we leave the prior mode below 1 at all?
      Because it sets up "tension" (negative feedback) in the inflation cycle:
      the prior pulls downwards, while the likelihood tends to pull upwards.

    - Reason 2: Boosting the inflation prior's certainty from N to xN*N.
      The aim is to take advantage of the fact that the ensemble may not
      have quite as much sampling error as a fully stochastic sample,
      as illustrated in section 2.1 of `bib.raanes2019adaptive`.

      - Its damping effect is similar to work done by J. Anderson.

    The tuning is controlled by:

    - `xN=1`: fully agnostic, i.e. assumes the ensemble is generated
      from a highly chaotic or stochastic model.
    - `xN>1`: increases the certainty of the hyper-prior,
      which is appropriate for more linear and deterministic systems.
    - `xN<1`: yields a more (than 'fully') agnostic hyper-prior,
      as if N were smaller than it truly is.
    - `xN<=0` is not meaningful.
806 """ 807 N1 = N - 1 808 809 eN = (N + 1) / N 810 cL = (N + g) / N1 811 812 # Mode correction (almost) as in eqn 36 of `bib.bocquet2015expanding` 813 prior_mode = eN / cL # Mode of l1 (before correction) 814 diagonal = pad0(s**2, N) + N1 # diag of Y@R.inv@Y + N1*I 815 # (Hessian of J) 816 I_KH = np.mean(diagonal ** (-1)) * N1 # ≈ 1/(1 + HBH/R) 817 # I_KH = 1/(1 + (s**2).sum()/N1) # Scalar alternative: use tr(HBH/R). 818 mc = sqrt(prior_mode**I_KH) # Correction coeff 819 820 # Apply correction 821 eN /= mc 822 cL *= mc 823 824 # Boost by xN 825 eN *= xN 826 cL *= xN 827 828 return eN, cL 829 830 831def zeta_a(eN, cL, w): 832 """EnKF-N inflation estimation via w. 833 834 Returns `zeta_a = (N-1)/pre-inflation^2`. 835 836 Using this inside an iterative minimization as in the 837 `dapper.da_methods.variational.iEnKS` effectively blends 838 the distinction between the primal and dual EnKF-N. 839 """ 840 N = len(w) 841 N1 = N - 1 842 za = N1 * cL / (eN + w @ w) 843 return za 844 845 846@ens_method 847class EnKF_N: 848 """Finite-size EnKF (EnKF-N). 849 850 Refs: `bib.bocquet2011ensemble`, `bib.bocquet2015expanding` 851 852 This implementation is pedagogical, prioritizing the "dual" form. 853 In consequence, the efficiency of the "primal" form suffers a bit. 854 The primal form is included for completeness and to demonstrate equivalence. 855 In `dapper.da_methods.variational.iEnKS`, however, 856 the primal form is preferred because it 857 already does optimization for w (as treatment for nonlinear models). 858 859 `infl` should be unnecessary (assuming no model error, or that Q is correct). 860 861 `Hess`: use non-approx Hessian for ensemble transform matrix? 862 863 `g` is the nullity of A (state anomalies's), ie. g=max(1,N-Nx), 864 compensating for the redundancy in the space of w. 865 But we have made it an input argument instead, with default 0, 866 because mode-finding (of p(x) via the dual) completely ignores this redundancy, 867 and the mode gets (undesireably) modified by g. 868 869 `xN` allows tuning the hyper-prior for the inflation. 870 Usually, I just try setting it to 1 (default), or 2. 871 Further description in hyperprior_coeffs(). 
872 """ 873 874 N: int 875 dual: bool = False 876 Hess: bool = False 877 xN: float = 1.0 878 g: int = 0 879 880 def assimilate(self, HMM, xx, yy): 881 N, N1 = self.N, self.N - 1 882 883 # Init 884 E = HMM.X0.sample(N) 885 self.stats.assess(0, E=E) 886 887 # Cycle 888 for k, ko, t, dt in progbar(HMM.tseq.ticker): 889 # Forecast 890 E = HMM.Dyn(E, t - dt, dt) 891 E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm) 892 893 # Analysis 894 if ko is not None: 895 self.stats.assess(k, ko, "f", E=E) 896 Eo = HMM.Obs(ko)(E) 897 y = yy[ko] 898 899 mu = np.mean(E, 0) 900 A = E - mu 901 902 xo = np.mean(Eo, 0) 903 Y = Eo - xo 904 dy = y - xo 905 906 R = HMM.Obs(ko).noise.C 907 V, s, UT = svd0(Y @ R.sym_sqrt_inv.T) 908 du = UT @ (dy @ R.sym_sqrt_inv.T) 909 910 def dgn_N(l1): 911 return pad0((l1 * s) ** 2, N) + N1 912 913 # Adjust hyper-prior 914 # xN_ = noise_level(self.xN, self.stats, HMM.tseq, N1, ko, A, 915 # locals().get('A_old', None)) 916 eN, cL = hyperprior_coeffs(s, N, self.xN, self.g) 917 918 if self.dual: 919 # Make dual cost function (in terms of l1) 920 def pad_rk(arr): 921 return pad0(arr, min(N, len(y))) 922 923 def dgn_rk(l1): 924 return pad_rk((l1 * s) ** 2) + N1 925 926 def J(l1): 927 val = ( 928 np.sum(du**2 / dgn_rk(l1)) + eN / l1**2 + cL * np.log(l1**2) 929 ) 930 return val 931 932 # Derivatives (not required with minimize_scalar): 933 def Jp(l1): 934 val = ( 935 -2 * l1 * np.sum(pad_rk(s**2) * du**2 / dgn_rk(l1) ** 2) 936 + -2 * eN / l1**3 937 + 2 * cL / l1 938 ) 939 return val 940 941 def Jpp(l1): 942 val = ( 943 8 * l1**2 * np.sum(pad_rk(s**4) * du**2 / dgn_rk(l1) ** 3) 944 + 6 * eN / l1**4 945 + -2 * cL / l1**2 946 ) 947 return val 948 949 # Find inflation factor (optimize) 950 l1 = Newton_m(Jp, Jpp, 1.0) 951 # l1 = fmin_bfgs(J, x0=[1], gtol=1e-4, disp=0) 952 # l1 = minimize_scalar(J, bracket=(sqrt(prior_mode), 1e2), 953 # tol=1e-4).x 954 955 else: 956 # Primal form, in a fully linearized version. 957 def za(w): 958 return zeta_a(eN, cL, w) 959 960 def J(w): 961 return 0.5 * np.sum( 962 ((dy - w @ Y) @ R.sym_sqrt_inv.T) ** 2 963 ) + 0.5 * N1 * cL * np.log(eN + w @ w) 964 965 # Derivatives (not required with fmin_bfgs): 966 def Jp(w): 967 return -Y @ R.inv @ (dy - w @ Y) + w * za(w) 968 969 # Jpp = lambda w: Y@R.inv@Y.T + \ 970 # za(w)*(eye(N) - 2*np.outer(w,w)/(eN + w@w)) 971 # Approx: no radial-angular cross-deriv: 972 # Jpp = lambda w: Y@R.inv@Y.T + za(w)*eye(N) 973 974 def nvrs(w): 975 # inverse of Jpp-approx 976 return (V * (pad0(s**2, N) + za(w)) ** -1.0) @ V.T 977 978 # Find w (optimize) 979 wa = Newton_m(Jp, nvrs, zeros(N), is_inverted=True) 980 # wa = Newton_m(Jp,Jpp ,zeros(N)) 981 # wa = fmin_bfgs(J,zeros(N),Jp,disp=0) 982 l1 = sqrt(N1 / za(wa)) 983 984 # Uncomment to revert to ETKF 985 # l1 = 1.0 986 987 # Explicitly inflate prior 988 # => formulae look different from `bib.bocquet2015expanding`. 989 A *= l1 990 Y *= l1 991 992 # Compute sqrt update 993 Pw = (V * dgn_N(l1) ** (-1.0)) @ V.T 994 w = dy @ R.inv @ Y.T @ Pw 995 # For the anomalies: 996 if not self.Hess: 997 # Regular ETKF (i.e. sym sqrt) update (with inflation) 998 T = (V * dgn_N(l1) ** (-0.5)) @ V.T * sqrt(N1) 999 # = (Y@R.inv@Y.T/N1 + eye(N))**(-0.5) 1000 else: 1001 # Also include angular-radial co-dependence. 1002 # Note: denominator not squared coz 1003 # unlike `bib.bocquet2015expanding` we have inflated Y. 1004 Hw = ( 1005 Y @ R.inv @ Y.T / N1 1006 + eye(N) 1007 - 2 * np.outer(w, w) / (eN + w @ w) 1008 ) 1009 T = funm_psd(Hw, lambda x: x**-0.5) # is there a sqrtm Woodbury? 

                E = mu + w @ A + T @ A
                E = post_process(E, self.infl, self.rot)

                self.stats.infl[ko] = l1
                self.stats.trHK[ko] = (
                    ((l1 * s) ** 2 + N1) ** (-1.0) * s**2
                ).sum() / len(y)

            self.stats.assess(k, ko, E=E)
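For orientation, a minimal usage sketch of the `EnKF` class above. It assumes DAPPER is installed and borrows the pre-configured Lorenz-63 setup from `dapper.mods.Lorenz63.sakov2012`, following the pattern of DAPPER's own examples; adjust names to your installation if they differ.

import dapper.da_methods as da
from dapper.mods.Lorenz63.sakov2012 import HMM  # a ready-made toy HMM

xx, yy = HMM.simulate()                          # truth (xx) and obs (yy)
xp = da.EnKF("Sqrt", N=10, infl=1.02, rot=True)  # upd_a="Sqrt" selects the ETKF branch
xp.assimilate(HMM, xx, yy)
xp.stats.average_in_time()
print(xp.avrgs.tabulate(["rmse.a", "rmv.a"]))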
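For a linear observation operator, the 'Sqrt' (ETKF) mean update above is algebraically equivalent to the textbook Kalman mean update. A self-contained NumPy check of that identity, on hypothetical toy data (no DAPPER imports; `np.linalg.svd` plays the role of `svd0`):

import numpy as np

rng = np.random.default_rng(42)
N, Nx, Ny = 30, 4, 2                      # ens size, state dim, obs dim
E = rng.standard_normal((N, Nx))          # prior ensemble
H = rng.standard_normal((Ny, Nx))         # linear obs operator; take R = I
y = rng.standard_normal(Ny)

N1 = N - 1
mu, A = E.mean(0), E - E.mean(0)
Eo = E @ H.T
xo, Y = Eo.mean(0), Eo - Eo.mean(0)
dy = y - xo

# Ensemble-space ('Sqrt svd'-style) mean update: mu + w @ A
V, s, _ = np.linalg.svd(Y, full_matrices=True)   # left sing. vectors, N x N
d = np.concatenate([s**2, np.zeros(N - len(s))]) + N1
Pw = (V * d**-1.0) @ V.T                          # = inv(Y @ Y.T + N1*I)
w = dy @ Y.T @ Pw

# Textbook Kalman update with P = A.T @ A / N1
P = A.T @ A / N1
K = P @ H.T @ np.linalg.inv(H @ P @ H.T + np.eye(Ny))
assert np.allclose(mu + w @ A, mu + K @ dy)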
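A small sanity check of the inflation branch of `post_process`, re-done in plain NumPy: multiplying anomalies by `infl` (with `rot=False`) leaves the ensemble mean untouched and scales each sample variance by `infl**2`.

import numpy as np

rng = np.random.default_rng(0)
E = rng.standard_normal((20, 3))
infl = 1.05

mu = E.mean(0)
A = E - mu
E2 = mu + infl * A  # the do_infl branch, without rotation

assert np.allclose(E2.mean(0), mu)
assert np.allclose(np.var(E2, 0, ddof=1), infl**2 * np.var(E, 0, ddof=1))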
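The "Mult-1" branch of `add_noise` inflates the whole ensemble by a single factor chosen so that the total (summed) variance grows by dt * tr(Q). A pure-NumPy sketch of just that bookkeeping, with a toy Q (the trailing `tsvd` rank-truncation step is omitted here):

import numpy as np

rng = np.random.default_rng(1)
N, Nx, dt = 25, 4, 0.1
E = rng.standard_normal((N, Nx))
Q = np.diag([0.5, 0.2, 0.1, 0.4])  # toy model-noise covariance

mu, A = E.mean(0), E - E.mean(0)
varE = np.var(E, axis=0, ddof=1).sum()
ratio = (varE + dt * np.diag(Q).sum()) / varE
E2 = mu + np.sqrt(ratio) * A

assert np.isclose(np.var(E2, axis=0, ddof=1).sum(), varE + dt * np.diag(Q).sum())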
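The `reshape_to`/`reshape_fr` pair in `EnKS` simply flattens the lag window into one wide "augmented-state" ensemble (so that `EnKF_analysis` can update all lagged states at once) and then unflattens it. A standalone round-trip check:

import numpy as np

K, N, Nx = 4, 3, 2
E = np.arange(K * N * Nx, dtype=float).reshape((K, N, Nx))

EE = E.transpose([1, 0, 2]).reshape((N, K * Nx))  # reshape_to
E2 = EE.reshape((N, K, Nx)).transpose([1, 0, 2])  # reshape_fr

assert np.array_equal(E, E2)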
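One backward step of the EnRTS pass, in isolation: the smoothed-minus-forecast difference at time k+1 is regressed back onto time k through J. A pure-NumPy sketch on made-up arrays, with `np.linalg.pinv` standing in for DAPPER's truncated pseudo-inverse `tinv`:

import numpy as np

rng = np.random.default_rng(2)
N, Nx = 10, 3
Ek   = rng.standard_normal((N, Nx))  # filter analysis at time k
Efk1 = rng.standard_normal((N, Nx))  # forecast at time k+1
Esk1 = rng.standard_normal((N, Nx))  # smoothed ensemble at time k+1
DeCorr = 1.0

A  = Ek - Ek.mean(0)
Af = Efk1 - Efk1.mean(0)
J = np.linalg.pinv(Af) @ A * DeCorr  # regression: k+1 anomalies -> k anomalies
Esk = Ek + (Esk1 - Efk1) @ J         # smoothed ensemble at time k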
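The three orderings produced by `serial_inds`, reproduced on toy numbers (pure NumPy; A stands for the anomalies used by the "sorted" branch when R has a constant diagonal):

import numpy as np

rng = np.random.default_rng(3)
y = np.zeros(5)                   # 5 observations
A = rng.standard_normal((10, 5))  # anomalies, one column per observation

mono  = np.arange(len(y))                              # "mono": 0, 1, ..., len(y)-1
srted = np.argsort(np.sum(A * A, 0) / (len(A) - 1))    # "sorted": by prior variance
rand  = rng.permutation(len(y))                        # default: random ordering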
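The scalar update at the heart of `SL_EAKF`: the contraction factor `alpha` is exactly the factor that shrinks the prior obs-space variance to the posterior one. A quick NumPy check of that identity on toy anomalies:

import numpy as np

rng = np.random.default_rng(4)
N = 20
N1 = N - 1
Y_j = rng.standard_normal(N)        # decorrelated obs-space anomalies
Y_j -= Y_j.mean()

sig2_j = Y_j @ Y_j                  # prior var * N1
sig2_u = 1 / (1 / sig2_j + 1 / N1)  # posterior var * N1
alpha = (N1 / (N1 + sig2_j)) ** 0.5 # anomaly contraction factor

# Contracted anomalies have exactly the posterior variance:
assert np.isclose((alpha * Y_j) @ (alpha * Y_j), sig2_u)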
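A usage sketch for `LETKF`, in the same style as the `EnKF` sketch above. It assumes DAPPER is installed; the Lorenz-96 setup module named here is taken from DAPPER's localization examples and is an assumption — substitute whichever HMM (with a configured `Obs.localizer`) you have at hand.

import dapper.da_methods as da
from dapper.mods.Lorenz96.bocquet2015loc import HMM  # assumed example setup

xx, yy = HMM.simulate()
xp = da.LETKF(N=20, loc_rad=4, taper="GC", infl=1.04)
xp.assimilate(HMM, xx, yy)
xp.stats.average_in_time()
print(xp.avrgs.tabulate(["rmse.a"]))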
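Note that `effective_N` consumes already-whitened quantities (`YR = Y @ R.sym_sqrt_inv.T`, and likewise for the innovation). A toy call on random data, assuming this module is importable as `dapper.da_methods.ensemble`:

import numpy as np
from dapper.da_methods.ensemble import effective_N

rng = np.random.default_rng(5)
N, Ny = 15, 8
YR = rng.standard_normal((N, Ny))  # whitened obs-space anomalies
YR -= YR.mean(0)
dyR = rng.standard_normal(Ny)      # whitened innovation

za = effective_N(YR, dyR, xN=1, g=0)  # za = N1 / l1**2
print("za:", za, "implied inflation l1:", np.sqrt((N - 1) / za))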
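A minimal demonstration of `Newton_m` outside the EnKF-N context (assuming the module is importable). In the scalar case, `deriv` returns a float, so the update reduces to dx = f(x)/f'(x):

from dapper.da_methods.ensemble import Newton_m

# Find the root of f(x) = x**2 - 2, i.e. sqrt(2):
root = Newton_m(lambda x: x**2 - 2.0, lambda x: 2.0 * x, 1.0)
print(root)  # ~ 1.41421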
769def hyperprior_coeffs(s, N, xN=1, g=0): 770 r"""Set EnKF-N inflation hyperparams. 771 772 The EnKF-N prior may be specified by the constants: 773 774 - `eN`: Effect of unknown mean 775 - `cL`: Coeff in front of log term 776 777 These are trivial constants in the original EnKF-N, 778 but are further adjusted (corrected and tuned) for the following reasons. 779 780 - Reason 1: mode correction. 781 These parameters bridge the Jeffreys (`xN=1`) and Dirac (`xN=Inf`) hyperpriors 782 for the prior covariance, B, as discussed in `bib.bocquet2015expanding`. 783 Indeed, mode correction becomes necessary when $$ R \rightarrow \infty $$ 784 because then there should be no ensemble update (and also no inflation!). 785 More specifically, the mode of `l1`'s should be adjusted towards 1 786 as a function of $$ I - K H $$ ("prior's weight"). 787 PS: why do we leave the prior mode below 1 at all? 788 Because it sets up "tension" (negative feedback) in the inflation cycle: 789 the prior pulls downwards, while the likelihood tends to pull upwards. 790 791 - Reason 2: Boosting the inflation prior's certainty from N to xN*N. 792 The aim is to take advantage of the fact that the ensemble may not 793 have quite as much sampling error as a fully stochastic sample, 794 as illustrated in section 2.1 of `bib.raanes2019adaptive`. 795 796 - Its damping effect is similar to work done by J. Anderson. 797 798 The tuning is controlled by: 799 800 - `xN=1`: is fully agnostic, i.e. assumes the ensemble is generated 801 from a highly chaotic or stochastic model. 802 - `xN>1`: increases the certainty of the hyper-prior, 803 which is appropriate for more linear and deterministic systems. 804 - `xN<1`: yields a more (than 'fully') agnostic hyper-prior, 805 as if N were smaller than it truly is. 806 - `xN<=0` is not meaningful. 807 """ 808 N1 = N - 1 809 810 eN = (N + 1) / N 811 cL = (N + g) / N1 812 813 # Mode correction (almost) as in eqn 36 of `bib.bocquet2015expanding` 814 prior_mode = eN / cL # Mode of l1 (before correction) 815 diagonal = pad0(s**2, N) + N1 # diag of Y@R.inv@Y + N1*I 816 # (Hessian of J) 817 I_KH = np.mean(diagonal ** (-1)) * N1 # ≈ 1/(1 + HBH/R) 818 # I_KH = 1/(1 + (s**2).sum()/N1) # Scalar alternative: use tr(HBH/R). 819 mc = sqrt(prior_mode**I_KH) # Correction coeff 820 821 # Apply correction 822 eN /= mc 823 cL *= mc 824 825 # Boost by xN 826 eN *= xN 827 cL *= xN 828 829 return eN, cL
Set EnKF-N inflation hyperparams.

The EnKF-N prior may be specified by the constants:

- `eN`: Effect of unknown mean
- `cL`: Coeff in front of log term

These are trivial constants in the original EnKF-N, but are further adjusted (corrected and tuned) for the following reasons.

- Reason 1: mode correction. These parameters bridge the Jeffreys (`xN=1`) and Dirac (`xN=Inf`) hyperpriors for the prior covariance, B, as discussed in `bib.bocquet2015expanding`. Indeed, mode correction becomes necessary when $$ R \rightarrow \infty $$, because then there should be no ensemble update (and also no inflation!). More specifically, the mode of `l1` should be adjusted towards 1 as a function of $$ I - KH $$ ("prior's weight"). PS: why do we leave the prior mode below 1 at all? Because it sets up "tension" (negative feedback) in the inflation cycle: the prior pulls downwards, while the likelihood tends to pull upwards.
- Reason 2: boosting the inflation prior's certainty from N to xN*N. The aim is to take advantage of the fact that the ensemble may not have quite as much sampling error as a fully stochastic sample, as illustrated in section 2.1 of `bib.raanes2019adaptive`. Its damping effect is similar to work done by J. Anderson.

The tuning is controlled by:

- `xN=1`: fully agnostic, i.e. assumes the ensemble is generated from a highly chaotic or stochastic model.
- `xN>1`: increases the certainty of the hyper-prior, which is appropriate for more linear and deterministic systems.
- `xN<1`: yields a more (than 'fully') agnostic hyper-prior, as if N were smaller than it truly is.
- `xN<=0`: not meaningful.
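As a concrete sketch (the singular values below are hypothetical, chosen only for illustration):

# Illustrative only: hyper-prior constants for an ensemble of size N=20.
s = np.array([3.0, 2.0, 1.0, 0.5])  # singular values of Y @ R.sym_sqrt_inv.T
eN, cL = hyperprior_coeffs(s, N=20, xN=1, g=0)
# Uncorrected values: eN = (20+1)/20 = 1.05, cL = (20+0)/19 ≈ 1.0526;
# the returned pair differs by the mode-correction factor mc (and the xN boost).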
def zeta_a(eN, cL, w):
    """EnKF-N inflation estimation via w.

    Returns `zeta_a = (N-1)/pre-inflation^2`.

    Using this inside an iterative minimization, as in
    `dapper.da_methods.variational.iEnKS`, effectively blends
    the distinction between the primal and dual EnKF-N.
    """
    N = len(w)
    N1 = N - 1
    za = N1 * cL / (eN + w @ w)
    return za
EnKF-N inflation estimation via w.

Returns `zeta_a = (N-1)/pre-inflation^2`.

Using this inside an iterative minimization, as in `dapper.da_methods.variational.iEnKS`, effectively blends the distinction between the primal and dual EnKF-N.
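In symbols, the code computes

$$ \zeta_a = \frac{(N-1)\, c_L}{e_N + \mathbf{w}^\top \mathbf{w}} , $$

so the implied inflation factor is $$ \lambda_1 = \sqrt{(N-1)/\zeta_a} $$ (as used in `EnKF_N` below).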
@ens_method
class EnKF_N:
    """Finite-size EnKF (EnKF-N).

    Refs: `bib.bocquet2011ensemble`, `bib.bocquet2015expanding`

    This implementation is pedagogical, prioritizing the "dual" form.
    In consequence, the efficiency of the "primal" form suffers a bit.
    The primal form is included for completeness and to demonstrate equivalence.
    In `dapper.da_methods.variational.iEnKS`, however,
    the primal form is preferred because it
    already does optimization for w (as treatment for nonlinear models).

    `infl` should be unnecessary (assuming no model error, or that Q is correct).

    `Hess`: use the non-approximate Hessian for the ensemble transform matrix?

    `g` is the nullity of A (the state anomalies), i.e. g = max(1, N - Nx),
    compensating for the redundancy in the space of w.
    But we have made it an input argument instead, with default 0,
    because mode-finding (of p(x) via the dual) completely ignores this redundancy,
    and the mode gets (undesirably) modified by g.

    `xN` allows tuning the hyper-prior for the inflation.
    Usually, I just try setting it to 1 (default), or 2.
    Further description in `hyperprior_coeffs`.
    """

    N: int
    dual: bool = False
    Hess: bool = False
    xN: float = 1.0
    g: int = 0

    def assimilate(self, HMM, xx, yy):
        N, N1 = self.N, self.N - 1

        # Init
        E = HMM.X0.sample(N)
        self.stats.assess(0, E=E)

        # Cycle
        for k, ko, t, dt in progbar(HMM.tseq.ticker):
            # Forecast
            E = HMM.Dyn(E, t - dt, dt)
            E = add_noise(E, dt, HMM.Dyn.noise, self.fnoise_treatm)

            # Analysis
            if ko is not None:
                self.stats.assess(k, ko, "f", E=E)
                Eo = HMM.Obs(ko)(E)
                y = yy[ko]

                mu = np.mean(E, 0)
                A = E - mu

                xo = np.mean(Eo, 0)
                Y = Eo - xo
                dy = y - xo

                R = HMM.Obs(ko).noise.C
                V, s, UT = svd0(Y @ R.sym_sqrt_inv.T)
                du = UT @ (dy @ R.sym_sqrt_inv.T)

                def dgn_N(l1):
                    return pad0((l1 * s) ** 2, N) + N1

                # Adjust hyper-prior
                # xN_ = noise_level(self.xN, self.stats, HMM.tseq, N1, ko, A,
                #                   locals().get('A_old', None))
                eN, cL = hyperprior_coeffs(s, N, self.xN, self.g)

                if self.dual:
                    # Make dual cost function (in terms of l1)
                    def pad_rk(arr):
                        return pad0(arr, min(N, len(y)))

                    def dgn_rk(l1):
                        return pad_rk((l1 * s) ** 2) + N1

                    def J(l1):
                        val = (
                            np.sum(du**2 / dgn_rk(l1))
                            + eN / l1**2
                            + cL * np.log(l1**2)
                        )
                        return val

                    # Derivatives (not required with minimize_scalar):
                    def Jp(l1):
                        val = (
                            -2 * l1 * np.sum(pad_rk(s**2) * du**2 / dgn_rk(l1) ** 2)
                            + -2 * eN / l1**3
                            + 2 * cL / l1
                        )
                        return val

                    def Jpp(l1):
                        val = (
                            8 * l1**2 * np.sum(pad_rk(s**4) * du**2 / dgn_rk(l1) ** 3)
                            + 6 * eN / l1**4
                            + -2 * cL / l1**2
                        )
                        return val

                    # Find inflation factor (optimize)
                    l1 = Newton_m(Jp, Jpp, 1.0)
                    # l1 = fmin_bfgs(J, x0=[1], gtol=1e-4, disp=0)
                    # l1 = minimize_scalar(J, bracket=(sqrt(prior_mode), 1e2),
                    #                      tol=1e-4).x

                else:
                    # Primal form, in a fully linearized version.
                    def za(w):
                        return zeta_a(eN, cL, w)

                    def J(w):
                        return 0.5 * np.sum(
                            ((dy - w @ Y) @ R.sym_sqrt_inv.T) ** 2
                        ) + 0.5 * N1 * cL * np.log(eN + w @ w)

                    # Derivatives (not required with fmin_bfgs):
                    def Jp(w):
                        return -Y @ R.inv @ (dy - w @ Y) + w * za(w)

                    # Jpp = lambda w: Y@R.inv@Y.T + \
                    #     za(w)*(eye(N) - 2*np.outer(w,w)/(eN + w@w))
                    # Approx: no radial-angular cross-deriv:
                    # Jpp = lambda w: Y@R.inv@Y.T + za(w)*eye(N)

                    def nvrs(w):
                        # inverse of Jpp-approx
                        return (V * (pad0(s**2, N) + za(w)) ** -1.0) @ V.T

                    # Find w (optimize)
                    wa = Newton_m(Jp, nvrs, zeros(N), is_inverted=True)
                    # wa = Newton_m(Jp, Jpp, zeros(N))
                    # wa = fmin_bfgs(J, zeros(N), Jp, disp=0)
                    l1 = sqrt(N1 / za(wa))

                # Uncomment to revert to ETKF
                # l1 = 1.0

                # Explicitly inflate prior
                # => formulae look different from `bib.bocquet2015expanding`.
                A *= l1
                Y *= l1

                # Compute sqrt update
                Pw = (V * dgn_N(l1) ** (-1.0)) @ V.T
                w = dy @ R.inv @ Y.T @ Pw
                # For the anomalies:
                if not self.Hess:
                    # Regular ETKF (i.e. sym sqrt) update (with inflation)
                    T = (V * dgn_N(l1) ** (-0.5)) @ V.T * sqrt(N1)
                    # = (Y@R.inv@Y.T/N1 + eye(N))**(-0.5)
                else:
                    # Also include angular-radial co-dependence.
                    # Note: denominator not squared because,
                    # unlike `bib.bocquet2015expanding`, we have inflated Y.
                    Hw = (
                        Y @ R.inv @ Y.T / N1
                        + eye(N)
                        - 2 * np.outer(w, w) / (eN + w @ w)
                    )
                    T = funm_psd(Hw, lambda x: x**-0.5)  # is there a sqrtm Woodbury?

                E = mu + w @ A + T @ A
                E = post_process(E, self.infl, self.rot)

                self.stats.infl[ko] = l1
                self.stats.trHK[ko] = (
                    ((l1 * s) ** 2 + N1) ** (-1.0) * s**2
                ).sum() / len(y)

            self.stats.assess(k, ko, E=E)
Finite-size EnKF (EnKF-N).

Refs: `bib.bocquet2011ensemble`, `bib.bocquet2015expanding`

This implementation is pedagogical, prioritizing the "dual" form; in consequence, the efficiency of the "primal" form suffers a bit. The primal form is included for completeness and to demonstrate equivalence. In `dapper.da_methods.variational.iEnKS`, however, the primal form is preferred, because it already does optimization for w (as treatment for nonlinear models).

`infl` should be unnecessary (assuming no model error, or that Q is correct).

`Hess`: use the non-approximate Hessian for the ensemble transform matrix?

`g` is the nullity of A (the state anomalies), i.e. g = max(1, N - Nx), compensating for the redundancy in the space of w. But we have made it an input argument instead, with default 0, because mode-finding (of p(x) via the dual) completely ignores this redundancy, and the mode gets (undesirably) modified by g.

`xN` allows tuning the hyper-prior for the inflation. Usually, I just try setting it to 1 (default), or 2. Further description in `hyperprior_coeffs`.
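For reference, the primal cost function minimized in the source above (a direct transcription of its `J(w)`, with $$ \delta $$ denoting the innovation `dy`) is

$$ J(\mathbf{w}) = \tfrac{1}{2} \big\| \mathbf{R}^{-1/2} (\delta - \mathbf{w}^\top \mathbf{Y}) \big\|^2 + \tfrac{1}{2} (N-1)\, c_L \log\!\big( e_N + \mathbf{w}^\top \mathbf{w} \big) , $$

whose second (log) term replaces the fixed $$ \tfrac{1}{2}(N-1)\, \mathbf{w}^\top \mathbf{w} $$ of the plain ETKF, thereby making the inflation adaptive.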
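A minimal usage sketch (assuming a ready-made `HMM` and a truth/observation pair `xx, yy`, e.g. from `HMM.simulate()`; the parameter values are illustrative):

import dapper.da_methods as da

xp = da.EnKF_N(N=20)  # dual form (default)
# xp = da.EnKF_N(N=20, dual=False, xN=2)  # primal form, tighter hyper-prior
xp.assimilate(HMM, xx, yy)
# The adaptively estimated inflation is recorded per analysis time:
# xp.stats.infl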