dapper.tools.liveplotting

On-line (live) plots of the DA process for various models and methods.

Liveplotters are specified by a list of tuples, given as a property of (or an argument to) dapper.mods.HiddenMarkovModel.

  • The first element of the tuple determines whether the liveplotter is shown by default, i.e. when the names of the liveplotters are not given via the liveplots argument of assimilate.

  • The second element of the tuple is the corresponding liveplotter function/class. See the function LPs in dapper.mods.Lorenz63 for an example.

The liveplotters can be fine-tuned for each DA experiment via the liveplots argument when calling assimilate.

  • liveplots = True turns on the liveplotters in HMM.liveplotters whose first tuple element marks them as default, as well as the default liveplotters defined in this module (sliding_diagnostics and weight_histogram).

  • liveplots can also be a list of liveplotter names, i.e. the names of the corresponding liveplotting classes/functions.

   1"""On-line (live) plots of the DA process for various models and methods.
   2
   3Liveplotters are given by a list of tuples as property or arguments in
   4`dapper.mods.HiddenMarkovModel`.
   5
   6- The first element of the tuple determines whether the liveplotter is shown if
   7the names of liveplotters are not given by `liveplots` argument in
   8`assimilate`.
   9
  10- The second element in the tuple gives the corresponding liveplotter
  11function/class. See example of function `LPs` in `dapper.mods.Lorenz63`.
  12
The liveplotters can be fine-tuned for each DA experiment via the
`liveplots` argument when calling `assimilate`.
  15
- `liveplots = True` turns on the liveplotters in `HMM.liveplotters` whose
first tuple element marks them as default, as well as the default liveplotters
defined in this module (`sliding_diagnostics` and `weight_histogram`).
  19
- `liveplots` can also be a list of liveplotter names, i.e. the names of
the corresponding liveplotting classes/functions.
  22"""
  23
  24import numpy as np
  25import scipy.linalg as sla
  26from matplotlib import pyplot as plt
  27from matplotlib.ticker import MaxNLocator
  28from mpl_toolkits.mplot3d.art3d import juggle_axes
  29from mpl_tools import is_notebook_or_qt, place, place_ax
  30from numpy import arange, nan, ones
  31from struct_tools import DotDict, deep_getattr
  32
  33import dapper.tools.progressbar as pb
  34import dapper.tools.viz as viz
  35from dapper.dpr_config import rc
  36from dapper.mods.utils import linspace_int
  37from dapper.tools.chronos import format_time
  38from dapper.tools.matrices import CovMat
  39from dapper.tools.progressbar import read1
  40from dapper.tools.series import FAUSt, RollingArray
  41from dapper.tools.viz import not_available_text, plot_pause
  42
  43
  44class LivePlot:
  45    """Live plotting manager.
  46
  47    Deals with
  48
  49    - Pause, skip.
  50    - Which liveploters to call.
  51    - `plot_u`
  52    - Figure window (title and number).
  53    """
  54
  55    def __init__(
  56        self,
  57        stats,
  58        liveplots,
  59        key0=(0, None, "u"),
  60        E=None,
  61        P=None,
  62        speed=1.0,
  63        replay=False,
  64        **kwargs,
  65    ):
  66        """
  67        Initialize plots.
  68
  69        - liveplots: figures to plot; alternatives:
  70            - `"default"/[]/True`: All default figures for this HMM.
  71            - `"all"`            : Even more.
  72            - non-empty `list`   : Only the figures with these numbers
  73                                 (int) or names (str).
  74            - `False`            : None.
  75        - speed: speed of animation.
  76            - `>100`: instantaneous
  77            - `1`   : (default) as quick as possible allowing for
  78                      plt.draw() to work on a moderately fast computer.
  79            - `<1`  : slower.
  80        """
  81        # Disable if not rc.liveplotting
  82        self.any_figs = False
  83        if rc.liveplotting:
  84            pass
  85        elif replay and np.isinf(speed):
  86            pass
  87        else:
  88            return
  89
  90        # Determine whether all/universal/intermediate stats are plotted
  91        self.plot_u = not replay or stats.store_u
  92
  93        # Set speed/pause params
  94        self.params = {
  95            "pause_f": 0.05,
  96            "pause_a": 0.05,
  97            "pause_s": 0.05,
  98            "pause_u": 0.001,
  99        }
 100        # If speed>100: set to inf. Coz pause=1e-99 causes hangup.
 101        for pause in ["pause_" + x for x in "faus"]:
 102            speed = speed if speed < 100 else np.inf
 103            self.params[pause] /= speed
 104
 105        # Write params
 106        self.params.update(getattr(stats.xp, "LP_kwargs", {}))
 107        self.params.update(kwargs)
 108
 109        def get_name(init):
 110            """Get name of liveplotter function/class."""
 111            try:
 112                return init.__qualname__.split(".")[0]
 113            except AttributeError:
 114                return init.__class__.__name__
 115
 116        # Set up dict of liveplotters
 117        potential_LPs = {}
 118        for show, init in default_liveplotters:
 119            potential_LPs[get_name(init)] = show, init
 120        # Add HMM-specific liveplotters
 121        for show, init in getattr(stats.HMM, "liveplotters", {}):
 122            potential_LPs[get_name(init)] = show, init
 123
 124        def parse_figlist(lst):
 125            """Figures requested for this xp. Convert to list."""
 126            if isinstance(lst, str):
 127                fn = lst.lower()
 128                if "all" == fn:
 129                    lst = ["all"]  # All potential_LPs
 130                elif "default" in fn:
 131                    lst = ["default"]  # All show_by_default
 132            elif hasattr(lst, "__len__"):
 133                lst = lst  # This list (only)
 134            elif lst:
 135                lst = ["default"]  # All show_by_default
 136            else:
 137                lst = [None]  # None
 138            return lst
 139
 140        figlist = parse_figlist(liveplots)
 141
 142        # Loop over requeted figures
 143        self.figures = {}
 144        for name, (show_by_default, init) in potential_LPs.items():
 145            if (
 146                (figlist == ["all"])
 147                or (name in figlist)
 148                or (figlist == ["default"] and show_by_default)
 149            ):
 150                # Startup message
 151                if not self.any_figs:
 152                    print("Initializing liveplots...")
 153                    if is_notebook_or_qt:
 154                        pauses = [self.params["pause_" + x] for x in "faus"]
 155                        if any((p > 0) for p in pauses):
 156                            print(
 157                                "Note: liveplotting does not work very well"
 158                                " inside Jupyter notebooks. In particular,"
 159                                " there is no way to stop/skip them except"
 160                                " to interrupt the kernel (the stop button"
 161                                " in the toolbar). Consider using instead"
 162                                " only the replay functionality (with infinite"
 163                                " playback speed)."
 164                            )
 165                    elif not pb.disable_user_interaction:
 166                        print("Hit <Space> to pause/step.")
 167                        print("Hit <Enter> to resume/skip.")
 168                        print("Hit <i> to enter debug mode.")
 169                    self.paused = False
 170                    self.run_ipdb = False
 171                    self.skipping = False
 172                    self.any_figs = True
 173
 174                # Init figure
 175                post_title = "" if self.plot_u else "\n(obs times only)"
 176                updater = init(name, stats, key0, self.plot_u, E, P, **kwargs)
 177                if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
 178                    self.figures[name] = updater
 179                    fig = plt.figure(name)
 180                    win = fig.canvas
 181                    ax0 = fig.axes[0]
 182                    win.manager.set_window_title(str(name))
 183                    ax0.set_title(ax0.get_title() + post_title)
 184                    self.update(key0, E, P)  # Call initial update
 185                    if not (replay and np.isinf(speed)):
 186                        plt.pause(0.01)  # Draw
 187
 188    def update(self, key, E, P):
 189        """Update liveplots"""
 190        # Check if there are still open figures
 191        if self.any_figs:
 192            open_figns = plt.get_figlabels()
 193            live_figns = set(self.figures.keys())
 194            self.any_figs = bool(live_figns.intersection(open_figns))
 195        else:
 196            return
 197
 198        # Playback control
 199        SPACE = b" "
 200        CHAR_I = b"i"
 201        ENTERs = [b"\n", b"\r"]  # Linux + Windows
 202
 203        def pause():
 204            """Loop until user decision is made."""
 205            ch = read1()
 206            while True:
 207                # Set state (pause, skipping, ipdb)
 208                if ch in ENTERs:
 209                    self.paused = False
 210                elif ch == CHAR_I:
 211                    self.run_ipdb = True
 212                # If keypress valid, resume execution
 213                if ch in ENTERs + [SPACE, CHAR_I]:
 214                    break
 215                ch = read1()
 216                # Pause to enable zoom, pan, etc. of mpl GUI
 217                plot_pause(0.01)  # Don't use time.sleep()!
 218
 219        # Enter pause loop
 220        if self.paused:
 221            pause()
 222
 223        else:
 224            if key == (0, None, "u"):
 225                # Skip read1 for key0 (coz it blocks)
 226                pass
 227            else:
 228                ch = read1()
 229                if ch == SPACE:
 230                    # Pause
 231                    self.paused = True
 232                    self.skipping = False
 233                    pause()
 234                elif ch in ENTERs:
 235                    # Toggle skipping
 236                    self.skipping = not self.skipping
 237                elif ch == CHAR_I:
 238                    # Schedule debug
 239                    # Note: The reason we dont set_trace(frame) right here is:
 240                    # - I could not find the right frame, even doing
 241                    #   >   frame = inspect.stack()[0]
 242                    #   >   while frame.f_code.co_name != "assimilate":
 243                    #   >       frame = frame.f_back
 244                    # - It just restarts the plot.
 245                    self.run_ipdb = True
 246
 247        # Update figures
 248        if not self.skipping:
 249            faus = key[-1]
 250            if faus != "u" or self.plot_u:
 251                for name, (updater) in self.figures.items():
 252                    if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
 253                        _ = plt.figure(name)
 254                        updater(key, E, P)
 255                        plot_pause(self.params["pause_" + faus])
 256
 257        if self.run_ipdb:
 258            self.run_ipdb = False
 259            import inspect
 260
 261            import ipdb
 262
 263            print("Entering debug mode (ipdb).")
 264            print("Type '?' (and Enter) for usage help.")
 265            print("Type 'c' to continue the assimilation.")
 266            ipdb.set_trace(inspect.stack()[2].frame)
 267
 268
 269# TODO 6:
 270# - iEnKS diagnostics don't work at all when store_u=False
 271star = "${}^*$"
 272
 273
 274class sliding_diagnostics:
 275    """Plots a sliding window (like a heart rate monitor) of certain diagnostics."""
 276
 277    def __init__(self, fignum, stats, key0, plot_u, E, P, Tplot=None, **kwargs):
 278        # STYLE TABLES - Defines which/how diagnostics get plotted
 279        styles = {}
 280
 281        def lin(a, b):
 282            return lambda x: a + b * x
 283
 284        divN = 1 / getattr(stats.xp, "N", 99)
 285        # Columns: transf, shape, plt kwargs
 286        styles["RMS"] = {
 287            "err.rms": [None, None, dict(c="k", label="Error")],
 288            "spread.rms": [None, None, dict(c="b", label="Spread", alpha=0.6)],
 289        }
 290        styles["Values"] = {
 291            "skew": [None, None, dict(c="g", label=star + r"Skew/$\sigma^3$")],
 292            "kurt": [None, None, dict(c="r", label=star + r"Kurt$/\sigma^4{-}3$")],
 293            "trHK": [None, None, dict(c="k", label=star + "HK")],
 294            "infl": [lin(-10, 10), "step", dict(c="c", label="10(infl-1)")],
 295            "N_eff": [lin(0, divN), "dirac", dict(c="y", label="N_eff/N", lw=3)],
 296            "iters": [lin(0, 0.1), "dirac", dict(c="m", label="iters/10")],
 297            "resmpl": [None, "dirac", dict(c="k", label="resampled?")],
 298        }
 299
 300        nAx = len(styles)
 301        GS = {"left": 0.125, "right": 0.76}
 302        fig, axs = place.freshfig(
 303            fignum, figsize=(5, 1 + nAx), nrows=nAx, sharex=True, gridspec_kw=GS
 304        )
 305
 306        axs[0].set_title("Diagnostics")
 307        for style, ax in zip(styles, axs):
 308            ax.set_ylabel(style)
 309        ax.set_xlabel("Time (t)")
 310        place_ax.adjust_position(ax, y0=0.03)
 311
 312        self.T_lag, K_lag, a_lag = validate_lag(Tplot, stats.HMM.tseq)
 313
 314        def init_ax(ax, style_table):
 315            lines = {}
 316            for name in style_table:
 317                # SKIP -- if stats[name] is not in existence
 318                # Note: The nan check/deletion comes after the first ko.
 319                try:
 320                    stat = deep_getattr(stats, name)
 321                except AttributeError:
 322                    continue
 323                # try: val0 = stat[key0[0]]
 324                # except KeyError: continue
 325                # PS: recall (from series.py) that even if store_u is false, stat[k] is
 326                # still present if liveplots=True via the k_tmp functionality.
 327
 328                # Unpack style
 329                ln = {}
 330                ln["transf"] = style_table[name][0] or (lambda x: x)
 331                ln["shape"] = style_table[name][1]
 332                ln["plt"] = style_table[name][2]
 333
 334                # Create series
 335                if isinstance(stat, FAUSt):
 336                    ln["plot_u"] = plot_u
 337                    K_plot = comp_K_plot(K_lag, a_lag, ln["plot_u"])
 338                else:
 339                    ln["plot_u"] = False
 340                    K_plot = a_lag
 341                ln["data"] = RollingArray(K_plot)
 342                ln["tt"] = RollingArray(K_plot)
 343
 344                # Plot (init)
 345                (ln["handle"],) = ax.plot(ln["tt"], ln["data"], **ln["plt"])
 346
 347                # Plotting only nans yield ugly limits. Revert to defaults.
 348                ax.set_xlim(0, 1)
 349                ax.set_ylim(0, 1)
 350
 351                lines[name] = ln
 352            return lines
 353
 354        # Plot
 355        self.d = [init_ax(ax, styles[style]) for style, ax in zip(styles, axs)]
 356
 357        # Horizontal line at y=0
 358        (self.baseline0,) = ax.plot(
 359            ax.get_xlim(), [0, 0], c=0.5 * ones(3), lw=0.7, label="_nolegend_"
 360        )
 361
 362        # Store
 363        self.axs = axs
 364        self.stats = stats
 365        self.init_incomplete = True
 366
 367    # Update plot
 368    def __call__(self, key, E, P):
 369        k, ko, faus = key
 370
 371        stats = self.stats
 372        tseq = stats.HMM.tseq
 373        ax0, ax1 = self.axs
 374
 375        def update_arrays(lines):
 376            for name, ln in lines.items():
 377                stat = deep_getattr(stats, name)
 378                t = tseq.tt[k]  # == tseq.tto[ko]
 379                if isinstance(stat, FAUSt):
 380                    # ln['data'] will contain duplicates for f/a times.
 381                    if ln["plot_u"]:
 382                        val = stat[key]
 383                        ln["tt"].insert(k, t)
 384                        ln["data"].insert(k, ln["transf"](val))
 385                    elif "u" not in faus:
 386                        val = stat[key]
 387                        ln["tt"].insert(ko, t)
 388                        ln["data"].insert(ko, ln["transf"](val))
 389                else:
 390                    # ln['data'] will not contain duplicates, coz only 'a' is input.
 391                    if "a" in faus:
 392                        val = stat[ko]
 393                        ln["tt"].insert(ko, t)
 394                        ln["data"].insert(ko, ln["transf"](val))
 395                    elif "f" in faus:
 396                        pass
 397
 398        def update_plot_data(ax, lines):
 399            def bend_into(shape, xx, yy):
 400                # Get arrays. Repeat (to use for intermediate nodes).
 401                yy = yy.array.repeat(3)
 402                xx = xx.array.repeat(3)
 403                if len(xx) == 0:
 404                    pass  # shortcircuit any modifications
 405                elif shape == "step":
 406                    yy = np.hstack([yy[1:], nan])  # roll leftward
 407                elif shape == "dirac":
 408                    nonlocal nDirac
 409                    axW = np.diff(ax.get_xlim())
 410                    yy[0::3] = False  # set datapoin to 0
 411                    xx[2::3] = nan  # make datapoint disappear
 412                    xx += nDirac * axW / 100  # offset datapoint horizontally
 413                    nDirac += 1
 414                return xx, yy
 415
 416            nDirac = 1
 417            for _name, ln in lines.items():
 418                ln["handle"].set_data(*bend_into(ln["shape"], ln["tt"], ln["data"]))
 419
 420        def finalize_init(ax, lines, mm):
 421            # Rm lines that only contain NaNs
 422            for name in list(lines):
 423                ln = lines[name]
 424                stat = deep_getattr(stats, name)
 425                if not stat.were_changed:
 426                    ln["handle"].remove()  # rm from axes
 427                    del lines[name]  # rm from dict
 428            # Add legends
 429            if lines:
 430                ax.legend(loc="upper left", bbox_to_anchor=(1.01, 1), borderaxespad=0)
 431                if mm:
 432                    ax.annotate(
 433                        star + ": mean of\nmarginals",
 434                        xy=(0, -1.5 / len(lines)),
 435                        xycoords=ax.get_legend().get_frame(),
 436                        bbox=dict(alpha=0.0),
 437                        fontsize="small",
 438                    )
 439            # coz placement of annotate needs flush sometimes:
 440            plot_pause(0.01)
 441
 442        # Insert current stats
 443        for lines, ax in zip(self.d, self.axs):
 444            update_arrays(lines)
 445            update_plot_data(ax, lines)
 446
 447        # Set x-limits (time)
 448        sliding_xlim(ax0, self.d[0]["err.rms"]["tt"], self.T_lag, margin=True)
 449        self.baseline0.set_xdata(ax0.get_xlim())
 450
 451        # Set y-limits
 452        data0 = [ln["data"].array for ln in self.d[0].values()]
 453        data1 = [ln["data"].array for ln in self.d[1].values()]
 454        ax0.set_ylim(0, d_ylim(data0, ax0, cC=0.2, cE=0.9)[1])
 455        ax1.set_ylim(*d_ylim(data1, ax1, Max=4, Min=-4, cC=0.3, cE=0.9))
 456
 457        # Init legend. Rm nan lines.
 458        if self.init_incomplete and "a" == faus:
 459            self.init_incomplete = False
 460            finalize_init(ax0, self.d[0], False)
 461            finalize_init(ax1, self.d[1], True)
 462
 463
 464def sliding_xlim(ax, tt, lag, margin=False):
 465    dt = lag / 20 if margin else 0
 466    if tt.nFilled == 0:
 467        return  # Quit
 468    t1, t2 = tt.span()  # Get suggested span.
 469    s1, s2 = ax.get_xlim()  # Get previous lims.
 470    # If zero span (eg tt holds single 'f' and 'a'):
 471    if t1 == t2:
 472        t1 -= 1  # add width
 473        t2 += 1  # add width
 474    # If user has skipped (too much):
 475    elif np.isnan(t1):
 476        s2 -= dt  # Correct for dt.
 477        span = s2 - s1  # Compute previous span
 478        # If span<lag:
 479        if span < lag:
 480            span += t2 - s2  # Grow by "dt".
 481        span = min(lag, span)  # Bound
 482        t1 = t2 - span  # Set span.
 483    ax.set_xlim(t1, t2 + dt)  # Set xlim to span
 484
 485
 486class weight_histogram:
 487    """Plots histogram of weights. Refreshed each analysis."""
 488
 489    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
 490        if not hasattr(stats, "w"):
 491            self.is_active = False
 492            return
 493        fig, ax = place.freshfig(fignum, figsize=(7, 3), gridspec_kw={"bottom": 0.15})
 494
 495        ax.set_xscale("log")
 496        ax.set_xlabel("Weigth")
 497        ax.set_ylabel("Count")
 498        self.stats = stats
 499        self.ax = ax
 500        self.hist = []
 501        self.bins = np.exp(np.linspace(np.log(1e-10), np.log(1), 31))
 502
 503    def __call__(self, key, E, P):
 504        k, ko, faus = key
 505        if "a" == faus:
 506            w = self.stats.w[key]
 507            N = len(w)
 508            ax = self.ax
 509
 510            self.is_active = N < 10001
 511            if not self.is_active:
 512                not_available_text(ax, "Not computed (N > threshold)")
 513                return
 514
 515            counted = w > self.bins[0]
 516            _ = [b.remove() for b in self.hist]
 517            nn, _, self.hist = ax.hist(w[counted], bins=self.bins, color="b")
 518            ax.set_ylim(top=max(nn))
 519
 520            ax.set_title(
 521                f"N: {N:d}.   N_eff: {1/(w@w):.4g}."
 522                "   Not shown: {N-np.sum(counted):d}. "
 523            )
 524
 525
 526class spectral_errors:
 527    """Plots the (spatial-RMS) error as a functional of the SVD index."""
 528
 529    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
 530        fig, ax = place.freshfig(fignum, figsize=(6, 3))
 531        ax.set_xlabel("Sing. value index")
 532        ax.set_yscale("log")
 533        self.init_incomplete = True
 534        self.ax = ax
 535        self.plot_u = plot_u
 536
 537        try:
 538            self.msft = stats.umisf
 539            self.sprd = stats.svals
 540        except AttributeError:
 541            self.is_active = False
 542            not_available_text(ax, "Spectral stats not being computed")
 543
 544    # Update plot
 545    def __call__(self, key, E, P):
 546        k, ko, faus = key
 547        ax = self.ax
 548        if self.init_incomplete:
 549            if self.plot_u or "f" == faus:
 550                self.init_incomplete = False
 551                msft = abs(self.msft[key])
 552                sprd = self.sprd[key]
 553                if np.any(np.isinf(msft)):
 554                    not_available_text(ax, "Spectral stats not finite")
 555                    self.is_active = False
 556                else:
 557                    (self.line_msft,) = ax.plot(msft, "k", lw=2, label="Error")
 558                    (self.line_sprd,) = ax.plot(
 559                        sprd, "b", lw=2, label="Spread", alpha=0.9
 560                    )
 561                    ax.get_xaxis().set_major_locator(MaxNLocator(integer=True))
 562                    ax.legend()
 563        else:
 564            msft = abs(self.msft[key])
 565            sprd = self.sprd[key]
 566            self.line_sprd.set_ydata(sprd)
 567            self.line_msft.set_ydata(msft)
 568        # ax.set_ylim(*d_ylim(msft))
 569        # ax.set_ylim(bottom=1e-5)
 570        ax.set_ylim([1e-3, 1e1])
 571
 572
 573class correlations:
 574    """Plots the state (auto-)correlation matrix."""
 575
 576    half = True  # Whether to show half/full (symmetric) corr matrix.
 577
 578    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
 579        GS = {"height_ratios": [4, 1], "hspace": 0.09, "top": 0.95}
 580        fig, (ax, ax2) = place.freshfig(fignum, figsize=(5, 6), nrows=2, gridspec_kw=GS)
 581
 582        if E is None and np.isnan(P.diag if isinstance(P, CovMat) else P).all():
 583            not_available_text(
 584                ax, ("Not available in replays" "\ncoz full Ens/Cov not stored.")
 585            )
 586            self.is_active = False
 587            return
 588
 589        Nx = len(stats.mu[key0])
 590        if Nx <= 1003:
 591            C = np.eye(Nx)
 592            # Mask half
 593            mask = np.zeros_like(C, dtype=bool)
 594            mask[np.tril_indices_from(mask)] = True
 595            cmap = plt.get_cmap("RdBu_r")
 596            VM = 1.0  # abs(np.percentile(C,[1,99])).max()
 597            im = ax.imshow(C, cmap=cmap, vmin=-VM, vmax=VM)
 598            # Colorbar
 599            _ = ax.figure.colorbar(im, ax=ax, shrink=0.8)
 600            # Tune plot
 601            plt.box(False)
 602            ax.set_facecolor("w")
 603            ax.grid(False)
 604            ax.set_title("State correlation matrix:", y=1.07)
 605            ax.xaxis.tick_top()
 606
 607            # ax2 = inset_axes(ax,width="30%",height="60%",loc=3)
 608            (line_AC,) = ax2.plot(arange(Nx), ones(Nx), label="Correlation")
 609            (line_AA,) = ax2.plot(arange(Nx), ones(Nx), label="Abs. corr.")
 610            _ = ax2.hlines(0, 0, Nx - 1, "k", "dotted", lw=1)
 611            # Align ax2 with ax
 612            bb_AC = ax2.get_position()
 613            bb_C = ax.get_position()
 614            ax2.set_position([bb_C.x0, bb_AC.y0, bb_C.width, bb_AC.height])
 615            # Tune plot
 616            ax2.set_title("Auto-correlation:")
 617            ax2.set_ylabel("Mean value")
 618            ax2.set_xlabel("Distance (in state indices)")
 619            ax2.set_xticklabels([])
 620            ax2.set_yticks([0, 1] + list(ax2.get_yticks()[[0, -1]]))
 621            ax2.set_ylim(top=1)
 622            ax2.legend(
 623                frameon=True,
 624                facecolor="w",
 625                bbox_to_anchor=(1, 1),
 626                loc="upper left",
 627                borderaxespad=0.02,
 628            )
 629
 630            self.ax = ax
 631            self.ax2 = ax2
 632            self.im = im
 633            self.line_AC = line_AC
 634            self.line_AA = line_AA
 635            self.mask = mask
 636            if hasattr(stats, "w"):
 637                self.w = stats.w
 638        else:
 639            not_available_text(ax)
 640
 641    # Update plot
 642    def __call__(self, key, E, P):
 643        # Get cov matrix
 644        if E is not None:
 645            if hasattr(self, "w"):
 646                C = np.cov(E, rowvar=False, aweights=self.w[key])
 647            else:
 648                C = np.cov(E, rowvar=False)
 649        else:
 650            assert P is not None
 651            C = P.full if isinstance(P, CovMat) else P
 652            C = C.copy()
 653        # Compute corr from cov
 654        std = np.sqrt(np.diag(C))
 655        C /= std[:, None]
 656        C /= std[None, :]
 657        # Mask
 658        if self.half:
 659            C = np.ma.masked_where(self.mask, C)
 660        # Plot
 661        self.im.set_data(C)
 662        # Auto-corr function
 663        ACF = circulant_ACF(C)
 664        AAF = circulant_ACF(C, do_abs=True)
 665        self.line_AC.set_ydata(ACF)
 666        self.line_AA.set_ydata(AAF)
 667
 668
 669def circulant_ACF(C, do_abs=False):
 670    """Compute the auto-covariance-function corresponding to `C`.
 671
 672    This assumes it is the cov/corr matrix of a 1D periodic domain.
 673
 674    Vectorized or FFT implementations are
 675    [possible](https://stackoverflow.com/questions/20360675).
 676    """
 677    M = len(C)
 678    # cols = np.flipud(sla.circulant(np.arange(M)[::-1]))
 679    cols = sla.circulant(np.arange(M))
 680    ACF = np.zeros(M)
 681    for i in range(M):
 682        row = C[i, cols[i]]
 683        if do_abs:
 684            row = abs(row)
 685        ACF += row
 686        # Note: this actually also accesses masked values in C.
 687    return ACF / M
 688
 689
 690def sliding_marginals(
 691    obs_inds=(),
 692    dims=(),
 693    labels=(),
 694    Tplot=None,
 695    ens_props=dict(alpha=0.4),  # noqa
 696    zoomy=1.0,
 697):
 698    # Store parameters
 699    params_orig = DotDict(**locals())
 700
 701    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
 702        xx, yy, mu, spread, tseq = (
 703            stats.xx,
 704            stats.yy,
 705            stats.mu,
 706            stats.spread,
 707            stats.HMM.tseq,
 708        )
 709
 710        # Set parameters (kwargs takes precedence over params_orig)
 711        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})
 712
 713        # Chose marginal dims to plot
 714        if not len(p.dims):
 715            p.dims = linspace_int(xx.shape[-1], min(10, xx.shape[-1]))
 716
 717        # Lag settings:
 718        T_lag, K_lag, a_lag = validate_lag(p.Tplot, tseq)
 719        K_plot = comp_K_plot(K_lag, a_lag, plot_u)
 720        # Extend K_plot forther for adding blanks in resampling (PartFilt):
 721        has_w = hasattr(stats, "w")
 722        if has_w:
 723            K_plot += a_lag
 724
 725        # Set up figure, axes
 726        fig, axs = place.freshfig(
 727            fignum, figsize=(5, 7), squeeze=False, nrows=len(p.dims), sharex=True
 728        )
 729        axs = axs.reshape(len(p.dims))
 730
 731        # Tune plots
 732        axs[0].set_title("Marginal time series")
 733        for ix, (m, ax) in enumerate(zip(p.dims, axs)):
 734            # ax.set_ylim(*viz.stretch(*viz.xtrema(xx[:, m]), 1/p.zoomy))
 735            if not p.labels:
 736                ax.set_ylabel("$x_{%d}$" % m)
 737            else:
 738                ax.set_ylabel(p.labels[ix])
 739        axs[-1].set_xlabel("Time (t)")
 740
 741        plot_pause(0.05)
 742        plt.tight_layout()
 743
 744        # Allocate
 745        d = DotDict()  # data arrays
 746        h = DotDict()  # plot handles
 747        # Why "if True" ? Just to indent the rest of the line...
 748        if True:
 749            d.t = RollingArray((K_plot,))
 750        if True:
 751            d.x = RollingArray((K_plot, len(p.dims)))
 752            h.x = []
 753        if not_empty(p.obs_inds):
 754            d.y = RollingArray((K_plot, len(p.dims)))
 755            h.y = []
 756        if E is not None:
 757            d.E = RollingArray((K_plot, len(E), len(p.dims)))
 758            h.E = []
 759        if P is not None:
 760            d.mu = RollingArray((K_plot, len(p.dims)))
 761            h.mu = []
 762        if P is not None:
 763            d.s = RollingArray((K_plot, 2, len(p.dims)))
 764            h.s = []
 765
 766        # Plot (invisible coz everything here is nan, for the moment).
 767        for ix, ax in zip(p.dims, axs):
 768            if True:
 769                h.x += ax.plot(d.t, d.x[:, ix], "k")
 770            if not_empty(p.obs_inds):
 771                h.y += ax.plot(d.t, d.y[:, ix], "g*", ms=10)
 772            if "E" in d:
 773                h.E += [ax.plot(d.t, d.E[:, :, ix], **p.ens_props)]
 774            if "mu" in d:
 775                h.mu += ax.plot(d.t, d.mu[:, ix], "b")
 776            if "s" in d:
 777                h.s += [ax.plot(d.t, d.s[:, :, ix], "b--", lw=1)]
 778
 779        def update(key, E, P):
 780            k, ko, faus = key
 781
 782            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)
 783
 784            # Roll data array
 785            ind = k if plot_u else ko
 786            for Ens in EE:  # If E is duplicated, so must the others be.
 787                if "E" in d:
 788                    d.E.insert(ind, Ens)
 789                if "mu" in d:
 790                    d.mu.insert(ind, mu[key][p.dims])
 791                if "s" in d:
 792                    d.s.insert(ind, mu[key][p.dims] + [[1], [-1]] * spread[key][p.dims])
 793                if True:
 794                    d.t.insert(ind, tseq.tt[k])
 795                if not_empty(p.obs_inds):
 796                    xy = nan * ones(len(p.dims))
 797                    if ko is not None:
 798                        jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
 799                        xy[jj] = yy[ko]
 800                    d.y.insert(ind, xy)
 801                if True:
 802                    d.x.insert(ind, xx[k, p.dims])
 803
 804            # Update graphs
 805            for ix, ax in zip(p.dims, axs):
 806                sliding_xlim(ax, d.t, T_lag, True)
 807                if True:
 808                    h.x[ix].set_data(d.t, d.x[:, ix])
 809                if not_empty(p.obs_inds):
 810                    h.y[ix].set_data(d.t, d.y[:, ix])
 811                if "mu" in d:
 812                    h.mu[ix].set_data(d.t, d.mu[:, ix])
 813                if "s" in d:
 814                    [h.s[ix][b].set_data(d.t, d.s[:, b, ix]) for b in [0, 1]]
 815                if "E" in d:
 816                    [h.E[ix][n].set_data(d.t, d.E[:, n, ix]) for n in range(len(E))]
 817                if "E" in d:
 818                    update_alpha(key, stats, h.E[ix])
 819
 820                # TODO 3: fixup. This might be slow?
 821                # In any case, it is very far from tested.
 822                # Also, relim'iting all of the time is distracting.
 823                # Use d_ylim?
 824                if "E" in d:
 825                    lims = d.E
 826                elif "mu" in d:
 827                    lims = d.mu
 828                lims = np.array(viz.xtrema(lims[..., ix]))
 829                if lims[0] == lims[1]:
 830                    lims += [-0.5, +0.5]
 831                ax.set_ylim(*viz.stretch(*lims, 1 / p.zoomy))
 832
 833            return
 834
 835        return update
 836
 837    return init
 838
 839
 840def phase_particles(
 841    is_3d=True,
 842    obs_inds=(),
 843    dims=(),
 844    labels=(),
 845    Tplot=None,
 846    ens_props=dict(alpha=0.4),  # noqa
 847    zoom=1.5,
 848):
 849    # Store parameters
 850    params_orig = DotDict(**locals())
 851
 852    M = 3 if is_3d else 2
 853
 854    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
 855        xx, yy, mu, _, tseq = stats.xx, stats.yy, stats.mu, stats.spread, stats.HMM.tseq
 856
 857        # Set parameters (kwargs takes precedence over params_orig)
 858        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})
 859
 860        # Lag settings:
 861        has_w = hasattr(stats, "w")
 862        if p.Tplot == 0:
 863            K_plot = 1
 864        else:
 865            T_lag, K_lag, a_lag = validate_lag(p.Tplot, tseq)
 866            K_plot = comp_K_plot(K_lag, a_lag, plot_u)
 867            # Extend K_plot forther for adding blanks in resampling (PartFilt):
 868            if has_w:
 869                K_plot += a_lag
 870
 871        # Dimension settings
 872        if not p.dims:
 873            p.dims = arange(M)
 874        if not p.labels:
 875            p.labels = ["$x_%d$" % d for d in p.dims]
 876        assert len(p.dims) == M
 877
 878        # Set up figure, axes
 879        fig, _ = place.freshfig(fignum, figsize=(5, 5))
 880        ax = plt.subplot(111, projection="3d" if is_3d else None)
 881        ax.set_facecolor("w")
 882        ax.set_title("Phase space trajectories")
 883        # Tune plot
 884        for ind, (s, i, t) in enumerate(zip(p.labels, p.dims, "xyz")):
 885            viz.set_ilim(ax, ind, *viz.stretch(*viz.xtrema(xx[:, i]), 1 / p.zoom))
 886            eval(f"ax.set_{t}label('{s!s}')")
 887
 888        # Allocate
 889        d = DotDict()  # data arrays
 890        h = DotDict()  # plot handles
 891        s = DotDict()  # scatter handles
 892        if E is not None:
 893            d.E = RollingArray((K_plot, len(E), M))
 894            h.E = []
 895        if P is not None:
 896            d.mu = RollingArray((K_plot, M))
 897        if True:
 898            d.x = RollingArray((K_plot, M))
 899        if not_empty(p.obs_inds):
 900            d.y = RollingArray((K_plot, M))
 901
 902        # Plot tails (invisible coz everything here is nan, for the moment).
 903        if "E" in d:
 904            h.E += [
 905                ax.plot(*xn, **p.ens_props)[0] for xn in np.transpose(d.E, [1, 2, 0])
 906            ]
 907        if "mu" in d:
 908            h.mu = ax.plot(*d.mu.T, "b", lw=2)[0]
 909        if True:
 910            h.x = ax.plot(*d.x.T, "k", lw=3)[0]
 911        if "y" in d:
 912            h.y = ax.plot(*d.y.T, "g*", ms=14)[0]
 913
 914        # Scatter. NB: don't init with nan's coz it's buggy
 915        # (wrt. get_color() and _offsets3d) since mpl 3.1.
 916        if "E" in d:
 917            s.E = ax.scatter(*E.T[p.dims], s=3**2, c=[hn.get_color() for hn in h.E])
 918        if "mu" in d:
 919            s.mu = ax.scatter(*ones(M), s=8**2, c=[h.mu.get_color()])
 920        if True:
 921            s.x = ax.scatter(
 922                *ones(M), s=14**2, c=[h.x.get_color()], marker=(5, 1), zorder=99
 923            )
 924
 925        def update(key, E, P):
 926            k, ko, faus = key
 927
 928            def update_tail(handle, newdata):
 929                handle.set_data(newdata[:, 0], newdata[:, 1])
 930                if is_3d:
 931                    handle.set_3d_properties(newdata[:, 2])
 932
 933            def update_sctr(handle, newdata):
 934                if is_3d:
 935                    handle._offsets3d = juggle_axes(*newdata.T, "z")
 936                else:
 937                    handle.set_offsets(newdata)
 938
 939            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)
 940
 941            # Roll data array
 942            ind = k if plot_u else ko
 943            for Ens in EE:  # If E is duplicated, so must the others be.
 944                if "E" in d:
 945                    d.E.insert(ind, Ens)
 946                if True:
 947                    d.x.insert(ind, xx[k, p.dims])
 948                if "y" in d:
 949                    xy = nan * ones(len(p.dims))
 950                    if ko is not None:
 951                        jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
 952                        jj = list(jj)
 953                        for i, dim in enumerate(p.dims):
 954                            try:
 955                                iobs = jj.index(dim)
 956                            except ValueError:
 957                                pass
 958                            else:
 959                                xy[i] = yy[ko][iobs]
 960                    d.y.insert(ind, xy)
 961                if "mu" in d:
 962                    d.mu.insert(ind, mu[key][p.dims])
 963
 964            # Update graph
 965            update_sctr(s.x, d.x[[-1]])
 966            update_tail(h.x, d.x)
 967            if "y" in d:
 968                update_tail(h.y, d.y)
 969            if "mu" in d:
 970                update_sctr(s.mu, d.mu[[-1]])
 971                update_tail(h.mu, d.mu)
 972            else:
 973                update_sctr(s.E, d.E[-1])
 974                for n in range(len(E)):
 975                    update_tail(h.E[n], d.E[:, n, :])
 976                update_alpha(key, stats, h.E, s.E)
 977
 978            return
 979
 980        return update
 981
 982    return init
 983
 984
 985def validate_lag(Tplot, tseq):
 986    """Return validated `T_lag` such that is is:
 987
 988    - equal to `Tplot` with fallback: `HMM.tseq.Tplot`.
 989    - no longer than `HMM.tseq.T`.
 990
 991    Also return corresponding `K_lag`, `a_lag`.
 992    """
 993    # Defaults
 994    if Tplot is None:
 995        Tplot = tseq.Tplot
 996
 997    # Rename
 998    T_lag = Tplot
 999
