dapper.stats

Statistics for the assessment of DA methods.

`Stats` is a data container for ([mostly] time series of) statistics. It comes with a battery of methods to compute the default statistics.

`Avrgs` is a data container for the same statistics, but after they have been averaged in time (after the assimilation has finished).

Instances of these objects are created by `dapper.da_methods.da_method` (i.e. "`xp`") objects and written to their `.stats` and `.avrgs` attributes.

## Default statistics

List them using:

>>> list(vars(xp.stats))
>>> list(vars(xp.avrgs))

## The `FAUSt` key/attribute

The time series of statistics (the attributes of `.stats`) may have attributes `.f`, `.a`, `.s`, `.u`, referring to whether the statistic is for a "forecast", "analysis", or "smoothing" estimate (as decided when the call to `Stats.assess` is made), or a "universal" (forecast, but at intermediate [non-obs.-time]) estimate.

The same applies for the time averages of `.avrgs`.
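
For example (assuming the experiment has been run, so that `xp.stats` and `xp.avrgs` are populated; `rmse` is an abbreviation for `err.rms` supported by `Avrgs`):

>>> xp.stats.err.rms.a   # analysis-time series of the rms error
>>> xp.stats.err.rms.f   # its forecast counterpart
>>> xp.avrgs.err.rms.a   # the corresponding time average
>>> xp.avrgs.rmse.a      # same thing, via the abbreviation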

## Field summaries

The statistics are also averaged in space. This is done according to the methods listed in `dpr.rc.field_summaries`.

Although the two are sometimes pretty close, `rmv` (a.k.a. `spread.rms`) is not (supposed to be) an unbiased estimator of `rmse` (a.k.a. `err.rms`). This is because of the square roots involved in the field summary. Instead, `spread.ms` (i.e. the mean variance) is the unbiased estimator of `err.ms`.
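
A minimal Monte-Carlo sketch of this (plain `numpy`, not part of DAPPER; a single state dimension, with truth and ensemble drawn from the same distribution): the mean squares agree in expectation, but averaging after taking the square root does not (Jensen's inequality).

>>> import numpy as np
>>> rng = np.random.default_rng(42)
>>> N, K = 10, 100_000                         # ensemble size, nr of trials
>>> x = rng.standard_normal(K)                 # truth
>>> E = rng.standard_normal((K, N))            # statistically consistent ensemble
>>> err2 = (E.mean(1) - x)**2                  # squared error of the ens. mean
>>> var = E.var(1, ddof=1) * (1 + 1/N)         # predicted variance of that error
>>> err2.mean(), var.mean()                    # ms: both ≈ 1.1 (they match)
>>> np.sqrt(err2).mean(), np.sqrt(var).mean()  # rms: ≈ 0.84 vs. ≈ 1.02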

## Regional field summaries

If the `HiddenMarkovModel` has the attribute `.sectors` with value, e.g.,

>>> HMM.sectors = {
...     "ocean": inds_of_state_of_the_ocean,
...     "atmos": inds_of_state_of_the_atmosphere,
... }

then the `rms` summaries of `.stats` and `.avrgs` (e.g. `stats.err.rms`) will also have attributes named after the keys of `HMM.sectors`, e.g. `stats.err.rms.ocean`. The same goes for the other field-summary methods (not just `rms`), as the examples below show.
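
For example (assuming `HMM.sectors` was set as above, and the experiment has been run):

>>> xp.stats.err.rms.ocean     # time series of the ocean-only rms error
>>> xp.avrgs.err.rms.ocean     # its time average
>>> xp.avrgs.spread.rms.atmos  # likewise for the other summarized series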

## Declaring new, custom statistics

Only the time series created with `Stats.new_series` will be in the format operated on by `Stats.average_in_time`. For example, create an `ndarray` of length `Ko+1` to hold the time series of estimated inflation values:

>>> self.stats.new_series('infl', 1, Ko+1)

Alternatively, you can overwrite a default statistic; for example:

>>> error_time_series_a = xx - ensemble_time_series_a.mean(axis=1)
>>> self.stats.err.rms.a = np.sqrt(np.mean(error_time_series_a**2, axis=-1))

Of course, you could just do this:

>>> self.stats.my_custom_stat = value

However, `dapper.xp_launch.run_experiment` (unless called with `free=False`) will delete the `Stats` object from `xp` after the assimilation, in order to save memory. Therefore, in order to have `my_custom_stat` be available among `xp.avrgs`, it must be "registered":

>>> self.stats.stat_register.append("my_custom_stat")

Alternatively, you can do both at once:

>>> self.stat("my_custom_stat", value)
  1"""Statistics for the assessment of DA methods.
  2
  3`Stats` is a data container for ([mostly] time series of) statistics.
  4It comes with a battery of methods to compute the default statistics.
  5
  6`Avrgs` is a data container *for the same statistics*,
  7but after they have been averaged in time (after the assimilation has finished).
  8
  9Instances of these objects are created by `dapper.da_methods.da_method`
 10(i.e. "`xp`") objects and written to their `.stats` and `.avrgs` attributes.
 11
 12.. include:: ../docs/stats_etc.md
 13"""
 14
 15import warnings
 16
 17import numpy as np
 18import scipy.linalg as sla
 19import struct_tools
 20from matplotlib import pyplot as plt
 21from patlib.std import do_once
 22from tabulate import tabulate
 23
 24import dapper.tools.liveplotting as liveplotting
 25import dapper.tools.series as series
 26from dapper.dpr_config import rc
 27from dapper.tools.matrices import CovMat
 28from dapper.tools.progressbar import progbar
 29


class Stats(series.StatPrint):
    """Contains and computes statistics of the DA methods."""

    def __init__(self, xp, HMM, xx, yy, liveplots=False, store_u=rc.store_u):
        """Init the default statistics."""
        ######################################
        # Preamble
        ######################################
        self.xp        = xp
        self.HMM       = HMM
        self.xx        = xx
        self.yy        = yy
        self.liveplots = liveplots
        self.store_u   = store_u
        self.store_s   = any(key in xp.__dict__ for key in
                             ["Lag", "DeCorr"])  # prms used by smoothers

        # Shapes
        K  = xx.shape[0] - 1
        Nx = xx.shape[1]
        Ko = yy.shape[0] - 1
        self.K, self.Ko, self.Nx = K, Ko, Nx

        # Methods for summarizing multivariate stats ("fields") as scalars
        # Don't use nanmean here; nan's should get propagated!
        en_mean = lambda x: np.mean(x, axis=0)  # noqa
        self.field_summaries = dict(
            m   = lambda x: en_mean(x),                  # mean-field
            ms  = lambda x: en_mean(x**2),               # mean-square
            rms = lambda x: np.sqrt(en_mean(x**2)),      # root-mean-square
            ma  = lambda x: en_mean(np.abs(x)),          # mean-absolute
            gm  = lambda x: np.exp(en_mean(np.log(x))),  # geometric mean
        )
        # Only keep the methods listed in rc
        self.field_summaries = struct_tools.intersect(self.field_summaries,
                                                      rc.field_summaries)

        # Define similar methods, but restricted to sectors
        self.sector_summaries = {}
        def restrict(fun, inds): return (lambda x: fun(x[inds]))
        for suffix, formula in self.field_summaries.items():
            for sector, inds in HMM.sectors.items():
                f = restrict(formula, inds)
                self.sector_summaries['%s.%s' % (suffix, sector)] = f

        ######################################
        # Allocate time series of various stats
        ######################################
        self.new_series('mu'    , Nx, field_mean='sectors')  # Mean
        self.new_series('spread', Nx, field_mean='sectors')  # Std. dev. ("spread")
        self.new_series('err'   , Nx, field_mean='sectors')  # Error (mu - truth)
        self.new_series('gscore', Nx, field_mean='sectors')  # Gaussian (log) score

        # To save memory, we only store these field means:
        self.new_series('mad' , 1)  # Mean abs deviations
        self.new_series('skew', 1)  # Skewness
        self.new_series('kurt', 1)  # Kurtosis

        if hasattr(xp, 'N'):
            N = xp.N
            self.new_series('w', N, field_mean=True)  # Importance weights
            self.new_series('rh', Nx, dtype=int)  # Rank histogram

            self._is_ens = True
            minN = min(Nx, N)
            self.do_spectral  = np.sqrt(Nx*N) <= rc.comps["max_spectral"]
        else:
            self._is_ens = False
            minN = Nx
            self.do_spectral = Nx <= rc.comps["max_spectral"]

        if self.do_spectral:
            # Note: the mean-field and RMS time-series of
            # (i) svals and (ii) umisf should match the corresponding series of
            # (i) spread and (ii) err.
            self.new_series('svals', minN)  # Principal component (SVD) scores
            self.new_series('umisf', minN)  # Error in component directions

        ######################################
        # Allocate a few series for outside use
        ######################################
        self.new_series('trHK' , 1, Ko+1)
        self.new_series('infl' , 1, Ko+1)
        self.new_series('iters', 1, Ko+1)

        # Weight-related
        self.new_series('N_eff' , 1, Ko+1)
        self.new_series('wroot' , 1, Ko+1)
        self.new_series('resmpl', 1, Ko+1)

    def new_series(self, name, shape, length='FAUSt', field_mean=False, **kws):
        """Create (and register) a statistics time series, initialized with `nan`s.

        If `length` is an integer, a `DataSeries` (a trivial subclass of
        `numpy.ndarray`) is made. By default, though, a `series.FAUSt` is created.

        NB: The `sliding_diagnostics` liveplotting relies on detecting `nan`'s
            to avoid plotting stats that are not being used.
            Thus, you cannot use `dtype=bool` or `int` for stats that get plotted.
        """
        # Convert int shape to tuple
        if not hasattr(shape, '__len__'):
            if shape == 1:
                shape = ()
            else:
                shape = (shape,)

        def make_series(parent, name, shape):
            if length == 'FAUSt':
                total_shape = self.K, self.Ko, shape
                store_opts = self.store_u, self.store_s
                tseries = series.FAUSt(*total_shape, *store_opts, **kws)
            else:
                total_shape = (length,)+shape
                tseries = series.DataSeries(total_shape, **kws)
            register_stat(parent, name, tseries)

        # Principal series
        make_series(self, name, shape)

        # Summary (scalar) series:
        if shape != ():
            if field_mean:
                for suffix in self.field_summaries:
                    make_series(getattr(self, name), suffix, ())
            # Make a nested level for sectors
            if field_mean == 'sectors':
                for ss in self.sector_summaries:
                    suffix, sector = ss.split('.')
                    make_series(struct_tools.deep_getattr(
                        self, f"{name}.{suffix}"), sector, ())

    @property
    def data_series(self):
        return [k for k in vars(self)
                if isinstance(getattr(self, k), series.DataSeries)]

    def assess(self, k, ko=None, faus=None,
               E=None, w=None, mu=None, Cov=None):
        """Common interface for both `Stats.assess_ens` and `Stats.assess_ext`.

        The `_ens` assessment function gets called if `E is not None`,
        and `_ext` if `mu is not None`.

        faus: One or more of `['f', 'a', 'u', 's']`, indicating
              that the result should be stored in (respectively)
              the forecast/analysis/universal/smoothing attribute.
              Default: `'u' if ko is None else 'au' ('a' and 'u')`.
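
        Example (an illustrative call, as made from within a DA method):
        >>> self.stats.assess(k, ko, 'f', E=E)  # store forecast stats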
        """
        # Initial consistency checks.
        if k == 0:
            if ko is not None:
                raise KeyError("DAPPER convention: no obs at t=0. Helps avoid bugs.")
            if self._is_ens:
                if E is None:
                    raise TypeError("Expected ensemble input but E is None")
                if mu is not None:
                    raise TypeError("Expected ensemble input but mu/Cov is not None")
            else:
                if E is not None:
                    raise TypeError("Expected mu/Cov input but E is not None")
                if mu is None:
                    raise TypeError("Expected mu/Cov input but mu is None")

        # Default. Don't add more defaults. It just gets confusing.
        if faus is None:
            faus = 'u' if ko is None else 'au'

        # TODO 4: for faus="au" (e.g.) we don't need to re-**compute** stats,
        #         merely re-write them?
        for sub in faus:

            # Skip assessment if ('u' and stats not stored or plotted)
            if k != 0 and ko is None:
                if not (self.store_u or self.LP_instance.any_figs):
                    continue

            # Silence repeat warnings caused by zero variance
            with np.errstate(divide='call', invalid='call'):
                np.seterrcall(warn_zero_variance)

                # Assess
                stats_now = Avrgs()
                if self._is_ens:
                    self.assess_ens(stats_now, self.xx[k], E, w)
                else:
                    self.assess_ext(stats_now, self.xx[k], mu, Cov)
                self.derivative_stats(stats_now)
                self.summarize_marginals(stats_now)

            self.write(stats_now, k, ko, sub)

            # LivePlot -- Both init and update must come after the assessment.
            try:
                self.LP_instance.update((k, ko, sub), E, Cov)
            except AttributeError:
                self.LP_instance = liveplotting.LivePlot(
                    self, self.liveplots, (k, ko, sub), E, Cov)

    def write(self, stat_dict, k, ko, sub):
        """Write `stat_dict` to series at `(k, ko, sub)`."""
        for name, val in stat_dict.items():
            stat = struct_tools.deep_getattr(self, name)
            isFaust = isinstance(stat, series.FAUSt)
            stat[(k, ko, sub) if isFaust else ko] = val

    def summarize_marginals(self, now):
        """Compute mean-field and RMS values."""
        formulae = {**self.field_summaries, **self.sector_summaries}

        with np.errstate(divide='ignore', invalid='ignore'):
            for stat in list(now):
                field = now[stat]
                for suffix, formula in formulae.items():
                    statpath = stat+'.'+suffix
                    if struct_tools.deep_hasattr(self, statpath):
                        now[statpath] = formula(field)

    def derivative_stats(self, now):
        """Stats that derive from others, and are not specific to `_ens` or `_ext`."""
        try:
            now.gscore = 2*np.log(now.spread) + (now.err/now.spread)**2
        except AttributeError:
            # happens in case rc.comps['error_only']
            pass

    def assess_ens(self, now, x, E, w):
        """Ensemble and Particle filter (weighted/importance) assessment."""
        N, Nx = E.shape

        # Weights
        if w is None:
            w = np.ones(N)/N  # All equal. Also, rm attr from stats:
            if hasattr(self, 'w'):
                delattr(self, 'w')
            # Use non-weight formula (since w=None) for mu computations.
            # The savings are noticeable when rc.comps['error_only'] is True.
            now.mu = E.mean(0)
        else:
            now.w = w
            if abs(w.sum()-1) > 1e-5:
                raise RuntimeError("Weights did not sum to one.")
            now.mu = w @ E

        # Crash checks
        if not np.all(np.isfinite(E)):
            raise RuntimeError("Ensemble not finite.")
        if not np.all(np.isreal(E)):
            raise RuntimeError("Ensemble not Real.")

        # Compute errors
        now.err = now.mu - x
        if rc.comps['error_only']:
            return

        A = E - now.mu
        # While A**2 is approx as fast as A*A,
        # A**3 is 10x slower than A**2 (or A**2.0).
        # => Use A2 = A**2, A3 = A*A2, A4 = A*A3.
        # But, to save memory, only use A_pow.
        A_pow = A**2

        # Compute variances
        var  = w @ A_pow
        ub   = unbias_var(w, avoid_pathological=True)
        var *= ub

        # Compute standard deviation ("spread")
        s = np.sqrt(var)  # NB: biased (even though var is unbiased)
        now.spread = s

        # For simplicity, use naive (biased) formulae, derived
        # from "empirical measure". See doc/unbiased_skew_kurt.jpg.
        # Normalize by var. Compute "excess" kurt, which is 0 for Gaussians.
        A_pow *= A
        now.skew = np.nanmean(w @ A_pow / (s*s*s))
        A_pow *= A
        now.kurt = np.nanmean(w @ A_pow / var**2 - 3)

        now.mad  = np.nanmean(w @ abs(A))

        if self.do_spectral:
            if N <= Nx:
                _, s, UT  = sla.svd((np.sqrt(w)*A.T).T, full_matrices=False)
                s        *= np.sqrt(ub)  # Makes s^2 unbiased
                now.svals = s
                now.umisf = UT @ now.err
            else:
                P         = (A.T * w) @ A
                s2, U     = sla.eigh(P)
                s2       *= ub
                now.svals = np.sqrt(s2.clip(0))[::-1]
                now.umisf = U.T[::-1] @ now.err

            # For each state dim [i], compute rank of truth (x) among the ensemble (E)
            E_x = np.sort(np.vstack((E, x)), axis=0, kind='heapsort')
            now.rh = np.asarray(
                [np.where(E_x[:, i] == x[i])[0][0] for i in range(Nx)])

    def assess_ext(self, now, x, mu, P):
        """Kalman filter (Gaussian) assessment."""
        if not np.all(np.isfinite(mu)):
            raise RuntimeError("Estimates not finite.")
        if not np.all(np.isreal(mu)):
            raise RuntimeError("Estimates not Real.")
        # Don't check the cov (might not be explicitly available)

        # Compute errors
        now.mu  = mu
        now.err = now.mu - x
        if rc.comps['error_only']:
            return

        # Get diag(P)
        if P is None:
            var = np.zeros_like(mu)
        elif np.isscalar(P):
            var = np.ones_like(mu) * P
        else:
            if isinstance(P, CovMat):
                var = P.diag
                P   = P.full
            else:
                var = np.diag(P)

            if self.do_spectral:
                s2, U     = sla.eigh(P)
                now.svals = np.sqrt(np.maximum(s2, 0.0))[::-1]
                now.umisf = (U.T @ now.err)[::-1]

        # Compute stddev
        now.spread = np.sqrt(var)
        # Here, sqrt(2/pi) is the ratio of MAD/spread for Gaussians
        now.mad = np.nanmean(now.spread) * np.sqrt(2/np.pi)

    def average_in_time(self, kk=None, kko=None, free=False):
        """Average all univariate (scalar) time series.

        - `kk`:  time inds for averaging
        - `kko`: time inds for averaging obs
        """
        tseq = self.HMM.tseq
        if kk is None:
            kk = tseq.mask
        if kko is None:
            kko = tseq.masko

        def average1(tseries):
            avrgs = Avrgs()

            def average_multivariate(): return avrgs
            # Plain averages of nd-series are rarely interesting.
            # => Shortcircuit => Leave for manual computations

            if isinstance(tseries, series.FAUSt):
                # Average series for each subscript
                if tseries.item_shape != ():
                    return average_multivariate()
                for sub in [ch for ch in 'fas' if hasattr(tseries, ch)]:
                    avrgs[sub] = series.mean_with_conf(tseries[kko, sub])
                if tseries.store_u:
                    avrgs['u'] = series.mean_with_conf(tseries[kk, 'u'])

            elif isinstance(tseries, series.DataSeries):
                if tseries.array.shape[1:] != ():
                    return average_multivariate()
                elif len(tseries.array) == self.Ko+1:
                    avrgs = series.mean_with_conf(tseries[kko])
                elif len(tseries.array) == self.K+1:
                    avrgs = series.mean_with_conf(tseries[kk])
                else:
                    raise ValueError

            elif np.isscalar(tseries):
                avrgs = tseries  # E.g. just copy over "duration" from stats

            else:
                raise TypeError(f"Don't know how to average {tseries}")

            return avrgs

        def recurse_average(stat_parent, avrgs_parent):
            for key in getattr(stat_parent, "stat_register", []):
                try:
                    tseries = getattr(stat_parent, key)
                except AttributeError:
                    continue  # E.g. assess_ens() deletes .w if w is None
                avrgs = average1(tseries)
                recurse_average(tseries, avrgs)
                avrgs_parent[key] = avrgs

        avrgs = Avrgs()
        recurse_average(self, avrgs)
        self.xp.avrgs = avrgs
        if free:
            delattr(self.xp, 'stats')

    def replay(self, figlist="default", speed=np.inf, t1=0, t2=None, **kwargs):
        """Replay LivePlot with what's been stored in `self`.

        - t1, t2: time window to plot.
        - 'figlist' and 'speed': See LivePlot's doc.

        .. note:: `store_u` (whether to store non-obs-time stats) must
            have been `True` to have smooth graphs as in the actual LivePlot.

        .. note:: Ensembles are generally not stored in the stats
            and so cannot be replayed.
        """
        # Time settings
        tseq = self.HMM.tseq
        if t2 is None:
            t2 = t1 + tseq.Tplot

        # Ens does not get stored in stats, so we cannot replay that.
        # If the LPs are initialized with P0!=None, then they will avoid ens plotting.
        # TODO 4: This system for switching from Ens to stats must be replaced.
        #         It breaks down when M is very large.
        try:
            P0 = np.full_like(self.HMM.X0.C.full, np.nan)
        except AttributeError:  # e.g. if X0 is defined via sampling func
            P0 = np.eye(self.HMM.Nx)

        LP = liveplotting.LivePlot(self, figlist, P=P0, speed=speed,
                                   Tplot=t2-t1, replay=True, **kwargs)

        # Remember: must use progbar to unblock read1.
        # Let's also make a proper description.
        desc = self.xp.da_method + " (replay)"

        # Play through assimilation cycles
        for k, ko, t, _dt in progbar(tseq.ticker, desc):
            if t1 <= t <= t2:
                if ko is not None:
                    LP.update((k, ko, 'f'), None, None)
                    LP.update((k, ko, 'a'), None, None)
                LP.update((k, ko, 'u'), None, None)

        # Pause required when speed=inf.
        # On Mac, it was also necessary to do it for each fig.
        if LP.any_figs:
            for _name, updater in LP.figures.items():
                if plt.fignum_exists(_name) and getattr(updater, 'is_active', 1):
                    plt.figure(_name)
                    plt.pause(0.01)


def register_stat(self, name, value):
    """Do `self.name = value` and register `name` in `self`'s `stat_register`.

    Note: `self` is not always a `Stats` object, but could be a "child" of it.
    """
    setattr(self, name, value)
    if not hasattr(self, "stat_register"):
        self.stat_register = []
    self.stat_register.append(name)


class Avrgs(series.StatPrint, struct_tools.DotDict):
    """A `dict` specialized for the averages of statistics.

    Embellishments:

    - `dapper.tools.series.StatPrint`
    - `Avrgs.tabulate`
    - `getattr` that supports abbreviations.
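
    Examples
    --------
    A sketch of the abbreviated lookup (made-up values):

    >>> avrgs = Avrgs(err=Avrgs(rms=Avrgs(a=0.3)))
    >>> avrgs.rmse.a  # 'rmse' expands to 'err.rms'
    0.3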
    """

    def tabulate(self, statkeys=(), decimals=None):
        columns = tabulate_avrgs([self], statkeys, decimals=decimals)
        return tabulate(columns, headers="keys").replace('␣', ' ')

    abbrevs = {'rmse': 'err.rms', 'rmss': 'spread.rms', 'rmv': 'spread.rms'}

    # Use getattribute coz it gets called before getattr.
    def __getattribute__(self, key):
        """Support deep and abbreviated lookup."""
        # key = abbrevs[key] # Instead of this, also support rmse.a:
        key = '.'.join(Avrgs.abbrevs.get(seg, seg) for seg in key.split('.'))

        if "." in key:
            return struct_tools.deep_getattr(self, key)
        else:
            return super().__getattribute__(key)

# In case of degeneracy, variance might be 0, causing warnings
# in computing skew/kurt/MGLS (which all normalize by variance).
# This should and will yield nan's, but we don't want mere diagnostics
# computations to cause repetitive warnings, so we only warn once.
#
# I would have expected this (more elegant solution?) to work,
# but it just makes it worse.
# with np.errstate(divide='warn',invalid='warn'), warnings.catch_warnings():
#     warnings.simplefilter("once",category=RuntimeWarning)
#     ...


@do_once
def warn_zero_variance(err, flag):
    msg = "\n".join(["Numerical error in stat comps.",
                     "Probably caused by a sample variance of 0."])
    warnings.warn(msg, stacklevel=2)


# Why not do all columns at once using the tabulate module? Coz
#  - Want subcolumns, including fancy formatting (e.g. +/-)
#  - Want separation (using '|') of attr and stats
#  - ...
def align_col(col, pad='␣', missingval='', just=">"):
    r"""Align column.

    Treats `int`s and fixed-point `float`/`str` especially, aligning on the point.

    Example:
    >>> xx = [1, 1., 1.234, 12.34, 123.4, "1.2e-3", None, np.nan, "inf", (1, 2)]
    >>> print(*align_col(xx), sep="\n")
    ␣␣1␣␣␣␣
    ␣␣1.0␣␣
    ␣␣1.234
    ␣12.34␣
    123.4␣␣
    ␣1.2e-3
    ␣␣␣␣␣␣␣
    ␣␣␣␣nan
    ␣␣␣␣inf
    ␣(1, 2)
    """
    def split_decimal(x):
        x = str(x)
        try:
            y = float(x)
        except ValueError:
            pass
        else:
            if np.isfinite(y) and ("e" not in x.lower()):
                a, *b = x.split(".")
                if b == []:
                    b = "int"
                else:
                    b = b[0]
                return a, b
        return x, False

    # Find max nInt, nDec
    nInt = nDec = -1
    for x in col:
        ints, decs = split_decimal(x)
        if decs:
            nInt = max(nInt, len(ints))
            if decs != "int":
                nDec = max(nDec, len(decs))

    # Format entries. Floats get aligned on point.
    def frmt(x):
        if x is None:
            return missingval
        ints, decs = split_decimal(x)
        x = f"{ints.rjust(nInt, pad)}"
        if decs == "int":
            if nDec >= 0:
                x += pad + pad*nDec
        elif decs:
            x += "." + f"{decs.ljust(nDec, pad)}"
        else:
            x = ints
        return x

    # Format
    col = [frmt(x) for x in col]
    # Find max width
    Max = max(len(x) for x in col)
    # Right-justify
    shift = str.rjust if just == ">" else str.ljust
    col = [shift(x, Max, pad) for x in col]
    return col


def unpack_uqs(uq_list, decimals=None):
    """Convert list of `uq`s into dict of lists (of equal-length) of attributes.

    The attributes are obtained by `vars(uq)`,
    and may get formatted somehow (e.g. cast to strings) in the output.

    If `uq` is `None`, then `None` is inserted in each list.
    Else, `uq` must be an instance of `dapper.tools.rounding.UncertainQtty`.

    Parameters
    ----------
    uq_list: list
        List of `uq`s.

    decimals: int
        Desired number of decimals.
        Used for (only) the columns "val" and "prec".
        Default: `None`. In this case, the formatting is left to the `uq`s.
    """
    def frmt(uq):
        if not isinstance(uq, series.UncertainQtty):
            # Presumably uq is just a number
            uq = series.UncertainQtty(uq)

        attrs = vars(uq).copy()

        # val/prec: round
        if decimals is None:
            v, p = str(uq).split(" ±")
        else:
            frmt = "%%.%df" % decimals
            v, p = frmt % uq.val, frmt % uq.prec
        attrs["val"], attrs["prec"] = v, p

        # tuned_coord: convert to tuple
        try:
            attrs["tuned_coord"] = tuple(a for a in uq.tuned_coord)
        except AttributeError:
            pass
        return attrs

    cols = {}
    for i, uq in enumerate(uq_list):
        if uq is not None:
            # Format
            attrs = frmt(uq)
            # Insert attrs as a "row" in the `cols`:
            for k in attrs:
                # Init column
                if k not in cols:
                    cols[k] = [None]*len(uq_list)
                # Insert element
                cols[k][i] = attrs[k]

    return cols


def tabulate_avrgs(avrgs_list, statkeys=(), decimals=None):
    """Tabulate avrgs (val±prec)."""
    if not statkeys:
        statkeys = ['rmse.a', 'rmv.a', 'rmse.f']

    columns = {}
    for stat in statkeys:
        column = [getattr(a, stat, None) for a in avrgs_list]
        column = unpack_uqs(column, decimals)
        if not column:
            raise ValueError(f"The stat. key '{stat}' was not"
                             " found among any of the averages.")
        vals  = align_col([stat] + column["val"])
        precs = align_col(['1σ'] + column["prec"], just="<")
        headr = vals[0]+'  '+precs[0]
        mattr = [f"{v} ±{c}" for v, c in zip(vals, precs)][1:]
        columns[headr] = mattr

    return columns


def center(E, axis=0, rescale=False):
    r"""Center ensemble.

    Makes use of `np` features: keepdims and broadcasting.

    Parameters
    ----------
    E: ndarray
        Ensemble to be centered

    axis: int, optional
        The axis to be centered. Default: 0

    rescale: bool, optional
        If True, inflate to compensate for reduction in the expected variance.
        The inflation factor is \(\sqrt{\frac{N}{N - 1}}\)
        where N is the ensemble size. Default: False

    Returns
    -------
    X: ndarray
        Ensemble anomaly

    x: ndarray
        Mean of the ensemble
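
    Examples
    --------
    >>> E = np.arange(6.).reshape(3, 2)
    >>> X, x = center(E)
    >>> x
    array([2., 3.])
    >>> X.sum(axis=0)
    array([0., 0.])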
    """
    x = np.mean(E, axis=axis, keepdims=True)
    X = E - x

    if rescale:
        N = E.shape[axis]
        X *= np.sqrt(N/(N-1))

    x = x.squeeze(axis=axis)

    return X, x


def mean0(E, axis=0, rescale=True):
    """Like `center`, but only return the anomalies (not the mean).

    Uses `rescale=True` by default, which is beneficial
    when used to center observation perturbations.
    """
    return center(E, axis=axis, rescale=rescale)[0]


def inflate_ens(E, factor):
    """Inflate the ensemble (center, inflate, re-combine).

    Parameters
    ----------
    E: ndarray
        Ensemble to be inflated

    factor: float
        Inflation factor

    Returns
    -------
    ndarray
        Inflated ensemble
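
    Examples
    --------
    >>> inflate_ens(np.array([[0.], [2.]]), 2)
    array([[-1.],
           [ 3.]])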
    """
    if factor == 1:
        return E
    X, x = center(E)
    return x + X*factor


def weight_degeneracy(w, prec=1e-10):
    """Check if the weights are degenerate.

    If degenerate, the maximum weight
    should be nearly one, since sum(w) = 1.

    Parameters
    ----------
    w: ndarray
        Importance weights. Must sum to 1.

    prec: float, optional
        Tolerance of the distance between the max weight and one. Default: 1e-10

    Returns
    -------
    bool
        True if the weights are degenerate, else False
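
    Examples
    --------
    >>> bool(weight_degeneracy(np.array([0.0, 1.0])))
    True
    >>> bool(weight_degeneracy(np.ones(4)/4))
    False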
    """
    return (1-w.max()) < prec


def unbias_var(w=None, N_eff=None, avoid_pathological=False):
    """Compute the unbiasing factor for variance estimation.

    Parameters
    ----------
    w: ndarray, optional
        Importance weights. Must sum to 1.
        Only one of `w` and `N_eff` can be `None`. Default: `None`

    N_eff: float, optional
        The "effective" size of the weighted ensemble.
        If not provided, it is computed from the weights.
        The unbiasing factor is $$ N_{eff} / (N_{eff} - 1) $$.

    avoid_pathological: bool, optional
        Avoid weight collapse. Default: `False`

    Returns
    -------
    ub: float
        Factor used to unbias the variance.

    References
    ----------
    [Wikipedia](https://wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights)
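
    Examples
    --------
    >>> bool(np.isclose(unbias_var(w=np.ones(5)/5), 5/4))  # = N/(N-1)
    True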
    """
    if N_eff is None:
        N_eff = 1/(w@w)

    if avoid_pathological and weight_degeneracy(w):
        ub = 1  # Don't unbias in case of weight collapse
    else:
        ub = 1/(1 - 1/N_eff)  # =N/(N-1) if w==ones(N)/N.
    return ub
class Stats(dapper.tools.series.StatPrint):
 32class Stats(series.StatPrint):
 33    """Contains and computes statistics of the DA methods."""
 34
 35    def __init__(self, xp, HMM, xx, yy, liveplots=False, store_u=rc.store_u):
 36        """Init the default statistics."""
 37        ######################################
 38        # Preamble
 39        ######################################
 40        self.xp        = xp
 41        self.HMM       = HMM
 42        self.xx        = xx
 43        self.yy        = yy
 44        self.liveplots = liveplots
 45        self.store_u   = store_u
 46        self.store_s   = any(key in xp.__dict__ for key in
 47                             ["Lag", "DeCorr"])  # prms used by smoothers
 48
 49        # Shapes
 50        K  = xx.shape[0] - 1
 51        Nx = xx.shape[1]
 52        Ko = yy.shape[0] - 1
 53        self.K, self.Ko, self.Nx = K, Ko, Nx
 54
 55        # Methods for summarizing multivariate stats ("fields") as scalars
 56        # Don't use nanmean here; nan's should get propagated!
 57        en_mean = lambda x: np.mean(x, axis=0)  # noqa
 58        self.field_summaries = dict(
 59            m   = lambda x: en_mean(x),                  # mean-field
 60            ms  = lambda x: en_mean(x**2),               # root-mean-square
 61            rms = lambda x: np.sqrt(en_mean(x**2)),      # root-mean-square
 62            ma  = lambda x: en_mean(np.abs(x)),          # mean-absolute
 63            gm  = lambda x: np.exp(en_mean(np.log(x))),  # geometric mean
 64        )
 65        # Only keep the methods listed in rc
 66        self.field_summaries = struct_tools.intersect(self.field_summaries,
 67                                                      rc.field_summaries)
 68
 69        # Define similar methods, but restricted to sectors
 70        self.sector_summaries = {}
 71        def restrict(fun, inds): return (lambda x: fun(x[inds]))
 72        for suffix, formula in self.field_summaries.items():
 73            for sector, inds in HMM.sectors.items():
 74                f = restrict(formula, inds)
 75                self.sector_summaries['%s.%s' % (suffix, sector)] = f
 76
 77        ######################################
 78        # Allocate time series of various stats
 79        ######################################
 80        self.new_series('mu'    , Nx, field_mean='sectors')  # Mean
 81        self.new_series('spread', Nx, field_mean='sectors')  # Std. dev. ("spread")
 82        self.new_series('err'   , Nx, field_mean='sectors')  # Error (mu - truth)
 83        self.new_series('gscore', Nx, field_mean='sectors')  # Gaussian (log) score
 84
 85        # To save memory, we only store these field means:
 86        self.new_series('mad' , 1)  # Mean abs deviations
 87        self.new_series('skew', 1)  # Skewness
 88        self.new_series('kurt', 1)  # Kurtosis
 89
 90        if hasattr(xp, 'N'):
 91            N = xp.N
 92            self.new_series('w', N, field_mean=True)  # Importance weights
 93            self.new_series('rh', Nx, dtype=int)  # Rank histogram
 94
 95            self._is_ens = True
 96            minN = min(Nx, N)
 97            self.do_spectral  = np.sqrt(Nx*N) <= rc.comps["max_spectral"]
 98        else:
 99            self._is_ens = False
100            minN = Nx
101            self.do_spectral = Nx <= rc.comps["max_spectral"]
102
103        if self.do_spectral:
104            # Note: the mean-field and RMS time-series of
105            # (i) svals and (ii) umisf should match the corresponding series of
106            # (i) spread and (ii) err.
107            self.new_series('svals', minN)  # Principal component (SVD) scores
108            self.new_series('umisf', minN)  # Error in component directions
109
110        ######################################
111        # Allocate a few series for outside use
112        ######################################
113        self.new_series('trHK' , 1, Ko+1)
114        self.new_series('infl' , 1, Ko+1)
115        self.new_series('iters', 1, Ko+1)
116
117        # Weight-related
118        self.new_series('N_eff' , 1, Ko+1)
119        self.new_series('wroot' , 1, Ko+1)
120        self.new_series('resmpl', 1, Ko+1)
121
122    def new_series(self, name, shape, length='FAUSt', field_mean=False, **kws):
123        """Create (and register) a statistics time series, initialized with `nan`s.
124
125        If `length` is an integer, a `DataSeries` (a trivial subclass of
126        `numpy.ndarray`) is made. By default, though, a `series.FAUSt` is created.
127
128        NB: The `sliding_diagnostics` liveplotting relies on detecting `nan`'s
129            to avoid plotting stats that are not being used.
130            Thus, you cannot use `dtype=bool` or `int` for stats that get plotted.
131        """
132        # Convert int shape to tuple
133        if not hasattr(shape, '__len__'):
134            if shape == 1:
135                shape = ()
136            else:
137                shape = (shape,)
138
139        def make_series(parent, name, shape):
140            if length == 'FAUSt':
141                total_shape = self.K, self.Ko, shape
142                store_opts = self.store_u, self.store_s
143                tseries = series.FAUSt(*total_shape, *store_opts, **kws)
144            else:
145                total_shape = (length,)+shape
146                tseries = series.DataSeries(total_shape, *kws)
147            register_stat(parent, name, tseries)
148
149        # Principal series
150        make_series(self, name, shape)
151
152        # Summary (scalar) series:
153        if shape != ():
154            if field_mean:
155                for suffix in self.field_summaries:
156                    make_series(getattr(self, name), suffix, ())
157            # Make a nested level for sectors
158            if field_mean == 'sectors':
159                for ss in self.sector_summaries:
160                    suffix, sector = ss.split('.')
161                    make_series(struct_tools.deep_getattr(
162                        self, f"{name}.{suffix}"), sector, ())
163
164    @property
165    def data_series(self):
166        return [k for k in vars(self)
167                if isinstance(getattr(self, k), series.DataSeries)]
168
169    def assess(self, k, ko=None, faus=None,
170               E=None, w=None, mu=None, Cov=None):
171        """Common interface for both `Stats.assess_ens` and `Stats.assess_ext`.
172
173        The `_ens` assessment function gets called if `E is not None`,
174        and `_ext` if `mu is not None`.
175
176        faus: One or more of `['f',' a', 'u', 's']`, indicating
177              that the result should be stored in (respectively)
178              the forecast/analysis/universal attribute.
179              Default: `'u' if ko is None else 'au' ('a' and 'u')`.
180        """
181        # Initial consistency checks.
182        if k == 0:
183            if ko is not None:
184                raise KeyError("DAPPER convention: no obs at t=0. Helps avoid bugs.")
185            if self._is_ens == True:
186                if E is None:
187                    raise TypeError("Expected ensemble input but E is None")
188                if mu is not None:
189                    raise TypeError("Expected ensemble input but mu/Cov is not None")
190            else:
191                if E is not None:
192                    raise TypeError("Expected mu/Cov input but E is not None")
193                if mu is None:
194                    raise TypeError("Expected mu/Cov input but mu is None")
195
196        # Default. Don't add more defaults. It just gets confusing.
197        if faus is None:
198            faus = 'u' if ko is None else 'au'
199
200        # TODO 4: for faus="au" (e.g.) we don't need to re-**compute** stats,
201        #         merely re-write them?
202        for sub in faus:
203
204            # Skip assessment if ('u' and stats not stored or plotted)
205            if k != 0 and ko == None:
206                if not (self.store_u or self.LP_instance.any_figs):
207                    continue
208
209            # Silence repeat warnings caused by zero variance
210            with np.errstate(divide='call', invalid='call'):
211                np.seterrcall(warn_zero_variance)
212
213                # Assess
214                stats_now = Avrgs()
215                if self._is_ens:
216                    self.assess_ens(stats_now, self.xx[k], E, w)
217                else:
218                    self.assess_ext(stats_now, self.xx[k], mu, Cov)
219                self.derivative_stats(stats_now)
220                self.summarize_marginals(stats_now)
221
222            self.write(stats_now, k, ko, sub)
223
224            # LivePlot -- Both init and update must come after the assessment.
225            try:
226                self.LP_instance.update((k, ko, sub), E, Cov)
227            except AttributeError:
228                self.LP_instance = liveplotting.LivePlot(
229                    self, self.liveplots, (k, ko, sub), E, Cov)
230
231    def write(self, stat_dict, k, ko, sub):
232        """Write `stat_dict` to series at `(k, ko, sub)`."""
233        for name, val in stat_dict.items():
234            stat = struct_tools.deep_getattr(self, name)
235            isFaust = isinstance(stat, series.FAUSt)
236            stat[(k, ko, sub) if isFaust else ko] = val
237
238    def summarize_marginals(self, now):
239        """Compute Mean-field and RMS values."""
240        formulae = {**self.field_summaries, **self.sector_summaries}
241
242        with np.errstate(divide='ignore', invalid='ignore'):
243            for stat in list(now):
244                field = now[stat]
245                for suffix, formula in formulae.items():
246                    statpath = stat+'.'+suffix
247                    if struct_tools.deep_hasattr(self, statpath):
248                        now[statpath] = formula(field)
249
250    def derivative_stats(self, now):
251        """Stats that derive from others, and are not specific for `_ens` or `_ext`)."""
252        try:
253            now.gscore = 2*np.log(now.spread) + (now.err/now.spread)**2
254        except AttributeError:
255            # happens in case rc.comps['error_only']
256            pass
257
258    def assess_ens(self, now, x, E, w):
259        """Ensemble and Particle filter (weighted/importance) assessment."""
260        N, Nx = E.shape
261
262        # weights
263        if w is None:
264            w = np.ones(N)/N  # All equal. Also, rm attr from stats:
265            if hasattr(self, 'w'):
266                delattr(self, 'w')
267            # Use non-weight formula (since w=None) for mu computations.
268            # The savings are noticeable when rc.comps['error_only'] is noticeable.
269            now.mu = E.mean(0)
270        else:
271            now.w = w
272            if abs(w.sum()-1) > 1e-5:
273                raise RuntimeError("Weights did not sum to one.")
274            now.mu = w @ E
275
276        # Crash checks
277        if not np.all(np.isfinite(E)):
278            raise RuntimeError("Ensemble not finite.")
279        if not np.all(np.isreal(E)):
280            raise RuntimeError("Ensemble not Real.")
281
282        # Compute errors
283        now.err = now.mu - x
284        if rc.comps['error_only']:
285            return
286
287        A = E - now.mu
288        # While A**2 is approx as fast as A*A,
289        # A**3 is 10x slower than A**2 (or A**2.0).
290        # => Use A2 = A**2, A3 = A*A2, A4=A*A3.
291        # But, to save memory, only use A_pow.
292        A_pow = A**2
293
294        # Compute variances
295        var  = w @ A_pow
296        ub   = unbias_var(w, avoid_pathological=True)
297        var *= ub
298
299        # Compute standard deviation ("Spread")
300        s = np.sqrt(var)  # NB: biased (even though var is unbiased)
301        now.spread = s
302
303        # For simplicity, use naive (biased) formulae, derived
304        # from "empirical measure". See doc/unbiased_skew_kurt.jpg.
305        # Normalize by var. Compute "excess" kurt, which is 0 for Gaussians.
306        A_pow *= A
307        now.skew = np.nanmean(w @ A_pow / (s*s*s))
308        A_pow *= A
309        now.kurt = np.nanmean(w @ A_pow / var**2 - 3)
310
311        now.mad  = np.nanmean(w @ abs(A))
312
313        if self.do_spectral:
314            if N <= Nx:
315                _, s, UT  = sla.svd((np.sqrt(w)*A.T).T, full_matrices=False)
316                s        *= np.sqrt(ub)  # Makes s^2 unbiased
317                now.svals = s
318                now.umisf = UT @ now.err
319            else:
320                P         = (A.T * w) @ A
321                s2, U     = sla.eigh(P)
322                s2       *= ub
323                now.svals = np.sqrt(s2.clip(0))[::-1]
324                now.umisf = U.T[::-1] @ now.err
325
326            # For each state dim [i], compute rank of truth (x) among the ensemble (E)
327            E_x = np.sort(np.vstack((E, x)), axis=0, kind='heapsort')
328            now.rh = np.asarray(
329                [np.where(E_x[:, i] == x[i])[0][0] for i in range(Nx)])
330
331    def assess_ext(self, now, x, mu, P):
332        """Kalman filter (Gaussian) assessment."""
333        if not np.all(np.isfinite(mu)):
334            raise RuntimeError("Estimates not finite.")
335        if not np.all(np.isreal(mu)):
336            raise RuntimeError("Estimates not Real.")
337        # Don't check the cov (might not be explicitly availble)
338
339        # Compute errors
340        now.mu  = mu
341        now.err = now.mu - x
342        if rc.comps['error_only']:
343            return
344
345        # Get diag(P)
346        if P is None:
347            var = np.zeros_like(mu)
348        elif np.isscalar(P):
349            var = np.ones_like(mu) * P
350        else:
351            if isinstance(P, CovMat):
352                var = P.diag
353                P   = P.full
354            else:
355                var = np.diag(P)
356
357            if self.do_spectral:
358                s2, U     = sla.eigh(P)
359                now.svals = np.sqrt(np.maximum(s2, 0.0))[::-1]
360                now.umisf = (U.T @ now.err)[::-1]
361
362        # Compute stddev
363        now.spread = np.sqrt(var)
364        # Here, sqrt(2/pi) is the ratio, of MAD/Spread for Gaussians
365        now.mad = np.nanmean(now.spread) * np.sqrt(2/np.pi)
366
367    def average_in_time(self, kk=None, kko=None, free=False):
368        """Avarage all univariate (scalar) time series.
369
370        - `kk`    time inds for averaging
371        - `kko` time inds for averaging obs
372        """
373        tseq = self.HMM.tseq
374        if kk is None:
375            kk     = tseq.mask
376        if kko is None:
377            kko  = tseq.masko
378
379        def average1(tseries):
380            avrgs = Avrgs()
381
382            def average_multivariate(): return avrgs
383            # Plain averages of nd-series are rarely interesting.
384            # => Shortcircuit => Leave for manual computations
385
386            if isinstance(tseries, series.FAUSt):
387                # Average series for each subscript
388                if tseries.item_shape != ():
389                    return average_multivariate()
390                for sub in [ch for ch in 'fas' if hasattr(tseries, ch)]:
391                    avrgs[sub] = series.mean_with_conf(tseries[kko, sub])
392                if tseries.store_u:
393                    avrgs['u'] = series.mean_with_conf(tseries[kk, 'u'])
394
395            elif isinstance(tseries, series.DataSeries):
396                if tseries.array.shape[1:] != ():
397                    return average_multivariate()
398                elif len(tseries.array) == self.Ko+1:
399                    avrgs = series.mean_with_conf(tseries[kko])
400                elif len(tseries.array) == self.K+1:
401                    avrgs = series.mean_with_conf(tseries[kk])
402                else:
403                    raise ValueError
404
405            elif np.isscalar(tseries):
406                avrgs = tseries  # Eg. just copy over "duration" from stats
407
408            else:
409                raise TypeError(f"Don't know how to average {tseries}")
410
411            return avrgs
412
413        def recurse_average(stat_parent, avrgs_parent):
414            for key in getattr(stat_parent, "stat_register", []):
415                try:
416                    tseries = getattr(stat_parent, key)
417                except AttributeError:
418                    continue  # Eg assess_ens() deletes .weights if None
419                avrgs = average1(tseries)
420                recurse_average(tseries, avrgs)
421                avrgs_parent[key] = avrgs
422
423        avrgs = Avrgs()
424        recurse_average(self, avrgs)
425        self.xp.avrgs = avrgs
426        if free:
427            delattr(self.xp, 'stats')
428
429    def replay(self, figlist="default", speed=np.inf, t1=0, t2=None, **kwargs):
430        """Replay LivePlot with what's been stored in 'self'.
431
432        - t1, t2: time window to plot.
433        - 'figlist' and 'speed': See LivePlot's doc.
434
435        .. note:: `store_u` (whether to store non-obs-time stats) must
436        have been `True` to have smooth graphs as in the actual LivePlot.
437
438        .. note:: Ensembles are generally not stored in the stats
439        and so cannot be replayed.
440        """
441        # Time settings
442        tseq = self.HMM.tseq
443        if t2 is None:
444            t2 = t1 + tseq.Tplot
445
446        # Ens does not get stored in stats, so we cannot replay that.
447        # If the LPs are initialized with P0!=None, then they will avoid ens plotting.
448        # TODO 4: This system for switching from Ens to stats must be replaced.
449        #       It breaks down when M is very large.
450        try:
451            P0 = np.full_like(self.HMM.X0.C.full, np.nan)
452        except AttributeError:  # e.g. if X0 is defined via sampling func
453            P0 = np.eye(self.HMM.Nx)
454
455        LP = liveplotting.LivePlot(self, figlist, P=P0, speed=speed,
456                                   Tplot=t2-t1, replay=True, **kwargs)
457
458        # Remember: must use progbar to unblock read1.
459        # Let's also make a proper description.
460        desc = self.xp.da_method + " (replay)"
461
462        # Play through assimilation cycles
463        for k, ko, t, _dt in progbar(tseq.ticker, desc):
464            if t1 <= t <= t2:
465                if ko is not None:
466                    LP.update((k, ko, 'f'), None, None)
467                    LP.update((k, ko, 'a'), None, None)
468                LP.update((k, ko, 'u'), None, None)
469
470        # Pause required when speed=inf.
471        # On Mac, it was also necessary to do it for each fig.
472        if LP.any_figs:
473            for _name, updater in LP.figures.items():
474                if plt.fignum_exists(_name) and getattr(updater, 'is_active', 1):
475                    plt.figure(_name)
476                    plt.pause(0.01)

Contains and computes statistics of the DA methods.

Stats(xp, HMM, xx, yy, liveplots=False, store_u=False)
 35    def __init__(self, xp, HMM, xx, yy, liveplots=False, store_u=rc.store_u):
 36        """Init the default statistics."""
 37        ######################################
 38        # Preamble
 39        ######################################
 40        self.xp        = xp
 41        self.HMM       = HMM
 42        self.xx        = xx
 43        self.yy        = yy
 44        self.liveplots = liveplots
 45        self.store_u   = store_u
 46        self.store_s   = any(key in xp.__dict__ for key in
 47                             ["Lag", "DeCorr"])  # prms used by smoothers
 48
 49        # Shapes
 50        K  = xx.shape[0] - 1
 51        Nx = xx.shape[1]
 52        Ko = yy.shape[0] - 1
 53        self.K, self.Ko, self.Nx = K, Ko, Nx
 54
 55        # Methods for summarizing multivariate stats ("fields") as scalars
 56        # Don't use nanmean here; nan's should get propagated!
 57        en_mean = lambda x: np.mean(x, axis=0)  # noqa
 58        self.field_summaries = dict(
 59            m   = lambda x: en_mean(x),                  # mean-field
 60            ms  = lambda x: en_mean(x**2),               # root-mean-square
 61            rms = lambda x: np.sqrt(en_mean(x**2)),      # root-mean-square
 62            ma  = lambda x: en_mean(np.abs(x)),          # mean-absolute
 63            gm  = lambda x: np.exp(en_mean(np.log(x))),  # geometric mean
 64        )
 65        # Only keep the methods listed in rc
 66        self.field_summaries = struct_tools.intersect(self.field_summaries,
 67                                                      rc.field_summaries)
 68
 69        # Define similar methods, but restricted to sectors
 70        self.sector_summaries = {}
 71        def restrict(fun, inds): return (lambda x: fun(x[inds]))
 72        for suffix, formula in self.field_summaries.items():
 73            for sector, inds in HMM.sectors.items():
 74                f = restrict(formula, inds)
 75                self.sector_summaries['%s.%s' % (suffix, sector)] = f
 76
 77        ######################################
 78        # Allocate time series of various stats
 79        ######################################
 80        self.new_series('mu'    , Nx, field_mean='sectors')  # Mean
 81        self.new_series('spread', Nx, field_mean='sectors')  # Std. dev. ("spread")
 82        self.new_series('err'   , Nx, field_mean='sectors')  # Error (mu - truth)
 83        self.new_series('gscore', Nx, field_mean='sectors')  # Gaussian (log) score
 84
 85        # To save memory, we only store these field means:
 86        self.new_series('mad' , 1)  # Mean abs deviations
 87        self.new_series('skew', 1)  # Skewness
 88        self.new_series('kurt', 1)  # Kurtosis
 89
 90        if hasattr(xp, 'N'):
 91            N = xp.N
 92            self.new_series('w', N, field_mean=True)  # Importance weights
 93            self.new_series('rh', Nx, dtype=int)  # Rank histogram
 94
 95            self._is_ens = True
 96            minN = min(Nx, N)
 97            self.do_spectral  = np.sqrt(Nx*N) <= rc.comps["max_spectral"]
 98        else:
 99            self._is_ens = False
100            minN = Nx
101            self.do_spectral = Nx <= rc.comps["max_spectral"]
102
103        if self.do_spectral:
104            # Note: the mean-field and RMS time-series of
105            # (i) svals and (ii) umisf should match the corresponding series of
106            # (i) spread and (ii) err.
107            self.new_series('svals', minN)  # Principal component (SVD) scores
108            self.new_series('umisf', minN)  # Error in component directions
109
110        ######################################
111        # Allocate a few series for outside use
112        ######################################
113        self.new_series('trHK' , 1, Ko+1)
114        self.new_series('infl' , 1, Ko+1)
115        self.new_series('iters', 1, Ko+1)
116
117        # Weight-related
118        self.new_series('N_eff' , 1, Ko+1)
119        self.new_series('wroot' , 1, Ko+1)
120        self.new_series('resmpl', 1, Ko+1)

Init the default statistics.

xp
HMM
xx
yy
liveplots
store_u
store_s
field_summaries
sector_summaries
def new_series(self, name, shape, length='FAUSt', field_mean=False, **kws):
122    def new_series(self, name, shape, length='FAUSt', field_mean=False, **kws):
123        """Create (and register) a statistics time series, initialized with `nan`s.
124
125        If `length` is an integer, a `DataSeries` (a trivial subclass of
126        `numpy.ndarray`) is made. By default, though, a `series.FAUSt` is created.
127
128        NB: The `sliding_diagnostics` liveplotting relies on detecting `nan`'s
129            to avoid plotting stats that are not being used.
130            Thus, you cannot use `dtype=bool` or `int` for stats that get plotted.
131        """
132        # Convert int shape to tuple
133        if not hasattr(shape, '__len__'):
134            if shape == 1:
135                shape = ()
136            else:
137                shape = (shape,)
138
139        def make_series(parent, name, shape):
140            if length == 'FAUSt':
141                total_shape = self.K, self.Ko, shape
142                store_opts = self.store_u, self.store_s
143                tseries = series.FAUSt(*total_shape, *store_opts, **kws)
144            else:
145                total_shape = (length,)+shape
146                tseries = series.DataSeries(total_shape, **kws)
147            register_stat(parent, name, tseries)
148
149        # Principal series
150        make_series(self, name, shape)
151
152        # Summary (scalar) series:
153        if shape != ():
154            if field_mean:
155                for suffix in self.field_summaries:
156                    make_series(getattr(self, name), suffix, ())
157            # Make a nested level for sectors
158            if field_mean == 'sectors':
159                for ss in self.sector_summaries:
160                    suffix, sector = ss.split('.')
161                    make_series(struct_tools.deep_getattr(
162                        self, f"{name}.{suffix}"), sector, ())

164    @property
165    def data_series(self):
166        return [k for k in vars(self)
167                if isinstance(getattr(self, k), series.DataSeries)]
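
For example, after an experiment has run, this property lists the top-level series containers (a sketch; the exact contents depend on the configuration):

>>> xp.stats.data_series
['mu', 'spread', 'err', 'gscore', 'mad', 'skew', 'kurt', ...]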
169    def assess(self, k, ko=None, faus=None,
170               E=None, w=None, mu=None, Cov=None):
171        """Common interface for both `Stats.assess_ens` and `Stats.assess_ext`.
172
173        The `_ens` assessment function gets called if `E is not None`,
174        and `_ext` if `mu is not None`.
175
176        faus: One or more of `['f', 'a', 'u', 's']`, indicating
177              that the result should be stored in (respectively)
178              the forecast/analysis/universal/smoothed attribute.
179              Default: `'u' if ko is None else 'au'` (i.e. both `'a'` and `'u'`).
180        """
181        # Initial consistency checks.
182        if k == 0:
183            if ko is not None:
184                raise KeyError("DAPPER convention: no obs at t=0. Helps avoid bugs.")
185            if self._is_ens:
186                if E is None:
187                    raise TypeError("Expected ensemble input but E is None")
188                if mu is not None:
189                    raise TypeError("Expected ensemble input but mu/Cov is not None")
190            else:
191                if E is not None:
192                    raise TypeError("Expected mu/Cov input but E is not None")
193                if mu is None:
194                    raise TypeError("Expected mu/Cov input but mu is None")
195
196        # Default. Don't add more defaults. It just gets confusing.
197        if faus is None:
198            faus = 'u' if ko is None else 'au'
199
200        # TODO 4: for faus="au" (e.g.) we don't need to re-**compute** stats,
201        #         merely re-write them?
202        for sub in faus:
203
204            # Skip assessment if ('u' and stats not stored or plotted)
205            if k != 0 and ko is None:
206                if not (self.store_u or self.LP_instance.any_figs):
207                    continue
208
209            # Silence repeat warnings caused by zero variance
210            with np.errstate(divide='call', invalid='call'):
211                np.seterrcall(warn_zero_variance)
212
213                # Assess
214                stats_now = Avrgs()
215                if self._is_ens:
216                    self.assess_ens(stats_now, self.xx[k], E, w)
217                else:
218                    self.assess_ext(stats_now, self.xx[k], mu, Cov)
219                self.derivative_stats(stats_now)
220                self.summarize_marginals(stats_now)
221
222            self.write(stats_now, k, ko, sub)
223
224            # LivePlot -- Both init and update must come after the assessment.
225            try:
226                self.LP_instance.update((k, ko, sub), E, Cov)
227            except AttributeError:
228                self.LP_instance = liveplotting.LivePlot(
229                    self, self.liveplots, (k, ko, sub), E, Cov)

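For illustration, this is how a `da_method` typically invokes `assess` from its assimilation loop (a sketch, assuming `k`, `ko`, and the ensemble `E` are provided by that loop):

>>> self.stats.assess(k, ko, 'f', E=E)  # store forecast stats
... self.stats.assess(k, ko, 'a', E=E)  # store analysis stats
... self.stats.assess(k, E=E)           # ko=None, so stored as 'u'
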
231    def write(self, stat_dict, k, ko, sub):
232        """Write `stat_dict` to series at `(k, ko, sub)`."""
233        for name, val in stat_dict.items():
234            stat = struct_tools.deep_getattr(self, name)
235            isFaust = isinstance(stat, series.FAUSt)
236            stat[(k, ko, sub) if isFaust else ko] = val

238    def summarize_marginals(self, now):
239        """Compute field summaries (mean-field, RMS, etc.)."""
240        formulae = {**self.field_summaries, **self.sector_summaries}
241
242        with np.errstate(divide='ignore', invalid='ignore'):
243            for stat in list(now):
244                field = now[stat]
245                for suffix, formula in formulae.items():
246                    statpath = stat+'.'+suffix
247                    if struct_tools.deep_hasattr(self, statpath):
248                        now[statpath] = formula(field)

250    def derivative_stats(self, now):
251        """Stats that derive from others, and are not specific to `_ens` or `_ext`."""
252        try:
253            now.gscore = 2*np.log(now.spread) + (now.err/now.spread)**2
254        except AttributeError:
255            # happens in case rc.comps['error_only']
256            pass

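Note that, up to the additive constant log(2*pi), the `gscore` computed above equals -2 times the log-density of a Gaussian with mean `mu` and standard deviation `spread`, evaluated at the truth: -2 log N(x | mu, spread^2) = log(2*pi) + 2 log(spread) + (err/spread)^2.
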
258    def assess_ens(self, now, x, E, w):
259        """Ensemble and Particle filter (weighted/importance) assessment."""
260        N, Nx = E.shape
261
262        # weights
263        if w is None:
264            w = np.ones(N)/N  # All equal. Also, rm attr from stats:
265            if hasattr(self, 'w'):
266                delattr(self, 'w')
267            # Use non-weight formula (since w=None) for mu computations.
268            # The savings are noticeable when rc.comps['error_only'] is True.
269            now.mu = E.mean(0)
270        else:
271            now.w = w
272            if abs(w.sum()-1) > 1e-5:
273                raise RuntimeError("Weights did not sum to one.")
274            now.mu = w @ E
275
276        # Crash checks
277        if not np.all(np.isfinite(E)):
278            raise RuntimeError("Ensemble not finite.")
279        if not np.all(np.isreal(E)):
280            raise RuntimeError("Ensemble not Real.")
281
282        # Compute errors
283        now.err = now.mu - x
284        if rc.comps['error_only']:
285            return
286
287        A = E - now.mu
288        # While A**2 is approx as fast as A*A,
289        # A**3 is 10x slower than A**2 (or A**2.0).
290        # => Use A2 = A**2, A3 = A*A2, A4=A*A3.
291        # But, to save memory, only use A_pow.
292        A_pow = A**2
293
294        # Compute variances
295        var  = w @ A_pow
296        ub   = unbias_var(w, avoid_pathological=True)
297        var *= ub
298
299        # Compute standard deviation ("Spread")
300        s = np.sqrt(var)  # NB: biased (even though var is unbiased)
301        now.spread = s
302
303        # For simplicity, use naive (biased) formulae, derived
304        # from "empirical measure". See doc/unbiased_skew_kurt.jpg.
305        # Normalize by var. Compute "excess" kurt, which is 0 for Gaussians.
306        A_pow *= A
307        now.skew = np.nanmean(w @ A_pow / (s*s*s))
308        A_pow *= A
309        now.kurt = np.nanmean(w @ A_pow / var**2 - 3)
310
311        now.mad  = np.nanmean(w @ abs(A))
312
313        if self.do_spectral:
314            if N <= Nx:
315                _, s, UT  = sla.svd((np.sqrt(w)*A.T).T, full_matrices=False)
316                s        *= np.sqrt(ub)  # Makes s^2 unbiased
317                now.svals = s
318                now.umisf = UT @ now.err
319            else:
320                P         = (A.T * w) @ A
321                s2, U     = sla.eigh(P)
322                s2       *= ub
323                now.svals = np.sqrt(s2.clip(0))[::-1]
324                now.umisf = U.T[::-1] @ now.err
325
326            # For each state dim [i], compute rank of truth (x) among the ensemble (E)
327            E_x = np.sort(np.vstack((E, x)), axis=0, kind='heapsort')
328            now.rh = np.asarray(
329                [np.where(E_x[:, i] == x[i])[0][0] for i in range(Nx)])

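To make the rank-histogram logic above concrete, here is a minimal sketch for a single state dimension:

>>> E = np.array([[0.], [2.], [4.]])          # ensemble (N=3, Nx=1)
... x = np.array([3.])                        # truth
... E_x = np.sort(np.vstack((E, x)), axis=0)  # [[0.], [2.], [3.], [4.]]
... int(np.where(E_x[:, 0] == x[0])[0][0])    # rank of the truth: 2
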
331    def assess_ext(self, now, x, mu, P):
332        """Kalman filter (Gaussian) assessment."""
333        if not np.all(np.isfinite(mu)):
334            raise RuntimeError("Estimates not finite.")
335        if not np.all(np.isreal(mu)):
336            raise RuntimeError("Estimates not Real.")
337        # Don't check the cov (it might not be explicitly available)
338
339        # Compute errors
340        now.mu  = mu
341        now.err = now.mu - x
342        if rc.comps['error_only']:
343            return
344
345        # Get diag(P)
346        if P is None:
347            var = np.zeros_like(mu)
348        elif np.isscalar(P):
349            var = np.ones_like(mu) * P
350        else:
351            if isinstance(P, CovMat):
352                var = P.diag
353                P   = P.full
354            else:
355                var = np.diag(P)
356
357            if self.do_spectral:
358                s2, U     = sla.eigh(P)
359                now.svals = np.sqrt(np.maximum(s2, 0.0))[::-1]
360                now.umisf = (U.T @ now.err)[::-1]
361
362        # Compute stddev
363        now.spread = np.sqrt(var)
364        # Here, sqrt(2/pi) is the ratio of MAD/Spread for Gaussians
365        now.mad = np.nanmean(now.spread) * np.sqrt(2/np.pi)

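The factor sqrt(2/pi) ~ 0.798 used above is the MAD/stddev ratio of a Gaussian variable, which can be checked by simulation (a sketch; the digits will vary with the seed):

>>> xx = np.random.default_rng(42).standard_normal(10**6)
... np.abs(xx).mean()  # approximately np.sqrt(2/np.pi) = 0.7979
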
367    def average_in_time(self, kk=None, kko=None, free=False):
368        """Average all univariate (scalar) time series.
369
370        - `kk` : time indices used for averaging
371        - `kko`: observation-time indices used for averaging
372        """
373        tseq = self.HMM.tseq
374        if kk is None:
375            kk     = tseq.mask
376        if kko is None:
377            kko  = tseq.masko
378
379        def average1(tseries):
380            avrgs = Avrgs()
381
382            def average_multivariate(): return avrgs
383            # Plain averages of nd-series are rarely interesting.
384            # => Shortcircuit => Leave for manual computations
385
386            if isinstance(tseries, series.FAUSt):
387                # Average series for each subscript
388                if tseries.item_shape != ():
389                    return average_multivariate()
390                for sub in [ch for ch in 'fas' if hasattr(tseries, ch)]:
391                    avrgs[sub] = series.mean_with_conf(tseries[kko, sub])
392                if tseries.store_u:
393                    avrgs['u'] = series.mean_with_conf(tseries[kk, 'u'])
394
395            elif isinstance(tseries, series.DataSeries):
396                if tseries.array.shape[1:] != ():
397                    return average_multivariate()
398                elif len(tseries.array) == self.Ko+1:
399                    avrgs = series.mean_with_conf(tseries[kko])
400                elif len(tseries.array) == self.K+1:
401                    avrgs = series.mean_with_conf(tseries[kk])
402                else:
403                    raise ValueError
404
405            elif np.isscalar(tseries):
406                avrgs = tseries  # E.g. just copy over "duration" from stats
407
408            else:
409                raise TypeError(f"Don't know how to average {tseries}")
410
411            return avrgs
412
413        def recurse_average(stat_parent, avrgs_parent):
414            for key in getattr(stat_parent, "stat_register", []):
415                try:
416                    tseries = getattr(stat_parent, key)
417                except AttributeError:
418                    continue  # E.g. assess_ens() deletes .w if w is None
419                avrgs = average1(tseries)
420                recurse_average(tseries, avrgs)
421                avrgs_parent[key] = avrgs
422
423        avrgs = Avrgs()
424        recurse_average(self, avrgs)
425        self.xp.avrgs = avrgs
426        if free:
427            delattr(self.xp, 'stats')

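Typical usage after the assimilation has finished (a sketch, assuming a completed `xp`):

>>> xp.stats.average_in_time()
... xp.avrgs.rmse.a  # abbreviation for xp.avrgs.err.rms.a (see `Avrgs` below)
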
429    def replay(self, figlist="default", speed=np.inf, t1=0, t2=None, **kwargs):
430        """Replay LivePlot with what's been stored in `self`.
431
432        - `t1`, `t2`: time window to plot.
433        - `figlist` and `speed`: see LivePlot's doc.
434
435        .. note:: `store_u` (whether to store non-obs-time stats) must
436            have been `True` to have smooth graphs as in the actual LivePlot.
437
438        .. note:: Ensembles are generally not stored in the stats
439            and so cannot be replayed.
440        """
441        # Time settings
442        tseq = self.HMM.tseq
443        if t2 is None:
444            t2 = t1 + tseq.Tplot
445
446        # Ens does not get stored in stats, so we cannot replay that.
447        # If the LPs are initialized with P0!=None, then they will avoid ens plotting.
448        # TODO 4: This system for switching from Ens to stats must be replaced.
449        #       It breaks down when M is very large.
450        try:
451            P0 = np.full_like(self.HMM.X0.C.full, np.nan)
452        except AttributeError:  # e.g. if X0 is defined via sampling func
453            P0 = np.eye(self.HMM.Nx)
454
455        LP = liveplotting.LivePlot(self, figlist, P=P0, speed=speed,
456                                   Tplot=t2-t1, replay=True, **kwargs)
457
458        # Remember: must use progbar to unblock read1.
459        # Let's also make a proper description.
460        desc = self.xp.da_method + " (replay)"
461
462        # Play through assimilation cycles
463        for k, ko, t, _dt in progbar(tseq.ticker, desc):
464            if t1 <= t <= t2:
465                if ko is not None:
466                    LP.update((k, ko, 'f'), None, None)
467                    LP.update((k, ko, 'a'), None, None)
468                LP.update((k, ko, 'u'), None, None)
469
470        # Pause required when speed=inf.
471        # On Mac, it was also necessary to do it for each fig.
472        if LP.any_figs:
473            for _name, updater in LP.figures.items():
474                if plt.fignum_exists(_name) and getattr(updater, 'is_active', 1):
475                    plt.figure(_name)
476                    plt.pause(0.01)

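For example (a sketch; requires a completed experiment, with `store_u=True` for smooth curves):

>>> xp.stats.replay(figlist="default", speed=10)
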
479def register_stat(self, name, value):
480    """Do `self.name = value` and register `name` as in self's `stat_register`.
481
482    Note: `self` is not always a `Stats` object, but could be a "child" of it.
483    """
484    setattr(self, name, value)
485    if not hasattr(self, "stat_register"):
486        self.stat_register = []
487    self.stat_register.append(name)

490class Avrgs(series.StatPrint, struct_tools.DotDict):
491    """A `dict` specialized for the averages of statistics.
492
493    Embellishments:
494
495    - `dapper.tools.StatPrint`
496    - `Avrgs.tabulate`
497    - `getattr` that supports abbreviations.
498    """
499
500    def tabulate(self, statkeys=(), decimals=None):
501        columns = tabulate_avrgs([self], statkeys, decimals=decimals)
502        return tabulate(columns, headers="keys").replace('␣', ' ')
503
504    abbrevs = {'rmse': 'err.rms', 'rmss': 'spread.rms', 'rmv': 'spread.rms'}
505
506    # Use getattribute coz it gets called before getattr.
507    def __getattribute__(self, key):
508        """Support deep and abbreviated lookup."""
509        # key = abbrevs[key] # Instead of this, also support rmse.a:
510        key = '.'.join(Avrgs.abbrevs.get(seg, seg) for seg in key.split('.'))
511
512        if "." in key:
513            return struct_tools.deep_getattr(self, key)
514        else:
515            return super().__getattribute__(key)

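A minimal sketch of the abbreviated, deep lookup (assuming the nested entries exist):

>>> avrg = Avrgs(err=Avrgs(rms=Avrgs(a=0.53)))
... avrg.rmse.a  # resolved as avrg.err.rms.a, yielding 0.53
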
540def align_col(col, pad='␣', missingval='', just=">"):
541    r"""Align column.
542
543    Treats `int`s and fixed-point `float`/`str` especially, aligning on the point.
544
545    Example:
546    >>> xx = [1, 1., 1.234, 12.34, 123.4, "1.2e-3", None, np.nan, "inf", (1, 2)]
547    >>> print(*align_col(xx), sep="\n")
548    ␣␣1␣␣␣␣
549    ␣␣1.0␣␣
550    ␣␣1.234
551    ␣12.34␣
552    123.4␣␣
553    ␣1.2e-3
554    ␣␣␣␣␣␣␣
555    ␣␣␣␣nan
556    ␣␣␣␣inf
557    ␣(1, 2)
558    """
559    def split_decimal(x):
560        x = str(x)
561        try:
562            y = float(x)
563        except ValueError:
564            pass
565        else:
566            if np.isfinite(y) and ("e" not in x.lower()):
567                a, *b = x.split(".")
568                if b == []:
569                    b = "int"
570                else:
571                    b = b[0]
572                return a, b
573        return x, False
574
575    # Find max nInt, nDec
576    nInt = nDec = -1
577    for x in col:
578        ints, decs = split_decimal(x)
579        if decs:
580            nInt = max(nInt, len(ints))
581            if decs != "int":
582                nDec = max(nDec, len(decs))
583
584    # Format entries. Floats get aligned on point.
585    def frmt(x):
586        if x is None:
587            return missingval
588        ints, decs = split_decimal(x)
589        x = f"{ints.rjust(nInt, pad)}"
590        if decs == "int":
591            if nDec >= 0:
592                x += pad + pad*nDec
593        elif decs:
594            x += "." + f"{decs.ljust(nDec, pad)}"
595        else:
596            x = ints
597        return x
598
599    # Format
600    col = [frmt(x) for x in col]
601    # Find max width
602    Max = max(len(x) for x in col)
603    # Right-justify
604    shift = str.rjust if just == ">" else str.ljust
605    col = [shift(x, Max, pad) for x in col]
606    return col

609def unpack_uqs(uq_list, decimals=None):
610    """Convert list of `uq`s into dict of lists (of equal-length) of attributes.
611
612    The attributes are obtained by `vars(uq)`,
613    and may get formatted somehow (e.g. cast to strings) in the output.
614
615    If `uq` is `None`, then `None` is inserted in each list.
616    Else, `uq` must be an instance of `dapper.tools.rounding.UncertainQtty`.
617
618    Parameters
619    ----------
620    uq_list: list
621        List of `uq`s.
622
623    decimals: int
624        Desired number of decimals.
625        Used for (only) the columns "val" and "prec".
626        Default: `None`. In this case, the formatting is left to the `uq`s.
627    """
628    def frmt(uq):
629        if not isinstance(uq, series.UncertainQtty):
630            # Presumably uq is just a number
631            uq = series.UncertainQtty(uq)
632
633        attrs = vars(uq).copy()
634
635        # val/prec: round
636        if decimals is None:
637            v, p = str(uq).split(" ±")
638        else:
639            frmt = "%%.%df" % decimals
640            v, p = frmt % uq.val, frmt % uq.prec
641        attrs["val"], attrs["prec"] = v, p
642
643        # tuned_coord: convert to tuple
644        try:
645            attrs["tuned_coord"] = tuple(a for a in uq.tuned_coord)
646        except AttributeError:
647            pass
648        return attrs
649
650    cols = {}
651    for i, uq in enumerate(uq_list):
652        if uq is not None:
653            # Format
654            attrs = frmt(uq)
655            # Insert attrs as a "row" in the `cols`:
656            for k in attrs:
657                # Init column
658                if k not in cols:
659                    cols[k] = [None]*len(uq_list)
660                # Insert element
661                cols[k][i] = attrs[k]
662
663    return cols

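For instance (a sketch, assuming `UncertainQtty(val, prec)` from `dapper.tools.rounding`; any further `uq` attributes would appear as extra columns):

>>> uqs = [series.UncertainQtty(1.2345, 0.01), None]
... unpack_uqs(uqs, decimals=2)  # -> {'val': ['1.23', None], 'prec': ['0.01', None], ...}
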
666def tabulate_avrgs(avrgs_list, statkeys=(), decimals=None):
667    """Tabulate avrgs (val±prec)."""
668    if not statkeys:
669        statkeys = ['rmse.a', 'rmv.a', 'rmse.f']
670
671    columns = {}
672    for stat in statkeys:
673        column = [getattr(a, stat, None) for a in avrgs_list]
674        column = unpack_uqs(column, decimals)
675        if not column:
676            raise ValueError(f"The stat. key '{stat}' was not"
677                             " found among any of the averages.")
678        vals  = align_col([stat] + column["val"])
679        precs = align_col(['1σ'] + column["prec"], just="<")
680        headr = vals[0]+'  '+precs[0]
681        mattr = [f"{v} ±{c}" for v, c in zip(vals, precs)][1:]
682        columns[headr] = mattr
683
684    return columns

687def center(E, axis=0, rescale=False):
688    r"""Center ensemble.
689
690    Makes use of `np` features: keepdims and broadcasting.
691
692    Parameters
693    ----------
694    E: ndarray
695        Ensemble to be centered
696
697    axis: int, optional
698        The axis to be centered. Default: 0
699
700    rescale: bool, optional
701        If True, inflate to compensate for reduction in the expected variance.
702        The inflation factor is \(\sqrt{\frac{N}{N - 1}}\)
703        where N is the ensemble size. Default: False
704
705    Returns
706    -------
707    X: ndarray
708        Ensemble anomaly
709
710    x: ndarray
711        Mean of the ensemble
712    """
713    x = np.mean(E, axis=axis, keepdims=True)
714    X = E - x
715
716    if rescale:
717        N = E.shape[axis]
718        X *= np.sqrt(N/(N-1))
719
720    x = x.squeeze(axis=axis)
721
722    return X, x

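A quick numeric example:

>>> E = np.array([[1., 2.],
...               [3., 4.]])
>>> X, x = center(E)
>>> x
array([2., 3.])
>>> X
array([[-1., -1.],
       [ 1.,  1.]])
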
725def mean0(E, axis=0, rescale=True):
726    """Like `center`, but only return the anomalies (not the mean).
727
728    Uses `rescale=True` by default, which is beneficial
729    when used to center observation perturbations.
730    """
731    return center(E, axis=axis, rescale=rescale)[0]

734def inflate_ens(E, factor):
735    """Inflate the ensemble (center, inflate, re-combine).
736
737    Parameters
738    ----------
739    E: ndarray
740        Ensemble to be inflated
741
742    factor: `float`
743        Inflation factor
744
745    Returns
746    -------
747    ndarray
748        Inflated ensemble
749    """
750    if factor == 1:
751        return E
752    X, x = center(E)
753    return x + X*factor

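For example, a factor of 2 doubles the anomalies while preserving the mean:

>>> E = np.array([[1., 2.],
...               [3., 4.]])
>>> inflate_ens(E, 2)
array([[0., 1.],
       [4., 5.]])
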
756def weight_degeneracy(w, prec=1e-10):
757    """Check if the weights are degenerate.
758
759    If the weights are degenerate, the maximum weight
760    should be nearly one, since sum(w) = 1.
761
762    Parameters
763    ----------
764    w: ndarray
765        Importance weights. Must sum to 1.
766
767    prec: float, optional
768        Tolerance for the distance between the max weight and one. Default: 1e-10
769
770    Returns
771    -------
772    bool
773        True if the weights are degenerate, else False
774    """
775    return (1-w.max()) < prec

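For example:

>>> weight_degeneracy(np.array([1., 0., 0.]))    # max weight is 1 => True
... weight_degeneracy(np.array([0.5, 0.3, 0.2]))  # => False
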
778def unbias_var(w=None, N_eff=None, avoid_pathological=False):
779    """Compute the unbiasing factor for variance estimation.
780
781    Parameters
782    ----------
783    w: ndarray, optional
784        Importance weights. Must sum to 1.
785        Only one of `w` and `N_eff` can be `None`. Default: `None`
786
787    N_eff: float, optional
788        The "effective" size of the weighted ensemble.
789        If not provided, it is computed from the weights.
790        The unbiasing factor is \(N_\text{eff} / (N_\text{eff} - 1)\).
791
792    avoid_pathological: bool, optional
793        Avoid weight collapse. Default: `False`
794
795    Returns
796    -------
797    ub: float
798        Factor used to unbias the variance
799
800    Reference
801    ---------
802    [Wikipedia](https://wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights)
803    """
804    if N_eff is None:
805        N_eff = 1/(w@w)
806
807    if avoid_pathological and weight_degeneracy(w):
808        ub = 1  # Don't unbias in case of weight collapse
809    else:
810        ub = 1/(1 - 1/N_eff)  # =N/(N-1) if w==ones(N)/N.
811    return ub

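For uniform weights, this recovers the familiar N/(N-1) correction:

>>> w = np.ones(4)/4
... unbias_var(w)  # N_eff = 4, so ub = 1/(1 - 1/4) = 4/3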