dapper.stats
Statistics for the assessment of DA methods.
`Stats` is a data container for ([mostly] time series of) statistics.
It comes with a battery of methods to compute the default statistics.
`Avrgs` is a data container for the same statistics,
but after they have been averaged in time (after the assimilation has finished).
Instances of these objects are created by `dapper.da_methods.da_method`
(i.e. "`xp`") objects and written to their `.stats` and `.avrgs` attributes.
Default statistics
List them using
>>> list(vars(xp.stats))
>>> list(vars(xp.avrgs))
The FAUSt key/attribute
The time series of statistics (the attributes of `.stats`) may have attributes
`.f`, `.a`, `.s`, `.u`, referring to whether the statistic is for a "forecast",
"analysis", or "smoothing" estimate (as decided when the call to
`Stats.assess` is made), or a "universal" (forecast, but at intermediate
[non-obs.-time]) estimate.
The same applies to the time-averages in `.avrgs`.
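The indexing convention can be sketched with a small stand-in class. This is not DAPPER's `series.FAUSt`; the class name `TinyFAUSt`, its `write` method, and the assumption that `f`/`a`/`s` hold one value per observation time while `u` holds one per model time step are illustrative only.

```python
import math


class TinyFAUSt:
    """Hypothetical stand-in for a FAUSt series (not DAPPER's implementation)."""

    def __init__(self, K, Ko):
        nan = math.nan
        # Assumption: 'f', 'a', 's' are indexed by observation time ko,
        # while 'u' is indexed by model time step k.
        self.f = [nan] * (Ko + 1)
        self.a = [nan] * (Ko + 1)
        self.s = [nan] * (Ko + 1)
        self.u = [nan] * (K + 1)

    def write(self, k, ko, sub, value):
        """Store `value` under the sub-key `sub` at time (k, ko)."""
        if sub == "u":
            self.u[k] = value
        else:
            getattr(self, sub)[ko] = value


series = TinyFAUSt(K=4, Ko=2)
series.write(k=2, ko=1, sub="f", value=0.9)     # forecast at obs. time ko=1
series.write(k=2, ko=1, sub="a", value=0.5)     # analysis at the same time
series.write(k=1, ko=None, sub="u", value=0.7)  # intermediate (non-obs.) time
```

Entries that are never written stay `nan`, which mirrors how the real series are initialized.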
Field summaries
The statistics are also averaged in space.
This is done according to the methods listed in `dpr.rc.field_summaries`.
Although sometimes pretty close, `rmv` (a.k.a. `spread.rms`) is not (supposed
to be) an unbiased estimator of `rmse` (a.k.a. `err.rms`). This is because
of the square roots involved in the field summary. Instead, `spread.ms` (i.e.
the mean variance) is the unbiased estimator of `err.ms`.
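The effect of the square root can be illustrated with a small simulation. It is a sketch under simplifying assumptions, not DAPPER code: each of `Nx = 3` error components is drawn from a Gaussian whose variance is known exactly, so that `spread.ms` equals `true_var` by construction. All variable names are hypothetical.

```python
import math
import random

random.seed(0)
Nx, trials = 3, 20000
true_var = 1.0  # assumed exact variance of each error component

sum_ms = sum_rms = 0.0
for _ in range(trials):
    err = [random.gauss(0.0, math.sqrt(true_var)) for _ in range(Nx)]
    ms = sum(e * e for e in err) / Nx  # field summary "ms" of the error
    sum_ms += ms
    sum_rms += math.sqrt(ms)           # field summary "rms" of the error

mean_err_ms = sum_ms / trials    # time-average of err.ms
mean_err_rms = sum_rms / trials  # time-average of err.rms
spread_ms = true_var
spread_rms = math.sqrt(true_var)
```

Here `mean_err_ms` comes out close to `spread_ms`, whereas `mean_err_rms` is systematically smaller than `spread_rms` (for three Gaussian components the expected ratio is about 0.92), as Jensen's inequality predicts for the concave square root.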
Regional field summaries
If the `HiddenMarkovModel` has the attribute `.sectors` with value, e.g.,
>>> HMM.sectors = {
... "ocean": inds_of_state_of_the_ocean,
... "atmos": inds_of_state_of_the_atmosphere,
... }
then the `.rms` summaries in `.stats` and `.avrgs` will also have attributes
named after the keys of `HMM.sectors`, e.g. `stats.err.rms.ocean`.
This also goes for any other (than `rms`) type of field summary method.
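How a sector-restricted summary relates to the global one can be sketched in plain Python. The index sets and error values below are made up for illustration, and `restrict` only imitates the idea of applying a summary to a subset of the state indices.

```python
import math


def rms(values):
    """Root-mean-square field summary."""
    return math.sqrt(sum(v * v for v in values) / len(values))


def restrict(summary, inds):
    """Return a version of `summary` that only sees the entries in `inds`."""
    return lambda field: summary([field[i] for i in inds])


# Hypothetical sectors: indices 0-1 are "ocean", indices 2-4 are "atmos".
sectors = {"ocean": [0, 1], "atmos": [2, 3, 4]}
sector_rms = {name: restrict(rms, inds) for name, inds in sectors.items()}

err = [3.0, 4.0, 0.0, 0.0, 6.0]  # made-up error field
global_rms = rms(err)
ocean_rms = sector_rms["ocean"](err)
atmos_rms = sector_rms["atmos"](err)
```

Each sector summary uses only its own indices, so `ocean_rms` and `atmos_rms` can both differ from `global_rms`.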
Declaring new, custom statistics
Only the time series created with `Stats.new_series` will be in the format
operated on by `Stats.average_in_time`. For example, create an `ndarray` of
length `Ko+1` to hold the time series of estimated inflation values:
>>> self.stats.new_series('infl', 1, Ko+1)
Alternatively you can overwrite a default statistic; for example:
>>> error_time_series_a = xx - ensemble_time_series_a.mean(axis=1)
>>> self.stats.err.rms.a = np.sqrt(np.mean(error_time_series_a**2, axis=-1))
Of course, you could just do this:
>>> self.stats.my_custom_stat = value
However, `dapper.xp_launch.run_experiment` (without `free=False`) will delete
the `Stats` object from `xp` after the assimilation, in order to save memory.
Therefore, in order for `my_custom_stat` to be available among `xp.avrgs`, it
must be "registered":
>>> self.stats.stat_register.append("my_custom_stat")
Alternatively, you can do both at once:
>>> self.stat("my_custom_stat", value)
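Why registration matters can be sketched without DAPPER. The `register_stat` function below follows the one in the module source; the `Holder` class and `average_registered` are hypothetical simplifications of the stats container and of the time-averaging step, which only visits registered names.

```python
class Holder:
    """Hypothetical stand-in for the stats container."""


def register_stat(obj, name, value):
    """Do `obj.name = value` and record `name` in `obj.stat_register`."""
    setattr(obj, name, value)
    if not hasattr(obj, "stat_register"):
        obj.stat_register = []
    obj.stat_register.append(name)


def average_registered(obj):
    """Average only the series whose names have been registered."""
    return {key: sum(getattr(obj, key)) / len(getattr(obj, key))
            for key in getattr(obj, "stat_register", [])}


stats = Holder()
register_stat(stats, "infl", [1.0, 1.2, 1.1])  # registered: will be averaged
stats.unregistered = [5.0, 5.0]                # plain attribute: is skipped

avrgs = average_registered(stats)
```

Only `infl` ends up in `avrgs`; the plain attribute would be lost once the container is deleted.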
"""Statistics for the assessment of DA methods.

`Stats` is a data container for ([mostly] time series of) statistics.
It comes with a battery of methods to compute the default statistics.

`Avrgs` is a data container *for the same statistics*,
but after they have been averaged in time (after the assimilation has finished).

Instances of these objects are created by `dapper.da_methods.da_method`
(i.e. "`xp`") objects and written to their `.stats` and `.avrgs` attributes.

.. include:: ../docs/stats_etc.md
"""

import warnings

import numpy as np
import scipy.linalg as sla
import struct_tools
from matplotlib import pyplot as plt
from patlib.std import do_once
from tabulate import tabulate

import dapper.tools.liveplotting as liveplotting
import dapper.tools.series as series
from dapper.dpr_config import rc
from dapper.tools.matrices import CovMat
from dapper.tools.progressbar import progbar


class Stats(series.StatPrint):
    """Contains and computes statistics of the DA methods."""

    def __init__(self, xp, HMM, xx, yy, liveplots=False, store_u=rc.store_u):
        """Init the default statistics."""
        ######################################
        # Preamble
        ######################################
        self.xp = xp
        self.HMM = HMM
        self.xx = xx
        self.yy = yy
        self.liveplots = liveplots
        self.store_u = store_u
        self.store_s = any(key in xp.__dict__ for key in
                           ["Lag", "DeCorr"])  # prms used by smoothers

        # Shapes
        K = xx.shape[0] - 1
        Nx = xx.shape[1]
        Ko = yy.shape[0] - 1
        self.K, self.Ko, self.Nx = K, Ko, Nx

        # Methods for summarizing multivariate stats ("fields") as scalars
        # Don't use nanmean here; nan's should get propagated!
        en_mean = lambda x: np.mean(x, axis=0)  # noqa
        self.field_summaries = dict(
            m   = lambda x: en_mean(x),                  # mean-field
            ms  = lambda x: en_mean(x**2),               # mean-square
            rms = lambda x: np.sqrt(en_mean(x**2)),      # root-mean-square
            ma  = lambda x: en_mean(np.abs(x)),          # mean-absolute
            gm  = lambda x: np.exp(en_mean(np.log(x))),  # geometric mean
        )
        # Only keep the methods listed in rc
        self.field_summaries = struct_tools.intersect(self.field_summaries,
                                                      rc.field_summaries)

        # Define similar methods, but restricted to sectors
        self.sector_summaries = {}
        def restrict(fun, inds): return (lambda x: fun(x[inds]))
        for suffix, formula in self.field_summaries.items():
            for sector, inds in HMM.sectors.items():
                f = restrict(formula, inds)
                self.sector_summaries['%s.%s' % (suffix, sector)] = f

        ######################################
        # Allocate time series of various stats
        ######################################
        self.new_series('mu'    , Nx, field_mean='sectors')  # Mean
        self.new_series('spread', Nx, field_mean='sectors')  # Std. dev. ("spread")
        self.new_series('err'   , Nx, field_mean='sectors')  # Error (mu - truth)
        self.new_series('gscore', Nx, field_mean='sectors')  # Gaussian (log) score

        # To save memory, we only store these field means:
        self.new_series('mad' , 1)  # Mean abs deviations
        self.new_series('skew', 1)  # Skewness
        self.new_series('kurt', 1)  # Kurtosis

        if hasattr(xp, 'N'):
            N = xp.N
            self.new_series('w', N, field_mean=True)  # Importance weights
            self.new_series('rh', Nx, dtype=int)  # Rank histogram

            self._is_ens = True
            minN = min(Nx, N)
            self.do_spectral = np.sqrt(Nx*N) <= rc.comps["max_spectral"]
        else:
            self._is_ens = False
            minN = Nx
            self.do_spectral = Nx <= rc.comps["max_spectral"]

        if self.do_spectral:
            # Note: the mean-field and RMS time-series of
            # (i) svals and (ii) umisf should match the corresponding series of
            # (i) spread and (ii) err.
            self.new_series('svals', minN)  # Principal component (SVD) scores
            self.new_series('umisf', minN)  # Error in component directions

        ######################################
        # Allocate a few series for outside use
        ######################################
        self.new_series('trHK' , 1, Ko+1)
        self.new_series('infl' , 1, Ko+1)
        self.new_series('iters', 1, Ko+1)

        # Weight-related
        self.new_series('N_eff' , 1, Ko+1)
        self.new_series('wroot' , 1, Ko+1)
        self.new_series('resmpl', 1, Ko+1)

    def new_series(self, name, shape, length='FAUSt', field_mean=False, **kws):
        """Create (and register) a statistics time series, initialized with `nan`s.

        If `length` is an integer, a `DataSeries` (a trivial subclass of
        `numpy.ndarray`) is made. By default, though, a `series.FAUSt` is created.

        NB: The `sliding_diagnostics` liveplotting relies on detecting `nan`'s
        to avoid plotting stats that are not being used.
        Thus, you cannot use `dtype=bool` or `int` for stats that get plotted.
        """
        # Convert int shape to tuple
        if not hasattr(shape, '__len__'):
            if shape == 1:
                shape = ()
            else:
                shape = (shape,)

        def make_series(parent, name, shape):
            if length == 'FAUSt':
                total_shape = self.K, self.Ko, shape
                store_opts = self.store_u, self.store_s
                tseries = series.FAUSt(*total_shape, *store_opts, **kws)
            else:
                total_shape = (length,)+shape
                # Keyword options must be unpacked with `**` (not `*`)
                tseries = series.DataSeries(total_shape, **kws)
            register_stat(parent, name, tseries)

        # Principal series
        make_series(self, name, shape)

        # Summary (scalar) series:
        if shape != ():
            if field_mean:
                for suffix in self.field_summaries:
                    make_series(getattr(self, name), suffix, ())
            # Make a nested level for sectors
            if field_mean == 'sectors':
                for ss in self.sector_summaries:
                    suffix, sector = ss.split('.')
                    make_series(struct_tools.deep_getattr(
                        self, f"{name}.{suffix}"), sector, ())

    @property
    def data_series(self):
        return [k for k in vars(self)
                if isinstance(getattr(self, k), series.DataSeries)]

    def assess(self, k, ko=None, faus=None,
               E=None, w=None, mu=None, Cov=None):
        """Common interface for both `Stats.assess_ens` and `Stats.assess_ext`.

        The `_ens` assessment function gets called if `E is not None`,
        and `_ext` if `mu is not None`.

        faus: One or more of `['f', 'a', 'u', 's']`, indicating
              that the result should be stored in (respectively)
              the forecast/analysis/universal/smoothed attribute.
              Default: `'u' if ko is None else 'au' ('a' and 'u')`.
        """
        # Initial consistency checks.
        if k == 0:
            if ko is not None:
                raise KeyError("DAPPER convention: no obs at t=0. Helps avoid bugs.")
            if self._is_ens:
                if E is None:
                    raise TypeError("Expected ensemble input but E is None")
                if mu is not None:
                    raise TypeError("Expected ensemble input but mu/Cov is not None")
            else:
                if E is not None:
                    raise TypeError("Expected mu/Cov input but E is not None")
                if mu is None:
                    raise TypeError("Expected mu/Cov input but mu is None")

        # Default. Don't add more defaults. It just gets confusing.
        if faus is None:
            faus = 'u' if ko is None else 'au'

        # TODO 4: for faus="au" (e.g.) we don't need to re-**compute** stats,
        #         merely re-write them?
        for sub in faus:

            # Skip assessment if ('u' and stats not stored or plotted)
            if k != 0 and ko is None:
                if not (self.store_u or self.LP_instance.any_figs):
                    continue

            # Silence repeat warnings caused by zero variance
            with np.errstate(divide='call', invalid='call'):
                np.seterrcall(warn_zero_variance)

                # Assess
                stats_now = Avrgs()
                if self._is_ens:
                    self.assess_ens(stats_now, self.xx[k], E, w)
                else:
                    self.assess_ext(stats_now, self.xx[k], mu, Cov)
                self.derivative_stats(stats_now)
                self.summarize_marginals(stats_now)

            self.write(stats_now, k, ko, sub)

            # LivePlot -- Both init and update must come after the assessment.
            try:
                self.LP_instance.update((k, ko, sub), E, Cov)
            except AttributeError:
                self.LP_instance = liveplotting.LivePlot(
                    self, self.liveplots, (k, ko, sub), E, Cov)

    def write(self, stat_dict, k, ko, sub):
        """Write `stat_dict` to series at `(k, ko, sub)`."""
        for name, val in stat_dict.items():
            stat = struct_tools.deep_getattr(self, name)
            isFaust = isinstance(stat, series.FAUSt)
            stat[(k, ko, sub) if isFaust else ko] = val

    def summarize_marginals(self, now):
        """Compute Mean-field and RMS values."""
        formulae = {**self.field_summaries, **self.sector_summaries}

        with np.errstate(divide='ignore', invalid='ignore'):
            for stat in list(now):
                field = now[stat]
                for suffix, formula in formulae.items():
                    statpath = stat+'.'+suffix
                    if struct_tools.deep_hasattr(self, statpath):
                        now[statpath] = formula(field)

    def derivative_stats(self, now):
        """Stats that derive from others (and are not specific to `_ens` or `_ext`)."""
        try:
            now.gscore = 2*np.log(now.spread) + (now.err/now.spread)**2
        except AttributeError:
            # happens in case rc.comps['error_only']
            pass

    def assess_ens(self, now, x, E, w):
        """Ensemble and Particle filter (weighted/importance) assessment."""
        N, Nx = E.shape

        # weights
        if w is None:
            w = np.ones(N)/N  # All equal. Also, rm attr from stats:
            if hasattr(self, 'w'):
                delattr(self, 'w')
            # Use non-weight formula (since w=None) for mu computations.
            # The savings are noticeable when rc.comps['error_only'] is `True`.
            now.mu = E.mean(0)
        else:
            now.w = w
            if abs(w.sum()-1) > 1e-5:
                raise RuntimeError("Weights did not sum to one.")
            now.mu = w @ E

        # Crash checks
        if not np.all(np.isfinite(E)):
            raise RuntimeError("Ensemble not finite.")
        if not np.all(np.isreal(E)):
            raise RuntimeError("Ensemble not Real.")

        # Compute errors
        now.err = now.mu - x
        if rc.comps['error_only']:
            return

        A = E - now.mu
        # While A**2 is approx as fast as A*A,
        # A**3 is 10x slower than A**2 (or A**2.0).
        # => Use A2 = A**2, A3 = A*A2, A4=A*A3.
        # But, to save memory, only use A_pow.
        A_pow = A**2

        # Compute variances
        var = w @ A_pow
        ub = unbias_var(w, avoid_pathological=True)
        var *= ub

        # Compute standard deviation ("Spread")
        s = np.sqrt(var)  # NB: biased (even though var is unbiased)
        now.spread = s

        # For simplicity, use naive (biased) formulae, derived
        # from "empirical measure". See doc/unbiased_skew_kurt.jpg.
        # Normalize by var. Compute "excess" kurt, which is 0 for Gaussians.
        A_pow *= A
        now.skew = np.nanmean(w @ A_pow / (s*s*s))
        A_pow *= A
        now.kurt = np.nanmean(w @ A_pow / var**2 - 3)

        now.mad = np.nanmean(w @ abs(A))

        if self.do_spectral:
            if N <= Nx:
                _, s, UT = sla.svd((np.sqrt(w)*A.T).T, full_matrices=False)
                s *= np.sqrt(ub)  # Makes s^2 unbiased
                now.svals = s
                now.umisf = UT @ now.err
            else:
                P = (A.T * w) @ A
                s2, U = sla.eigh(P)
                s2 *= ub
                now.svals = np.sqrt(s2.clip(0))[::-1]
                now.umisf = U.T[::-1] @ now.err

            # For each state dim [i], compute rank of truth (x) among the ensemble (E)
            E_x = np.sort(np.vstack((E, x)), axis=0, kind='heapsort')
            now.rh = np.asarray(
                [np.where(E_x[:, i] == x[i])[0][0] for i in range(Nx)])

    def assess_ext(self, now, x, mu, P):
        """Kalman filter (Gaussian) assessment."""
        if not np.all(np.isfinite(mu)):
            raise RuntimeError("Estimates not finite.")
        if not np.all(np.isreal(mu)):
            raise RuntimeError("Estimates not Real.")
        # Don't check the cov (might not be explicitly available)

        # Compute errors
        now.mu = mu
        now.err = now.mu - x
        if rc.comps['error_only']:
            return

        # Get diag(P)
        if P is None:
            var = np.zeros_like(mu)
        elif np.isscalar(P):
            var = np.ones_like(mu) * P
        else:
            if isinstance(P, CovMat):
                var = P.diag
                P = P.full
            else:
                var = np.diag(P)

            if self.do_spectral:
                s2, U = sla.eigh(P)
                now.svals = np.sqrt(np.maximum(s2, 0.0))[::-1]
                now.umisf = (U.T @ now.err)[::-1]

        # Compute stddev
        now.spread = np.sqrt(var)
        # Here, sqrt(2/pi) is the ratio of MAD/Spread for Gaussians
        now.mad = np.nanmean(now.spread) * np.sqrt(2/np.pi)

    def average_in_time(self, kk=None, kko=None, free=False):
        """Average all univariate (scalar) time series.

        - `kk`    time inds for averaging
        - `kko`   time inds for averaging obs
        """
        tseq = self.HMM.tseq
        if kk is None:
            kk = tseq.mask
        if kko is None:
            kko = tseq.masko

        def average1(tseries):
            avrgs = Avrgs()

            def average_multivariate(): return avrgs
            # Plain averages of nd-series are rarely interesting.
            # => Shortcircuit => Leave for manual computations

            if isinstance(tseries, series.FAUSt):
                # Average series for each subscript
                if tseries.item_shape != ():
                    return average_multivariate()
                for sub in [ch for ch in 'fas' if hasattr(tseries, ch)]:
                    avrgs[sub] = series.mean_with_conf(tseries[kko, sub])
                if tseries.store_u:
                    avrgs['u'] = series.mean_with_conf(tseries[kk, 'u'])

            elif isinstance(tseries, series.DataSeries):
                if tseries.array.shape[1:] != ():
                    return average_multivariate()
                elif len(tseries.array) == self.Ko+1:
                    avrgs = series.mean_with_conf(tseries[kko])
                elif len(tseries.array) == self.K+1:
                    avrgs = series.mean_with_conf(tseries[kk])
                else:
                    raise ValueError

            elif np.isscalar(tseries):
                avrgs = tseries  # Eg. just copy over "duration" from stats

            else:
                raise TypeError(f"Don't know how to average {tseries}")

            return avrgs

        def recurse_average(stat_parent, avrgs_parent):
            for key in getattr(stat_parent, "stat_register", []):
                try:
                    tseries = getattr(stat_parent, key)
                except AttributeError:
                    continue  # Eg assess_ens() deletes .weights if None
                avrgs = average1(tseries)
                recurse_average(tseries, avrgs)
                avrgs_parent[key] = avrgs

        avrgs = Avrgs()
        recurse_average(self, avrgs)
        self.xp.avrgs = avrgs
        if free:
            delattr(self.xp, 'stats')

    def replay(self, figlist="default", speed=np.inf, t1=0, t2=None, **kwargs):
        """Replay LivePlot with what's been stored in 'self'.

        - t1, t2: time window to plot.
        - 'figlist' and 'speed': See LivePlot's doc.

        .. note:: `store_u` (whether to store non-obs-time stats) must
        have been `True` to have smooth graphs as in the actual LivePlot.

        .. note:: Ensembles are generally not stored in the stats
        and so cannot be replayed.
        """
        # Time settings
        tseq = self.HMM.tseq
        if t2 is None:
            t2 = t1 + tseq.Tplot

        # Ens does not get stored in stats, so we cannot replay that.
        # If the LPs are initialized with P0!=None, then they will avoid ens plotting.
        # TODO 4: This system for switching from Ens to stats must be replaced.
        #       It breaks down when M is very large.
        try:
            P0 = np.full_like(self.HMM.X0.C.full, np.nan)
        except AttributeError:  # e.g. if X0 is defined via sampling func
            P0 = np.eye(self.HMM.Nx)

        LP = liveplotting.LivePlot(self, figlist, P=P0, speed=speed,
                                   Tplot=t2-t1, replay=True, **kwargs)

        # Remember: must use progbar to unblock read1.
        # Let's also make a proper description.
        desc = self.xp.da_method + " (replay)"

        # Play through assimilation cycles
        for k, ko, t, _dt in progbar(tseq.ticker, desc):
            if t1 <= t <= t2:
                if ko is not None:
                    LP.update((k, ko, 'f'), None, None)
                    LP.update((k, ko, 'a'), None, None)
                LP.update((k, ko, 'u'), None, None)

        # Pause required when speed=inf.
        # On Mac, it was also necessary to do it for each fig.
        if LP.any_figs:
            for _name, updater in LP.figures.items():
                if plt.fignum_exists(_name) and getattr(updater, 'is_active', 1):
                    plt.figure(_name)
                    plt.pause(0.01)


def register_stat(self, name, value):
    """Do `self.name = value` and register `name` as in self's `stat_register`.

    Note: `self` is not always a `Stats` object, but could be a "child" of it.
    """
    setattr(self, name, value)
    if not hasattr(self, "stat_register"):
        self.stat_register = []
    self.stat_register.append(name)


class Avrgs(series.StatPrint, struct_tools.DotDict):
    """A `dict` specialized for the averages of statistics.

    Embellishments:

    - `dapper.tools.StatPrint`
    - `Avrgs.tabulate`
    - `getattr` that supports abbreviations.
    """

    def tabulate(self, statkeys=(), decimals=None):
        columns = tabulate_avrgs([self], statkeys, decimals=decimals)
        return tabulate(columns, headers="keys").replace('␣', ' ')

    abbrevs = {'rmse': 'err.rms', 'rmss': 'spread.rms', 'rmv': 'spread.rms'}

    # Use getattribute coz it gets called before getattr.
    def __getattribute__(self, key):
        """Support deep and abbreviated lookup."""
        # key = abbrevs[key] # Instead of this, also support rmse.a:
        key = '.'.join(Avrgs.abbrevs.get(seg, seg) for seg in key.split('.'))

        if "." in key:
            return struct_tools.deep_getattr(self, key)
        else:
            return super().__getattribute__(key)

# In case of degeneracy, variance might be 0, causing warnings
# in computing skew/kurt/MGLS (which all normalize by variance).
# This should and will yield nan's, but we don't want mere diagnostics
# computations to cause repetitive warnings, so we only warn once.
#
# I would have expected this (more elegant solution?) to work,
# but it just makes it worse.
# with np.errstate(divide='warn',invalid='warn'), warnings.catch_warnings():
# warnings.simplefilter("once",category=RuntimeWarning)
# ...


@do_once
def warn_zero_variance(err, flag):
    msg = "\n".join(["Numerical error in stat comps.",
                     "Probably caused by a sample variance of 0."])
    warnings.warn(msg, stacklevel=2)


# Why not do all columns at once using the tabulate module? Coz
#  - Want subcolumns, including fancy formatting (e.g. +/-)
#  - Want separation (using '|') of attr and stats
#  - ...
def align_col(col, pad='␣', missingval='', just=">"):
    r"""Align column.

    Treats `int`s and fixed-point `float`/`str` especially, aligning on the point.

    Example:
    >>> xx = [1, 1., 1.234, 12.34, 123.4, "1.2e-3", None, np.nan, "inf", (1, 2)]
    >>> print(*align_col(xx), sep="\n")
    ␣␣1␣␣␣␣
    ␣␣1.0␣␣
    ␣␣1.234
    ␣12.34␣
    123.4␣␣
    ␣1.2e-3
    ␣␣␣␣␣␣␣
    ␣␣␣␣nan
    ␣␣␣␣inf
    ␣(1, 2)
    """
    def split_decimal(x):
        x = str(x)
        try:
            y = float(x)
        except ValueError:
            pass
        else:
            if np.isfinite(y) and ("e" not in x.lower()):
                a, *b = x.split(".")
                if b == []:
                    b = "int"
                else:
                    b = b[0]
                return a, b
        return x, False

    # Find max nInt, nDec
    nInt = nDec = -1
    for x in col:
        ints, decs = split_decimal(x)
        if decs:
            nInt = max(nInt, len(ints))
            if decs != "int":
                nDec = max(nDec, len(decs))

    # Format entries. Floats get aligned on point.
    def frmt(x):
        if x is None:
            return missingval
        ints, decs = split_decimal(x)
        x = f"{ints.rjust(nInt, pad)}"
        if decs == "int":
            if nDec >= 0:
                x += pad + pad*nDec
        elif decs:
            x += "." + f"{decs.ljust(nDec, pad)}"
        else:
            x = ints
        return x

    # Format
    col = [frmt(x) for x in col]
    # Find max width
    Max = max(len(x) for x in col)
    # Right-justify
    shift = str.rjust if just == ">" else str.ljust
    col = [shift(x, Max, pad) for x in col]
    return col


def unpack_uqs(uq_list, decimals=None):
    """Convert list of `uq`s into dict of lists (of equal-length) of attributes.

    The attributes are obtained by `vars(uq)`,
    and may get formatted somehow (e.g. cast to strings) in the output.

    If `uq` is `None`, then `None` is inserted in each list.
    Else, `uq` must be an instance of `dapper.tools.rounding.UncertainQtty`.

    Parameters
    ----------
    uq_list: list
        List of `uq`s.

    decimals: int
        Desired number of decimals.
        Used for (only) the columns "val" and "prec".
        Default: `None`. In this case, the formatting is left to the `uq`s.
    """
    def frmt(uq):
        if not isinstance(uq, series.UncertainQtty):
            # Presumably uq is just a number
            uq = series.UncertainQtty(uq)

        attrs = vars(uq).copy()

        # val/prec: round
        if decimals is None:
            v, p = str(uq).split(" ±")
        else:
            frmt = "%%.%df" % decimals
            v, p = frmt % uq.val, frmt % uq.prec
        attrs["val"], attrs["prec"] = v, p

        # tuned_coord: convert to tuple
        try:
            attrs["tuned_coord"] = tuple(a for a in uq.tuned_coord)
        except AttributeError:
            pass
        return attrs

    cols = {}
    for i, uq in enumerate(uq_list):
        if uq is not None:
            # Format
            attrs = frmt(uq)
            # Insert attrs as a "row" in the `cols`:
            for k in attrs:
                # Init column
                if k not in cols:
                    cols[k] = [None]*len(uq_list)
                # Insert element
                cols[k][i] = attrs[k]

    return cols


def tabulate_avrgs(avrgs_list, statkeys=(), decimals=None):
    """Tabulate avrgs (val±prec)."""
    if not statkeys:
        statkeys = ['rmse.a', 'rmv.a', 'rmse.f']

    columns = {}
    for stat in statkeys:
        column = [getattr(a, stat, None) for a in avrgs_list]
        column = unpack_uqs(column, decimals)
        if not column:
            raise ValueError(f"The stat. key '{stat}' was not"
                             " found among any of the averages.")
        vals = align_col([stat] + column["val"])
        precs = align_col(['1σ'] + column["prec"], just="<")
        headr = vals[0]+'  '+precs[0]
        mattr = [f"{v} ±{c}" for v, c in zip(vals, precs)][1:]
        columns[headr] = mattr

    return columns


def center(E, axis=0, rescale=False):
    r"""Center ensemble.

    Makes use of `np` features: keepdims and broadcasting.

    Parameters
    ----------
    E: ndarray
        Ensemble to be centered

    axis: int, optional
        The axis to be centered. Default: 0

    rescale: bool, optional
        If True, inflate to compensate for reduction in the expected variance.
        The inflation factor is \(\sqrt{\frac{N}{N - 1}}\)
        where N is the ensemble size. Default: False

    Returns
    -------
    X: ndarray
        Ensemble anomaly

    x: ndarray
        Mean of the ensemble
    """
    x = np.mean(E, axis=axis, keepdims=True)
    X = E - x

    if rescale:
        N = E.shape[axis]
        X *= np.sqrt(N/(N-1))

    x = x.squeeze(axis=axis)

    return X, x


def mean0(E, axis=0, rescale=True):
    """Like `center`, but only return the anomalies (not the mean).

    Uses `rescale=True` by default, which is beneficial
    when used to center observation perturbations.
    """
    return center(E, axis=axis, rescale=rescale)[0]


def inflate_ens(E, factor):
    """Inflate the ensemble (center, inflate, re-combine).

    Parameters
    ----------
    E : ndarray
        Ensemble to be inflated

    factor: `float`
        Inflation factor

    Returns
    -------
    ndarray
        Inflated ensemble
    """
    if factor == 1:
        return E
    X, x = center(E)
    return x + X*factor


def weight_degeneracy(w, prec=1e-10):
    """Check if the weights are degenerate.

    If they are degenerate, the maximum weight
    should be nearly one, since sum(w) = 1.

    Parameters
    ----------
    w: ndarray
        Importance weights. Must sum to 1.

    prec: float, optional
        Tolerance for the distance between max(w) and one. Default: 1e-10

    Returns
    -------
    bool
        True if the weights are degenerate, else False
    """
    return (1-w.max()) < prec


def unbias_var(w=None, N_eff=None, avoid_pathological=False):
    """Compute unbias-ing factor for variance estimation.

    Parameters
    ----------
    w: ndarray, optional
        Importance weights. Must sum to 1.
        Only one of `w` and `N_eff` can be `None`. Default: `None`

    N_eff: float, optional
        The "effective" size of the weighted ensemble.
        If not provided, it is computed from the weights.
        The unbiasing factor is $$ N_{eff} / (N_{eff} - 1) $$.

    avoid_pathological: bool, optional
        Avoid weight collapse. Default: `False`

    Returns
    -------
    ub: float
        Factor used to unbias the variance

    Reference
    --------
    [Wikipedia](https://wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights)
    """
    if N_eff is None:
        N_eff = 1/(w@w)

    if avoid_pathological and weight_degeneracy(w):
        ub = 1  # Don't do in case of weights collapse
    else:
        ub = 1/(1 - 1/N_eff)  # =N/(N-1) if w==ones(N)/N.
    return ub
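
The unbiasing factor used by `unbias_var` can be checked by hand for small weight vectors. The sketch below re-implements only the two formulas from the listing, `N_eff = 1/(w@w)` and `ub = 1/(1 - 1/N_eff)`, in plain Python; the helper names `n_eff` and `unbias_factor` are hypothetical and the degeneracy check is omitted.

```python
import math


def n_eff(w):
    """Effective ensemble size, 1 / sum(w_i^2), for weights summing to 1."""
    return 1.0 / sum(wi * wi for wi in w)


def unbias_factor(w):
    """N_eff / (N_eff - 1), written as 1 / (1 - 1/N_eff)."""
    return 1.0 / (1.0 - 1.0 / n_eff(w))


uniform = [0.25] * 4          # equal weights: N_eff equals N = 4
skewed = [0.7, 0.1, 0.1, 0.1]  # unequal weights: N_eff is below 2

ub_uniform = unbias_factor(uniform)
ub_skewed = unbias_factor(skewed)
```

With equal weights the factor reduces to the familiar N/(N-1); the more unequal the weights, the smaller the effective size and the larger the correction.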
269 now.mu = E.mean(0) 270 else: 271 now.w = w 272 if abs(w.sum()-1) > 1e-5: 273 raise RuntimeError("Weights did not sum to one.") 274 now.mu = w @ E 275 276 # Crash checks 277 if not np.all(np.isfinite(E)): 278 raise RuntimeError("Ensemble not finite.") 279 if not np.all(np.isreal(E)): 280 raise RuntimeError("Ensemble not Real.") 281 282 # Compute errors 283 now.err = now.mu - x 284 if rc.comps['error_only']: 285 return 286 287 A = E - now.mu 288 # While A**2 is approx as fast as A*A, 289 # A**3 is 10x slower than A**2 (or A**2.0). 290 # => Use A2 = A**2, A3 = A*A2, A4=A*A3. 291 # But, to save memory, only use A_pow. 292 A_pow = A**2 293 294 # Compute variances 295 var = w @ A_pow 296 ub = unbias_var(w, avoid_pathological=True) 297 var *= ub 298 299 # Compute standard deviation ("Spread") 300 s = np.sqrt(var) # NB: biased (even though var is unbiased) 301 now.spread = s 302 303 # For simplicity, use naive (biased) formulae, derived 304 # from "empirical measure". See doc/unbiased_skew_kurt.jpg. 305 # Normalize by var. Compute "excess" kurt, which is 0 for Gaussians. 
306 A_pow *= A 307 now.skew = np.nanmean(w @ A_pow / (s*s*s)) 308 A_pow *= A 309 now.kurt = np.nanmean(w @ A_pow / var**2 - 3) 310 311 now.mad = np.nanmean(w @ abs(A)) 312 313 if self.do_spectral: 314 if N <= Nx: 315 _, s, UT = sla.svd((np.sqrt(w)*A.T).T, full_matrices=False) 316 s *= np.sqrt(ub) # Makes s^2 unbiased 317 now.svals = s 318 now.umisf = UT @ now.err 319 else: 320 P = (A.T * w) @ A 321 s2, U = sla.eigh(P) 322 s2 *= ub 323 now.svals = np.sqrt(s2.clip(0))[::-1] 324 now.umisf = U.T[::-1] @ now.err 325 326 # For each state dim [i], compute rank of truth (x) among the ensemble (E) 327 E_x = np.sort(np.vstack((E, x)), axis=0, kind='heapsort') 328 now.rh = np.asarray( 329 [np.where(E_x[:, i] == x[i])[0][0] for i in range(Nx)]) 330 331 def assess_ext(self, now, x, mu, P): 332 """Kalman filter (Gaussian) assessment.""" 333 if not np.all(np.isfinite(mu)): 334 raise RuntimeError("Estimates not finite.") 335 if not np.all(np.isreal(mu)): 336 raise RuntimeError("Estimates not Real.") 337 # Don't check the cov (might not be explicitly availble) 338 339 # Compute errors 340 now.mu = mu 341 now.err = now.mu - x 342 if rc.comps['error_only']: 343 return 344 345 # Get diag(P) 346 if P is None: 347 var = np.zeros_like(mu) 348 elif np.isscalar(P): 349 var = np.ones_like(mu) * P 350 else: 351 if isinstance(P, CovMat): 352 var = P.diag 353 P = P.full 354 else: 355 var = np.diag(P) 356 357 if self.do_spectral: 358 s2, U = sla.eigh(P) 359 now.svals = np.sqrt(np.maximum(s2, 0.0))[::-1] 360 now.umisf = (U.T @ now.err)[::-1] 361 362 # Compute stddev 363 now.spread = np.sqrt(var) 364 # Here, sqrt(2/pi) is the ratio, of MAD/Spread for Gaussians 365 now.mad = np.nanmean(now.spread) * np.sqrt(2/np.pi) 366 367 def average_in_time(self, kk=None, kko=None, free=False): 368 """Avarage all univariate (scalar) time series. 
369 370 - `kk` time inds for averaging 371 - `kko` time inds for averaging obs 372 """ 373 tseq = self.HMM.tseq 374 if kk is None: 375 kk = tseq.mask 376 if kko is None: 377 kko = tseq.masko 378 379 def average1(tseries): 380 avrgs = Avrgs() 381 382 def average_multivariate(): return avrgs 383 # Plain averages of nd-series are rarely interesting. 384 # => Shortcircuit => Leave for manual computations 385 386 if isinstance(tseries, series.FAUSt): 387 # Average series for each subscript 388 if tseries.item_shape != (): 389 return average_multivariate() 390 for sub in [ch for ch in 'fas' if hasattr(tseries, ch)]: 391 avrgs[sub] = series.mean_with_conf(tseries[kko, sub]) 392 if tseries.store_u: 393 avrgs['u'] = series.mean_with_conf(tseries[kk, 'u']) 394 395 elif isinstance(tseries, series.DataSeries): 396 if tseries.array.shape[1:] != (): 397 return average_multivariate() 398 elif len(tseries.array) == self.Ko+1: 399 avrgs = series.mean_with_conf(tseries[kko]) 400 elif len(tseries.array) == self.K+1: 401 avrgs = series.mean_with_conf(tseries[kk]) 402 else: 403 raise ValueError 404 405 elif np.isscalar(tseries): 406 avrgs = tseries # Eg. just copy over "duration" from stats 407 408 else: 409 raise TypeError(f"Don't know how to average {tseries}") 410 411 return avrgs 412 413 def recurse_average(stat_parent, avrgs_parent): 414 for key in getattr(stat_parent, "stat_register", []): 415 try: 416 tseries = getattr(stat_parent, key) 417 except AttributeError: 418 continue # Eg assess_ens() deletes .weights if None 419 avrgs = average1(tseries) 420 recurse_average(tseries, avrgs) 421 avrgs_parent[key] = avrgs 422 423 avrgs = Avrgs() 424 recurse_average(self, avrgs) 425 self.xp.avrgs = avrgs 426 if free: 427 delattr(self.xp, 'stats') 428 429 def replay(self, figlist="default", speed=np.inf, t1=0, t2=None, **kwargs): 430 """Replay LivePlot with what's been stored in 'self'. 431 432 - t1, t2: time window to plot. 433 - 'figlist' and 'speed': See LivePlot's doc. 434 435 .. 
note:: `store_u` (whether to store non-obs-time stats) must 436 have been `True` to have smooth graphs as in the actual LivePlot. 437 438 .. note:: Ensembles are generally not stored in the stats 439 and so cannot be replayed. 440 """ 441 # Time settings 442 tseq = self.HMM.tseq 443 if t2 is None: 444 t2 = t1 + tseq.Tplot 445 446 # Ens does not get stored in stats, so we cannot replay that. 447 # If the LPs are initialized with P0!=None, then they will avoid ens plotting. 448 # TODO 4: This system for switching from Ens to stats must be replaced. 449 # It breaks down when M is very large. 450 try: 451 P0 = np.full_like(self.HMM.X0.C.full, np.nan) 452 except AttributeError: # e.g. if X0 is defined via sampling func 453 P0 = np.eye(self.HMM.Nx) 454 455 LP = liveplotting.LivePlot(self, figlist, P=P0, speed=speed, 456 Tplot=t2-t1, replay=True, **kwargs) 457 458 # Remember: must use progbar to unblock read1. 459 # Let's also make a proper description. 460 desc = self.xp.da_method + " (replay)" 461 462 # Play through assimilation cycles 463 for k, ko, t, _dt in progbar(tseq.ticker, desc): 464 if t1 <= t <= t2: 465 if ko is not None: 466 LP.update((k, ko, 'f'), None, None) 467 LP.update((k, ko, 'a'), None, None) 468 LP.update((k, ko, 'u'), None, None) 469 470 # Pause required when speed=inf. 471 # On Mac, it was also necessary to do it for each fig. 472 if LP.any_figs: 473 for _name, updater in LP.figures.items(): 474 if plt.fignum_exists(_name) and getattr(updater, 'is_active', 1): 475 plt.figure(_name) 476 plt.pause(0.01)
Contains and computes statistics of the DA methods.
    def __init__(self, xp, HMM, xx, yy, liveplots=False, store_u=rc.store_u):
        """Init the default statistics."""
        ######################################
        # Preamble
        ######################################
        self.xp = xp
        self.HMM = HMM
        self.xx = xx
        self.yy = yy
        self.liveplots = liveplots
        self.store_u = store_u
        self.store_s = any(key in xp.__dict__ for key in
                           ["Lag", "DeCorr"])  # prms used by smoothers

        # Shapes
        K = xx.shape[0] - 1
        Nx = xx.shape[1]
        Ko = yy.shape[0] - 1
        self.K, self.Ko, self.Nx = K, Ko, Nx

        # Methods for summarizing multivariate stats ("fields") as scalars
        # Don't use nanmean here; nan's should get propagated!
        en_mean = lambda x: np.mean(x, axis=0)  # noqa
        self.field_summaries = dict(
            m   = lambda x: en_mean(x),                  # mean-field
            ms  = lambda x: en_mean(x**2),               # mean-square
            rms = lambda x: np.sqrt(en_mean(x**2)),      # root-mean-square
            ma  = lambda x: en_mean(np.abs(x)),          # mean-absolute
            gm  = lambda x: np.exp(en_mean(np.log(x))),  # geometric mean
        )
        # Only keep the methods listed in rc
        self.field_summaries = struct_tools.intersect(self.field_summaries,
                                                      rc.field_summaries)

        # Define similar methods, but restricted to sectors
        self.sector_summaries = {}
        def restrict(fun, inds): return (lambda x: fun(x[inds]))
        for suffix, formula in self.field_summaries.items():
            for sector, inds in HMM.sectors.items():
                f = restrict(formula, inds)
                self.sector_summaries['%s.%s' % (suffix, sector)] = f

        ######################################
        # Allocate time series of various stats
        ######################################
        self.new_series('mu'    , Nx, field_mean='sectors')  # Mean
        self.new_series('spread', Nx, field_mean='sectors')  # Std. dev. ("spread")
        self.new_series('err'   , Nx, field_mean='sectors')  # Error (mu - truth)
        self.new_series('gscore', Nx, field_mean='sectors')  # Gaussian (log) score

        # To save memory, we only store these field means:
        self.new_series('mad' , 1)  # Mean abs deviations
        self.new_series('skew', 1)  # Skewness
        self.new_series('kurt', 1)  # Kurtosis

        if hasattr(xp, 'N'):
            N = xp.N
            self.new_series('w', N, field_mean=True)  # Importance weights
            self.new_series('rh', Nx, dtype=int)  # Rank histogram

            self._is_ens = True
            minN = min(Nx, N)
            self.do_spectral = np.sqrt(Nx*N) <= rc.comps["max_spectral"]
        else:
            self._is_ens = False
            minN = Nx
            self.do_spectral = Nx <= rc.comps["max_spectral"]

        if self.do_spectral:
            # Note: the mean-field and RMS time-series of
            # (i) svals and (ii) umisf should match the corresponding series of
            # (i) spread and (ii) err.
            self.new_series('svals', minN)  # Principal component (SVD) scores
            self.new_series('umisf', minN)  # Error in component directions

        ######################################
        # Allocate a few series for outside use
        ######################################
        self.new_series('trHK' , 1, Ko+1)
        self.new_series('infl' , 1, Ko+1)
        self.new_series('iters', 1, Ko+1)

        # Weight-related
        self.new_series('N_eff' , 1, Ko+1)
        self.new_series('wroot' , 1, Ko+1)
        self.new_series('resmpl', 1, Ko+1)
Init the default statistics.
    def new_series(self, name, shape, length='FAUSt', field_mean=False, **kws):
        """Create (and register) a statistics time series, initialized with `nan`s.

        If `length` is an integer, a `DataSeries` (a trivial subclass of
        `numpy.ndarray`) is made. By default, though, a `series.FAUSt` is created.

        NB: The `sliding_diagnostics` liveplotting relies on detecting `nan`'s
        to avoid plotting stats that are not being used.
        Thus, you cannot use `dtype=bool` or `int` for stats that get plotted.
        """
        # Convert int shape to tuple
        if not hasattr(shape, '__len__'):
            if shape == 1:
                shape = ()
            else:
                shape = (shape,)

        def make_series(parent, name, shape):
            if length == 'FAUSt':
                total_shape = self.K, self.Ko, shape
                store_opts = self.store_u, self.store_s
                tseries = series.FAUSt(*total_shape, *store_opts, **kws)
            else:
                total_shape = (length,)+shape
                # Keyword options must be unpacked with `**`, not `*`
                tseries = series.DataSeries(total_shape, **kws)
            register_stat(parent, name, tseries)

        # Principal series
        make_series(self, name, shape)

        # Summary (scalar) series:
        if shape != ():
            if field_mean:
                for suffix in self.field_summaries:
                    make_series(getattr(self, name), suffix, ())
            # Make a nested level for sectors
            if field_mean == 'sectors':
                for ss in self.sector_summaries:
                    suffix, sector = ss.split('.')
                    make_series(struct_tools.deep_getattr(
                        self, f"{name}.{suffix}"), sector, ())
Create (and register) a statistics time series, initialized with nan values.
If length is an integer, a DataSeries (a trivial subclass of numpy.ndarray) is made. By default, though, a series.FAUSt is created.
NB: The sliding_diagnostics liveplotting relies on detecting nan values to avoid plotting stats that are not being used. Thus, you cannot use dtype=bool or int for stats that get plotted.
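The nan convention can be illustrated without DAPPER. The following is a minimal plain-Python sketch; the helpers `make_series` and `is_unused` are hypothetical stand-ins, not DAPPER functions. It shows how an all-nan series signals "never written", and why an integer type cannot carry that signal.

```python
import math


def make_series(length):
    """Return a list of `length` entries, all nan (i.e. 'not yet written')."""
    return [math.nan] * length


def is_unused(series):
    """True if every value is still nan, i.e. nothing was ever written."""
    return all(math.isnan(v) for v in series)


infl = make_series(4)
unused_before = is_unused(infl)

infl[2] = 1.02  # record one (hypothetical) inflation value
unused_after = is_unused(infl)

# Integers have no nan, so an int-typed series cannot mark "unused":
try:
    int(math.nan)
    int_can_hold_nan = True
except ValueError:
    int_can_hold_nan = False

print(unused_before, unused_after, int_can_hold_nan)
```

A plotting routine can then skip any series for which `is_unused` holds, which is the kind of check the note above alludes to.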
    def assess(self, k, ko=None, faus=None,
               E=None, w=None, mu=None, Cov=None):
        """Common interface for both `Stats.assess_ens` and `Stats.assess_ext`.

        The `_ens` assessment function gets called if `E is not None`,
        and `_ext` if `mu is not None`.

        faus: One or more of `['f', 'a', 'u', 's']`, indicating
              that the result should be stored in (respectively)
              the forecast/analysis/universal/smoothing attribute.
              Default: `'u' if ko is None else 'au' ('a' and 'u')`.
        """
        # Initial consistency checks.
        if k == 0:
            if ko is not None:
                raise KeyError("DAPPER convention: no obs at t=0. Helps avoid bugs.")
            if self._is_ens:
                if E is None:
                    raise TypeError("Expected ensemble input but E is None")
                if mu is not None:
                    raise TypeError("Expected ensemble input but mu/Cov is not None")
            else:
                if E is not None:
                    raise TypeError("Expected mu/Cov input but E is not None")
                if mu is None:
                    raise TypeError("Expected mu/Cov input but mu is None")

        # Default. Don't add more defaults. It just gets confusing.
        if faus is None:
            faus = 'u' if ko is None else 'au'

        # TODO 4: for faus="au" (e.g.) we don't need to re-**compute** stats,
        #         merely re-write them?
        for sub in faus:

            # Skip assessment if ('u' and stats not stored or plotted)
            if k != 0 and ko is None:
                if not (self.store_u or self.LP_instance.any_figs):
                    continue

            # Silence repeat warnings caused by zero variance
            with np.errstate(divide='call', invalid='call'):
                np.seterrcall(warn_zero_variance)

                # Assess
                stats_now = Avrgs()
                if self._is_ens:
                    self.assess_ens(stats_now, self.xx[k], E, w)
                else:
                    self.assess_ext(stats_now, self.xx[k], mu, Cov)
                self.derivative_stats(stats_now)
                self.summarize_marginals(stats_now)

            self.write(stats_now, k, ko, sub)

            # LivePlot -- Both init and update must come after the assessment.
            try:
                self.LP_instance.update((k, ko, sub), E, Cov)
            except AttributeError:
                self.LP_instance = liveplotting.LivePlot(
                    self, self.liveplots, (k, ko, sub), E, Cov)
Common interface for both Stats.assess_ens and Stats.assess_ext.
The _ens assessment function gets called if E is not None, and _ext if mu is not None.
faus: One or more of ['f', 'a', 'u', 's'], indicating that the result should be stored in (respectively) the forecast/analysis/universal/smoothing attribute. Default: 'u' if ko is None else 'au' ('a' and 'u').
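The default for faus and the "no obs at t=0" convention can be sketched in isolation. The helper names `default_faus` and `check_time_zero` below are hypothetical and only mirror the rules stated in the listing above.

```python
def default_faus(ko):
    """Mirror the stated default: 'u' between obs times, 'au' at obs times."""
    return 'u' if ko is None else 'au'


def check_time_zero(k, ko):
    """Enforce the convention that there is no observation at k == 0."""
    if k == 0 and ko is not None:
        raise KeyError("no obs at t=0")


# (k, ko) pairs: ko is None except at observation times.
schedule = [(0, None), (1, None), (2, 0), (3, None), (4, 1)]
subs = []
for k, ko in schedule:
    check_time_zero(k, ko)
    subs.append(list(default_faus(ko)))

print(subs)
```

Note that the test is `ko is None` rather than a truthiness check, so the first observation index `ko = 0` is still treated as an observation time.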
    def write(self, stat_dict, k, ko, sub):
        """Write `stat_dict` to series at `(k, ko, sub)`."""
        for name, val in stat_dict.items():
            stat = struct_tools.deep_getattr(self, name)
            isFaust = isinstance(stat, series.FAUSt)
            stat[(k, ko, sub) if isFaust else ko] = val
Write stat_dict to series at (k, ko, sub).
    def summarize_marginals(self, now):
        """Compute Mean-field and RMS values."""
        formulae = {**self.field_summaries, **self.sector_summaries}

        with np.errstate(divide='ignore', invalid='ignore'):
            for stat in list(now):
                field = now[stat]
                for suffix, formula in formulae.items():
                    statpath = stat+'.'+suffix
                    if struct_tools.deep_hasattr(self, statpath):
                        now[statpath] = formula(field)
Compute Mean-field and RMS values.
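As a rough illustration of what the field summaries compute, here is a plain-Python sketch of the five formulae (m, ms, rms, ma, gm) applied to a small, made-up error field. The list-based `mean` helper and the example values are assumptions for illustration; DAPPER applies the same formulae to NumPy arrays.

```python
import math


def mean(xs):
    """Arithmetic mean of a list of floats."""
    return sum(xs) / len(xs)


field_summaries = {
    "m":   lambda x: mean(x),                                   # mean-field
    "ms":  lambda x: mean([v**2 for v in x]),                   # mean-square
    "rms": lambda x: math.sqrt(mean([v**2 for v in x])),        # root-mean-square
    "ma":  lambda x: mean([abs(v) for v in x]),                 # mean-absolute
    "gm":  lambda x: math.exp(mean([math.log(v) for v in x])),  # geometric mean
}

err = [3.0, 4.0]  # hypothetical error field with Nx = 2
summary = {name: f(err) for name, f in field_summaries.items()}
print(summary)
```

Because rms takes a square root after averaging, rms is not the average of the absolute values (here 3.5 for ma versus about 3.54 for rms); the geometric mean is only defined for positive fields.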
    def derivative_stats(self, now):
        """Stats that derive from others, and are not specific for `_ens` or `_ext`."""
        try:
            now.gscore = 2*np.log(now.spread) + (now.err/now.spread)**2
        except AttributeError:
            # happens in case rc.comps['error_only']
            pass
Stats that derive from others, and are not specific for _ens or _ext.
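The gscore formula in the listing above, 2*log(spread) + (err/spread)**2, equals minus twice the log-density of a zero-mean Gaussian with standard deviation spread, evaluated at err, up to the additive constant log(2π). A scalar sketch (the function name `gscore` is used here only for illustration):

```python
import math


def gscore(err, spread):
    """Per-component Gaussian score: 2*log(spread) + (err/spread)**2."""
    return 2 * math.log(spread) + (err / spread) ** 2


well_calibrated = gscore(1.0, 1.0)   # error comparable to the spread
overconfident = gscore(1.0, 0.1)     # spread far smaller than the error
underconfident = gscore(1.0, 10.0)   # spread far larger than the error
print(well_calibrated, overconfident, underconfident)
```

Lower is better: the score penalizes both a spread that is too small for the error (the quadratic term grows) and a spread that is needlessly large (the log term grows).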
    def assess_ens(self, now, x, E, w):
        """Ensemble and Particle filter (weighted/importance) assessment."""
        N, Nx = E.shape

        # weights
        if w is None:
            w = np.ones(N)/N  # All equal. Also, rm attr from stats:
            if hasattr(self, 'w'):
                delattr(self, 'w')
            # Use non-weight formula (since w=None) for mu computations.
            # The savings are noticeable when rc.comps['error_only'] is True.
            now.mu = E.mean(0)
        else:
            now.w = w
            if abs(w.sum()-1) > 1e-5:
                raise RuntimeError("Weights did not sum to one.")
            now.mu = w @ E

        # Crash checks
        if not np.all(np.isfinite(E)):
            raise RuntimeError("Ensemble not finite.")
        if not np.all(np.isreal(E)):
            raise RuntimeError("Ensemble not Real.")

        # Compute errors
        now.err = now.mu - x
        if rc.comps['error_only']:
            return

        A = E - now.mu
        # While A**2 is approx as fast as A*A,
        # A**3 is 10x slower than A**2 (or A**2.0).
        # => Use A2 = A**2, A3 = A*A2, A4=A*A3.
        # But, to save memory, only use A_pow.
        A_pow = A**2

        # Compute variances
        var = w @ A_pow
        ub = unbias_var(w, avoid_pathological=True)
        var *= ub

        # Compute standard deviation ("Spread")
        s = np.sqrt(var)  # NB: biased (even though var is unbiased)
        now.spread = s

        # For simplicity, use naive (biased) formulae, derived
        # from "empirical measure". See doc/unbiased_skew_kurt.jpg.
        # Normalize by var. Compute "excess" kurt, which is 0 for Gaussians.
        A_pow *= A
        now.skew = np.nanmean(w @ A_pow / (s*s*s))
        A_pow *= A
        now.kurt = np.nanmean(w @ A_pow / var**2 - 3)

        now.mad = np.nanmean(w @ abs(A))

        if self.do_spectral:
            if N <= Nx:
                _, s, UT = sla.svd((np.sqrt(w)*A.T).T, full_matrices=False)
                s *= np.sqrt(ub)  # Makes s^2 unbiased
                now.svals = s
                now.umisf = UT @ now.err
            else:
                P = (A.T * w) @ A
                s2, U = sla.eigh(P)
                s2 *= ub
                now.svals = np.sqrt(s2.clip(0))[::-1]
                now.umisf = U.T[::-1] @ now.err

            # For each state dim [i], compute rank of truth (x) among the ensemble (E)
            E_x = np.sort(np.vstack((E, x)), axis=0, kind='heapsort')
            now.rh = np.asarray(
                [np.where(E_x[:, i] == x[i])[0][0] for i in range(Nx)])
Ensemble and Particle filter (weighted/importance) assessment.
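The weighted mean and variance computed in the listing above can be sketched for a scalar ensemble. This is a plain-Python illustration, not DAPPER's implementation: the function `weighted_moments` is hypothetical, and the unbiasing factor 1/(1 - sum(w**2)), which reduces to N/(N-1) for equal weights, is an assumption about what `unbias_var` returns in the non-degenerate case.

```python
import math


def weighted_moments(ens, w=None):
    """Mean, unbiased variance and spread of a scalar ensemble (a list)."""
    N = len(ens)
    if w is None:
        w = [1.0 / N] * N  # equal weights
    if abs(sum(w) - 1) > 1e-5:
        raise RuntimeError("Weights did not sum to one.")
    mu = sum(wi * xi for wi, xi in zip(w, ens))
    var = sum(wi * (xi - mu) ** 2 for wi, xi in zip(w, ens))
    # Assumed unbiasing factor; equals N/(N-1) when all weights are 1/N.
    var *= 1.0 / (1.0 - sum(wi**2 for wi in w))
    return mu, var, math.sqrt(var)


mu, var, spread = weighted_moments([1.0, 2.0, 3.0, 4.0])
print(mu, var, spread)
```

With a fully collapsed weight vector (one weight equal to 1) the factor above divides by zero, which is presumably the pathological case that the `avoid_pathological=True` argument in the listing guards against.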
    def assess_ext(self, now, x, mu, P):
        """Kalman filter (Gaussian) assessment."""
        if not np.all(np.isfinite(mu)):
            raise RuntimeError("Estimates not finite.")
        if not np.all(np.isreal(mu)):
            raise RuntimeError("Estimates not Real.")
        # Don't check the cov (might not be explicitly available)

        # Compute errors
        now.mu = mu
        now.err = now.mu - x
        if rc.comps['error_only']:
            return

        # Get diag(P)
        if P is None:
            var = np.zeros_like(mu)
        elif np.isscalar(P):
            var = np.ones_like(mu) * P
        else:
            if isinstance(P, CovMat):
                var = P.diag
                P = P.full
            else:
                var = np.diag(P)

            if self.do_spectral:
                s2, U = sla.eigh(P)
                now.svals = np.sqrt(np.maximum(s2, 0.0))[::-1]
                now.umisf = (U.T @ now.err)[::-1]

        # Compute stddev
        now.spread = np.sqrt(var)
        # Here, sqrt(2/pi) is the ratio, of MAD/Spread for Gaussians
        now.mad = np.nanmean(now.spread) * np.sqrt(2/np.pi)
Kalman filter (Gaussian) assessment.
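The factor sqrt(2/pi) used for mad in the listing above is the ratio between the mean absolute deviation and the standard deviation of a Gaussian. A quick Monte Carlo check with the standard library (sample size, seed and sigma are arbitrary choices):

```python
import math
import random

random.seed(0)
sigma = 2.0
samples = [random.gauss(0.0, sigma) for _ in range(200_000)]

# Empirical mean absolute deviation about the (known) zero mean
mad_empirical = sum(abs(s) for s in samples) / len(samples)
# Closed-form value for a Gaussian with standard deviation sigma
mad_gaussian = sigma * math.sqrt(2 / math.pi)

print(mad_empirical, mad_gaussian)
```

The two numbers should agree to within sampling error, which is why the Gaussian assessment can derive mad from the spread without any samples.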
    def average_in_time(self, kk=None, kko=None, free=False):
        """Average all univariate (scalar) time series.

        - `kk`    time inds for averaging
        - `kko`   time inds for averaging obs
        """
        tseq = self.HMM.tseq
        if kk is None:
            kk = tseq.mask
        if kko is None:
            kko = tseq.masko

        def average1(tseries):
            avrgs = Avrgs()

            def average_multivariate(): return avrgs
            # Plain averages of nd-series are rarely interesting.
            # => Shortcircuit => Leave for manual computations

            if isinstance(tseries, series.FAUSt):
                # Average series for each subscript
                if tseries.item_shape != ():
                    return average_multivariate()
                for sub in [ch for ch in 'fas' if hasattr(tseries, ch)]:
                    avrgs[sub] = series.mean_with_conf(tseries[kko, sub])
                if tseries.store_u:
                    avrgs['u'] = series.mean_with_conf(tseries[kk, 'u'])

            elif isinstance(tseries, series.DataSeries):
                if tseries.array.shape[1:] != ():
                    return average_multivariate()
                elif len(tseries.array) == self.Ko+1:
                    avrgs = series.mean_with_conf(tseries[kko])
                elif len(tseries.array) == self.K+1:
                    avrgs = series.mean_with_conf(tseries[kk])
                else:
                    raise ValueError

            elif np.isscalar(tseries):
                avrgs = tseries  # Eg. just copy over "duration" from stats

            else:
                raise TypeError(f"Don't know how to average {tseries}")

            return avrgs

        def recurse_average(stat_parent, avrgs_parent):
            for key in getattr(stat_parent, "stat_register", []):
                try:
                    tseries = getattr(stat_parent, key)
                except AttributeError:
                    continue  # Eg assess_ens() deletes .weights if None
                avrgs = average1(tseries)
                recurse_average(tseries, avrgs)
                avrgs_parent[key] = avrgs

        avrgs = Avrgs()
        recurse_average(self, avrgs)
        self.xp.avrgs = avrgs
        if free:
            delattr(self.xp, 'stats')
Average all univariate (scalar) time series.
- kk: time inds for averaging
- kko: time inds for averaging obs
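To make the role of the index masks concrete, here is a simplified sketch of averaging a scalar series over a subset of time indices. The function `mean_with_naive_conf` is hypothetical; it treats the samples as uncorrelated and skips nan entries, and DAPPER's `series.mean_with_conf` may well compute the confidence differently.

```python
import math


def mean_with_naive_conf(values):
    """Return (mean, standard error of the mean), skipping nan entries.

    Assumption: samples are treated as uncorrelated, which is a
    simplification for time series.
    """
    xs = [v for v in values if not math.isnan(v)]
    n = len(xs)
    mu = sum(xs) / n
    if n < 2:
        return mu, math.nan
    var = sum((v - mu) ** 2 for v in xs) / (n - 1)
    return mu, math.sqrt(var / n)


rmse_a = [math.nan, 0.9, 1.1, 1.0, 1.2, 0.8]  # hypothetical series; index 0 unused
kko = [2, 3, 4, 5]                            # hypothetical mask that skips spin-up
val, prec = mean_with_naive_conf([rmse_a[i] for i in kko])
print(val, prec)
```

Restricting the average to the masked indices is what lets the time averages exclude an initial spin-up period.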
    def replay(self, figlist="default", speed=np.inf, t1=0, t2=None, **kwargs):
        """Replay LivePlot with what's been stored in 'self'.

        - t1, t2: time window to plot.
        - 'figlist' and 'speed': See LivePlot's doc.

        .. note:: `store_u` (whether to store non-obs-time stats) must
            have been `True` to have smooth graphs as in the actual LivePlot.

        .. note:: Ensembles are generally not stored in the stats
            and so cannot be replayed.
        """
        # Time settings
        tseq = self.HMM.tseq
        if t2 is None:
            t2 = t1 + tseq.Tplot

        # Ens does not get stored in stats, so we cannot replay that.
        # If the LPs are initialized with P0!=None, then they will avoid ens plotting.
        # TODO 4: This system for switching from Ens to stats must be replaced.
        #       It breaks down when M is very large.
        try:
            P0 = np.full_like(self.HMM.X0.C.full, np.nan)
        except AttributeError:  # e.g. if X0 is defined via sampling func
            P0 = np.eye(self.HMM.Nx)

        LP = liveplotting.LivePlot(self, figlist, P=P0, speed=speed,
                                   Tplot=t2-t1, replay=True, **kwargs)

        # Remember: must use progbar to unblock read1.
        # Let's also make a proper description.
        desc = self.xp.da_method + " (replay)"

        # Play through assimilation cycles
        for k, ko, t, _dt in progbar(tseq.ticker, desc):
            if t1 <= t <= t2:
                if ko is not None:
                    LP.update((k, ko, 'f'), None, None)
                    LP.update((k, ko, 'a'), None, None)
                LP.update((k, ko, 'u'), None, None)

        # Pause required when speed=inf.
        # On Mac, it was also necessary to do it for each fig.
        if LP.any_figs:
            for _name, updater in LP.figures.items():
                if plt.fignum_exists(_name) and getattr(updater, 'is_active', 1):
                    plt.figure(_name)
                    plt.pause(0.01)
Replay LivePlot with what's been stored in 'self'.
- t1, t2: time window to plot.
- 'figlist' and 'speed': See LivePlot's doc.
Note: store_u (whether to store non-obs-time stats) must have been True to have smooth graphs as in the actual LivePlot.
Note: Ensembles are generally not stored in the stats and so cannot be replayed.
def register_stat(self, name, value):
    """Do `self.name = value` and register `name` in self's `stat_register`.

    Note: `self` is not always a `Stats` object, but could be a "child" of it.
    """
    setattr(self, name, value)
    if not hasattr(self, "stat_register"):
        self.stat_register = []
    self.stat_register.append(name)
Do self.name = value and register name in self's stat_register.
Note: self is not always a Stats object, but could be a "child" of it.
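A short usage sketch of the registration pattern, with `types.SimpleNamespace` objects standing in for a Stats object and for a series (both stand-ins are assumptions made so the example runs without DAPPER). The function body follows the listing above, with the first parameter renamed to `obj`.

```python
from types import SimpleNamespace


def register_stat(obj, name, value):
    """Do `obj.name = value` and record `name` in `obj.stat_register`."""
    setattr(obj, name, value)
    if not hasattr(obj, "stat_register"):
        obj.stat_register = []
    obj.stat_register.append(name)


stats = SimpleNamespace()  # stand-in for a Stats object
err = SimpleNamespace()    # stand-in for a series, i.e. a "child"

register_stat(stats, "err", err)             # top-level statistic
register_stat(stats.err, "rms", [0.0, 0.0])  # nested field summary

print(stats.stat_register, stats.err.stat_register)
```

Each parent keeps its own register, so a recursive walk over `stat_register` attributes, as done in the time averaging, reaches nested entries such as `err.rms`.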
class Avrgs(series.StatPrint, struct_tools.DotDict):
    """A `dict` specialized for the averages of statistics.

    Embellishments:

    - `dapper.tools.StatPrint`
    - `Avrgs.tabulate`
    - `getattr` that supports abbreviations.
    """

    def tabulate(self, statkeys=(), decimals=None):
        columns = tabulate_avrgs([self], statkeys, decimals=decimals)
        return tabulate(columns, headers="keys").replace('␣', ' ')

    abbrevs = {'rmse': 'err.rms', 'rmss': 'spread.rms', 'rmv': 'spread.rms'}

    # Use getattribute coz it gets called before getattr.
    def __getattribute__(self, key):
        """Support deep and abbreviated lookup."""
        # key = abbrevs[key] # Instead of this, also support rmse.a:
        key = '.'.join(Avrgs.abbrevs.get(seg, seg) for seg in key.split('.'))

        if "." in key:
            return struct_tools.deep_getattr(self, key)
        else:
            return super().__getattribute__(key)
A dict specialized for the averages of statistics.
Embellishments:
- dapper.tools.StatPrint
- Avrgs.tabulate
- getattr that supports abbreviations.
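The abbreviated lookup can be sketched with plain nested dicts. The `abbrevs` table is copied from the class source; the helpers `expand` and `deep_get` and the numeric values are hypothetical stand-ins for `Avrgs.__getattribute__` and `struct_tools.deep_getattr`.

```python
abbrevs = {'rmse': 'err.rms', 'rmss': 'spread.rms', 'rmv': 'spread.rms'}


def expand(key):
    """Expand each dot-separated segment of `key` through `abbrevs`."""
    return '.'.join(abbrevs.get(seg, seg) for seg in key.split('.'))


def deep_get(d, key):
    """Follow the expanded dotted path through nested dicts."""
    for seg in expand(key).split('.'):
        d = d[seg]
    return d


avrgs = {'err': {'rms': {'a': 0.42, 'f': 0.55}}}  # hypothetical averages

print(expand('rmse.a'), deep_get(avrgs, 'rmse.a'))
```

Expanding segment by segment, rather than looking up the whole key, is what allows an abbreviation to be followed by a subscript, as in rmse.a.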
Inherited Members
- struct_tools.DotDict
- DotDict
- builtins.dict
- get
- setdefault
- pop
- popitem
- keys
- items
- values
- update
- fromkeys
- clear
- copy
def align_col(col, pad='␣', missingval='', just=">"):
    r"""Align column.

    Treats `int`s and fixed-point `float`/`str` especially, aligning on the point.

    Example:
    >>> xx = [1, 1., 1.234, 12.34, 123.4, "1.2e-3", None, np.nan, "inf", (1, 2)]
    >>> print(*align_col(xx), sep="\n")
    ␣␣1␣␣␣␣
    ␣␣1.0␣␣
    ␣␣1.234
    ␣12.34␣
    123.4␣␣
    ␣1.2e-3
    ␣␣␣␣␣␣␣
    ␣␣␣␣nan
    ␣␣␣␣inf
    ␣(1, 2)
    """
    def split_decimal(x):
        x = str(x)
        try:
            y = float(x)
        except ValueError:
            pass
        else:
            if np.isfinite(y) and ("e" not in x.lower()):
                a, *b = x.split(".")
                if b == []:
                    b = "int"
                else:
                    b = b[0]
                return a, b
        return x, False

    # Find max nInt, nDec
    nInt = nDec = -1
    for x in col:
        ints, decs = split_decimal(x)
        if decs:
            nInt = max(nInt, len(ints))
            if decs != "int":
                nDec = max(nDec, len(decs))

    # Format entries. Floats get aligned on point.
    def frmt(x):
        if x is None:
            return missingval
        ints, decs = split_decimal(x)
        x = f"{ints.rjust(nInt, pad)}"
        if decs == "int":
            if nDec >= 0:
                x += pad + pad*nDec
        elif decs:
            x += "." + f"{decs.ljust(nDec, pad)}"
        else:
            x = ints
        return x

    # Format
    col = [frmt(x) for x in col]
    # Find max width
    Max = max(len(x) for x in col)
    # Right-justify
    shift = str.rjust if just == ">" else str.ljust
    col = [shift(x, Max, pad) for x in col]
    return col
Align column.
Treats ints and fixed-point float/str especially, aligning on the point.
Example:
>>> xx = [1, 1., 1.234, 12.34, 123.4, "1.2e-3", None, np.nan, "inf", (1, 2)]
>>> print(*align_col(xx), sep="\n")
␣␣1␣␣␣␣
␣␣1.0␣␣
␣␣1.234
␣12.34␣
123.4␣␣
␣1.2e-3
␣␣␣␣␣␣␣
␣␣␣␣nan
␣␣␣␣inf
␣(1, 2)
def unpack_uqs(uq_list, decimals=None):
    """Convert list of `uq`s into dict of lists (of equal-length) of attributes.

    The attributes are obtained by `vars(uq)`,
    and may get formatted somehow (e.g. cast to strings) in the output.

    If `uq` is `None`, then `None` is inserted in each list.
    Else, `uq` must be an instance of `dapper.tools.rounding.UncertainQtty`.

    Parameters
    ----------
    uq_list: list
        List of `uq`s.

    decimals: int
        Desired number of decimals.
        Used for (only) the columns "val" and "prec".
        Default: `None`. In this case, the formatting is left to the `uq`s.
    """
    def frmt(uq):
        if not isinstance(uq, series.UncertainQtty):
            # Presumably uq is just a number
            uq = series.UncertainQtty(uq)

        attrs = vars(uq).copy()

        # val/prec: round
        if decimals is None:
            v, p = str(uq).split(" ±")
        else:
            frmt = "%%.%df" % decimals
            v, p = frmt % uq.val, frmt % uq.prec
        attrs["val"], attrs["prec"] = v, p

        # tuned_coord: convert to tuple
        try:
            attrs["tuned_coord"] = tuple(a for a in uq.tuned_coord)
        except AttributeError:
            pass
        return attrs

    cols = {}
    for i, uq in enumerate(uq_list):
        if uq is not None:
            # Format
            attrs = frmt(uq)
            # Insert attrs as a "row" in the `cols`:
            for k in attrs:
                # Init column
                if k not in cols:
                    cols[k] = [None]*len(uq_list)
                # Insert element
                cols[k][i] = attrs[k]

    return cols
Convert list of uqs into dict of lists (of equal-length) of attributes.
The attributes are obtained by vars(uq), and may get formatted somehow (e.g. cast to strings) in the output.
If uq is None, then None is inserted in each list. Else, uq must be an instance of dapper.tools.rounding.UncertainQtty.
Parameters
- uq_list (list): List of uqs.
- decimals (int): Desired number of decimals. Used for (only) the columns "val" and "prec". Default: None. In this case, the formatting is left to the uqs.
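The "list of objects to dict of equal-length lists" transformation can be sketched on its own. The dataclass `UQ` below is a hypothetical stand-in for an uncertain quantity with only `val` and `prec` attributes, and `unpack` omits the string formatting that the real function applies.

```python
from dataclasses import dataclass


@dataclass
class UQ:
    """Hypothetical stand-in for an uncertain quantity."""
    val: float
    prec: float


def unpack(uq_list):
    """Turn a list of objects (or None) into a dict of equal-length columns."""
    cols = {}
    for i, uq in enumerate(uq_list):
        if uq is None:
            continue  # leave None in every column for this row
        for k, v in vars(uq).items():
            # Create the column on first use, pre-filled with None
            cols.setdefault(k, [None] * len(uq_list))[i] = v
    return cols


cols = unpack([UQ(0.42, 0.01), None, UQ(0.55, 0.02)])
print(cols)
```

Pre-filling each column with None keeps the rows aligned even when some experiments lack a given statistic, which is what a table layout needs.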
def tabulate_avrgs(avrgs_list, statkeys=(), decimals=None):
    """Tabulate avrgs (val±prec)."""
    if not statkeys:
        statkeys = ['rmse.a', 'rmv.a', 'rmse.f']

    columns = {}
    for stat in statkeys:
        column = [getattr(a, stat, None) for a in avrgs_list]
        column = unpack_uqs(column, decimals)
        if not column:
            raise ValueError(f"The stat. key '{stat}' was not"
                             " found among any of the averages.")
        vals = align_col([stat] + column["val"])
        precs = align_col(['1σ'] + column["prec"], just="<")
        headr = vals[0]+' '+precs[0]
        mattr = [f"{v} ±{c}" for v, c in zip(vals, precs)][1:]
        columns[headr] = mattr

    return columns
Tabulate avrgs (val±prec).
def center(E, axis=0, rescale=False):
    r"""Center ensemble.

    Makes use of `np` features: keepdims and broadcasting.

    Parameters
    ----------
    E: ndarray
        Ensemble to be centered

    axis: int, optional
        The axis to be centered. Default: 0

    rescale: bool, optional
        If True, inflate to compensate for reduction in the expected variance.
        The inflation factor is \(\sqrt{\frac{N}{N - 1}}\)
        where N is the ensemble size. Default: False

    Returns
    -------
    X: ndarray
        Ensemble anomaly

    x: ndarray
        Mean of the ensemble
    """
    x = np.mean(E, axis=axis, keepdims=True)
    X = E - x

    if rescale:
        N = E.shape[axis]
        X *= np.sqrt(N/(N-1))

    x = x.squeeze(axis=axis)

    return X, x
Center ensemble.
Makes use of `np` features: keepdims and broadcasting.
Parameters
- E (ndarray): Ensemble to be centered.
- axis (int, optional): The axis to be centered, i.e. the one indexing the ensemble members. Default: 0
- rescale (bool, optional): If True, inflate to compensate for reduction in the expected variance. The inflation factor is \(\sqrt{\frac{N}{N - 1}}\) where N is the ensemble size. Default: False
Returns
- X (ndarray): Ensemble anomaly
- x (ndarray): Mean of the ensemble
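To illustrate what `rescale` does, the sketch below copies the body of `center` from the listing above into a standalone script and applies it to a small random ensemble. The ensemble shape and the seed are arbitrary choices made for this example.

```python
import numpy as np


def center(E, axis=0, rescale=False):
    """Standalone copy of the function above, for illustration."""
    x = np.mean(E, axis=axis, keepdims=True)  # keepdims lets x broadcast against E
    X = E - x
    if rescale:
        N = E.shape[axis]
        X *= np.sqrt(N / (N - 1))
    return X, x.squeeze(axis=axis)


rng = np.random.default_rng(0)
E = rng.normal(size=(5, 3))  # N=5 members (axis 0), 3 state variables

X, x = center(E)
X_rescaled, _ = center(E, rescale=True)

# The anomalies have zero mean along the ensemble axis
print(np.allclose(X.mean(axis=0), 0))
# With rescale=True, the plain mean of squares (division by N)
# reproduces the sample variance of E (division by N-1)
print(np.allclose((X_rescaled**2).mean(axis=0), E.var(axis=0, ddof=1)))
```

Both checks print `True`: the factor \(\sqrt{N/(N-1)}\) on the anomalies becomes \(N/(N-1)\) on their squares, which converts division by N into division by N-1.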
def mean0(E, axis=0, rescale=True):
    """Like `center`, but only return the anomalies (not the mean).

    Uses `rescale=True` by default, which is beneficial
    when used to center observation perturbations.
    """
    return center(E, axis=axis, rescale=rescale)[0]
Like `center`, but only return the anomalies (not the mean).
Uses `rescale=True` by default, which is beneficial when used to center observation perturbations.
def inflate_ens(E, factor):
    """Inflate the ensemble (center, inflate, re-combine).

    Parameters
    ----------
    E : ndarray
        Ensemble to be inflated.

    factor: `float`
        Inflation factor

    Returns
    -------
    ndarray
        Inflated ensemble
    """
    if factor == 1:
        return E
    X, x = center(E)
    return x + X*factor
Inflate the ensemble (center, inflate, re-combine).
Parameters
- E (ndarray): Ensemble to be inflated.
- factor (`float`): Inflation factor
Returns
- ndarray: Inflated ensemble
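As a worked example of the center, inflate, re-combine steps, the sketch below reimplements the same arithmetic in a standalone function and applies it to a hand-checkable ensemble. It inlines the centering instead of calling `center`, which is a simplification made for this example.

```python
import numpy as np


def inflate_ens(E, factor):
    """Scale the anomalies of E (members along axis 0) by `factor`."""
    if factor == 1:
        return E
    x = E.mean(axis=0, keepdims=True)  # ensemble mean
    return x + (E - x) * factor        # mean + inflated anomalies


E = np.array([[1.0, 10.0],
              [2.0, 12.0],
              [3.0, 14.0]])  # 3 members, 2 state variables

E_infl = inflate_ens(E, 2.0)
print(E_infl)
# → [[ 0.  8.]
#    [ 2. 12.]
#    [ 4. 16.]]
```

The ensemble mean `[2, 12]` is unchanged, while each member's distance from it doubles, so the variance grows by `factor**2 = 4`.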
def weight_degeneracy(w, prec=1e-10):
    """Check if the weights are degenerate.

    If the weights are degenerate, the maximum weight
    is nearly one, since sum(w) = 1.

    Parameters
    ----------
    w: ndarray
        Importance weights. Must sum to 1.

    prec: float, optional
        Tolerance on the distance between the maximum weight and one.
        Default: 1e-10

    Returns
    -------
    bool
        True if the weights are degenerate, else False.
    """
    return (1-w.max()) < prec
Check if the weights are degenerate.
If the weights are degenerate, the maximum weight is nearly one, since sum(w) = 1.
Parameters
- w (ndarray): Importance weights. Must sum to 1.
- prec (float, optional): Tolerance on the distance between the maximum weight and one. Default: 1e-10
Returns
- bool: True if the weights are degenerate, else False.
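The check can be tried on two weight vectors, one uniform and one collapsed. In this sketch the collapsed weights are obtained by normalizing made-up log-likelihoods, which is an assumption chosen to resemble a particle-filter setting; the function body is copied from the listing above.

```python
import numpy as np


def weight_degeneracy(w, prec=1e-10):
    """Standalone copy of the function above, for illustration."""
    return (1 - w.max()) < prec


# Uniform weights: far from degenerate
uniform = np.full(4, 0.25)

# Hypothetical log-likelihoods where one particle dominates
logw = np.array([0.0, -50.0, -60.0, -80.0])
collapsed = np.exp(logw - logw.max())
collapsed /= collapsed.sum()  # normalize so that sum(w) = 1

print(bool(weight_degeneracy(uniform)))    # → False
print(bool(weight_degeneracy(collapsed)))  # → True
```

For the uniform case `1 - max(w) = 0.75`, well above the tolerance. In the collapsed case the other weights are of order `1e-22` or smaller, so the maximum weight rounds to one in double precision.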
def unbias_var(w=None, N_eff=None, avoid_pathological=False):
    """Compute unbias-ing factor for variance estimation.

    Parameters
    ----------
    w: ndarray, optional
        Importance weights. Must sum to 1.
        At least one of `w` and `N_eff` must be provided. Default: `None`

    N_eff: float, optional
        The "effective" size of the weighted ensemble.
        If not provided, it is computed from the weights.
        The unbiasing factor is $$ N_{eff} / (N_{eff} - 1) $$.

    avoid_pathological: bool, optional
        If True, return 1 (no correction) when the weights
        have collapsed. Default: `False`

    Returns
    -------
    ub: float
        Factor used to unbias the variance estimate.

    Reference
    --------
    [Wikipedia](https://wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights)
    """
    if N_eff is None:
        N_eff = 1/(w@w)

    if avoid_pathological and weight_degeneracy(w):
        ub = 1  # Don't do in case of weights collapse
    else:
        ub = 1/(1 - 1/N_eff)  # =N/(N-1) if w==ones(N)/N.
    return ub
Compute unbias-ing factor for variance estimation.
Parameters
- w (ndarray, optional): Importance weights. Must sum to 1. At least one of `w` and `N_eff` must be provided. Default: `None`
- N_eff (float, optional): The "effective" size of the weighted ensemble. If not provided, it is computed from the weights. The unbiasing factor is $$ N_{eff} / (N_{eff} - 1) $$.
- avoid_pathological (bool, optional): If True, return 1 (no correction) when the weights have collapsed. Default: `False`
Returns
- ub (float): Factor used to unbias the variance estimate.
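A small numerical check shows how the factor relates to the familiar N/(N-1) correction. The sketch below keeps only the default code path (`N_eff` computed from `w`, no degeneracy check), and the data values are made up for this example.

```python
import numpy as np


def unbias_var(w):
    """Simplified version: N_eff from the weights, then N_eff/(N_eff - 1)."""
    N_eff = 1 / (w @ w)
    return 1 / (1 - 1 / N_eff)


x = np.array([1.0, 2.0, 4.0, 7.0])  # made-up scalar ensemble
w = np.full(4, 0.25)                # uniform weights, so N_eff = N = 4

mu = w @ x                          # weighted mean
var_biased = w @ (x - mu) ** 2      # weighted variance (biased)
var_unbiased = unbias_var(w) * var_biased

# For uniform weights this matches the usual sample variance
print(np.isclose(var_unbiased, np.var(x, ddof=1)))

# Two equally weighted members, the rest with zero weight: N_eff = 2
w2 = np.array([0.5, 0.5, 0.0, 0.0])
ub2 = unbias_var(w2)  # 2 / (2 - 1) = 2
```

With uniform weights the biased variance is 5.25 and the factor is 4/3, giving 7.0, the same as `np.var(x, ddof=1)`. Members with zero weight do not count towards the effective size, which is why `w2` yields a factor of 2.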