1000    assert T_lag >= 0
1001
1002    # Validate T_lag
1003    t2 = tseq.tt[-1]
1004    t1 = max(tseq.tt[0], t2 - T_lag)
1005    T_lag = t2 - t1
1006
1007    K_lag = int(T_lag / tseq.dt) + 1  # Lag in indices
1008    a_lag = K_lag // tseq.dko + 1  # Lag in obs indices
1009
1010    return T_lag, K_lag, a_lag
1011
1012
def comp_K_plot(K_lag, a_lag, plot_u):
    """Return how many time points the rolling plot arrays must hold.

    The forecast and analysis series contribute `a_lag` points each;
    when `plot_u` is set, the universal series adds another `K_lag`.
    """
    n_fa = a_lag + a_lag  # one lag each for the "f" and "a" series
    n_u = K_lag if plot_u else 0
    return n_fa + n_u
1018
1019
def update_alpha(key, stats, lines, scatters=None):
    """Adjust color alpha (for particle filters).

    Does nothing unless `ko` is set, `faus != "f"` and `stats` has weights `w`.
    Otherwise each line's alpha becomes `w / max(w)` clipped to [0.1, 0.4].
    For `scatters`, alpha is emulated by writing it into the RGBA colours.
    """
    _, ko, faus = key
    if ko is None or faus == "f" or not hasattr(stats, "w"):
        return

    weights = stats.w[key]
    alphas = np.clip(weights / weights.max(), 0.1, 0.4)

    for line, alpha in zip(lines, alphas):
        line.set_alpha(alpha)

    if scatters is None:
        return
    rgb = scatters.get_facecolor()[:, :3]
    if len(rgb) == 1:
        # A single shared colour: broadcast it to one row per particle.
        rgb = rgb.repeat(len(weights), axis=0)
    scatters.set_color(np.hstack([rgb, alphas[:, None]]))
1044
1045
def not_empty(xx):
    """Return whether `xx` is non-empty (if sized) or truthy (otherwise).

    Sized objects (lists, tuples, strings, ndarrays) are judged by `len`,
    which sidesteps the ambiguous truth value of multi-element arrays.
    Unsized objects (numbers, None, callables, 0-d arrays) fall back to `bool`.
    """
    try:
        size = len(xx)
    except TypeError:
        return bool(xx)
    return size > 0
1052
1053
# Persistent storage for ens: first plotted column of the latest "a" ensemble.
_Ea = [None]


def duplicate_with_blanks_for_resampled(E, dims, key, has_w):
    """Particle filter: insert breaks for resampled particles.

    Returns a list of ensembles (restricted to columns `dims`) to insert
    into the rolling plot arrays: normally just ``[E[:, dims]]``
    (``[None]`` if `E` is None). For weighted ensembles (`has_w`):

    - at ``"a"`` keys, the first plotted column is stored in `_Ea`;
    - at ``"u"`` keys with an obs index, particles whose first column differs
      from the stored one are deemed resampled. A copy with those rows set to
      NaN is prepended, so that their plotted trajectories are broken.
    """
    if E is None:
        return [E]

    Ed = E[:, dims]
    if not has_w:
        return [Ed]

    _, ko, faus = key
    if faus == "a":
        _Ea[0] = Ed[:, 0]  # Store (1st dim of) ens.
        return [Ed]
    if faus == "u" and ko is not None:
        # Mark as resampled if ens changed.
        resampled = _Ea[0] != Ed[:, 0]
        blanked = Ed.copy()  # copy to avoid overwriting
        blanked[resampled] = nan  # Write breaks
        return [blanked, Ed]
    return [Ed]
1078
1079
def d_ylim(data, ax=None, cC=0, cE=1, pp=(1, 99), Min=-1e20, Max=+1e20):
    """Provide new ylim's intelligently, from percentiles of the data.

    - `data`: iterable of arrays for computing percentiles
      (non-finite values are ignored).
    - `pp`: percentiles (lower, upper).

    - `ax`: If present, then the delta_zoom in/out is also considered,
      relative to `ax.get_ylim()`.

      - `cE`: expansion (widening) rate ∈ [0,1].
        Default: 1, which immediately expands to percentile.
      - `cC`: compression (narrowing) rate ∈ [0,1].
        Default: 0, which does not allow compression.

    - `Min`/`Max`: bounds

    Returns `(ymin, ymax)`. Either one is `None` if it is not finite
    (e.g. when `data` holds no finite values). Nothing is set on `ax`.

    Despite being a little involved,
    the cost of this subroutine is typically not substantial
    because there's usually not that much data to sort through.
    """
    # Find "reasonable" limits (by percentiles), looping over data.
    # The lower percentile is negated, so that one np.maximum tracks both
    # the smallest lower and the largest upper percentile.
    maxv = minv = -np.inf  # init
    for arr in data:
        arr = arr[np.isfinite(arr)]
        if len(arr):
            perc = np.array([-1, 1]) * np.percentile(arr, pp)
            minv, maxv = np.maximum([minv, maxv], perc)
    minv *= -1

    # Pry apart equal values
    if np.isclose(minv, maxv):
        maxv += 0.5
        minv -= 0.5

    # Make the zooming transition smooth
    if ax is not None:
        current = ax.get_ylim()
        # Compression rate when narrowing, expansion rate when widening.
        c0 = cC if minv > current[0] else cE
        c1 = cC if maxv < current[1] else cE
        # Interpolate between current (rate 0) and target (rate 1) limits.
        minv = np.interp(c0, (0, 1), (current[0], minv))
        maxv = np.interp(c1, (0, 1), (current[1], maxv))

    # Bounds
    maxv = min(Max, maxv)
    minv = max(Min, minv)

    # Some mpl versions don't handle inf limits.
    if not np.isfinite(minv):
        minv = None
    if not np.isfinite(maxv):
        maxv = None

    return minv, maxv
1145
1146
def spatial1d(
    obs_inds=(),
    periodicity=None,
    dims=(),
    ens_props={"color": "b", "alpha": 0.1},  # noqa
    conf_mult=None,
):
    """Liveplotter: 1d amplitude plot (value vs. state index) of the current state.

    Parameters (each may be overridden by a kwarg of the same name passed to `init`):

    - `obs_inds`: observed state indices, or a callable `ko -> indices`.
    - `periodicity`: passed to `viz.setup_wrapping` for the x-axis.
    - `dims`: state indices to plot (default: all).
    - `ens_props`: keyword arguments for the ensemble member lines.
    - `conf_mult`: if truthy, plot mean ± `conf_mult` * spread instead of the
      ensemble. Set to 2 if left `None` and no ensemble (`E`) is given.

    Returns the `init` function expected by `LivePlot`.
    """
    # Store parameters
    params_orig = DotDict(**locals())

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        xx, yy, mu = stats.xx, stats.yy, stats.mu

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

        # NB: not_empty (rather than `not p.dims`) also accepts ndarrays.
        if not_empty(p.dims):
            M = len(p.dims)
        else:
            M = xx.shape[-1]
            p.dims = arange(M)
        # As an ndarray, p.dims works as a fancy index below
        # (a tuple/list would not, e.g. in `mu[key][p.dims]` or `p.dims[xt]`).
        p.dims = np.asarray(p.dims)

        # Make periodic wrapper
        ii, wrap = viz.setup_wrapping(M, p.periodicity)

        # Set up figure, axes
        fig, ax = place.freshfig(fignum, figsize=(8, 5))
        fig.suptitle("1d amplitude plot")

        # Nans
        nan1 = wrap(nan * ones(M))

        if E is None and p.conf_mult is None:
            p.conf_mult = 2

        # Init plots
        if p.conf_mult:
            lines_s = ax.plot(
                ii, nan1, "b-", lw=1, label=(str(p.conf_mult) + r"$\sigma$ conf")
            )
            lines_s += ax.plot(ii, nan1, "b-", lw=1)
            (line_mu,) = ax.plot(ii, nan1, "b-", lw=2, label="DA mean")
        else:
            nanE = nan * ones((stats.xp.N, M))
            lines_E = ax.plot(ii, wrap(nanE[0]), **p.ens_props, lw=1, label="Ensemble")
            lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)
        # Truth, Obs
        (line_x,) = ax.plot(ii, nan1, "k-", lw=3, label="Truth")
        if not_empty(p.obs_inds):
            (line_y,) = ax.plot(ii, nan1, "g*", ms=5, label="Obs")

        # Tune plot
        ax.set_ylim(*viz.xtrema(xx))
        ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))
        # Xticks
        xt = ax.get_xticks()
        xt = xt[abs(xt % 1) < 0.01].astype(int)  # Keep only the integer ticks
        xt = xt[xt >= 0]
        xt = xt[xt < len(p.dims)]
        ax.set_xticks(xt)
        ax.set_xticklabels(p.dims[xt])  # Label ticks by actual state index

        ax.set_xlabel("State index")
        ax.set_ylabel("Value")
        ax.legend(loc="upper right")

        text_t = ax.text(
            0.01,
            0.01,
            format_time(None, None, None),
            transform=ax.transAxes,
            family="monospace",
            ha="left",
        )

        def update(key, E, P):
            k, ko, faus = key

            if p.conf_mult:
                sigma = mu[key] + p.conf_mult * stats.spread[key] * [[1], [-1]]
                lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
                lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
                line_mu.set_ydata(wrap(mu[key][p.dims]))
            else:
                for n, line in enumerate(lines_E):
                    line.set_ydata(wrap(E[n, p.dims]))
                update_alpha(key, stats, lines_E)

            line_x.set_ydata(wrap(xx[k, p.dims]))

            text_t.set_text(format_time(k, ko, stats.HMM.tseq.tt[k]))

            # Show obs at forecast time; hide them again at universal time.
            if "f" in faus:
                if not_empty(p.obs_inds):
                    xy = nan * ones(len(xx[0]))
                    jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
                    xy[jj] = yy[ko]
                    line_y.set_ydata(wrap(xy[p.dims]))
                    line_y.set_zorder(5)
                    line_y.set_visible(True)

            if "u" in faus:
                if not_empty(p.obs_inds):
                    line_y.set_visible(False)

            return

        return update

    return init
1257
1258
def spatial2d(
    square,
    ind2sub,
    obs_inds=(),
    cm=plt.cm.jet,
    clims=((-40, 40), (-40, 40), (-10, 10), (-10, 10)),
):
    """Liveplotter: 2x2 panels of 2d fields (mean, truth, spread, error).

    - `square`: maps a state vector to a 2d field (as passed to `imshow`).
    - `ind2sub`: maps state indices to `(iy, ix)` subscripts (for obs locations).
    - `obs_inds`: observed state indices, or a callable `ko -> indices`.
    - `cm`: colormap for the mean and truth panels.
    - `clims`: colour limits for (mean, truth, spread, error).

    Returns the `init` function expected by `LivePlot`.
    """

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        GS = {"left": 0.125 - 0.04, "right": 0.9 - 0.04}
        fig, axs = place.freshfig(
            fignum,
            figsize=(6, 6),
            nrows=2,
            ncols=2,
            sharex=True,
            sharey=True,
            gridspec_kw=GS,
        )

        for ax in axs.flatten():
            ax.set_aspect("equal", "box")

        ((ax_11, ax_12), (ax_21, ax_22)) = axs

        ax_11.grid(color="w", linewidth=0.2)
        ax_12.grid(color="w", linewidth=0.2)
        ax_21.grid(color="k", linewidth=0.1)
        ax_22.grid(color="k", linewidth=0.1)

        # Upper colorbar -- position relative to ax_12
        bb = ax_12.get_position()
        dy = 0.1 * bb.height
        ax_13 = fig.add_axes([bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])
        # Lower colorbar -- position relative to ax_22
        bb = ax_22.get_position()
        dy = 0.1 * bb.height
        ax_23 = fig.add_axes([bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])

        # Extract data arrays
        xx, mu, spread, err = stats.xx, stats.mu, stats.spread, stats.err
        k = key0[0]
        tt = stats.HMM.tseq.tt

        # Plot
        # - origin='lower' might get overturned by set_ylim() below.
        im_11 = ax_11.imshow(square(mu[key0]), cmap=cm)
        im_12 = ax_12.imshow(square(xx[k]), cmap=cm)
        # hot is better, but needs +1 colorbar
        im_21 = ax_21.imshow(square(spread[key0]), cmap=plt.cm.bwr)
        im_22 = ax_22.imshow(square(err[key0]), cmap=plt.cm.bwr)
        ims = (im_11, im_12, im_21, im_22)
        # Obs handle: item 0 is the currently plotted obs line (or None).
        # Initially it is something invisible.
        lh = list(ax_12.plot(0, 0)[0:1])

        sx = "$\\psi$"
        ax_11.set_title("mean " + sx)
        ax_12.set_title("true " + sx)
        ax_21.set_title("spread. " + sx)
        ax_22.set_title("err. " + sx)

        for im, clim in zip(ims, clims):
            im.set_clim(clim)

        fig.colorbar(im_12, cax=ax_13)
        fig.colorbar(im_22, cax=ax_23)
        for ax in [ax_13, ax_23]:
            ax.yaxis.set_tick_params(
                "major", length=2, width=0.5, direction="in", left=True, right=True
            )
            ax.set_axisbelow("line")  # make ticks appear over colorbar patch

        # Title
        title = "Streamfunction (" + sx + ")"
        fig.suptitle(title)
        # Time info
        text_t = ax_12.text(
            1,
            1.1,
            format_time(None, None, None),
            transform=ax_12.transAxes,
            family="monospace",
            ha="left",
        )

        def update(key, E, P):
            k, ko, faus = key
            t = tt[k]

            im_11.set_data(square(mu[key]))
            im_12.set_data(square(xx[k]))
            im_21.set_data(square(spread[key]))
            im_22.set_data(square(err[key]))

            # Remove previous obs (each artist is removed only once,
            # so nothing needs to be silenced).
            if lh[0] is not None:
                lh[0].remove()
                lh[0] = None

            # Plot current obs.
            #  - plot() automatically adjusts to direction of y-axis in use.
            #  - ind2sub returns (iy,ix), while plot takes (ix,iy) => reverse.
            if ko is not None and not_empty(obs_inds):
                # Accept static indices too (consistent with the other liveplotters).
                jj = obs_inds(ko) if callable(obs_inds) else obs_inds
                lh[0] = ax_12.plot(*ind2sub(jj)[::-1], "k.", ms=1, zorder=5)[0]

            text_t.set_text(format_time(k, ko, t))

            return

        return update

    return init
1384
1385
# List of liveplotters available for all HMMs.
# Each entry is a tuple of:
# - show_by_default: whether it is shown when the default figures are requested
#   (see `LivePlot`);
# - function/class: the liveplotter (its name also names the figure).
default_liveplotters = [
    (1, sliding_diagnostics),
    (1, weight_histogram),
]
class LivePlot:
 45class LivePlot:
 46    """Live plotting manager.
 47
 48    Deals with
 49
 50    - Pause, skip.
 51    - Which liveploters to call.
 52    - `plot_u`
 53    - Figure window (title and number).
 54    """
 55
 56    def __init__(
 57        self,
 58        stats,
 59        liveplots,
 60        key0=(0, None, "u"),
 61        E=None,
 62        P=None,
 63        speed=1.0,
 64        replay=False,
 65        **kwargs,
 66    ):
 67        """
 68        Initialize plots.
 69
 70        - liveplots: figures to plot; alternatives:
 71            - `"default"/[]/True`: All default figures for this HMM.
 72            - `"all"`            : Even more.
 73            - non-empty `list`   : Only the figures with these numbers
 74                                 (int) or names (str).
 75            - `False`            : None.
 76        - speed: speed of animation.
 77            - `>100`: instantaneous
 78            - `1`   : (default) as quick as possible allowing for
 79                      plt.draw() to work on a moderately fast computer.
 80            - `<1`  : slower.
 81        """
 82        # Disable if not rc.liveplotting
 83        self.any_figs = False
 84        if rc.liveplotting:
 85            pass
 86        elif replay and np.isinf(speed):
 87            pass
 88        else:
 89            return
 90
 91        # Determine whether all/universal/intermediate stats are plotted
 92        self.plot_u = not replay or stats.store_u
 93
 94        # Set speed/pause params
 95        self.params = {
 96            "pause_f": 0.05,
 97            "pause_a": 0.05,
 98            "pause_s": 0.05,
 99            "pause_u": 0.001,
100        }
101        # If speed>100: set to inf. Coz pause=1e-99 causes hangup.
102        for pause in ["pause_" + x for x in "faus"]:
103            speed = speed if speed < 100 else np.inf
104            self.params[pause] /= speed
105
106        # Write params
107        self.params.update(getattr(stats.xp, "LP_kwargs", {}))
108        self.params.update(kwargs)
109
110        def get_name(init):
111            """Get name of liveplotter function/class."""
112            try:
113                return init.__qualname__.split(".")[0]
114            except AttributeError:
115                return init.__class__.__name__
116
117        # Set up dict of liveplotters
118        potential_LPs = {}
119        for show, init in default_liveplotters:
120            potential_LPs[get_name(init)] = show, init
121        # Add HMM-specific liveplotters
122        for show, init in getattr(stats.HMM, "liveplotters", {}):
123            potential_LPs[get_name(init)] = show, init
124
125        def parse_figlist(lst):
126            """Figures requested for this xp. Convert to list."""
127            if isinstance(lst, str):
128                fn = lst.lower()
129                if "all" == fn:
130                    lst = ["all"]  # All potential_LPs
131                elif "default" in fn:
132                    lst = ["default"]  # All show_by_default
133            elif hasattr(lst, "__len__"):
134                lst = lst  # This list (only)
135            elif lst:
136                lst = ["default"]  # All show_by_default
137            else:
138                lst = [None]  # None
139            return lst
140
141        figlist = parse_figlist(liveplots)
142
143        # Loop over requeted figures
144        self.figures = {}
145        for name, (show_by_default, init) in potential_LPs.items():
146            if (
147                (figlist == ["all"])
148                or (name in figlist)
149                or (figlist == ["default"] and show_by_default)
150            ):
151                # Startup message
152                if not self.any_figs:
153                    print("Initializing liveplots...")
154                    if is_notebook_or_qt:
155                        pauses = [self.params["pause_" + x] for x in "faus"]
156                        if any((p > 0) for p in pauses):
157                            print(
158                                "Note: liveplotting does not work very well"
159                                " inside Jupyter notebooks. In particular,"
160                                " there is no way to stop/skip them except"
161                                " to interrupt the kernel (the stop button"
162                                " in the toolbar). Consider using instead"
163                                " only the replay functionality (with infinite"
164                                " playback speed)."
165                            )
166                    elif not pb.disable_user_interaction:
167                        print("Hit <Space> to pause/step.")
168                        print("Hit <Enter> to resume/skip.")
169                        print("Hit <i> to enter debug mode.")
170                    self.paused = False
171                    self.run_ipdb = False
172                    self.skipping = False
173                    self.any_figs = True
174
175                # Init figure
176                post_title = "" if self.plot_u else "\n(obs times only)"
177                updater = init(name, stats, key0, self.plot_u, E, P, **kwargs)
178                if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
179                    self.figures[name] = updater
180                    fig = plt.figure(name)
181                    win = fig.canvas
182                    ax0 = fig.axes[0]
183                    win.manager.set_window_title(str(name))
184                    ax0.set_title(ax0.get_title() + post_title)
185                    self.update(key0, E, P)  # Call initial update
186                    if not (replay and np.isinf(speed)):
187                        plt.pause(0.01)  # Draw
188
189    def update(self, key, E, P):
190        """Update liveplots"""
191        # Check if there are still open figures
192        if self.any_figs:
193            open_figns = plt.get_figlabels()
194            live_figns = set(self.figures.keys())
195            self.any_figs = bool(live_figns.intersection(open_figns))
196        else:
197            return
198
199        # Playback control
200        SPACE = b" "
201        CHAR_I = b"i"
202        ENTERs = [b"\n", b"\r"]  # Linux + Windows
203
204        def pause():
205            """Loop until user decision is made."""
206            ch = read1()
207            while True:
208                # Set state (pause, skipping, ipdb)
209                if ch in ENTERs:
210                    self.paused = False
211                elif ch == CHAR_I:
212                    self.run_ipdb = True
213                # If keypress valid, resume execution
214                if ch in ENTERs + [SPACE, CHAR_I]:
215                    break
216                ch = read1()
217                # Pause to enable zoom, pan, etc. of mpl GUI
218                plot_pause(0.01)  # Don't use time.sleep()!
219
220        # Enter pause loop
221        if self.paused:
222            pause()
223
224        else:
225            if key == (0, None, "u"):
226                # Skip read1 for key0 (coz it blocks)
227                pass
228            else:
229                ch = read1()
230                if ch == SPACE:
231                    # Pause
232                    self.paused = True
233                    self.skipping = False
234                    pause()
235                elif ch in ENTERs:
236                    # Toggle skipping
237                    self.skipping = not self.skipping
238                elif ch == CHAR_I:
239                    # Schedule debug
240                    # Note: The reason we dont set_trace(frame) right here is:
241                    # - I could not find the right frame, even doing
242                    #   >   frame = inspect.stack()[0]
243                    #   >   while frame.f_code.co_name != "assimilate":
244                    #   >       frame = frame.f_back
245                    # - It just restarts the plot.
246                    self.run_ipdb = True
247
248        # Update figures
249        if not self.skipping:
250            faus = key[-1]
251            if faus != "u" or self.plot_u:
252                for name, (updater) in self.figures.items():
253                    if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
254                        _ = plt.figure(name)
255                        updater(key, E, P)
256                        plot_pause(self.params["pause_" + faus])
257
258        if self.run_ipdb:
259            self.run_ipdb = False
260            import inspect
261
262            import ipdb
263
264            print("Entering debug mode (ipdb).")
265            print("Type '?' (and Enter) for usage help.")
266            print("Type 'c' to continue the assimilation.")
267            ipdb.set_trace(inspect.stack()[2].frame)

Live plotting manager.

Deals with

  • Pause, skip.
  • Which liveplotters to call.
  • plot_u
  • Figure window (title and number).
LivePlot( stats, liveplots, key0=(0, None, 'u'), E=None, P=None, speed=1.0, replay=False, **kwargs)
 56    def __init__(
 57        self,
 58        stats,
 59        liveplots,
 60        key0=(0, None, "u"),
 61        E=None,
 62        P=None,
 63        speed=1.0,
 64        replay=False,
 65        **kwargs,
 66    ):
 67        """
 68        Initialize plots.
 69
 70        - liveplots: figures to plot; alternatives:
 71            - `"default"/[]/True`: All default figures for this HMM.
 72            - `"all"`            : Even more.
 73            - non-empty `list`   : Only the figures with these numbers
 74                                 (int) or names (str).
 75            - `False`            : None.
 76        - speed: speed of animation.
 77            - `>100`: instantaneous
 78            - `1`   : (default) as quick as possible allowing for
 79                      plt.draw() to work on a moderately fast computer.
 80            - `<1`  : slower.
 81        """
 82        # Disable if not rc.liveplotting
 83        self.any_figs = False
 84        if rc.liveplotting:
 85            pass
 86        elif replay and np.isinf(speed):
 87            pass
 88        else:
 89            return
 90
 91        # Determine whether all/universal/intermediate stats are plotted
 92        self.plot_u = not replay or stats.store_u
 93
 94        # Set speed/pause params
 95        self.params = {
 96            "pause_f": 0.05,
 97            "pause_a": 0.05,
 98            "pause_s": 0.05,
 99            "pause_u": 0.001,
100        }
101        # If speed>100: set to inf. Coz pause=1e-99 causes hangup.
102        for pause in ["pause_" + x for x in "faus"]:
103            speed = speed if speed < 100 else np.inf
104            self.params[pause] /= speed
105
106        # Write params
107        self.params.update(getattr(stats.xp, "LP_kwargs", {}))
108        self.params.update(kwargs)
109
110        def get_name(init):
111            """Get name of liveplotter function/class."""
112            try:
113                return init.__qualname__.split(".")[0]
114            except AttributeError:
115                return init.__class__.__name__
116
117        # Set up dict of liveplotters
118        potential_LPs = {}
119        for show, init in default_liveplotters:
120            potential_LPs[get_name(init)] = show, init
121        # Add HMM-specific liveplotters
122        for show, init in getattr(stats.HMM, "liveplotters", {}):
123            potential_LPs[get_name(init)] = show, init
124
125        def parse_figlist(lst):
126            """Figures requested for this xp. Convert to list."""
127            if isinstance(lst, str):
128                fn = lst.lower()
129                if "all" == fn:
130                    lst = ["all"]  # All potential_LPs
131                elif "default" in fn:
132                    lst = ["default"]  # All show_by_default
133            elif hasattr(lst, "__len__"):
134                lst = lst  # This list (only)
135            elif lst:
136                lst = ["default"]  # All show_by_default
137            else:
138                lst = [None]  # None
139            return lst
140
141        figlist = parse_figlist(liveplots)
142
143        # Loop over requeted figures
144        self.figures = {}
145        for name, (show_by_default, init) in potential_LPs.items():
146            if (
147                (figlist == ["all"])
148                or (name in figlist)
149                or (figlist == ["default"] and show_by_default)
150            ):
151                # Startup message
152                if not self.any_figs:
153                    print("Initializing liveplots...")
154                    if is_notebook_or_qt:
155                        pauses = [self.params["pause_" + x] for x in "faus"]
156                        if any((p > 0) for p in pauses):
157                            print(
158                                "Note: liveplotting does not work very well"
159                                " inside Jupyter notebooks. In particular,"
160                                " there is no way to stop/skip them except"
161                                " to interrupt the kernel (the stop button"
162                                " in the toolbar). Consider using instead"
163                                " only the replay functionality (with infinite"
164                                " playback speed)."
165                            )
166                    elif not pb.disable_user_interaction:
167                        print("Hit <Space> to pause/step.")
168                        print("Hit <Enter> to resume/skip.")
169                        print("Hit <i> to enter debug mode.")
170                    self.paused = False
171                    self.run_ipdb = False
172                    self.skipping = False
173                    self.any_figs = True
174
175                # Init figure
176                post_title = "" if self.plot_u else "\n(obs times only)"
177                updater = init(name, stats, key0, self.plot_u, E, P, **kwargs)
178                if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
179                    self.figures[name] = updater
180                    fig = plt.figure(name)
181                    win = fig.canvas
182                    ax0 = fig.axes[0]
183                    win.manager.set_window_title(str(name))
184                    ax0.set_title(ax0.get_title() + post_title)
185                    self.update(key0, E, P)  # Call initial update
186                    if not (replay and np.isinf(speed)):
187                        plt.pause(0.01)  # Draw

Initialize plots.

  • liveplots: figures to plot; alternatives:
    • "default"/True: All default figures for this HMM. (Note: as implemented, an empty list [] selects no figures.)
    • "all": All available liveplotters, including the non-default ones.
    • non-empty list: Only the figures with these names (str), i.e. the names of the liveplotting classes/functions.
    • False: None.
  • speed: speed of animation.
    • >100: instantaneous.
    • 1: (default) as quick as possible while still allowing plt.draw() to work on a moderately fast computer.
    • <1: slower.
any_figs
plot_u
params
figures
def update(self, key, E, P):
189    def update(self, key, E, P):
190        """Update liveplots"""
191        # Check if there are still open figures
192        if self.any_figs:
193            open_figns = plt.get_figlabels()
194            live_figns = set(self.figures.keys())
195            self.any_figs = bool(live_figns.intersection(open_figns))
196        else:
197            return
198
199        # Playback control
200        SPACE = b" "
201        CHAR_I = b"i"
202        ENTERs = [b"\n", b"\r"]  # Linux + Windows
203
204        def pause():
205            """Loop until user decision is made."""
206            ch = read1()
207            while True:
208                # Set state (pause, skipping, ipdb)
209                if ch in ENTERs:
210                    self.paused = False
211                elif ch == CHAR_I:
212                    self.run_ipdb = True
213                # If keypress valid, resume execution
214                if ch in ENTERs + [SPACE, CHAR_I]:
215                    break
216                ch = read1()
217                # Pause to enable zoom, pan, etc. of mpl GUI
218                plot_pause(0.01)  # Don't use time.sleep()!
219
220        # Enter pause loop
221        if self.paused:
222            pause()
223
224        else:
225            if key == (0, None, "u"):
226                # Skip read1 for key0 (coz it blocks)
227                pass
228            else:
229                ch = read1()
230                if ch == SPACE:
231                    # Pause
232                    self.paused = True
233                    self.skipping = False
234                    pause()
235                elif ch in ENTERs:
236                    # Toggle skipping
237                    self.skipping = not self.skipping
238                elif ch == CHAR_I:
239                    # Schedule debug
240                    # Note: The reason we dont set_trace(frame) right here is:
241                    # - I could not find the right frame, even doing
242                    #   >   frame = inspect.stack()[0]
243                    #   >   while frame.f_code.co_name != "assimilate":
244                    #   >       frame = frame.f_back
245                    # - It just restarts the plot.
246                    self.run_ipdb = True
247
248        # Update figures
249        if not self.skipping:
250            faus = key[-1]
251            if faus != "u" or self.plot_u:
252                for name, (updater) in self.figures.items():
253                    if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
254                        _ = plt.figure(name)
255                        updater(key, E, P)
256                        plot_pause(self.params["pause_" + faus])
257
258        if self.run_ipdb:
259            self.run_ipdb = False
260            import inspect
261
262            import ipdb
263
264            print("Entering debug mode (ipdb).")
265            print("Type '?' (and Enter) for usage help.")
266            print("Type 'c' to continue the assimilation.")
267            ipdb.set_trace(inspect.stack()[2].frame)

Update the liveplots: handle playback control (pause, skip, debug) and redraw the open figures.

star = '${}^*$'
class sliding_diagnostics:
275class sliding_diagnostics:
276    """Plots a sliding window (like a heart rate monitor) of certain diagnostics."""
277
278    def __init__(self, fignum, stats, key0, plot_u, E, P, Tplot=None, **kwargs):
279        # STYLE TABLES - Defines which/how diagnostics get plotted
280        styles = {}
281
282        def lin(a, b):
283            return lambda x: a + b * x
284
285        divN = 1 / getattr(stats.xp, "N", 99)
286        # Columns: transf, shape, plt kwargs
287        styles["RMS"] = {
288            "err.rms": [None, None, dict(c="k", label="Error")],
289            "spread.rms": [None, None, dict(c="b", label="Spread", alpha=0.6)],
290        }
291        styles["Values"] = {
292            "skew": [None, None, dict(c="g", label=star + r"Skew/$\sigma^3$")],
293            "kurt": [None, None, dict(c="r", label=star + r"Kurt$/\sigma^4{-}3$")],
294            "trHK": [None, None, dict(c="k", label=star + "HK")],
295            "infl": [lin(-10, 10), "step", dict(c="c", label="10(infl-1)")],
296            "N_eff": [lin(0, divN), "dirac", dict(c="y", label="N_eff/N", lw=3)],
297            "iters": [lin(0, 0.1), "dirac", dict(c="m", label="iters/10")],
298            "resmpl": [None, "dirac", dict(c="k", label="resampled?")],
299        }
300
301        nAx = len(styles)
302        GS = {"left": 0.125, "right": 0.76}
303        fig, axs = place.freshfig(
304            fignum, figsize=(5, 1 + nAx), nrows=nAx, sharex=True, gridspec_kw=GS
305        )
306
307        axs[0].set_title("Diagnostics")
308        for style, ax in zip(styles, axs):
309            ax.set_ylabel(style)
310        ax.set_xlabel("Time (t)")
311        place_ax.adjust_position(ax, y0=0.03)
312
313        self.T_lag, K_lag, a_lag = validate_lag(Tplot, stats.HMM.tseq)
314
315        def init_ax(ax, style_table):
316            lines = {}
317            for name in style_table:
318                # SKIP -- if stats[name] is not in existence
319                # Note: The nan check/deletion comes after the first ko.
320                try:
321                    stat = deep_getattr(stats, name)
322                except AttributeError:
323                    continue
324                # try: val0 = stat[key0[0]]
325                # except KeyError: continue
326                # PS: recall (from series.py) that even if store_u is false, stat[k] is
327                # still present if liveplots=True via the k_tmp functionality.
328
329                # Unpack style
330                ln = {}
331                ln["transf"] = style_table[name][0] or (lambda x: x)
332                ln["shape"] = style_table[name][1]
333                ln["plt"] = style_table[name][2]
334
335                # Create series
336                if isinstance(stat, FAUSt):
337                    ln["plot_u"] = plot_u
338                    K_plot = comp_K_plot(K_lag, a_lag, ln["plot_u"])
339                else:
340                    ln["plot_u"] = False
341                    K_plot = a_lag
342                ln["data"] = RollingArray(K_plot)
343                ln["tt"] = RollingArray(K_plot)
344
345                # Plot (init)
346                (ln["handle"],) = ax.plot(ln["tt"], ln["data"], **ln["plt"])
347
348                # Plotting only nans yield ugly limits. Revert to defaults.
349                ax.set_xlim(0, 1)
350                ax.set_ylim(0, 1)
351
352                lines[name] = ln
353            return lines
354
355        # Plot
356        self.d = [init_ax(ax, styles[style]) for style, ax in zip(styles, axs)]
357
358        # Horizontal line at y=0
359        (self.baseline0,) = ax.plot(
360            ax.get_xlim(), [0, 0], c=0.5 * ones(3), lw=0.7, label="_nolegend_"
361        )
362
363        # Store
364        self.axs = axs
365        self.stats = stats
366        self.init_incomplete = True
367
368    # Update plot
369    def __call__(self, key, E, P):
370        k, ko, faus = key
371
372        stats = self.stats
373        tseq = stats.HMM.tseq
374        ax0, ax1 = self.axs
375
376        def update_arrays(lines):
377            for name, ln in lines.items():
378                stat = deep_getattr(stats, name)
379                t = tseq.tt[k]  # == tseq.tto[ko]
380                if isinstance(stat, FAUSt):
381                    # ln['data'] will contain duplicates for f/a times.
382                    if ln["plot_u"]:
383                        val = stat[key]
384                        ln["tt"].insert(k, t)
385                        ln["data"].insert(k, ln["transf"](val))
386                    elif "u" not in faus:
387                        val = stat[key]
388                        ln["tt"].insert(ko, t)
389                        ln["data"].insert(ko, ln["transf"](val))
390                else:
391                    # ln['data'] will not contain duplicates, coz only 'a' is input.
392                    if "a" in faus:
393                        val = stat[ko]
394                        ln["tt"].insert(ko, t)
395                        ln["data"].insert(ko, ln["transf"](val))
396                    elif "f" in faus:
397                        pass
398
399        def update_plot_data(ax, lines):
400            def bend_into(shape, xx, yy):
401                # Get arrays. Repeat (to use for intermediate nodes).
402                yy = yy.array.repeat(3)
403                xx = xx.array.repeat(3)
404                if len(xx) == 0:
405                    pass  # shortcircuit any modifications
406                elif shape == "step":
407                    yy = np.hstack([yy[1:], nan])  # roll leftward
408                elif shape == "dirac":
409                    nonlocal nDirac
410                    axW = np.diff(ax.get_xlim())
411                    yy[0::3] = False  # set datapoin to 0
412                    xx[2::3] = nan  # make datapoint disappear
413                    xx += nDirac * axW / 100  # offset datapoint horizontally
414                    nDirac += 1
415                return xx, yy
416
417            nDirac = 1
418            for _name, ln in lines.items():
419                ln["handle"].set_data(*bend_into(ln["shape"], ln["tt"], ln["data"]))
420
421        def finalize_init(ax, lines, mm):
422            # Rm lines that only contain NaNs
423            for name in list(lines):
424                ln = lines[name]
425                stat = deep_getattr(stats, name)
426                if not stat.were_changed:
427                    ln["handle"].remove()  # rm from axes
428                    del lines[name]  # rm from dict
429            # Add legends
430            if lines:
431                ax.legend(loc="upper left", bbox_to_anchor=(1.01, 1), borderaxespad=0)
432                if mm:
433                    ax.annotate(
434                        star + ": mean of\nmarginals",
435                        xy=(0, -1.5 / len(lines)),
436                        xycoords=ax.get_legend().get_frame(),
437                        bbox=dict(alpha=0.0),
438                        fontsize="small",
439                    )
440            # coz placement of annotate needs flush sometimes:
441            plot_pause(0.01)
442
443        # Insert current stats
444        for lines, ax in zip(self.d, self.axs):
445            update_arrays(lines)
446            update_plot_data(ax, lines)
447
448        # Set x-limits (time)
449        sliding_xlim(ax0, self.d[0]["err.rms"]["tt"], self.T_lag, margin=True)
450        self.baseline0.set_xdata(ax0.get_xlim())
451
452        # Set y-limits
453        data0 = [ln["data"].array for ln in self.d[0].values()]
454        data1 = [ln["data"].array for ln in self.d[1].values()]
455        ax0.set_ylim(0, d_ylim(data0, ax0, cC=0.2, cE=0.9)[1])
456        ax1.set_ylim(*d_ylim(data1, ax1, Max=4, Min=-4, cC=0.3, cE=0.9))
457
458        # Init legend. Rm nan lines.
459        if self.init_incomplete and "a" == faus:
460            self.init_incomplete = False
461            finalize_init(ax0, self.d[0], False)
462            finalize_init(ax1, self.d[1], True)

Plots a sliding window (like a heart rate monitor) of certain diagnostics.

sliding_diagnostics(fignum, stats, key0, plot_u, E, P, Tplot=None, **kwargs)
278    def __init__(self, fignum, stats, key0, plot_u, E, P, Tplot=None, **kwargs):
279        # STYLE TABLES - Defines which/how diagnostics get plotted
280        styles = {}
281
282        def lin(a, b):
283            return lambda x: a + b * x
284
285        divN = 1 / getattr(stats.xp, "N", 99)
286        # Columns: transf, shape, plt kwargs
287        styles["RMS"] = {
288            "err.rms": [None, None, dict(c="k", label="Error")],
289            "spread.rms": [None, None, dict(c="b", label="Spread", alpha=0.6)],
290        }
291        styles["Values"] = {
292            "skew": [None, None, dict(c="g", label=star + r"Skew/$\sigma^3$")],
293            "kurt": [None, None, dict(c="r", label=star + r"Kurt$/\sigma^4{-}3$")],
294            "trHK": [None, None, dict(c="k", label=star + "HK")],
295            "infl": [lin(-10, 10), "step", dict(c="c", label="10(infl-1)")],
296            "N_eff": [lin(0, divN), "dirac", dict(c="y", label="N_eff/N", lw=3)],
297            "iters": [lin(0, 0.1), "dirac", dict(c="m", label="iters/10")],
298            "resmpl": [None, "dirac", dict(c="k", label="resampled?")],
299        }
300
301        nAx = len(styles)
302        GS = {"left": 0.125, "right": 0.76}
303        fig, axs = place.freshfig(
304            fignum, figsize=(5, 1 + nAx), nrows=nAx, sharex=True, gridspec_kw=GS
305        )
306
307        axs[0].set_title("Diagnostics")
308        for style, ax in zip(styles, axs):
309            ax.set_ylabel(style)
310        ax.set_xlabel("Time (t)")
311        place_ax.adjust_position(ax, y0=0.03)
312
313        self.T_lag, K_lag, a_lag = validate_lag(Tplot, stats.HMM.tseq)
314
315        def init_ax(ax, style_table):
316            lines = {}
317            for name in style_table:
318                # SKIP -- if stats[name] is not in existence
319                # Note: The nan check/deletion comes after the first ko.
320                try:
321                    stat = deep_getattr(stats, name)
322                except AttributeError:
323                    continue
324                # try: val0 = stat[key0[0]]
325                # except KeyError: continue
326                # PS: recall (from series.py) that even if store_u is false, stat[k] is
327                # still present if liveplots=True via the k_tmp functionality.
328
329                # Unpack style
330                ln = {}
331                ln["transf"] = style_table[name][0] or (lambda x: x)
332                ln["shape"] = style_table[name][1]
333                ln["plt"] = style_table[name][2]
334
335                # Create series
336                if isinstance(stat, FAUSt):
337                    ln["plot_u"] = plot_u
338                    K_plot = comp_K_plot(K_lag, a_lag, ln["plot_u"])
339                else:
340                    ln["plot_u"] = False
341                    K_plot = a_lag
342                ln["data"] = RollingArray(K_plot)
343                ln["tt"] = RollingArray(K_plot)
344
345                # Plot (init)
346                (ln["handle"],) = ax.plot(ln["tt"], ln["data"], **ln["plt"])
347
348                # Plotting only nans yield ugly limits. Revert to defaults.
349                ax.set_xlim(0, 1)
350                ax.set_ylim(0, 1)
351
352                lines[name] = ln
353            return lines
354
355        # Plot
356        self.d = [init_ax(ax, styles[style]) for style, ax in zip(styles, axs)]
357
358        # Horizontal line at y=0
359        (self.baseline0,) = ax.plot(
360            ax.get_xlim(), [0, 0], c=0.5 * ones(3), lw=0.7, label="_nolegend_"
361        )
362
363        # Store
364        self.axs = axs
365        self.stats = stats
366        self.init_incomplete = True
d
axs
stats
init_incomplete
def sliding_xlim(ax, tt, lag, margin=False):
465def sliding_xlim(ax, tt, lag, margin=False):
466    dt = lag / 20 if margin else 0
467    if tt.nFilled == 0:
468        return  # Quit
469    t1, t2 = tt.span()  # Get suggested span.
470    s1, s2 = ax.get_xlim()  # Get previous lims.
471    # If zero span (eg tt holds single 'f' and 'a'):
472    if t1 == t2:
473        t1 -= 1  # add width
474        t2 += 1  # add width
475    # If user has skipped (too much):
476    elif np.isnan(t1):
477        s2 -= dt  # Correct for dt.
478        span = s2 - s1  # Compute previous span
479        # If span<lag:
480        if span < lag:
481            span += t2 - s2  # Grow by "dt".
482        span = min(lag, span)  # Bound
483        t1 = t2 - span  # Set span.
484    ax.set_xlim(t1, t2 + dt)  # Set xlim to span
class weight_histogram:
487class weight_histogram:
488    """Plots histogram of weights. Refreshed each analysis."""
489
490    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
491        if not hasattr(stats, "w"):
492            self.is_active = False
493            return
494        fig, ax = place.freshfig(fignum, figsize=(7, 3), gridspec_kw={"bottom": 0.15})
495
496        ax.set_xscale("log")
497        ax.set_xlabel("Weigth")
498        ax.set_ylabel("Count")
499        self.stats = stats
500        self.ax = ax
501        self.hist = []
502        self.bins = np.exp(np.linspace(np.log(1e-10), np.log(1), 31))
503
504    def __call__(self, key, E, P):
505        k, ko, faus = key
506        if "a" == faus:
507            w = self.stats.w[key]
508            N = len(w)
509            ax = self.ax
510
511            self.is_active = N < 10001
512            if not self.is_active:
513                not_available_text(ax, "Not computed (N > threshold)")
514                return
515
516            counted = w > self.bins[0]
517            _ = [b.remove() for b in self.hist]
518            nn, _, self.hist = ax.hist(w[counted], bins=self.bins, color="b")
519            ax.set_ylim(top=max(nn))
520
521            ax.set_title(
522                f"N: {N:d}.   N_eff: {1/(w@w):.4g}."
523                "   Not shown: {N-np.sum(counted):d}. "
524            )

Plots a histogram of the weights, refreshed at each analysis time.

weight_histogram(fignum, stats, key0, plot_u, E, P, **kwargs)
490    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
491        if not hasattr(stats, "w"):
492            self.is_active = False
493            return
494        fig, ax = place.freshfig(fignum, figsize=(7, 3), gridspec_kw={"bottom": 0.15})
495
496        ax.set_xscale("log")
497        ax.set_xlabel("Weigth")
498        ax.set_ylabel("Count")
499        self.stats = stats
500        self.ax = ax
501        self.hist = []
502        self.bins = np.exp(np.linspace(np.log(1e-10), np.log(1), 31))
stats
ax
hist
bins
class spectral_errors:
527class spectral_errors:
528    """Plots the (spatial-RMS) error as a functional of the SVD index."""
529
530    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
531        fig, ax = place.freshfig(fignum, figsize=(6, 3))
532        ax.set_xlabel("Sing. value index")
533        ax.set_yscale("log")
534        self.init_incomplete = True
535        self.ax = ax
536        self.plot_u = plot_u
537
538        try:
539            self.msft = stats.umisf
540            self.sprd = stats.svals
541        except AttributeError:
542            self.is_active = False
543            not_available_text(ax, "Spectral stats not being computed")
544
545    # Update plot
546    def __call__(self, key, E, P):
547        k, ko, faus = key
548        ax = self.ax
549        if self.init_incomplete:
550            if self.plot_u or "f" == faus:
551                self.init_incomplete = False
552                msft = abs(self.msft[key])
553                sprd = self.sprd[key]
554                if np.any(np.isinf(msft)):
555                    not_available_text(ax, "Spectral stats not finite")
556                    self.is_active = False
557                else:
558                    (self.line_msft,) = ax.plot(msft, "k", lw=2, label="Error")
559                    (self.line_sprd,) = ax.plot(
560                        sprd, "b", lw=2, label="Spread", alpha=0.9
561                    )
562                    ax.get_xaxis().set_major_locator(MaxNLocator(integer=True))
563                    ax.legend()
564        else:
565            msft = abs(self.msft[key])
566            sprd = self.sprd[key]
567            self.line_sprd.set_ydata(sprd)
568            self.line_msft.set_ydata(msft)
569        # ax.set_ylim(*d_ylim(msft))
570        # ax.set_ylim(bottom=1e-5)
571        ax.set_ylim([1e-3, 1e1])

Plots the (spatial-RMS) error as a function of the SVD index.

spectral_errors(fignum, stats, key0, plot_u, E, P, **kwargs)
530    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
531        fig, ax = place.freshfig(fignum, figsize=(6, 3))
532        ax.set_xlabel("Sing. value index")
533        ax.set_yscale("log")
534        self.init_incomplete = True
535        self.ax = ax
536        self.plot_u = plot_u
537
538        try:
539            self.msft = stats.umisf
540            self.sprd = stats.svals
541        except AttributeError:
542            self.is_active = False
543            not_available_text(ax, "Spectral stats not being computed")
init_incomplete
ax
plot_u
class correlations:
574class correlations:
575    """Plots the state (auto-)correlation matrix."""
576
577    half = True  # Whether to show half/full (symmetric) corr matrix.
578
579    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
580        GS = {"height_ratios": [4, 1], "hspace": 0.09, "top": 0.95}
581        fig, (ax, ax2) = place.freshfig(fignum, figsize=(5, 6), nrows=2, gridspec_kw=GS)
582
583        if E is None and np.isnan(P.diag if isinstance(P, CovMat) else P).all():
584            not_available_text(
585                ax, ("Not available in replays" "\ncoz full Ens/Cov not stored.")
586            )
587            self.is_active = False
588            return
589
590        Nx = len(stats.mu[key0])
591        if Nx <= 1003:
592            C = np.eye(Nx)
593            # Mask half
594            mask = np.zeros_like(C, dtype=bool)
595            mask[np.tril_indices_from(mask)] = True
596            cmap = plt.get_cmap("RdBu_r")
597            VM = 1.0  # abs(np.percentile(C,[1,99])).max()
598            im = ax.imshow(C, cmap=cmap, vmin=-VM, vmax=VM)
599            # Colorbar
600            _ = ax.figure.colorbar(im, ax=ax, shrink=0.8)
601            # Tune plot
602            plt.box(False)
603            ax.set_facecolor("w")
604            ax.grid(False)
605            ax.set_title("State correlation matrix:", y=1.07)
606            ax.xaxis.tick_top()
607
608            # ax2 = inset_axes(ax,width="30%",height="60%",loc=3)
609            (line_AC,) = ax2.plot(arange(Nx), ones(Nx), label="Correlation")
610            (line_AA,) = ax2.plot(arange(Nx), ones(Nx), label="Abs. corr.")
611            _ = ax2.hlines(0, 0, Nx - 1, "k", "dotted", lw=1)
612            # Align ax2 with ax
613            bb_AC = ax2.get_position()
614            bb_C = ax.get_position()
615            ax2.set_position([bb_C.x0, bb_AC.y0, bb_C.width, bb_AC.height])
616            # Tune plot
617            ax2.set_title("Auto-correlation:")
618            ax2.set_ylabel("Mean value")
619            ax2.set_xlabel("Distance (in state indices)")
620            ax2.set_xticklabels([])
621            ax2.set_yticks([0, 1] + list(ax2.get_yticks()[[0, -1]]))
622            ax2.set_ylim(top=1)
623            ax2.legend(
624                frameon=True,
625                facecolor="w",
626                bbox_to_anchor=(1, 1),
627                loc="upper left",
628                borderaxespad=0.02,
629            )
630
631            self.ax = ax
632            self.ax2 = ax2
633            self.im = im
634            self.line_AC = line_AC
635            self.line_AA = line_AA
636            self.mask = mask
637            if hasattr(stats, "w"):
638                self.w = stats.w
639        else:
640            not_available_text(ax)
641
642    # Update plot
643    def __call__(self, key, E, P):
644        # Get cov matrix
645        if E is not None:
646            if hasattr(self, "w"):
647                C = np.cov(E, rowvar=False, aweights=self.w[key])
648            else:
649                C = np.cov(E, rowvar=False)
650        else:
651            assert P is not None
652            C = P.full if isinstance(P, CovMat) else P
653            C = C.copy()
654        # Compute corr from cov
655        std = np.sqrt(np.diag(C))
656        C /= std[:, None]
657        C /= std[None, :]
658        # Mask
659        if self.half:
660            C = np.ma.masked_where(self.mask, C)
661        # Plot
662        self.im.set_data(C)
663        # Auto-corr function
664        ACF = circulant_ACF(C)
665        AAF = circulant_ACF(C, do_abs=True)
666        self.line_AC.set_ydata(ACF)
667        self.line_AA.set_ydata(AAF)

Plots the state (auto-)correlation matrix.

correlations(fignum, stats, key0, plot_u, E, P, **kwargs)
579    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
580        GS = {"height_ratios": [4, 1], "hspace": 0.09, "top": 0.95}
581        fig, (ax, ax2) = place.freshfig(fignum, figsize=(5, 6), nrows=2, gridspec_kw=GS)
582
583        if E is None and np.isnan(P.diag if isinstance(P, CovMat) else P).all():
584            not_available_text(
585                ax, ("Not available in replays" "\ncoz full Ens/Cov not stored.")
586            )
587            self.is_active = False
588            return
589
590        Nx = len(stats.mu[key0])
591        if Nx <= 1003:
592            C = np.eye(Nx)
593            # Mask half
594            mask = np.zeros_like(C, dtype=bool)
595            mask[np.tril_indices_from(mask)] = True
596            cmap = plt.get_cmap("RdBu_r")
597            VM = 1.0  # abs(np.percentile(C,[1,99])).max()
598            im = ax.imshow(C, cmap=cmap, vmin=-VM, vmax=VM)
599            # Colorbar
600            _ = ax.figure.colorbar(im, ax=ax, shrink=0.8)
601            # Tune plot
602            plt.box(False)
603            ax.set_facecolor("w")
604            ax.grid(False)
605            ax.set_title("State correlation matrix:", y=1.07)
606            ax.xaxis.tick_top()
607
608            # ax2 = inset_axes(ax,width="30%",height="60%",loc=3)
609            (line_AC,) = ax2.plot(arange(Nx), ones(Nx), label="Correlation")
610            (line_AA,) = ax2.plot(arange(Nx), ones(Nx), label="Abs. corr.")
611            _ = ax2.hlines(0, 0, Nx - 1, "k", "dotted", lw=1)
612            # Align ax2 with ax
613            bb_AC = ax2.get_position()
614            bb_C = ax.get_position()
615            ax2.set_position([bb_C.x0, bb_AC.y0, bb_C.width, bb_AC.height])
616            # Tune plot
617            ax2.set_title("Auto-correlation:")
618            ax2.set_ylabel("Mean value")
619            ax2.set_xlabel("Distance (in state indices)")
620            ax2.set_xticklabels([])
621            ax2.set_yticks([0, 1] + list(ax2.get_yticks()[[0, -1]]))
622            ax2.set_ylim(top=1)
623            ax2.legend(
624                frameon=True,
625                facecolor="w",
626                bbox_to_anchor=(1, 1),
627                loc="upper left",
628                borderaxespad=0.02,
629            )
630
631            self.ax = ax
632            self.ax2 = ax2
633            self.im = im
634            self.line_AC = line_AC
635            self.line_AA = line_AA
636            self.mask = mask
637            if hasattr(stats, "w"):
638                self.w = stats.w
639        else:
640            not_available_text(ax)
half = True
def circulant_ACF(C, do_abs=False):
def circulant_ACF(C, do_abs=False):
    """Compute the auto-covariance-function corresponding to `C`.

    `C` is taken to be the cov/corr matrix of a 1D periodic domain.
    For each lag, the entries of `C` at that (circular) index distance are
    averaged over all rows. With `do_abs`, absolute values are averaged.

    Vectorized or FFT implementations are
    [possible](https://stackoverflow.com/questions/20360675).
    """
    size = len(C)
    # lag_index[i, j] is the column of C that lies at lag j from row i.
    lag_index = sla.circulant(np.arange(size))
    acf = np.zeros(size)
    for i, cols in enumerate(lag_index):
        vals = C[i, cols]
        # Note: this actually also accesses masked values in C.
        acf += abs(vals) if do_abs else vals
    return acf / size

Compute the auto-covariance-function corresponding to C.

This assumes it is the cov/corr matrix of a 1D periodic domain.

Vectorized or FFT implementations are possible.

def sliding_marginals( obs_inds=(), dims=(), labels=(), Tplot=None, ens_props={'alpha': 0.4}, zoomy=1.0):
def sliding_marginals(
    obs_inds=(),
    dims=(),
    labels=(),
    Tplot=None,
    ens_props=dict(alpha=0.4),  # noqa
    zoomy=1.0,
):
    """Plot sliding time series of selected marginals (state components).

    Parameters (each may be overridden by `kwargs` passed to `init`):

    - `obs_inds`: observed state indices; an array, or a callable of `ko`.
    - `dims`: state indices to plot.
      Default: `linspace_int(Nx, min(10, Nx))`.
    - `labels`: y-axis labels, one per entry of `dims`. Default: `$x_{m}$`.
    - `Tplot`: length of the sliding time window (see `validate_lag`).
    - `ens_props`: line properties of the ensemble members.
    - `zoomy`: y-axis zoom (limits are passed to `viz.stretch(..., 1/zoomy)`).

    Returns the liveplotter's `init` function, which itself returns `update`.
    """
    # Store parameters
    params_orig = DotDict(**locals())

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        xx, yy, mu, spread, tseq = (
            stats.xx,
            stats.yy,
            stats.mu,
            stats.spread,
            stats.HMM.tseq,
        )

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

        # Choose marginal dims to plot
        if not len(p.dims):
            p.dims = linspace_int(xx.shape[-1], min(10, xx.shape[-1]))
        n_dims = len(p.dims)

        # Lag settings:
        T_lag, K_lag, a_lag = validate_lag(p.Tplot, tseq)
        K_plot = comp_K_plot(K_lag, a_lag, plot_u)
        # Extend K_plot further for adding blanks in resampling (PartFilt):
        has_w = hasattr(stats, "w")
        if has_w:
            K_plot += a_lag

        # Set up figure, axes
        fig, axs = place.freshfig(
            fignum, figsize=(5, 7), squeeze=False, nrows=n_dims, sharex=True
        )
        axs = axs.reshape(n_dims)

        # Tune plots
        axs[0].set_title("Marginal time series")
        for ix, (m, ax) in enumerate(zip(p.dims, axs)):
            # ax.set_ylim(*viz.stretch(*viz.xtrema(xx[:, m]), 1/p.zoomy))
            if not p.labels:
                ax.set_ylabel("$x_{%d}$" % m)
            else:
                ax.set_ylabel(p.labels[ix])
        axs[-1].set_xlabel("Time (t)")

        plot_pause(0.05)
        plt.tight_layout()

        # Allocate
        d = DotDict()  # data arrays
        h = DotDict()  # plot handles
        # Why "if True" ? Just to indent the rest of the line...
        if True:
            d.t = RollingArray((K_plot,))
        if True:
            d.x = RollingArray((K_plot, n_dims))
            h.x = []
        if not_empty(p.obs_inds):
            d.y = RollingArray((K_plot, n_dims))
            h.y = []
        if E is not None:
            d.E = RollingArray((K_plot, len(E), n_dims))
            h.E = []
        if P is not None:
            d.mu = RollingArray((K_plot, n_dims))
            h.mu = []
        if P is not None:
            d.s = RollingArray((K_plot, 2, n_dims))
            h.s = []

        def obs_on_dims(ko):
            """Return the obs at `ko` placed at the positions of `p.dims` (nan elsewhere)."""
            # NB: obs_inds are *state* indices, which need not coincide with
            # positions within p.dims, so match them explicitly.
            xy = nan * ones(n_dims)
            if ko is not None:
                jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
                jj = list(jj)
                for i, dim in enumerate(p.dims):
                    if dim in jj:
                        xy[i] = yy[ko][jj.index(dim)]
            return xy

        # Plot (invisible coz everything here is nan, for the moment).
        # NB: `ix` is the position within `p.dims`, i.e. the column of the data
        # arrays and the index into the handle lists -- not a state index.
        for ix, ax in enumerate(axs):
            if True:
                h.x += ax.plot(d.t, d.x[:, ix], "k")
            if not_empty(p.obs_inds):
                h.y += ax.plot(d.t, d.y[:, ix], "g*", ms=10)
            if "E" in d:
                h.E += [ax.plot(d.t, d.E[:, :, ix], **p.ens_props)]
            if "mu" in d:
                h.mu += ax.plot(d.t, d.mu[:, ix], "b")
            if "s" in d:
                h.s += [ax.plot(d.t, d.s[:, :, ix], "b--", lw=1)]

        def update(key, E, P):
            k, ko, faus = key

            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)

            # Roll data array
            ind = k if plot_u else ko
            for Ens in EE:  # If E is duplicated, so must the others be.
                if "E" in d:
                    d.E.insert(ind, Ens)
                if "mu" in d:
                    d.mu.insert(ind, mu[key][p.dims])
                if "s" in d:
                    d.s.insert(ind, mu[key][p.dims] + [[1], [-1]] * spread[key][p.dims])
                if True:
                    d.t.insert(ind, tseq.tt[k])
                if not_empty(p.obs_inds):
                    d.y.insert(ind, obs_on_dims(ko))
                if True:
                    d.x.insert(ind, xx[k, p.dims])

            # Update graphs
            for ix, ax in enumerate(axs):
                sliding_xlim(ax, d.t, T_lag, True)
                if True:
                    h.x[ix].set_data(d.t, d.x[:, ix])
                if not_empty(p.obs_inds):
                    h.y[ix].set_data(d.t, d.y[:, ix])
                if "mu" in d:
                    h.mu[ix].set_data(d.t, d.mu[:, ix])
                if "s" in d:
                    for b in [0, 1]:
                        h.s[ix][b].set_data(d.t, d.s[:, b, ix])
                if "E" in d:
                    for n in range(len(E)):
                        h.E[ix][n].set_data(d.t, d.E[:, n, ix])
                    update_alpha(key, stats, h.E[ix])

                # TODO 3: fixup. This might be slow?
                # In any case, it is very far from tested.
                # Also, relim'iting all of the time is distracting.
                # Use d_ylim?
                if "E" in d:
                    lims = d.E
                elif "mu" in d:
                    lims = d.mu
                else:
                    # Fallback, so that `lims` is always defined.
                    lims = d.x
                lims = np.array(viz.xtrema(lims[..., ix]))
                if lims[0] == lims[1]:
                    lims += [-0.5, +0.5]
                ax.set_ylim(*viz.stretch(*lims, 1 / p.zoomy))

            return

        return update

    return init
def phase_particles( is_3d=True, obs_inds=(), dims=(), labels=(), Tplot=None, ens_props={'alpha': 0.4}, zoom=1.5):
def phase_particles(
    is_3d=True,
    obs_inds=(),
    dims=(),
    labels=(),
    Tplot=None,
    ens_props=dict(alpha=0.4),  # noqa
    zoom=1.5,
):
    """Plot trajectories (truth, obs, and mean or ensemble) in phase space.

    Parameters (overridable by `kwargs` passed to `init`, except `is_3d`,
    which is fixed at construction):

    - `is_3d`: 3D plot of 3 `dims` if True, else 2D plot of 2 `dims`.
    - `obs_inds`: observed state indices; an array, or a callable of `ko`.
    - `dims`: state indices to plot. Default: the first 2 (or 3).
    - `labels`: axis labels. Default: `$x_i$`.
    - `Tplot`: time length of the trailing tails; 0 keeps only the current point.
    - `ens_props`: line properties of the ensemble members.
    - `zoom`: axis zoom (limits are passed to `viz.stretch(..., 1/zoom)`).

    Returns the liveplotter's `init` function, which itself returns `update`.
    """
    # Store parameters
    params_orig = DotDict(**locals())

    M = 3 if is_3d else 2

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        xx, yy, mu, _, tseq = stats.xx, stats.yy, stats.mu, stats.spread, stats.HMM.tseq

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

        # Lag settings:
        has_w = hasattr(stats, "w")
        if p.Tplot == 0:
            K_plot = 1
        else:
            T_lag, K_lag, a_lag = validate_lag(p.Tplot, tseq)
            K_plot = comp_K_plot(K_lag, a_lag, plot_u)
            # Extend K_plot further for adding blanks in resampling (PartFilt):
            if has_w:
                K_plot += a_lag

        # Dimension settings
        # NB: use len() so that `dims` may also be given as an ndarray.
        if not len(p.dims):
            p.dims = arange(M)
        if not p.labels:
            p.labels = ["$x_%d$" % d for d in p.dims]
        assert len(p.dims) == M

        # Set up figure, axes
        fig, _ = place.freshfig(fignum, figsize=(5, 5))
        ax = plt.subplot(111, projection="3d" if is_3d else None)
        ax.set_facecolor("w")
        ax.set_title("Phase space trajectories")
        # Tune plot
        for ind, (label, i, t) in enumerate(zip(p.labels, p.dims, "xyz")):
            viz.set_ilim(ax, ind, *viz.stretch(*viz.xtrema(xx[:, i]), 1 / p.zoom))
            # Call ax.set_{x,y,z}label directly. Building source code and
            # eval'ing it broke labels with quotes or backslashes (LaTeX).
            getattr(ax, f"set_{t}label")(label)

        # Allocate
        d = DotDict()  # data arrays
        h = DotDict()  # plot handles
        s = DotDict()  # scatter handles
        if E is not None:
            d.E = RollingArray((K_plot, len(E), M))
            h.E = []
        if P is not None:
            d.mu = RollingArray((K_plot, M))
        if True:
            d.x = RollingArray((K_plot, M))
        if not_empty(p.obs_inds):
            d.y = RollingArray((K_plot, M))

        # Plot tails (invisible coz everything here is nan, for the moment).
        if "E" in d:
            h.E += [
                ax.plot(*xn, **p.ens_props)[0] for xn in np.transpose(d.E, [1, 2, 0])
            ]
        if "mu" in d:
            h.mu = ax.plot(*d.mu.T, "b", lw=2)[0]
        if True:
            h.x = ax.plot(*d.x.T, "k", lw=3)[0]
        if "y" in d:
            h.y = ax.plot(*d.y.T, "g*", ms=14)[0]

        # Scatter. NB: don't init with nan's coz it's buggy
        # (wrt. get_color() and _offsets3d) since mpl 3.1.
        if "E" in d:
            s.E = ax.scatter(*E.T[p.dims], s=3**2, c=[hn.get_color() for hn in h.E])
        if "mu" in d:
            s.mu = ax.scatter(*ones(M), s=8**2, c=[h.mu.get_color()])
        if True:
            s.x = ax.scatter(
                *ones(M), s=14**2, c=[h.x.get_color()], marker=(5, 1), zorder=99
            )

        def update(key, E, P):
            k, ko, faus = key

            def update_tail(handle, newdata):
                handle.set_data(newdata[:, 0], newdata[:, 1])
                if is_3d:
                    handle.set_3d_properties(newdata[:, 2])

            def update_sctr(handle, newdata):
                if is_3d:
                    handle._offsets3d = juggle_axes(*newdata.T, "z")
                else:
                    handle.set_offsets(newdata)

            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)

            # Roll data array
            ind = k if plot_u else ko
            for Ens in EE:  # If E is duplicated, so must the others be.
                if "E" in d:
                    d.E.insert(ind, Ens)
                if True:
                    d.x.insert(ind, xx[k, p.dims])
                if "y" in d:
                    # Place obs at the positions of the (observed) p.dims.
                    xy = nan * ones(len(p.dims))
                    if ko is not None:
                        jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
                        jj = list(jj)
                        for i, dim in enumerate(p.dims):
                            try:
                                iobs = jj.index(dim)
                            except ValueError:
                                pass
                            else:
                                xy[i] = yy[ko][iobs]
                    d.y.insert(ind, xy)
                if "mu" in d:
                    d.mu.insert(ind, mu[key][p.dims])

            # Update graph
            update_sctr(s.x, d.x[[-1]])
            update_tail(h.x, d.x)
            if "y" in d:
                update_tail(h.y, d.y)
            if "mu" in d:
                update_sctr(s.mu, d.mu[[-1]])
                update_tail(h.mu, d.mu)
            elif "E" in d:
                update_sctr(s.E, d.E[-1])
                for n in range(len(E)):
                    update_tail(h.E[n], d.E[:, n, :])
                update_alpha(key, stats, h.E, s.E)

            return

        return update

    return init
def validate_lag(Tplot, tseq):
 986def validate_lag(Tplot, tseq):
 987    """Return validated `T_lag` such that is is:
 988
 989    - equal to `Tplot` with fallback: `HMM.tseq.Tplot`.
 990    - no longer than `HMM.tseq.T`.
 991
 992    Also return corresponding `K_lag`, `a_lag`.
 993    """
 994    # Defaults
 995    if Tplot is None:
 996        Tplot = tseq.Tplot
 997
 998    # Rename
 999    T_lag = Tplot
1000
1001    assert T_lag >= 0
1002
1003    # Validate T_lag
1004    t2 = tseq.tt[-1]
1005    t1 = max(tseq.tt[0], t2 - T_lag)
1006    T_lag = t2 - t1
1007
1008    K_lag = int(T_lag / tseq.dt) + 1  # Lag in indices
1009    a_lag = K_lag // tseq.dko + 1  # Lag in obs indices
1010
1011    return T_lag, K_lag, a_lag

Return validated T_lag such that it is:

  • equal to Tplot with fallback: HMM.tseq.Tplot.
  • no longer than HMM.tseq.T.

Also return corresponding K_lag, a_lag.

def comp_K_plot(K_lag, a_lag, plot_u):
def comp_K_plot(K_lag, a_lag, plot_u):
    """Return the length of the rolling data arrays of the sliding plots.

    This is `a_lag` entries for each of the {f,a} series, plus `K_lag`
    entries for the "u" series when `plot_u` is set.
    """
    u_len = K_lag if plot_u else 0
    return 2 * a_lag + u_len
def update_alpha(key, stats, lines, scatters=None):
def update_alpha(key, stats, lines, scatters=None):
    """Adjust color alpha (for particle filters).

    This does nothing unless `key` is at an observation time (`ko` set),
    is not a forecast (`faus != "f"`), and `stats` has weights `w`.
    Otherwise each line (and scatter point) gets the alpha
    `w / max(w)`, clipped to [0.1, 0.4].
    """
    _, ko, faus = key
    if ko is None or faus == "f" or not hasattr(stats, "w"):
        return

    weights = stats.w[key]
    alpha = (weights / weights.max()).clip(0.1, 0.4)

    for line, a in zip(lines, alpha):
        line.set_alpha(a)

    if scatters is None:
        return
    # Scatter plot does not have alpha. => Fake it, via RGBA face colors.
    rgb = scatters.get_facecolor()[:, :3]
    if len(rgb) == 1:
        rgb = rgb.repeat(len(weights), axis=0)
    scatters.set_color(np.hstack([rgb, alpha[:, None]]))

Adjust color alpha (for particle filters).

def not_empty(xx):
def not_empty(xx):
    """Tell whether `xx` is non-empty.

    Sized objects (including ndarrays) are non-empty iff their length is
    positive. Objects without a length (e.g. callables, scalars, `None`)
    fall back to their truthiness.
    """
    try:
        size = len(xx)
    except TypeError:
        return bool(xx)
    return size > 0

Works for non-iterable and iterables (including ndarrays).

def duplicate_with_blanks_for_resampled(E, dims, key, has_w):
def duplicate_with_blanks_for_resampled(E, dims, key, has_w):
    """Particle filter: insert breaks for resampled particles.

    Returns a list of ensembles (restricted to `dims`) to be inserted in the
    rolling data arrays. Normally this is just `[E[:, dims]]` (or `[None]`).

    For weighted methods (`has_w`), a reference is stored in the module-level
    `_Ea` at the "a" step. At the subsequent "u" step of an observation time,
    the particles that changed since then are marked as resampled. A copy of
    the current ensemble, with those particles set to nan, is then prepended,
    so that their trajectory lines get broken.
    """
    if E is None:
        return [E]
    EE = []
    Ek = E[:, dims]
    if has_w:
        k, ko, faus = key
        if faus == "f":
            pass
        elif faus == "a":
            # Store (1st component of) the *full* ensemble. NB: `_Ea` is shared
            # by all liveplotters, which may use different `dims`, so the
            # reference must not depend on `dims`. Copy it, since `E[:, 0]`
            # is a view into the caller's array.
            _Ea[0] = E[:, 0].copy()
        elif faus == "u" and ko is not None:
            # Find resampled particles. Insert duplicate ensemble. Write nans (breaks).
            resampled = _Ea[0] != E[:, 0]  # Mark as resampled if ens changed.
            # Insert current ensemble (copy to avoid overwriting).
            EE.append(Ek.copy())
            EE[0][resampled] = nan  # Write breaks
    # Always: append current ensemble
    EE.append(Ek)
    return EE

Particle filter: insert breaks for resampled particles.

def d_ylim(data, ax=None, cC=0, cE=1, pp=(1, 99), Min=-1e+20, Max=1e+20):
def d_ylim(data, ax=None, cC=0, cE=1, pp=(1, 99), Min=-1e20, Max=+1e20):
    """Provide new ylim's intelligently, from percentiles of the data.

    - `data`: iterable of arrays for computing percentiles.
    - `pp`: percentiles

    - `ax`: If present, then the delta_zoom in/out is also considered.

      - `cE`: expansion (widening) rate ∈ [0,1].
        Default: 1, which immediately expands to percentile.
      - `cC`: compression (narrowing) rate ∈ [0,1].
        Default: 0, which does not allow compression.

    - `Min`/`Max`: bounds

    Returns `(minv, maxv)`. Either is `None` if it is not finite
    (e.g. when `data` contains no finite values).

    Despite being a little involved,
    the cost of this subroutine is typically not substantial
    because there's usually not that much data to sort through.
    """
    # Find "reasonable" limits (by percentiles), looping over data.
    # NB: the lower limit is tracked negated, so that both limits
    # can be widened by a single `np.maximum`.
    maxv = minv = -np.inf  # init
    for d in data:
        d = d[np.isfinite(d)]
        if len(d):
            perc = np.array([-1, 1]) * np.percentile(d, pp)
            minv, maxv = np.maximum([minv, maxv], perc)
    minv *= -1

    # Pry apart equal values
    if np.isclose(minv, maxv):
        maxv += 0.5
        minv -= 0.5

    # Make the zooming transition smooth
    if ax is not None:
        current = ax.get_ylim()
        # Set rate factor as compress or expand factor.
        c0 = cC if minv > current[0] else cE
        c1 = cC if maxv < current[1] else cE
        # Adjust: move from the current limits towards the new ones.
        minv = np.interp(c0, (0, 1), (current[0], minv))
        maxv = np.interp(c1, (0, 1), (current[1], maxv))

    # Bounds
    maxv = min(Max, maxv)
    minv = max(Min, minv)

    # Some mpl versions don't handle inf limits.
    if not np.isfinite(minv):
        minv = None
    if not np.isfinite(maxv):
        maxv = None

    return minv, maxv

Provide new ylim's intelligently, from percentiles of the data.

  • data: iterable of arrays for computing percentiles.
  • pp: percentiles

  • ax: If present, then the delta_zoom in/out is also considered.

    • cE: expansion (widening) rate ∈ [0,1]. Default: 1, which immediately expands to percentile.
    • cC: compression (narrowing) rate ∈ [0,1]. Default: 0, which does not allow compression.
  • Min/Max: bounds

Despite being a little involved, the cost of this subroutine is typically not substantial because there's usually not that much data to sort through.

def spatial1d( obs_inds=(), periodicity=None, dims=(), ens_props={'color': 'b', 'alpha': 0.1}, conf_mult=None):
def spatial1d(
    obs_inds=(),
    periodicity=None,
    dims=(),
    ens_props={"color": "b", "alpha": 0.1},  # noqa
    conf_mult=None,
):
    """Plot the state (truth, obs, and ensemble or mean±spread) vs. state index.

    Parameters (each may be overridden by `kwargs` passed to `init`):

    - `obs_inds`: observed state indices; an array, or a callable of `ko`.
    - `periodicity`: passed to `viz.setup_wrapping`.
    - `dims`: state indices to plot. Default: all.
    - `ens_props`: line properties of the ensemble members.
    - `conf_mult`: if set, plot mean ± `conf_mult` × spread instead of the
      ensemble members. Set to 2 if no ensemble is available.

    Returns the liveplotter's `init` function, which itself returns `update`.
    """
    # Store parameters
    params_orig = DotDict(**locals())

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        xx, yy, mu = stats.xx, stats.yy, stats.mu

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

        # NB: use len() so that `dims` may also be given as an ndarray.
        if not len(p.dims):
            M = xx.shape[-1]
            p.dims = arange(M)
        else:
            M = len(p.dims)

        # Make periodic wrapper
        ii, wrap = viz.setup_wrapping(M, p.periodicity)

        # Set up figure, axes
        fig, ax = place.freshfig(fignum, figsize=(8, 5))
        fig.suptitle("1d amplitude plot")

        # Nans
        nan1 = wrap(nan * ones(M))

        if E is None and p.conf_mult is None:
            p.conf_mult = 2

        # Init plots
        if p.conf_mult:
            lines_s = ax.plot(
                ii, nan1, "b-", lw=1, label=(str(p.conf_mult) + r"$\sigma$ conf")
            )
            lines_s += ax.plot(ii, nan1, "b-", lw=1)
            (line_mu,) = ax.plot(ii, nan1, "b-", lw=2, label="DA mean")
        else:
            nanE = nan * ones((stats.xp.N, M))
            lines_E = ax.plot(ii, wrap(nanE[0]), **p.ens_props, lw=1, label="Ensemble")
            lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)
        # Truth, Obs
        (line_x,) = ax.plot(ii, nan1, "k-", lw=3, label="Truth")
        if not_empty(p.obs_inds):
            (line_y,) = ax.plot(ii, nan1, "g*", ms=5, label="Obs")

        # Tune plot
        ax.set_ylim(*viz.xtrema(xx))
        ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))
        # Xticks
        xt = ax.get_xticks()
        xt = xt[abs(xt % 1) < 0.01].astype(int)  # Keep only the integer ticks
        xt = xt[xt >= 0]
        xt = xt[xt < len(p.dims)]
        ax.set_xticks(xt)
        # np.asarray: `dims` may be given as a list, which can't be fancy-indexed.
        ax.set_xticklabels(np.asarray(p.dims)[xt])

        ax.set_xlabel("State index")
        ax.set_ylabel("Value")
        ax.legend(loc="upper right")

        text_t = ax.text(
            0.01,
            0.01,
            format_time(None, None, None),
            transform=ax.transAxes,
            family="monospace",
            ha="left",
        )

        def update(key, E, P):
            k, ko, faus = key

            if p.conf_mult:
                sigma = mu[key] + p.conf_mult * stats.spread[key] * [[1], [-1]]
                lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
                lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
                line_mu.set_ydata(wrap(mu[key][p.dims]))
            else:
                for n, line in enumerate(lines_E):
                    line.set_ydata(wrap(E[n, p.dims]))
                update_alpha(key, stats, lines_E)

            line_x.set_ydata(wrap(xx[k, p.dims]))

            text_t.set_text(format_time(k, ko, stats.HMM.tseq.tt[k]))

            # Show obs at the forecast step; hide them again at the "u" step.
            if "f" in faus:
                if not_empty(p.obs_inds):
                    xy = nan * ones(len(xx[0]))
                    jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
                    xy[jj] = yy[ko]
                    line_y.set_ydata(wrap(xy[p.dims]))
                    line_y.set_zorder(5)
                    line_y.set_visible(True)

            if "u" in faus:
                if not_empty(p.obs_inds):
                    line_y.set_visible(False)

            return

        return update

    return init
def spatial2d( square, ind2sub, obs_inds=(), cm=<matplotlib.colors.LinearSegmentedColormap object>, clims=((-40, 40), (-40, 40), (-10, 10), (-10, 10))):
def spatial2d(
    square,
    ind2sub,
    obs_inds=(),
    cm=plt.cm.jet,
    clims=((-40, 40), (-40, 40), (-10, 10), (-10, 10)),
):
    """Plot 2D fields as images: DA mean, truth, spread and error.

    - `square`: maps a state vector to a 2D array (for `imshow`).
    - `ind2sub`: maps state indices to `(iy, ix)` subscripts.
    - `obs_inds`: observed state indices; an array, or a callable of `ko`.
    - `cm`: colormap of the mean and truth panels.
    - `clims`: color limits of the (mean, truth, spread, error) panels.

    Returns the liveplotter's `init` function, which itself returns `update`.
    """

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        GS = {"left": 0.125 - 0.04, "right": 0.9 - 0.04}
        fig, axs = place.freshfig(
            fignum,
            figsize=(6, 6),
            nrows=2,
            ncols=2,
            sharex=True,
            sharey=True,
            gridspec_kw=GS,
        )

        for ax in axs.flatten():
            ax.set_aspect("equal", "box")

        ((ax_11, ax_12), (ax_21, ax_22)) = axs

        ax_11.grid(color="w", linewidth=0.2)
        ax_12.grid(color="w", linewidth=0.2)
        ax_21.grid(color="k", linewidth=0.1)
        ax_22.grid(color="k", linewidth=0.1)

        # Upper colorbar -- position relative to ax_12
        bb = ax_12.get_position()
        dy = 0.1 * bb.height
        ax_13 = fig.add_axes([bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])
        # Lower colorbar -- position relative to ax_22
        bb = ax_22.get_position()
        dy = 0.1 * bb.height
        ax_23 = fig.add_axes([bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])

        # Extract data arrays
        xx, _, mu, spread, err = stats.xx, stats.yy, stats.mu, stats.spread, stats.err
        k = key0[0]
        tt = stats.HMM.tseq.tt

        # Plot
        # - origin='lower' might get overturned by set_ylim() below.
        im_11 = ax_11.imshow(square(mu[key0]), cmap=cm)
        im_12 = ax_12.imshow(square(xx[k]), cmap=cm)
        # hot is better, but needs +1 colorbar
        im_21 = ax_21.imshow(square(spread[key0]), cmap=plt.cm.bwr)
        im_22 = ax_22.imshow(square(err[key0]), cmap=plt.cm.bwr)
        ims = (im_11, im_12, im_21, im_22)
        # Obs init -- a list where item 0 is the handle of something invisible.
        lh = list(ax_12.plot(0, 0)[0:1])

        sx = "$\\psi$"
        ax_11.set_title("mean " + sx)
        ax_12.set_title("true " + sx)
        ax_21.set_title("spread. " + sx)
        ax_22.set_title("err. " + sx)

        # TODO 7
        # for ax in axs.flatten():
        # Crop boundries (which should be 0, i.e. yield harsh q gradients):
        # lims = (1, nx-2)
        # step = (nx - 1)/8
        # ticks = arange(step,nx-1,step)
        # ax.set_xlim  (lims)
        # ax.set_ylim  (lims[::-1])
        # ax.set_xticks(ticks)
        # ax.set_yticks(ticks)

        for im, clim in zip(ims, clims):
            im.set_clim(clim)

        fig.colorbar(im_12, cax=ax_13)
        fig.colorbar(im_22, cax=ax_23)
        for ax in [ax_13, ax_23]:
            ax.yaxis.set_tick_params(
                "major", length=2, width=0.5, direction="in", left=True, right=True
            )
            ax.set_axisbelow("line")  # make ticks appear over colorbar patch

        # Title
        title = "Streamfunction (" + sx + ")"
        fig.suptitle(title)
        # Time info
        text_t = ax_12.text(
            1,
            1.1,
            format_time(None, None, None),
            transform=ax_12.transAxes,
            family="monospace",
            ha="left",
        )

        def update(key, E, P):
            k, ko, faus = key
            t = tt[k]

            im_11.set_data(square(mu[key]))
            im_12.set_data(square(xx[k]))
            im_21.set_data(square(spread[key]))
            im_22.set_data(square(err[key]))

            # Remove previous obs (ValueError if it was already removed)
            try:
                lh[0].remove()
            except ValueError:
                pass
            # Plot current obs.
            #  - plot() automatically adjusts to direction of y-axis in use.
            #  - ind2sub returns (iy,ix), while plot takes (ix,iy) => reverse.
            if ko is not None and not_empty(obs_inds):
                # Support both a fixed array and a callable of ko (as elsewhere).
                jj = obs_inds(ko) if callable(obs_inds) else obs_inds
                lines = ax_12.plot(*ind2sub(jj)[::-1], "k.", ms=1, zorder=5)
                lh[0] = lines[0]

            text_t.set_text(format_time(k, ko, t))

            return

        return update

    return init
default_liveplotters = [(1, <class 'sliding_diagnostics'>), (1, <class 'weight_histogram'>)]