
On-line (live) plots of the DA process for various models and methods.

Liveplotters are specified by a list of tuples, given as a property of (or an argument to) dapper.mods.HiddenMarkovModel.

  • The first element of the tuple determines whether the liveplotter is shown by default, i.e. when the names of the liveplotters are not given via the liveplots argument of assimilate.

  • The second element in the tuple gives the corresponding liveplotter function/class. See example of function LPs in dapper.mods.Lorenz63.

The liveplotters can be selected for each DA experiment via the liveplots argument when calling assimilate.

  • liveplots = True turns on the liveplotters marked as default (by the first element of their tuple) in HMM.liveplotters, as well as the default liveplotters defined in this module (sliding_diagnostics and weight_histogram).

  • liveplots can also be a list of liveplotter names, i.e. the names of the corresponding liveplotting classes/functions.

"""On-line (live) plots of the DA process for various models and methods.

Liveplotters are specified by a list of tuples, given as a property of (or an
argument to) `dapper.mods.HiddenMarkovModel`.

- The first element of the tuple determines whether the liveplotter is shown
  by default, i.e. when the names of the liveplotters are not given via the
  `liveplots` argument of `assimilate`.

- The second element in the tuple gives the corresponding liveplotter
  function/class. See example of function `LPs` in `dapper.mods.Lorenz63`.

The liveplotters can be selected for each DA experiment via the `liveplots`
argument when calling `assimilate`.

- `liveplots = True` turns on the liveplotters marked as default (by the first
  element of their tuple) in `HMM.liveplotters`, as well as the default
  liveplotters defined in this module (`sliding_diagnostics` and
  `weight_histogram`).

- `liveplots` can also be a list of liveplotter names, i.e. the names of the
  corresponding liveplotting classes/functions.
"""
  24import numpy as np
  25import scipy.linalg as sla
  26from matplotlib import pyplot as plt
  27from matplotlib.ticker import MaxNLocator
  28from mpl_toolkits.mplot3d.art3d import juggle_axes
  29from mpl_tools import is_notebook_or_qt, place, place_ax
  30from numpy import arange, nan, ones
  31from struct_tools import DotDict, deep_getattr
  33import dapper.tools.progressbar as pb
  34import dapper.tools.viz as viz
  35from dapper.dpr_config import rc
  36from dapper.mods.utils import linspace_int
  37from dapper.tools.chronos import format_time
  38from dapper.tools.matrices import CovMat
  39from dapper.tools.progressbar import read1
  40from dapper.tools.series import FAUSt, RollingArray
  41from dapper.tools.viz import not_available_text, plot_pause
  44class LivePlot:
  45    """Live plotting manager.
  47    Deals with
  49    - Pause, skip.
  50    - Which liveploters to call.
  51    - `plot_u`
  52    - Figure window (title and number).
  53    """
  55    def __init__(
  56        self,
  57        stats,
  58        liveplots,
  59        key0=(0, None, "u"),
  60        E=None,
  61        P=None,
  62        speed=1.0,
  63        replay=False,
  64        **kwargs,
  65    ):
  66        """
  67        Initialize plots.
  69        - liveplots: figures to plot; alternatives:
  70            - `"default"/[]/True`: All default figures for this HMM.
  71            - `"all"`            : Even more.
  72            - non-empty `list`   : Only the figures with these numbers
  73                                 (int) or names (str).
  74            - `False`            : None.
  75        - speed: speed of animation.
  76            - `>100`: instantaneous
  77            - `1`   : (default) as quick as possible allowing for
  78                      plt.draw() to work on a moderately fast computer.
  79            - `<1`  : slower.
  80        """
  81        # Disable if not rc.liveplotting
  82        self.any_figs = False
  83        if rc.liveplotting:
  84            pass
  85        elif replay and np.isinf(speed):
  86            pass
  87        else:
  88            return
  90        # Determine whether all/universal/intermediate stats are plotted
  91        self.plot_u = not replay or stats.store_u
  93        # Set speed/pause params
  94        self.params = {
  95            "pause_f": 0.05,
  96            "pause_a": 0.05,
  97            "pause_s": 0.05,
  98            "pause_u": 0.001,
  99        }
 100        # If speed>100: set to inf. Coz pause=1e-99 causes hangup.
 101        for pause in ["pause_" + x for x in "faus"]:
 102            speed = speed if speed < 100 else np.inf
 103            self.params[pause] /= speed
 105        # Write params
 106        self.params.update(getattr(stats.xp, "LP_kwargs", {}))
 107        self.params.update(kwargs)
 109        def get_name(init):
 110            """Get name of liveplotter function/class."""
 111            try:
 112                return init.__qualname__.split(".")[0]
 113            except AttributeError:
 114                return init.__class__.__name__
 116        # Set up dict of liveplotters
 117        potential_LPs = {}
 118        for show, init in default_liveplotters:
 119            potential_LPs[get_name(init)] = show, init
 120        # Add HMM-specific liveplotters
 121        for show, init in getattr(stats.HMM, "liveplotters", {}):
 122            potential_LPs[get_name(init)] = show, init
 124        def parse_figlist(lst):
 125            """Figures requested for this xp. Convert to list."""
 126            if isinstance(lst, str):
 127                fn = lst.lower()
 128                if "all" == fn:
 129                    lst = ["all"]  # All potential_LPs
 130                elif "default" in fn:
 131                    lst = ["default"]  # All show_by_default
 132            elif hasattr(lst, "__len__"):
 133                lst = lst  # This list (only)
 134            elif lst:
 135                lst = ["default"]  # All show_by_default
 136            else:
 137                lst = [None]  # None
 138            return lst
 140        figlist = parse_figlist(liveplots)
 142        # Loop over requeted figures
 143        self.figures = {}
 144        for name, (show_by_default, init) in potential_LPs.items():
 145            if (
 146                (figlist == ["all"])
 147                or (name in figlist)
 148                or (figlist == ["default"] and show_by_default)
 149            ):
 150                # Startup message
 151                if not self.any_figs:
 152                    print("Initializing liveplots...")
 153                    if is_notebook_or_qt:
 154                        pauses = [self.params["pause_" + x] for x in "faus"]
 155                        if any((p > 0) for p in pauses):
 156                            print(
 157                                "Note: liveplotting does not work very well"
 158                                " inside Jupyter notebooks. In particular,"
 159                                " there is no way to stop/skip them except"
 160                                " to interrupt the kernel (the stop button"
 161                                " in the toolbar). Consider using instead"
 162                                " only the replay functionality (with infinite"
 163                                " playback speed)."
 164                            )
 165                    elif not pb.disable_user_interaction:
 166                        print("Hit <Space> to pause/step.")
 167                        print("Hit <Enter> to resume/skip.")
 168                        print("Hit <i> to enter debug mode.")
 169                    self.paused = False
 170                    self.run_ipdb = False
 171                    self.skipping = False
 172                    self.any_figs = True
 174                # Init figure
 175                post_title = "" if self.plot_u else "\n(obs times only)"
 176                updater = init(name, stats, key0, self.plot_u, E, P, **kwargs)
 177                if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
 178                    self.figures[name] = updater
 179                    fig = plt.figure(name)
 180                    win = fig.canvas
 181                    ax0 = fig.axes[0]
 182                    win.manager.set_window_title(str(name))
 183                    ax0.set_title(ax0.get_title() + post_title)
 184                    self.update(key0, E, P)  # Call initial update
 185                    if not (replay and np.isinf(speed)):
 186                        plt.pause(0.01)  # Draw
 188    def update(self, key, E, P):
 189        """Update liveplots"""
 190        # Check if there are still open figures
 191        if self.any_figs:
 192            open_figns = plt.get_figlabels()
 193            live_figns = set(self.figures.keys())
 194            self.any_figs = bool(live_figns.intersection(open_figns))
 195        else:
 196            return
 198        # Playback control
 199        SPACE = b" "
 200        CHAR_I = b"i"
 201        ENTERs = [b"\n", b"\r"]  # Linux + Windows
 203        def pause():
 204            """Loop until user decision is made."""
 205            ch = read1()
 206            while True:
 207                # Set state (pause, skipping, ipdb)
 208                if ch in ENTERs:
 209                    self.paused = False
 210                elif ch == CHAR_I:
 211                    self.run_ipdb = True
 212                # If keypress valid, resume execution
 213                if ch in ENTERs + [SPACE, CHAR_I]:
 214                    break
 215                ch = read1()
 216                # Pause to enable zoom, pan, etc. of mpl GUI
 217                plot_pause(0.01)  # Don't use time.sleep()!
 219        # Enter pause loop
 220        if self.paused:
 221            pause()
 223        else:
 224            if key == (0, None, "u"):
 225                # Skip read1 for key0 (coz it blocks)
 226                pass
 227            else:
 228                ch = read1()
 229                if ch == SPACE:
 230                    # Pause
 231                    self.paused = True
 232                    self.skipping = False
 233                    pause()
 234                elif ch in ENTERs:
 235                    # Toggle skipping
 236                    self.skipping = not self.skipping
 237                elif ch == CHAR_I:
 238                    # Schedule debug
 239                    # Note: The reason we dont set_trace(frame) right here is:
 240                    # - I could not find the right frame, even doing
 241                    #   >   frame = inspect.stack()[0]
 242                    #   >   while frame.f_code.co_name != "assimilate":
 243                    #   >       frame = frame.f_back
 244                    # - It just restarts the plot.
 245                    self.run_ipdb = True
 247        # Update figures
 248        if not self.skipping:
 249            faus = key[-1]
 250            if faus != "u" or self.plot_u:
 251                for name, (updater) in self.figures.items():
 252                    if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
 253                        _ = plt.figure(name)
 254                        updater(key, E, P)
 255                        plot_pause(self.params["pause_" + faus])
 257        if self.run_ipdb:
 258            self.run_ipdb = False
 259            import inspect
 261            import ipdb
 263            print("Entering debug mode (ipdb).")
 264            print("Type '?' (and Enter) for usage help.")
 265            print("Type 'c' to continue the assimilation.")
 266            ipdb.set_trace(inspect.stack()[2].frame)
 269# TODO 6:
 270# - iEnKS diagnostics don't work at all when store_u=False
 271star = "${}^*$"
class sliding_diagnostics:
    """Plots a sliding window (like a heart rate monitor) of certain diagnostics.

    The figure has one axes per entry of the style tables in `__init__`
    ("RMS" and "Values"). Each diagnostic is looked up in `stats` by its
    (dotted) name; those that do not exist are skipped, and lines whose stat
    was never changed are removed at the first analysis update.

    `__init__` follows the liveplotter signature
    `(fignum, stats, key0, plot_u, E, P, **kwargs)`, plus:

    - Tplot: length of the sliding time window (passed to `validate_lag`).
    """

    def __init__(self, fignum, stats, key0, plot_u, E, P, Tplot=None, **kwargs):
        # STYLE TABLES - Defines which/how diagnostics get plotted
        styles = {}

        def lin(a, b):
            """Return the affine map `x -> a + b*x` (used to rescale a series)."""
            return lambda x: a + b * x

        # Used for plotting N_eff/N.
        # NOTE(review): 99 is only a fallback for when the xp has no `N`.
        divN = 1 / getattr(stats.xp, "N", 99)
        # Columns: transf, shape, plt kwargs
        #   - transf: applied to each value before plotting (None: identity).
        #   - shape : None (plain line), "step" or "dirac" (see `bend_into`).
        #   - plt kwargs: passed to `ax.plot`.
        styles["RMS"] = {
            "err.rms": [None, None, dict(c="k", label="Error")],
            "spread.rms": [None, None, dict(c="b", label="Spread", alpha=0.6)],
        }
        styles["Values"] = {
            "skew": [None, None, dict(c="g", label=star + r"Skew/$\sigma^3$")],
            "kurt": [None, None, dict(c="r", label=star + r"Kurt$/\sigma^4{-}3$")],
            "trHK": [None, None, dict(c="k", label=star + "HK")],
            "infl": [lin(-10, 10), "step", dict(c="c", label="10(infl-1)")],
            "N_eff": [lin(0, divN), "dirac", dict(c="y", label="N_eff/N", lw=3)],
            "iters": [lin(0, 0.1), "dirac", dict(c="m", label="iters/10")],
            "resmpl": [None, "dirac", dict(c="k", label="resampled?")],
        }

        nAx = len(styles)
        GS = {"left": 0.125, "right": 0.76}
        fig, axs = place.freshfig(
            fignum, figsize=(5, 1 + nAx), nrows=nAx, sharex=True, gridspec_kw=GS
        )

        axs[0].set_title("Diagnostics")
        for style, ax in zip(styles, axs):
            ax.set_ylabel(style)
        # After the loop, `ax` is the last (bottom) axes.
        ax.set_xlabel("Time (t)")
        place_ax.adjust_position(ax, y0=0.03)

        self.T_lag, K_lag, a_lag = validate_lag(Tplot, stats.HMM.tseq)

        def init_ax(ax, style_table):
            """Create one (initially empty) line per existing stat in `style_table`.

            Returns a dict: stat name -> dict with keys
            "transf", "shape", "plt", "plot_u", "data", "tt", "handle".
            """
            lines = {}
            for name in style_table:
                # SKIP -- if stats[name] is not in existence
                # Note: The nan check/deletion comes after the first ko.
                try:
                    stat = deep_getattr(stats, name)
                except AttributeError:
                    continue
                # try: val0 = stat[key0[0]]
                # except KeyError: continue
                # PS: recall (from series.py) that even if store_u is false, stat[k] is
                # still present if liveplots=True via the k_tmp functionality.

                # Unpack style
                ln = {}
                ln["transf"] = style_table[name][0] or (lambda x: x)
                ln["shape"] = style_table[name][1]
                ln["plt"] = style_table[name][2]

                # Create series.
                # FAUSt stats are indexed by the full key (cf. `update_arrays`),
                # others only by `ko` (i.e. at analysis times).
                if isinstance(stat, FAUSt):
                    ln["plot_u"] = plot_u
                    K_plot = comp_K_plot(K_lag, a_lag, ln["plot_u"])
                else:
                    ln["plot_u"] = False
                    K_plot = a_lag
                ln["data"] = RollingArray(K_plot)
                ln["tt"] = RollingArray(K_plot)

                # Plot (init)
                (ln["handle"],) = ax.plot(ln["tt"], ln["data"], **ln["plt"])

                # Plotting only nans yield ugly limits. Revert to defaults.
                ax.set_xlim(0, 1)
                ax.set_ylim(0, 1)

                lines[name] = ln
            return lines

        # Plot
        self.d = [init_ax(ax, styles[style]) for style, ax in zip(styles, axs)]

        # Horizontal line at y=0 (on the bottom axes, i.e. `ax` from the loop above)
        (self.baseline0,) = ax.plot(
            ax.get_xlim(), [0, 0], c=0.5 * ones(3), lw=0.7, label="_nolegend_"
        )

        # Store
        self.axs = axs
        self.stats = stats
        # Legends are added (and never-changed lines removed) at the first analysis.
        self.init_incomplete = True

    # Update plot
    def __call__(self, key, E, P):
        """Insert the current values of the diagnostics and redraw.

        - key: `(k, ko, faus)`.
        - E, P: not used here (part of the liveplotter call signature).
        """
        k, ko, faus = key

        stats = self.stats
        tseq = stats.HMM.tseq
        ax0, ax1 = self.axs

        def update_arrays(lines):
            """Append the current time and (transformed) value of each stat."""
            for name, ln in lines.items():
                stat = deep_getattr(stats, name)
                t = tseq.tt[k]  # == tseq.tto[ko]
                if isinstance(stat, FAUSt):
                    # ln['data'] will contain duplicates for f/a times.
                    if ln["plot_u"]:
                        val = stat[key]
                        ln["tt"].insert(k, t)
                        ln["data"].insert(k, ln["transf"](val))
                    elif "u" not in faus:
                        val = stat[key]
                        ln["tt"].insert(ko, t)
                        ln["data"].insert(ko, ln["transf"](val))
                else:
                    # ln['data'] will not contain duplicates, coz only 'a' is input.
                    if "a" in faus:
                        val = stat[ko]
                        ln["tt"].insert(ko, t)
                        ln["data"].insert(ko, ln["transf"](val))
                    elif "f" in faus:
                        pass

        def update_plot_data(ax, lines):
            """Push the rolling arrays into the line handles, shaped per line."""

            def bend_into(shape, xx, yy):
                """Convert (tt, data) rolling arrays into x/y data for `shape`.

                Each point is repeated 3 times, then:
                - None   : unchanged.
                - "step" : the y-array is rolled leftward by one element
                           (padded with nan), giving a staircase.
                - "dirac": each triplet becomes a vertical spike from 0 up to
                           the value, followed by a break (nan x). Successive
                           dirac lines are offset horizontally (by 1% of the
                           axes width each) so that they don't overlap.
                """
                # Get arrays. Repeat (to use for intermediate nodes).
                yy = yy.array.repeat(3)
                xx = xx.array.repeat(3)
                if len(xx) == 0:
                    pass  # shortcircuit any modifications
                elif shape == "step":
                    yy = np.hstack([yy[1:], nan])  # roll leftward
                elif shape == "dirac":
                    nonlocal nDirac
                    axW = np.diff(ax.get_xlim())
                    yy[0::3] = False  # set datapoin to 0
                    xx[2::3] = nan  # make datapoint disappear
                    xx += nDirac * axW / 100  # offset datapoint horizontally
                    nDirac += 1
                return xx, yy

            # Counter of dirac lines in this axes (for the horizontal offsets).
            nDirac = 1
            for _name, ln in lines.items():
                ln["handle"].set_data(*bend_into(ln["shape"], ln["tt"], ln["data"]))

        def finalize_init(ax, lines, mm):
            """Remove never-changed lines; add legend (and the `star` note if `mm`)."""
            # Rm lines that only contain NaNs
            for name in list(lines):
                ln = lines[name]
                stat = deep_getattr(stats, name)
                if not stat.were_changed:
                    ln["handle"].remove()  # rm from axes
                    del lines[name]  # rm from dict
            # Add legends
            if lines:
                ax.legend(loc="upper left", bbox_to_anchor=(1.01, 1), borderaxespad=0)
                if mm:
                    ax.annotate(
                        star + ": mean of\nmarginals",
                        xy=(0, -1.5 / len(lines)),
                        xycoords=ax.get_legend().get_frame(),
                        bbox=dict(alpha=0.0),
                        fontsize="small",
                    )
            # coz placement of annotate needs flush sometimes:
            plot_pause(0.01)

        # Insert current stats
        for lines, ax in zip(self.d, self.axs):
            update_arrays(lines)
            update_plot_data(ax, lines)

        # Set x-limits (time). The RMS error's times drive the (shared) time axis.
        sliding_xlim(ax0, self.d[0]["err.rms"]["tt"], self.T_lag, margin=True)
        self.baseline0.set_xdata(ax0.get_xlim())

        # Set y-limits
        data0 = [ln["data"].array for ln in self.d[0].values()]
        data1 = [ln["data"].array for ln in self.d[1].values()]
        ax0.set_ylim(0, d_ylim(data0, ax0, cC=0.2, cE=0.9)[1])
        ax1.set_ylim(*d_ylim(data1, ax1, Max=4, Min=-4, cC=0.3, cE=0.9))

        # Init legend. Rm nan lines.
        if self.init_incomplete and "a" == faus:
            self.init_incomplete = False
            finalize_init(ax0, self.d[0], False)
            finalize_init(ax1, self.d[1], True)
 464def sliding_xlim(ax, tt, lag, margin=False):
 465    dt = lag / 20 if margin else 0
 466    if tt.nFilled == 0:
 467        return  # Quit
 468    t1, t2 = tt.span()  # Get suggested span.
 469    s1, s2 = ax.get_xlim()  # Get previous lims.
 470    # If zero span (eg tt holds single 'f' and 'a'):
 471    if t1 == t2:
 472        t1 -= 1  # add width
 473        t2 += 1  # add width
 474    # If user has skipped (too much):
 475    elif np.isnan(t1):
 476        s2 -= dt  # Correct for dt.
 477        span = s2 - s1  # Compute previous span
 478        # If span<lag:
 479        if span < lag:
 480            span += t2 - s2  # Grow by "dt".
 481        span = min(lag, span)  # Bound
 482        t1 = t2 - span  # Set span.
 483    ax.set_xlim(t1, t2 + dt)  # Set xlim to span
 486class weight_histogram:
 487    """Plots histogram of weights. Refreshed each analysis."""
 489    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
 490        if not hasattr(stats, "w"):
 491            self.is_active = False
 492            return
 493        fig, ax = place.freshfig(fignum, figsize=(7, 3), gridspec_kw={"bottom": 0.15})
 495        ax.set_xscale("log")
 496        ax.set_xlabel("Weigth")
 497        ax.set_ylabel("Count")
 498        self.stats = stats
 499        self.ax = ax
 500        self.hist = []
 501        self.bins = np.exp(np.linspace(np.log(1e-10), np.log(1), 31))
 503    def __call__(self, key, E, P):
 504        k, ko, faus = key
 505        if "a" == faus:
 506            w = self.stats.w[key]
 507            N = len(w)
 508            ax = self.ax
 510            self.is_active = N < 10001
 511            if not self.is_active:
 512                not_available_text(ax, "Not computed (N > threshold)")
 513                return
 515            counted = w > self.bins[0]
 516            _ = [b.remove() for b in self.hist]
 517            nn, _, self.hist = ax.hist(w[counted], bins=self.bins, color="b")
 518            ax.set_ylim(top=max(nn))
 520            ax.set_title(
 521                f"N: {N:d}.   N_eff: {1/(w@w):.4g}."
 522                "   Not shown: {N-np.sum(counted):d}. "
 523            )
 526class spectral_errors:
 527    """Plots the (spatial-RMS) error as a functional of the SVD index."""
 529    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
 530        fig, ax = place.freshfig(fignum, figsize=(6, 3))
 531        ax.set_xlabel("Sing. value index")
 532        ax.set_yscale("log")
 533        self.init_incomplete = True
 534        self.ax = ax
 535        self.plot_u = plot_u
 537        try:
 538            self.msft = stats.umisf
 539            self.sprd = stats.svals
 540        except AttributeError:
 541            self.is_active = False
 542            not_available_text(ax, "Spectral stats not being computed")
 544    # Update plot
 545    def __call__(self, key, E, P):
 546        k, ko, faus = key
 547        ax = self.ax
 548        if self.init_incomplete:
 549            if self.plot_u or "f" == faus:
 550                self.init_incomplete = False
 551                msft = abs(self.msft[key])
 552                sprd = self.sprd[key]
 553                if np.any(np.isinf(msft)):
 554                    not_available_text(ax, "Spectral stats not finite")
 555                    self.is_active = False
 556                else:
 557                    (self.line_msft,) = ax.plot(msft, "k", lw=2, label="Error")
 558                    (self.line_sprd,) = ax.plot(
 559                        sprd, "b", lw=2, label="Spread", alpha=0.9
 560                    )
 561                    ax.get_xaxis().set_major_locator(MaxNLocator(integer=True))
 562                    ax.legend()
 563        else:
 564            msft = abs(self.msft[key])
 565            sprd = self.sprd[key]
 566            self.line_sprd.set_ydata(sprd)
 567            self.line_msft.set_ydata(msft)
 568        # ax.set_ylim(*d_ylim(msft))
 569        # ax.set_ylim(bottom=1e-5)
 570        ax.set_ylim([1e-3, 1e1])
 573class correlations:
 574    """Plots the state (auto-)correlation matrix."""
 576    half = True  # Whether to show half/full (symmetric) corr matrix.
 578    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
 579        GS = {"height_ratios": [4, 1], "hspace": 0.09, "top": 0.95}
 580        fig, (ax, ax2) = place.freshfig(fignum, figsize=(5, 6), nrows=2, gridspec_kw=GS)
 582        if E is None and np.isnan(P.diag if isinstance(P, CovMat) else P).all():
 583            not_available_text(
 584                ax, ("Not available in replays" "\ncoz full Ens/Cov not stored.")
 585            )
 586            self.is_active = False
 587            return
 589        Nx = len(stats.mu[key0])
 590        if Nx <= 1003:
 591            C = np.eye(Nx)
 592            # Mask half
 593            mask = np.zeros_like(C, dtype=bool)
 594            mask[np.tril_indices_from(mask)] = True
 595            cmap = plt.get_cmap("RdBu_r")
 596            VM = 1.0  # abs(np.percentile(C,[1,99])).max()
 597            im = ax.imshow(C, cmap=cmap, vmin=-VM, vmax=VM)
 598            # Colorbar
 599            _ = ax.figure.colorbar(im, ax=ax, shrink=0.8)
 600            # Tune plot
 601            plt.box(False)
 602            ax.set_facecolor("w")
 603            ax.grid(False)
 604            ax.set_title("State correlation matrix:", y=1.07)
 605            ax.xaxis.tick_top()
 607            # ax2 = inset_axes(ax,width="30%",height="60%",loc=3)
 608            (line_AC,) = ax2.plot(arange(Nx), ones(Nx), label="Correlation")
 609            (line_AA,) = ax2.plot(arange(Nx), ones(Nx), label="Abs. corr.")
 610            _ = ax2.hlines(0, 0, Nx - 1, "k", "dotted", lw=1)
 611            # Align ax2 with ax
 612            bb_AC = ax2.get_position()
 613            bb_C = ax.get_position()
 614            ax2.set_position([bb_C.x0, bb_AC.y0, bb_C.width, bb_AC.height])
 615            # Tune plot
 616            ax2.set_title("Auto-correlation:")
 617            ax2.set_ylabel("Mean value")
 618            ax2.set_xlabel("Distance (in state indices)")
 619            ax2.set_xticklabels([])
 620            ax2.set_yticks([0, 1] + list(ax2.get_yticks()[[0, -1]]))
 621            ax2.set_ylim(top=1)
 622            ax2.legend(
 623                frameon=True,
 624                facecolor="w",
 625                bbox_to_anchor=(1, 1),
 626                loc="upper left",
 627                borderaxespad=0.02,
 628            )
 630            self.ax = ax
 631            self.ax2 = ax2
 632            self.im = im
 633            self.line_AC = line_AC
 634            self.line_AA = line_AA
 635            self.mask = mask
 636            if hasattr(stats, "w"):
 637                self.w = stats.w
 638        else:
 639            not_available_text(ax)
 641    # Update plot
 642    def __call__(self, key, E, P):
 643        # Get cov matrix
 644        if E is not None:
 645            if hasattr(self, "w"):
 646                C = np.cov(E, rowvar=False, aweights=self.w[key])
 647            else:
 648                C = np.cov(E, rowvar=False)
 649        else:
 650            assert P is not None
 651            C = P.full if isinstance(P, CovMat) else P
 652            C = C.copy()
 653        # Compute corr from cov
 654        std = np.sqrt(np.diag(C))
 655        C /= std[:, None]
 656        C /= std[None, :]
 657        # Mask
 658        if self.half:
 659            C = np.ma.masked_where(self.mask, C)
 660        # Plot
 661        self.im.set_data(C)
 662        # Auto-corr function
 663        ACF = circulant_ACF(C)
 664        AAF = circulant_ACF(C, do_abs=True)
 665        self.line_AC.set_ydata(ACF)
 666        self.line_AA.set_ydata(AAF)
 669def circulant_ACF(C, do_abs=False):
 670    """Compute the auto-covariance-function corresponding to `C`.
 672    This assumes it is the cov/corr matrix of a 1D periodic domain.
 674    Vectorized or FFT implementations are
 675    [possible](https://stackoverflow.com/questions/20360675).
 676    """
 677    M = len(C)
 678    # cols = np.flipud(sla.circulant(np.arange(M)[::-1]))
 679    cols = sla.circulant(np.arange(M))
 680    ACF = np.zeros(M)
 681    for i in range(M):
 682        row = C[i, cols[i]]
 683        if do_abs:
 684            row = abs(row)
 685        ACF += row
 686        # Note: this actually also accesses masked values in C.
 687    return ACF / M
 690def sliding_marginals(
 691    obs_inds=(),
 692    dims=(),
 693    labels=(),
 694    Tplot=None,
 695    ens_props=dict(alpha=0.4),  # noqa
 696    zoomy=1.0,
 698    # Store parameters
 699    params_orig = DotDict(**locals())
 701    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
 702        xx, yy, mu, spread, tseq = (
 703            stats.xx,
 704            stats.yy,
 705            stats.mu,
 706            stats.spread,
 707            stats.HMM.tseq,
 708        )
 710        # Set parameters (kwargs takes precedence over params_orig)
 711        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})
 713        # Chose marginal dims to plot
 714        if not len(p.dims):
 715            p.dims = linspace_int(xx.shape[-1], min(10, xx.shape[-1]))
 717        # Lag settings:
 718        T_lag, K_lag, a_lag = validate_lag(p.Tplot, tseq)
 719        K_plot = comp_K_plot(K_lag, a_lag, plot_u)
 720        # Extend K_plot forther for adding blanks in resampling (PartFilt):
 721        has_w = hasattr(stats, "w")
 722        if has_w:
 723            K_plot += a_lag
 725        # Set up figure, axes
 726        fig, axs = place.freshfig(
 727            fignum, figsize=(5, 7), squeeze=False, nrows=len(p.dims), sharex=True
 728        )
 729        axs = axs.reshape(len(p.dims))
 731        # Tune plots
 732        axs[0].set_title("Marginal time series")
 733        for ix, (m, ax) in enumerate(zip(p.dims, axs)):
 734            # ax.set_ylim(*viz.stretch(*viz.xtrema(xx[:, m]), 1/p.zoomy))
 735            if not p.labels:
 736                ax.set_ylabel("$x_{%d}$" % m)
 737            else:
 738                ax.set_ylabel(p.labels[ix])
 739        axs[-1].set_xlabel("Time (t)")
 741        plot_pause(0.05)
 742        plt.tight_layout()
 744        # Allocate
 745        d = DotDict()  # data arrays
 746        h = DotDict()  # plot handles
 747        # Why "if True" ? Just to indent the rest of the line...
 748        if True:
 749            d.t = RollingArray((K_plot,))
 750        if True:
 751            d.x = RollingArray((K_plot, len(p.dims)))
 752            h.x = []
 753        if not_empty(p.obs_inds):
 754            d.y = RollingArray((K_plot, len(p.dims)))
 755            h.y = []
 756        if E is not None:
 757            d.E = RollingArray((K_plot, len(E), len(p.dims)))
 758            h.E = []
 759        if P is not None:
 760            d.mu = RollingArray((K_plot, len(p.dims)))
 761            h.mu = []
 762        if P is not None:
 763            d.s = RollingArray((K_plot, 2, len(p.dims)))
 764            h.s = []
 766        # Plot (invisible coz everything here is nan, for the moment).
 767        for ix, ax in zip(p.dims, axs):
 768            if True:
 769                h.x += ax.plot(d.t, d.x[:, ix], "k")
 770            if not_empty(p.obs_inds):
 771                h.y += ax.plot(d.t, d.y[:, ix], "g*", ms=10)
 772            if "E" in d:
 773                h.E += [ax.plot(d.t, d.E[:, :, ix], **p.ens_props)]
 774            if "mu" in d:
 775                h.mu += ax.plot(d.t, d.mu[:, ix], "b")
 776            if "s" in d:
 777                h.s += [ax.plot(d.t, d.s[:, :, ix], "b--", lw=1)]
 779        def update(key, E, P):
 780            k, ko, faus = key
 782            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)
 784            # Roll data array
 785            ind = k if plot_u else ko
 786            for Ens in EE:  # If E is duplicated, so must the others be.
 787                if "E" in d:
 788                    d.E.insert(ind, Ens)
 789                if "mu" in d:
 790                    d.mu.insert(ind, mu[key][p.dims])
 791                if "s" in d:
 792                    d.s.insert(ind, mu[key][p.dims] + [[1], [-1]] * spread[key][p.dims])
 793                if True:
 794                    d.t.insert(ind, tseq.tt[k])
 795                if not_empty(p.obs_inds):
 796                    xy = nan * ones(len(p.dims))
 797                    if ko is not None:
 798                        jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
 799                        xy[jj] = yy[ko]
 800                    d.y.insert(ind, xy)
 801                if True:
 802                    d.x.insert(ind, xx[k, p.dims])
 804            # Update graphs
 805            for ix, ax in zip(p.dims, axs):
 806                sliding_xlim(ax, d.t, T_lag, True)
 807                if True:
 808                    h.x[ix].set_data(d.t, d.x[:, ix])
 809                if not_empty(p.obs_inds):
 810                    h.y[ix].set_data(d.t, d.y[:, ix])
 811                if "mu" in d:
 812                    h.mu[ix].set_data(d.t, d.mu[:, ix])
 813                if "s" in d:
 814                    [h.s[ix][b].set_data(d.t, d.s[:, b, ix]) for b in [0, 1]]
 815                if "E" in d:
 816                    [h.E[ix][n].set_data(d.t, d.E[:, n, ix]) for n in range(len(E))]
 817                if "E" in d:
 818                    update_alpha(key, stats, h.E[ix])
 820                # TODO 3: fixup. This might be slow?
 821                # In any case, it is very far from tested.
 822                # Also, relim'iting all of the time is distracting.
 823                # Use d_ylim?
 824                if "E" in d:
 825                    lims = d.E
 826                elif "mu" in d:
 827                    lims = d.mu
 828                lims = np.array(viz.xtrema(lims[..., ix]))
 829                if lims[0] == lims[1]:
 830                    lims += [-0.5, +0.5]
 831                ax.set_ylim(*viz.stretch(*lims, 1 / p.zoomy))
 833            return
 835        return update
 837    return init
 840def phase_particles(
 841    is_3d=True,
 842    obs_inds=(),
 843    dims=(),
 844    labels=(),
 845    Tplot=None,
 846    ens_props=dict(alpha=0.4),  # noqa
 847    zoom=1.5,
 849    # Store parameters
 850    params_orig = DotDict(**locals())
 852    M = 3 if is_3d else 2
 854    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
 855        xx, yy, mu, _, tseq = stats.xx, stats.yy, stats.mu, stats.spread, stats.HMM.tseq
 857        # Set parameters (kwargs takes precedence over params_orig)
 858        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})
 860        # Lag settings:
 861        has_w = hasattr(stats, "w")
 862        if p.Tplot == 0:
 863            K_plot = 1
 864        else:
 865            T_lag, K_lag, a_lag = validate_lag(p.Tplot, tseq)
 866            K_plot = comp_K_plot(K_lag, a_lag, plot_u)
 867            # Extend K_plot forther for adding blanks in resampling (PartFilt):
 868            if has_w:
 869                K_plot += a_lag
 871        # Dimension settings
 872        if not p.dims:
 873            p.dims = arange(M)
 874        if not p.labels:
 875            p.labels = ["$x_%d$" % d for d in p.dims]
 876        assert len(p.dims) == M
 878        # Set up figure, axes
 879        fig, _ = place.freshfig(fignum, figsize=(5, 5))
 880        ax = plt.subplot(111, projection="3d" if is_3d else None)
 881        ax.set_facecolor("w")
 882        ax.set_title("Phase space trajectories")
 883        # Tune plot
 884        for ind, (s, i, t) in enumerate(zip(p.labels, p.dims, "xyz")):
 885            viz.set_ilim(ax, ind, *viz.stretch(*viz.xtrema(xx[:, i]), 1 / p.zoom))
 886            eval(f"ax.set_{t}label('{s!s}')")
 888        # Allocate
 889        d = DotDict()  # data arrays
 890        h = DotDict()  # plot handles
 891        s = DotDict()  # scatter handles
 892        if E is not None:
 893            d.E = RollingArray((K_plot, len(E), M))
 894            h.E = []
 895        if P is not None:
 896            d.mu = RollingArray((K_plot, M))
 897        if True:
 898            d.x = RollingArray((K_plot, M))
 899        if not_empty(p.obs_inds):
 900            d.y = RollingArray((K_plot, M))
 902        # Plot tails (invisible coz everything here is nan, for the moment).
 903        if "E" in d:
 904            h.E += [
 905                ax.plot(*xn, **p.ens_props)[0] for xn in np.transpose(d.E, [1, 2, 0])
 906            ]
 907        if "mu" in d:
 908            h.mu = ax.plot(*d.mu.T, "b", lw=2)[0]
 909        if True:
 910            h.x = ax.plot(*d.x.T, "k", lw=3)[0]
 911        if "y" in d:
 912            h.y = ax.plot(*d.y.T, "g*", ms=14)[0]
 914        # Scatter. NB: don't init with nan's coz it's buggy
 915        # (wrt. get_color() and _offsets3d) since mpl 3.1.
 916        if "E" in d:
 917            s.E = ax.scatter(*E.T[p.dims], s=3**2, c=[hn.get_color() for hn in h.E])
 918        if "mu" in d:
 919            s.mu = ax.scatter(*ones(M), s=8**2, c=[h.mu.get_color()])
 920        if True:
 921            s.x = ax.scatter(
 922                *ones(M), s=14**2, c=[h.x.get_color()], marker=(5, 1), zorder=99
 923            )
 925        def update(key, E, P):
 926            k, ko, faus = key
 928            def update_tail(handle, newdata):
 929                handle.set_data(newdata[:, 0], newdata[:, 1])
 930                if is_3d:
 931                    handle.set_3d_properties(newdata[:, 2])
 933            def update_sctr(handle, newdata):
 934                if is_3d:
 935                    handle._offsets3d = juggle_axes(*newdata.T, "z")
 936                else:
 937                    handle.set_offsets(newdata)
 939            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)
 941            # Roll data array
 942            ind = k if plot_u else ko
 943            for Ens in EE:  # If E is duplicated, so must the others be.
 944                if "E" in d:
 945                    d.E.insert(ind, Ens)
 946                if True:
 947                    d.x.insert(ind, xx[k, p.dims])
 948                if "y" in d:
 949                    xy = nan * ones(len(p.dims))
 950                    if ko is not None:
 951                        jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
 952                        jj = list(jj)
 953                        for i, dim in enumerate(p.dims):
 954                            try:
 955                                iobs = jj.index(dim)
 956                            except ValueError:
 957                                pass
 958                            else:
 959                                xy[i] = yy[ko][iobs]
 960                    d.y.insert(ind, xy)
 961                if "mu" in d:
 962                    d.mu.insert(ind, mu[key][p.dims])
 964            # Update graph
 965            update_sctr(s.x, d.x[[-1]])
 966            update_tail(h.x, d.x)
 967            if "y" in d:
 968                update_tail(h.y, d.y)
 969            if "mu" in d:
 970                update_sctr(s.mu, d.mu[[-1]])
 971                update_tail(h.mu, d.mu)
 972            else:
 973                update_sctr(s.E, d.E[-1])
 974                for n in range(len(E)):
 975                    update_tail(h.E[n], d.E[:, n, :])
 976                update_alpha(key, stats, h.E, s.E)
 978            return
 980        return update
 982    return init
 985def validate_lag(Tplot, tseq):
 986    """Return validated `T_lag` such that is is:
 988    - equal to `Tplot` with fallback: `HMM.tseq.Tplot`.
 989    - no longer than `HMM.tseq.T`.
 991    Also return corresponding `K_lag`, `a_lag`.
 992    """
 993    # Defaults
 994    if Tplot is None:
 995        Tplot = tseq.Tplot
 997    # Rename
 998    T_lag = Tplot
1000    assert T_lag >= 0
1002    # Validate T_lag
1003    t2 = tseq.tt[-1]
1004    t1 = max(tseq.tt[0], t2 - T_lag)
1005    T_lag = t2 - t1
1007    K_lag = int(T_lag / tseq.dt) + 1  # Lag in indices
1008    a_lag = K_lag // tseq.dko + 1  # Lag in obs indices
1010    return T_lag, K_lag, a_lag
def comp_K_plot(K_lag, a_lag, plot_u):
    """Return the number of time slots for the rolling plot buffers.

    The forecast and analysis series need `a_lag` slots each; when
    `plot_u` is truthy, the `K_lag` slots of the universal series are added.
    """
    n_forecast_analysis = 2 * a_lag
    n_universal = K_lag if plot_u else 0
    return n_forecast_analysis + n_universal
def update_alpha(key, stats, lines, scatters=None):
    """Adjust color alpha (for particle filters).

    Does nothing for keys without an obs index (`ko is None`), for
    forecast keys (`faus == "f"`), or if `stats` has no weights `w`.
    Otherwise the alpha of each of `lines` (and, if given, of each point in
    `scatters`) is set to the particle's weight relative to the largest
    weight, clipped to [0.1, 0.4].
    """
    _, ko, faus = key
    if ko is None or faus == "f" or not hasattr(stats, "w"):
        return

    weights = stats.w[key]
    alphas = np.clip(weights / weights.max(), 0.1, 0.4)

    for a, line in zip(alphas, lines):
        line.set_alpha(a)

    # Scatter plot does not have (per-point) alpha. => Fake it via RGBA colors.
    if scatters is None:
        return
    rgb = scatters.get_facecolor()[:, :3]
    if len(rgb) == 1:
        # Single shared color: give every particle its own copy.
        rgb = np.repeat(rgb, len(weights), axis=0)
    scatters.set_color(np.column_stack([rgb, alphas]))
def not_empty(xx):
    """Tell whether `xx` is non-empty (sized) or truthy (unsized).

    Sized objects (lists, tuples, ndarrays, ...) are non-empty iff their
    length is positive. Objects without a usable `len` (e.g. callables,
    scalars, 0-d arrays) fall back to their truth value.
    """
    try:
        size = len(xx)
    except TypeError:
        return bool(xx)
    return size > 0
def duplicate_with_blanks_for_resampled(E, dims, key, has_w):
    """Particle filter: insert breaks for resampled particles.

    Returns the list of ensembles (restricted to `dims`) to be inserted
    into the rolling data arrays:

    - `[None]` if `E` is None.
    - Otherwise `E[:, dims]` is always the last item.
    - With weights (`has_w`): at `"a"` keys, the 1st plotted component of
      the ensemble is stored (in `_Ea`). At `"u"` keys with an obs index,
      a copy is prepended in which the particles whose 1st plotted component
      differs from the stored one (deemed resampled) are set to nan,
      which breaks their plotted tails.
    """
    if E is None:
        return [E]

    E = E[:, dims]
    ensembles = []

    if has_w:
        _, ko, faus = key
        if faus == "a":
            _Ea[0] = E[:, 0]  # Store (1st dim of) ens.
        elif faus == "u" and ko is not None:
            # Copy (to avoid overwriting E) and write nans (breaks)
            # for particles whose ens. changed since the analysis.
            blanked = E.copy()
            blanked[_Ea[0] != E[:, 0]] = nan
            ensembles.append(blanked)

    # Always: append current ensemble
    ensembles.append(E)
    return ensembles


# Persistent storage (1-element list, mutated in place) of the 1st plotted
# component of the latest analysis ensemble; used to detect resampling.
_Ea = [None]
def d_ylim(data, ax=None, cC=0, cE=1, pp=(1, 99), Min=-1e20, Max=+1e20):
    """Provide new ylim's intelligently, from percentiles of the data.

    - `data`: iterable of arrays (or array-likes, e.g. lists) for computing
      percentiles. Non-finite values are ignored.
    - `pp`: percentiles

    - `ax`: If present, then the delta_zoom in/out is also considered.
      Only its current ylim is read; `ax` itself is not modified.

      - `cE`: expansion (widening) rate ∈ [0,1].
        Default: 1, which immediately expands to percentile.
      - `cC`: compression (narrowing) rate ∈ [0,1].
        Default: 0, which does not allow compression.

    - `Min`/`Max`: bounds

    Returns `(minv, maxv)`; either is `None` if it is not finite
    (e.g. if `data` contains no finite values).

    Despite being a little involved,
    the cost of this subroutine is typically not substantial
    because there's usually not that much data to sort through.
    """
    # Find "reasonable" limits (by percentiles), looping over data.
    # NB: during the loop, `minv` holds the *negated* lower limit,
    # so that both limits are updated by a single np.maximum.
    maxv = minv = -np.inf  # init
    for d in data:
        d = np.asarray(d)
        d = d[np.isfinite(d)]
        if len(d):
            perc = np.array([-1, 1]) * np.percentile(d, pp)
            minv, maxv = np.maximum([minv, maxv], perc)
    minv *= -1

    # Pry apart equal values
    if np.isclose(minv, maxv):
        maxv += 0.5
        minv -= 0.5

    # Make the zooming transition smooth
    if ax is not None:
        current = ax.get_ylim()
        # Use the compression rate where the new limit narrows the
        # current range, and the expansion rate where it widens it.
        c0 = cC if minv > current[0] else cE
        c1 = cC if maxv < current[1] else cE
        # Adjust: move from the current limit towards the new one, by rate c.
        minv = np.interp(c0, (0, 1), (current[0], minv))
        maxv = np.interp(c1, (0, 1), (current[1], maxv))

    # Bounds
    maxv = min(Max, maxv)
    minv = max(Min, minv)

    # Some mpl versions don't handle inf limits.
    if not np.isfinite(minv):
        minv = None
    if not np.isfinite(maxv):
        maxv = None

    return minv, maxv
def spatial1d(
    obs_inds=(),
    periodicity=None,
    dims=(),
    ens_props={"color": "b", "alpha": 0.1},  # noqa
    conf_mult=None,
):
    """Plot the state as a 1d amplitude curve (value vs. state index).

    Parameters (each may be overridden by a same-named kwarg passed to `init`):

    - `obs_inds`: indices of the observed state components, or a callable
      mapping the obs time index `ko` to such indices.
    - `periodicity`: passed on to `viz.setup_wrapping`.
    - `dims`: state components to plot. Default: all.
    - `ens_props`: line properties for the ensemble members.
    - `conf_mult`: if truthy, plot the mean and `mu ± conf_mult*spread`
      instead of the ensemble members. Set to 2 if no ensemble is given.

    Returns the `init` function expected by `LivePlot`.
    """
    # Store parameters (NB: must come first, so that locals() == the params)
    params_orig = DotDict(**locals())

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        xx, yy, mu = stats.xx, stats.yy, stats.mu

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

        # Dimension settings.
        # NB: convert `dims` to an array so that `mu[key][p.dims]` and
        # `p.dims[xt]` (below) also work if given as a tuple/list
        # (a tuple index would be read as a multi-dimensional index).
        if not_empty(p.dims):
            p.dims = np.asarray(p.dims)
            M = len(p.dims)
        else:
            M = xx.shape[-1]
            p.dims = arange(M)

        # Make periodic wrapper
        ii, wrap = viz.setup_wrapping(M, p.periodicity)

        # Set up figure, axes
        fig, ax = place.freshfig(fignum, figsize=(8, 5))
        fig.suptitle("1d amplitude plot")

        # Nans
        nan1 = wrap(nan * ones(M))

        # Without an ensemble, only the mean and confidence curves can be
        # drawn (the ensemble branch of `update` would index a None `E`).
        if E is None and not p.conf_mult:
            p.conf_mult = 2

        # Init plots
        if p.conf_mult:
            lines_s = ax.plot(
                ii, nan1, "b-", lw=1, label=(str(p.conf_mult) + r"$\sigma$ conf")
            )
            lines_s += ax.plot(ii, nan1, "b-", lw=1)
            (line_mu,) = ax.plot(ii, nan1, "b-", lw=2, label="DA mean")
        else:
            nanE = nan * ones((stats.xp.N, M))
            lines_E = ax.plot(ii, wrap(nanE[0]), **p.ens_props, lw=1, label="Ensemble")
            lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)
        # Truth, Obs
        (line_x,) = ax.plot(ii, nan1, "k-", lw=3, label="Truth")
        if not_empty(p.obs_inds):
            (line_y,) = ax.plot(ii, nan1, "g*", ms=5, label="Obs")

        # Tune plot
        ax.set_ylim(*viz.xtrema(xx))
        ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))
        # Xticks: keep only the integer ticks in [0, len(dims)),
        # and label them by the corresponding state index.
        xt = ax.get_xticks()
        xt = xt[abs(xt % 1) < 0.01].astype(int)
        xt = xt[xt >= 0]
        xt = xt[xt < len(p.dims)]
        ax.set_xticks(xt)
        ax.set_xticklabels(p.dims[xt])

        ax.set_xlabel("State index")
        ax.set_ylabel("Value")
        ax.legend(loc="upper right")

        text_t = ax.text(
            0.01,
            0.01,
            format_time(None, None, None),
            transform=ax.transAxes,
            family="monospace",
            ha="left",
        )

        def update(key, E, P):
            k, ko, faus = key

            if p.conf_mult:
                sigma = mu[key] + p.conf_mult * stats.spread[key] * [[1], [-1]]
                lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
                lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
                line_mu.set_ydata(wrap(mu[key][p.dims]))
            else:
                for n, line in enumerate(lines_E):
                    line.set_ydata(wrap(E[n, p.dims]))
                update_alpha(key, stats, lines_E)

            line_x.set_ydata(wrap(xx[k, p.dims]))

            text_t.set_text(format_time(k, ko, stats.HMM.tseq.tt[k]))

            # Show obs at forecast keys; hide them again at universal keys.
            if "f" in faus:
                if not_empty(p.obs_inds):
                    xy = nan * ones(len(xx[0]))
                    jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
                    xy[jj] = yy[ko]
                    line_y.set_ydata(wrap(xy[p.dims]))
                    line_y.set_zorder(5)
                    line_y.set_visible(True)

            if "u" in faus:
                if not_empty(p.obs_inds):
                    line_y.set_visible(False)

        return update

    return init
def spatial2d(
    square,
    ind2sub,
    obs_inds=(),
    cm=plt.cm.jet,
    clims=((-40, 40), (-40, 40), (-10, 10), (-10, 10)),
):
    """Plot 2d fields: mean, truth, spread and error, plus obs locations.

    - `square`: maps a state vector to a 2d array (for `imshow`).
    - `ind2sub`: maps state indices to 2d subscripts, returned as `(iy, ix)`.
    - `obs_inds`: indices of the observed state components, or a callable
      mapping the obs time index `ko` to such indices.
    - `cm`: colormap for the mean and truth panels.
    - `clims`: color limits of the (mean, truth, spread, error) panels.

    Returns the `init` function expected by `LivePlot`.
    """

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        GS = {"left": 0.125 - 0.04, "right": 0.9 - 0.04}
        fig, axs = place.freshfig(
            fignum,
            figsize=(6, 6),
            nrows=2,
            ncols=2,
            sharex=True,
            sharey=True,
            gridspec_kw=GS,
        )

        for ax in axs.flatten():
            ax.set_aspect("equal", "box")

        ((ax_11, ax_12), (ax_21, ax_22)) = axs

        ax_11.grid(color="w", linewidth=0.2)
        ax_12.grid(color="w", linewidth=0.2)
        ax_21.grid(color="k", linewidth=0.1)
        ax_22.grid(color="k", linewidth=0.1)

        # Upper colorbar -- position relative to ax_12
        bb = ax_12.get_position()
        dy = 0.1 * bb.height
        ax_13 = fig.add_axes([bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])
        # Lower colorbar -- position relative to ax_22
        bb = ax_22.get_position()
        dy = 0.1 * bb.height
        ax_23 = fig.add_axes([bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])

        # Extract data arrays
        xx, mu, spread, err = stats.xx, stats.mu, stats.spread, stats.err
        k = key0[0]
        tt = stats.HMM.tseq.tt

        # Plot
        # - origin='lower' might get overturned by set_ylim() below.
        im_11 = ax_11.imshow(square(mu[key0]), cmap=cm)
        im_12 = ax_12.imshow(square(xx[k]), cmap=cm)
        # hot is better, but needs +1 colorbar
        im_21 = ax_21.imshow(square(spread[key0]), cmap=plt.cm.bwr)
        im_22 = ax_22.imshow(square(err[key0]), cmap=plt.cm.bwr)
        ims = (im_11, im_12, im_21, im_22)
        # Obs init -- a list where item 0 is the handle of something invisible.
        lh = list(ax_12.plot(0, 0)[0:1])

        sx = "$\\psi$"
        ax_11.set_title("mean " + sx)
        ax_12.set_title("true " + sx)
        ax_21.set_title("spread. " + sx)
        ax_22.set_title("err. " + sx)

        # TODO 7
        # for ax in axs.flatten():
        # Crop boundries (which should be 0, i.e. yield harsh q gradients):
        # lims = (1, nx-2)
        # step = (nx - 1)/8
        # ticks = arange(step,nx-1,step)
        # ax.set_xlim  (lims)
        # ax.set_ylim  (lims[::-1])
        # ax.set_xticks(ticks)
        # ax.set_yticks(ticks)

        for im, clim in zip(ims, clims):
            im.set_clim(clim)

        fig.colorbar(im_12, cax=ax_13)
        fig.colorbar(im_22, cax=ax_23)
        for ax in [ax_13, ax_23]:
            ax.yaxis.set_tick_params(
                "major", length=2, width=0.5, direction="in", left=True, right=True
            )
            ax.set_axisbelow("line")  # make ticks appear over colorbar patch

        # Title
        title = "Streamfunction (" + sx + ")"
        fig.suptitle(title)
        # Time info
        text_t = ax_12.text(
            1,
            1.1,
            format_time(None, None, None),
            transform=ax_12.transAxes,
            family="monospace",
            ha="left",
        )

        def update(key, E, P):
            k, ko, faus = key
            t = tt[k]

            im_11.set_data(square(mu[key]))
            im_12.set_data(square(xx[k]))
            im_21.set_data(square(spread[key]))
            im_22.set_data(square(err[key]))

            # Remove previous obs
            try:
                lh[0].remove()
            except ValueError:
                # Presumably already removed on an earlier call
                # (no new obs were plotted since).
                pass
            # Plot current obs.
            #  - plot() automatically adjusts to direction of y-axis in use.
            #  - ind2sub returns (iy,ix), while plot takes (ix,iy) => reverse.
            if ko is not None and not_empty(obs_inds):
                # Support both a fixed index array and a callable of ko,
                # like the other liveplotters of this module.
                jj = obs_inds(ko) if callable(obs_inds) else obs_inds
                lh[0] = ax_12.plot(*ind2sub(jj)[::-1], "k.", ms=1, zorder=5)[0]

            text_t.set_text(format_time(k, ko, t))

        return update

    return init
1386# List of liveplotters available for all HMMs.
1387# Columns (one tuple per liveplotter):
1388# - show_by_default
1389# - function/class
1390#   (its name is used as the figure name/number; see `LivePlot`)
1391default_liveplotters = [
1392    (1, sliding_diagnostics),
1393    (1, weight_histogram),
class LivePlot:
 45class LivePlot:
 46    """Live plotting manager.
 48    Deals with
 50    - Pause, skip.
 51    - Which liveploters to call.
 52    - `plot_u`
 53    - Figure window (title and number).
 54    """
 56    def __init__(
 57        self,
 58        stats,
 59        liveplots,
 60        key0=(0, None, "u"),
 61        E=None,
 62        P=None,
 63        speed=1.0,
 64        replay=False,
 65        **kwargs,
 66    ):
 67        """
 68        Initialize plots.
 70        - liveplots: figures to plot; alternatives:
 71            - `"default"/[]/True`: All default figures for this HMM.
 72            - `"all"`            : Even more.
 73            - non-empty `list`   : Only the figures with these numbers
 74                                 (int) or names (str).
 75            - `False`            : None.
 76        - speed: speed of animation.
 77            - `>100`: instantaneous
 78            - `1`   : (default) as quick as possible allowing for
 79                      plt.draw() to work on a moderately fast computer.
 80            - `<1`  : slower.
 81        """
 82        # Disable if not rc.liveplotting
 83        self.any_figs = False
 84        if rc.liveplotting:
 85            pass
 86        elif replay and np.isinf(speed):
 87            pass
 88        else:
 89            return
 91        # Determine whether all/universal/intermediate stats are plotted
 92        self.plot_u = not replay or stats.store_u
 94        # Set speed/pause params
 95        self.params = {
 96            "pause_f": 0.05,
 97            "pause_a": 0.05,
 98            "pause_s": 0.05,
 99            "pause_u": 0.001,
100        }
101        # If speed>100: set to inf. Coz pause=1e-99 causes hangup.
102        for pause in ["pause_" + x for x in "faus"]:
103            speed = speed if speed < 100 else np.inf
104            self.params[pause] /= speed
106        # Write params
107        self.params.update(getattr(stats.xp, "LP_kwargs", {}))
108        self.params.update(kwargs)
110        def get_name(init):
111            """Get name of liveplotter function/class."""
112            try:
113                return init.__qualname__.split(".")[0]
114            except AttributeError:
115                return init.__class__.__name__
117        # Set up dict of liveplotters
118        potential_LPs = {}
119        for show, init in default_liveplotters:
120            potential_LPs[get_name(init)] = show, init
121        # Add HMM-specific liveplotters
122        for show, init in getattr(stats.HMM, "liveplotters", {}):
123            potential_LPs[get_name(init)] = show, init
125        def parse_figlist(lst):
126            """Figures requested for this xp. Convert to list."""
127            if isinstance(lst, str):
128                fn = lst.lower()
129                if "all" == fn:
130                    lst = ["all"]  # All potential_LPs
131                elif "default" in fn:
132                    lst = ["default"]  # All show_by_default
133            elif hasattr(lst, "__len__"):
134                lst = lst  # This list (only)
135            elif lst:
136                lst = ["default"]  # All show_by_default
137            else:
138                lst = [None]  # None
139            return lst
141        figlist = parse_figlist(liveplots)
143        # Loop over requeted figures
144        self.figures = {}
145        for name, (show_by_default, init) in potential_LPs.items():
146            if (
147                (figlist == ["all"])
148                or (name in figlist)
149                or (figlist == ["default"] and show_by_default)
150            ):
151                # Startup message
152                if not self.any_figs:
153                    print("Initializing liveplots...")
154                    if is_notebook_or_qt:
155                        pauses = [self.params["pause_" + x] for x in "faus"]
156                        if any((p > 0) for p in pauses):
157                            print(
158                                "Note: liveplotting does not work very well"
159                                " inside Jupyter notebooks. In particular,"
160                                " there is no way to stop/skip them except"
161                                " to interrupt the kernel (the stop button"
162                                " in the toolbar). Consider using instead"
163                                " only the replay functionality (with infinite"
164                                " playback speed)."
165                            )
166                    elif not pb.disable_user_interaction:
167                        print("Hit <Space> to pause/step.")
168                        print("Hit <Enter> to resume/skip.")
169                        print("Hit <i> to enter debug mode.")
170                    self.paused = False
171                    self.run_ipdb = False
172                    self.skipping = False
173                    self.any_figs = True
175                # Init figure
176                post_title = "" if self.plot_u else "\n(obs times only)"
177                updater = init(name, stats, key0, self.plot_u, E, P, **kwargs)
178                if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
179                    self.figures[name] = updater
180                    fig = plt.figure(name)
181                    win = fig.canvas
182                    ax0 = fig.axes[0]
183                    win.manager.set_window_title(str(name))
184                    ax0.set_title(ax0.get_title() + post_title)
185                    self.update(key0, E, P)  # Call initial update
186                    if not (replay and np.isinf(speed)):
187                        plt.pause(0.01)  # Draw
189    def update(self, key, E, P):
190        """Update liveplots"""
191        # Check if there are still open figures
192        if self.any_figs:
193            open_figns = plt.get_figlabels()
194            live_figns = set(self.figures.keys())
195            self.any_figs = bool(live_figns.intersection(open_figns))
196        else:
197            return
199        # Playback control
200        SPACE = b" "
201        CHAR_I = b"i"
202        ENTERs = [b"\n", b"\r"]  # Linux + Windows
204        def pause():
205            """Loop until user decision is made."""
206            ch = read1()
207            while True:
208                # Set state (pause, skipping, ipdb)
209                if ch in ENTERs:
210                    self.paused = False
211                elif ch == CHAR_I:
212                    self.run_ipdb = True
213                # If keypress valid, resume execution
214                if ch in ENTERs + [SPACE, CHAR_I]:
215                    break
216                ch = read1()
217                # Pause to enable zoom, pan, etc. of mpl GUI
218                plot_pause(0.01)  # Don't use time.sleep()!
220        # Enter pause loop
221        if self.paused:
222            pause()
224        else:
225            if key == (0, None, "u"):
226                # Skip read1 for key0 (coz it blocks)
227                pass
228            else:
229                ch = read1()
230                if ch == SPACE:
231                    # Pause
232                    self.paused = True
233                    self.skipping = False
234                    pause()
235                elif ch in ENTERs:
236                    # Toggle skipping
237                    self.skipping = not self.skipping
238                elif ch == CHAR_I:
239                    # Schedule debug
240                    # Note: The reason we dont set_trace(frame) right here is:
241                    # - I could not find the right frame, even doing
242                    #   >   frame = inspect.stack()[0]
243                    #   >   while frame.f_code.co_name != "assimilate":
244                    #   >       frame = frame.f_back
245                    # - It just restarts the plot.
246                    self.run_ipdb = True
248        # Update figures
249        if not self.skipping:
250            faus = key[-1]
251            if faus != "u" or self.plot_u:
252                for name, (updater) in self.figures.items():
253                    if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
254                        _ = plt.figure(name)
255                        updater(key, E, P)
256                        plot_pause(self.params["pause_" + faus])
258        if self.run_ipdb:
259            self.run_ipdb = False
260            import inspect
262            import ipdb
264            print("Entering debug mode (ipdb).")
265            print("Type '?' (and Enter) for usage help.")
266            print("Type 'c' to continue the assimilation.")
267            ipdb.set_trace(inspect.stack()[2].frame)

Live plotting manager.

Deals with

  • Pause, skip.
  • Which liveploters to call.
  • plot_u
  • Figure window (title and number).
LivePlot( stats, liveplots, key0=(0, None, 'u'), E=None, P=None, speed=1.0, replay=False, **kwargs)
 56    def __init__(
 57        self,
 58        stats,
 59        liveplots,
 60        key0=(0, None, "u"),
 61        E=None,
 62        P=None,
 63        speed=1.0,
 64        replay=False,
 65        **kwargs,
 66    ):
 67        """
 68        Initialize plots.
 70        - liveplots: figures to plot; alternatives:
 71            - `"default"/[]/True`: All default figures for this HMM.
 72            - `"all"`            : Even more.
 73            - non-empty `list`   : Only the figures with these numbers
 74                                 (int) or names (str).
 75            - `False`            : None.
 76        - speed: speed of animation.
 77            - `>100`: instantaneous
 78            - `1`   : (default) as quick as possible allowing for
 79                      plt.draw() to work on a moderately fast computer.
 80            - `<1`  : slower.
 81        """
 82        # Disable if not rc.liveplotting
 83        self.any_figs = False
 84        if rc.liveplotting:
 85            pass
 86        elif replay and np.isinf(speed):
 87            pass
 88        else:
 89            return
 91        # Determine whether all/universal/intermediate stats are plotted
 92        self.plot_u = not replay or stats.store_u
 94        # Set speed/pause params
 95        self.params = {
 96            "pause_f": 0.05,
 97            "pause_a": 0.05,
 98            "pause_s": 0.05,
 99            "pause_u": 0.001,
100        }
101        # If speed>100: set to inf. Coz pause=1e-99 causes hangup.
102        for pause in ["pause_" + x for x in "faus"]:
103            speed = speed if speed < 100 else np.inf
104            self.params[pause] /= speed
106        # Write params
107        self.params.update(getattr(stats.xp, "LP_kwargs", {}))
108        self.params.update(kwargs)
110        def get_name(init):
111            """Get name of liveplotter function/class."""
112            try:
113                return init.__qualname__.split(".")[0]
114            except AttributeError:
115                return init.__class__.__name__
117        # Set up dict of liveplotters
118        potential_LPs = {}
119        for show, init in default_liveplotters:
120            potential_LPs[get_name(init)] = show, init
121        # Add HMM-specific liveplotters
122        for show, init in getattr(stats.HMM, "liveplotters", {}):
123            potential_LPs[get_name(init)] = show, init
125        def parse_figlist(lst):
126            """Figures requested for this xp. Convert to list."""
127            if isinstance(lst, str):
128                fn = lst.lower()
129                if "all" == fn:
130                    lst = ["all"]  # All potential_LPs
131                elif "default" in fn:
132                    lst = ["default"]  # All show_by_default
133            elif hasattr(lst, "__len__"):
134                lst = lst  # This list (only)
135            elif lst:
136                lst = ["default"]  # All show_by_default
137            else:
138                lst = [None]  # None
139            return lst
141        figlist = parse_figlist(liveplots)
143        # Loop over requeted figures
144        self.figures = {}
145        for name, (show_by_default, init) in potential_LPs.items():
146            if (
147                (figlist == ["all"])
148                or (name in figlist)
149                or (figlist == ["default"] and show_by_default)
150            ):
151                # Startup message
152                if not self.any_figs:
153                    print("Initializing liveplots...")
154                    if is_notebook_or_qt:
155                        pauses = [self.params["pause_" + x] for x in "faus"]
156                        if any((p > 0) for p in pauses):
157                            print(
158                                "Note: liveplotting does not work very well"
159                                " inside Jupyter notebooks. In particular,"
160                                " there is no way to stop/skip them except"
161                                " to interrupt the kernel (the stop button"
162                                " in the toolbar). Consider using instead"
163                                " only the replay functionality (with infinite"
164                                " playback speed)."
165                            )
166                    elif not pb.disable_user_interaction:
167                        print("Hit <Space> to pause/step.")
168                        print("Hit <Enter> to resume/skip.")
169                        print("Hit <i> to enter debug mode.")
170                    self.paused = False
171                    self.run_ipdb = False
172                    self.skipping = False
173                    self.any_figs = True
175                # Init figure
176                post_title = "" if self.plot_u else "\n(obs times only)"
177                updater = init(name, stats, key0, self.plot_u, E, P, **kwargs)
178                if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
179                    self.figures[name] = updater
180                    fig = plt.figure(name)
181                    win = fig.canvas
182                    ax0 = fig.axes[0]
183                    win.manager.set_window_title(str(name))
184                    ax0.set_title(ax0.get_title() + post_title)
185                    self.update(key0, E, P)  # Call initial update
186                    if not (replay and np.isinf(speed)):
187                        plt.pause(0.01)  # Draw

Initialize plots.

  • liveplots: figures to plot; alternatives:
    • "default"/True : All default figures for this HMM.
    • "all" : Even more.
    • non-empty list : Only the figures with these names (str), i.e. the names of the liveplotting functions/classes.
    • []/False/None : None.
  • speed: speed of animation.
    • >100: instantaneous
    • 1 : (default) as quick as possible allowing for plt.draw() to work on a moderately fast computer.
    • <1 : slower.
def update(self, key, E, P):
189    def update(self, key, E, P):
190        """Update liveplots"""
191        # Check if there are still open figures
192        if self.any_figs:
193            open_figns = plt.get_figlabels()
194            live_figns = set(self.figures.keys())
195            self.any_figs = bool(live_figns.intersection(open_figns))
196        else:
197            return
199        # Playback control
200        SPACE = b" "
201        CHAR_I = b"i"
202        ENTERs = [b"\n", b"\r"]  # Linux + Windows
204        def pause():
205            """Loop until user decision is made."""
206            ch = read1()
207            while True:
208                # Set state (pause, skipping, ipdb)
209                if ch in ENTERs:
210                    self.paused = False
211                elif ch == CHAR_I:
212                    self.run_ipdb = True
213                # If keypress valid, resume execution
214                if ch in ENTERs + [SPACE, CHAR_I]:
215                    break
216                ch = read1()
217                # Pause to enable zoom, pan, etc. of mpl GUI
218                plot_pause(0.01)  # Don't use time.sleep()!
220        # Enter pause loop
221        if self.paused:
222            pause()
224        else:
225            if key == (0, None, "u"):
226                # Skip read1 for key0 (coz it blocks)
227                pass
228            else:
229                ch = read1()
230                if ch == SPACE:
231                    # Pause
232                    self.paused = True
233                    self.skipping = False
234                    pause()
235                elif ch in ENTERs:
236                    # Toggle skipping
237                    self.skipping = not self.skipping
238                elif ch == CHAR_I:
239                    # Schedule debug
240                    # Note: The reason we dont set_trace(frame) right here is:
241                    # - I could not find the right frame, even doing
242                    #   >   frame = inspect.stack()[0]
243                    #   >   while frame.f_code.co_name != "assimilate":
244                    #   >       frame = frame.f_back
245                    # - It just restarts the plot.
246                    self.run_ipdb = True
248        # Update figures
249        if not self.skipping:
250            faus = key[-1]
251            if faus != "u" or self.plot_u:
252                for name, (updater) in self.figures.items():
253                    if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
254                        _ = plt.figure(name)
255                        updater(key, E, P)
256                        plot_pause(self.params["pause_" + faus])
258        if self.run_ipdb:
259            self.run_ipdb = False
260            import inspect
262            import ipdb
264            print("Entering debug mode (ipdb).")
265            print("Type '?' (and Enter) for usage help.")
266            print("Type 'c' to continue the assimilation.")
267            ipdb.set_trace(inspect.stack()[2].frame)

Update the liveplots, and handle keyboard playback control (pause/step, resume/skip, and entering debug mode).

star = '${}^*$'
class sliding_diagnostics:
275class sliding_diagnostics:
276    """Plots a sliding window (like a heart rate monitor) of certain diagnostics."""
278    def __init__(self, fignum, stats, key0, plot_u, E, P, Tplot=None, **kwargs):
279        # STYLE TABLES - Defines which/how diagnostics get plotted
280        styles = {}
282        def lin(a, b):
283            return lambda x: a + b * x
285        divN = 1 / getattr(stats.xp, "N", 99)
286        # Columns: transf, shape, plt kwargs
287        styles["RMS"] = {
288            "err.rms": [None, None, dict(c="k", label="Error")],
289            "spread.rms": [None, None, dict(c="b", label="Spread", alpha=0.6)],
290        }
291        styles["Values"] = {
292            "skew": [None, None, dict(c="g", label=star + r"Skew/$\sigma^3$")],
293            "kurt": [None, None, dict(c="r", label=star + r"Kurt$/\sigma^4{-}3$")],
294            "trHK": [None, None, dict(c="k", label=star + "HK")],
295            "infl": [lin(-10, 10), "step", dict(c="c", label="10(infl-1)")],
296            "N_eff": [lin(0, divN), "dirac", dict(c="y", label="N_eff/N", lw=3)],
297            "iters": [lin(0, 0.1), "dirac", dict(c="m", label="iters/10")],
298            "resmpl": [None, "dirac", dict(c="k", label="resampled?")],
299        }
301        nAx = len(styles)
302        GS = {"left": 0.125, "right": 0.76}
303        fig, axs = place.freshfig(
304            fignum, figsize=(5, 1 + nAx), nrows=nAx, sharex=True, gridspec_kw=GS
305        )
307        axs[0].set_title("Diagnostics")
308        for style, ax in zip(styles, axs):
309            ax.set_ylabel(style)
310        ax.set_xlabel("Time (t)")
311        place_ax.adjust_position(ax, y0=0.03)
313        self.T_lag, K_lag, a_lag = validate_lag(Tplot, stats.HMM.tseq)
315        def init_ax(ax, style_table):
316            lines = {}
317            for name in style_table:
318                # SKIP -- if stats[name] is not in existence
319                # Note: The nan check/deletion comes after the first ko.
320                try:
321                    stat = deep_getattr(stats, name)
322                except AttributeError:
323                    continue
324                # try: val0 = stat[key0[0]]
325                # except KeyError: continue
326                # PS: recall (from series.py) that even if store_u is false, stat[k] is
327                # still present if liveplots=True via the k_tmp functionality.
329                # Unpack style
330                ln = {}
331                ln["transf"] = style_table[name][0] or (lambda x: x)
332                ln["shape"] = style_table[name][1]
333                ln["plt"] = style_table[name][2]
335                # Create series
336                if isinstance(stat, FAUSt):
337                    ln["plot_u"] = plot_u
338                    K_plot = comp_K_plot(K_lag, a_lag, ln["plot_u"])
339                else:
340                    ln["plot_u"] = False
341                    K_plot = a_lag
342                ln["data"] = RollingArray(K_plot)
343                ln["tt"] = RollingArray(K_plot)
345                # Plot (init)
346                (ln["handle"],) = ax.plot(ln["tt"], ln["data"], **ln["plt"])
348                # Plotting only nans yield ugly limits. Revert to defaults.
349                ax.set_xlim(0, 1)
350                ax.set_ylim(0, 1)
352                lines[name] = ln
353            return lines
355        # Plot
356        self.d = [init_ax(ax, styles[style]) for style, ax in zip(styles, axs)]
358        # Horizontal line at y=0
359        (self.baseline0,) = ax.plot(
360            ax.get_xlim(), [0, 0], c=0.5 * ones(3), lw=0.7, label="_nolegend_"
361        )
363        # Store
364        self.axs = axs
365        self.stats = stats
366        self.init_incomplete = True
368    # Update plot
369    def __call__(self, key, E, P):
370        k, ko, faus = key
372        stats = self.stats
373        tseq = stats.HMM.tseq
374        ax0, ax1 = self.axs
376        def update_arrays(lines):
377            for name, ln in lines.items():
378                stat = deep_getattr(stats, name)
379                t = tseq.tt[k]  # == tseq.tto[ko]
380                if isinstance(stat, FAUSt):
381                    # ln['data'] will contain duplicates for f/a times.
382                    if ln["plot_u"]:
383                        val = stat[key]
384                        ln["tt"].insert(k, t)
385                        ln["data"].insert(k, ln["transf"](val))
386                    elif "u" not in faus:
387                        val = stat[key]
388                        ln["tt"].insert(ko, t)
389                        ln["data"].insert(ko, ln["transf"](val))
390                else:
391                    # ln['data'] will not contain duplicates, coz only 'a' is input.
392                    if "a" in faus:
393                        val = stat[ko]
394                        ln["tt"].insert(ko, t)
395                        ln["data"].insert(ko, ln["transf"](val))
396                    elif "f" in faus:
397                        pass
399        def update_plot_data(ax, lines):
400            def bend_into(shape, xx, yy):
401                # Get arrays. Repeat (to use for intermediate nodes).
402                yy = yy.array.repeat(3)
403                xx = xx.array.repeat(3)
404                if len(xx) == 0:
405                    pass  # shortcircuit any modifications
406                elif shape == "step":
407                    yy = np.hstack([yy[1:], nan])  # roll leftward
408                elif shape == "dirac":
409                    nonlocal nDirac
410                    axW = np.diff(ax.get_xlim())
411                    yy[0::3] = False  # set datapoin to 0
412                    xx[2::3] = nan  # make datapoint disappear
413                    xx += nDirac * axW / 100  # offset datapoint horizontally
414                    nDirac += 1
415                return xx, yy
417            nDirac = 1
418            for _name, ln in lines.items():
419                ln["handle"].set_data(*bend_into(ln["shape"], ln["tt"], ln["data"]))
421        def finalize_init(ax, lines, mm):
422            # Rm lines that only contain NaNs
423            for name in list(lines):
424                ln = lines[name]
425                stat = deep_getattr(stats, name)
426                if not stat.were_changed:
427                    ln["handle"].remove()  # rm from axes
428                    del lines[name]  # rm from dict
429            # Add legends
430            if lines:
431                ax.legend(loc="upper left", bbox_to_anchor=(1.01, 1), borderaxespad=0)
432                if mm:
433                    ax.annotate(
434                        star + ": mean of\nmarginals",
435                        xy=(0, -1.5 / len(lines)),
436                        xycoords=ax.get_legend().get_frame(),
437                        bbox=dict(alpha=0.0),
438                        fontsize="small",
439                    )
440            # coz placement of annotate needs flush sometimes:
441            plot_pause(0.01)
443        # Insert current stats
444        for lines, ax in zip(self.d, self.axs):
445            update_arrays(lines)
446            update_plot_data(ax, lines)
448        # Set x-limits (time)
449        sliding_xlim(ax0, self.d[0]["err.rms"]["tt"], self.T_lag, margin=True)
450        self.baseline0.set_xdata(ax0.get_xlim())
452        # Set y-limits
453        data0 = [ln["data"].array for ln in self.d[0].values()]
454        data1 = [ln["data"].array for ln in self.d[1].values()]
455        ax0.set_ylim(0, d_ylim(data0, ax0, cC=0.2, cE=0.9)[1])
456        ax1.set_ylim(*d_ylim(data1, ax1, Max=4, Min=-4, cC=0.3, cE=0.9))
458        # Init legend. Rm nan lines.
459        if self.init_incomplete and "a" == faus:
460            self.init_incomplete = False
461            finalize_init(ax0, self.d[0], False)
462            finalize_init(ax1, self.d[1], True)

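Plots a sliding time window (like a heart-rate monitor) of diagnostics: the RMS error and spread, plus (where available in the stats) values such as skewness, kurtosis, inflation, effective ensemble size, and iteration counts.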

sliding_diagnostics(fignum, stats, key0, plot_u, E, P, Tplot=None, **kwargs)
278    def __init__(self, fignum, stats, key0, plot_u, E, P, Tplot=None, **kwargs):
279        # STYLE TABLES - Defines which/how diagnostics get plotted
280        styles = {}
282        def lin(a, b):
283            return lambda x: a + b * x
285        divN = 1 / getattr(stats.xp, "N", 99)
286        # Columns: transf, shape, plt kwargs
287        styles["RMS"] = {
288            "err.rms": [None, None, dict(c="k", label="Error")],
289            "spread.rms": [None, None, dict(c="b", label="Spread", alpha=0.6)],
290        }
291        styles["Values"] = {
292            "skew": [None, None, dict(c="g", label=star + r"Skew/$\sigma^3$")],
293            "kurt": [None, None, dict(c="r", label=star + r"Kurt$/\sigma^4{-}3$")],
294            "trHK": [None, None, dict(c="k", label=star + "HK")],
295            "infl": [lin(-10, 10), "step", dict(c="c", label="10(infl-1)")],
296            "N_eff": [lin(0, divN), "dirac", dict(c="y", label="N_eff/N", lw=3)],
297            "iters": [lin(0, 0.1), "dirac", dict(c="m", label="iters/10")],
298            "resmpl": [None, "dirac", dict(c="k", label="resampled?")],
299        }
301        nAx = len(styles)
302        GS = {"left": 0.125, "right": 0.76}
303        fig, axs = place.freshfig(
304            fignum, figsize=(5, 1 + nAx), nrows=nAx, sharex=True, gridspec_kw=GS
305        )
307        axs[0].set_title("Diagnostics")
308        for style, ax in zip(styles, axs):
309            ax.set_ylabel(style)
310        ax.set_xlabel("Time (t)")
311        place_ax.adjust_position(ax, y0=0.03)
313        self.T_lag, K_lag, a_lag = validate_lag(Tplot, stats.HMM.tseq)
315        def init_ax(ax, style_table):
316            lines = {}
317            for name in style_table:
318                # SKIP -- if stats[name] is not in existence
319                # Note: The nan check/deletion comes after the first ko.
320                try:
321                    stat = deep_getattr(stats, name)
322                except AttributeError:
323                    continue
324                # try: val0 = stat[key0[0]]
325                # except KeyError: continue
326                # PS: recall (from series.py) that even if store_u is false, stat[k] is
327                # still present if liveplots=True via the k_tmp functionality.
329                # Unpack style
330                ln = {}
331                ln["transf"] = style_table[name][0] or (lambda x: x)
332                ln["shape"] = style_table[name][1]
333                ln["plt"] = style_table[name][2]
335                # Create series
336                if isinstance(stat, FAUSt):
337                    ln["plot_u"] = plot_u
338                    K_plot = comp_K_plot(K_lag, a_lag, ln["plot_u"])
339                else:
340                    ln["plot_u"] = False
341                    K_plot = a_lag
342                ln["data"] = RollingArray(K_plot)
343                ln["tt"] = RollingArray(K_plot)
345                # Plot (init)
346                (ln["handle"],) = ax.plot(ln["tt"], ln["data"], **ln["plt"])
348                # Plotting only nans yield ugly limits. Revert to defaults.
349                ax.set_xlim(0, 1)
350                ax.set_ylim(0, 1)
352                lines[name] = ln
353            return lines
355        # Plot
356        self.d = [init_ax(ax, styles[style]) for style, ax in zip(styles, axs)]
358        # Horizontal line at y=0
359        (self.baseline0,) = ax.plot(
360            ax.get_xlim(), [0, 0], c=0.5 * ones(3), lw=0.7, label="_nolegend_"
361        )
363        # Store
364        self.axs = axs
365        self.stats = stats
366        self.init_incomplete = True
def sliding_xlim(ax, tt, lag, margin=False):
465def sliding_xlim(ax, tt, lag, margin=False):
466    dt = lag / 20 if margin else 0
467    if tt.nFilled == 0:
468        return  # Quit
469    t1, t2 = tt.span()  # Get suggested span.
470    s1, s2 = ax.get_xlim()  # Get previous lims.
471    # If zero span (eg tt holds single 'f' and 'a'):
472    if t1 == t2:
473        t1 -= 1  # add width
474        t2 += 1  # add width
475    # If user has skipped (too much):
476    elif np.isnan(t1):
477        s2 -= dt  # Correct for dt.
478        span = s2 - s1  # Compute previous span
479        # If span<lag:
480        if span < lag:
481            span += t2 - s2  # Grow by "dt".
482        span = min(lag, span)  # Bound
483        t1 = t2 - span  # Set span.
484    ax.set_xlim(t1, t2 + dt)  # Set xlim to span
class weight_histogram:
487class weight_histogram:
488    """Plots histogram of weights. Refreshed each analysis."""
490    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
491        if not hasattr(stats, "w"):
492            self.is_active = False
493            return
494        fig, ax = place.freshfig(fignum, figsize=(7, 3), gridspec_kw={"bottom": 0.15})
496        ax.set_xscale("log")
497        ax.set_xlabel("Weigth")
498        ax.set_ylabel("Count")
499        self.stats = stats
500        self.ax = ax
501        self.hist = []
502        self.bins = np.exp(np.linspace(np.log(1e-10), np.log(1), 31))
504    def __call__(self, key, E, P):
505        k, ko, faus = key
506        if "a" == faus:
507            w = self.stats.w[key]
508            N = len(w)
509            ax = self.ax
511            self.is_active = N < 10001
512            if not self.is_active:
513                not_available_text(ax, "Not computed (N > threshold)")
514                return
516            counted = w > self.bins[0]
517            _ = [b.remove() for b in self.hist]
518            nn, _, self.hist = ax.hist(w[counted], bins=self.bins, color="b")
519            ax.set_ylim(top=max(nn))
521            ax.set_title(
522                f"N: {N:d}.   N_eff: {1/(w@w):.4g}."
523                "   Not shown: {N-np.sum(counted):d}. "
524            )

Plots a histogram of the weights. Refreshed at each analysis time.

weight_histogram(fignum, stats, key0, plot_u, E, P, **kwargs)
490    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
491        if not hasattr(stats, "w"):
492            self.is_active = False
493            return
494        fig, ax = place.freshfig(fignum, figsize=(7, 3), gridspec_kw={"bottom": 0.15})
496        ax.set_xscale("log")
497        ax.set_xlabel("Weigth")
498        ax.set_ylabel("Count")
499        self.stats = stats
500        self.ax = ax
501        self.hist = []
502        self.bins = np.exp(np.linspace(np.log(1e-10), np.log(1), 31))
class spectral_errors:
527class spectral_errors:
528    """Plots the (spatial-RMS) error as a functional of the SVD index."""
530    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
531        fig, ax = place.freshfig(fignum, figsize=(6, 3))
532        ax.set_xlabel("Sing. value index")
533        ax.set_yscale("log")
534        self.init_incomplete = True
535        self.ax = ax
536        self.plot_u = plot_u
538        try:
539            self.msft = stats.umisf
540            self.sprd = stats.svals
541        except AttributeError:
542            self.is_active = False
543            not_available_text(ax, "Spectral stats not being computed")
545    # Update plot
546    def __call__(self, key, E, P):
547        k, ko, faus = key
548        ax = self.ax
549        if self.init_incomplete:
550            if self.plot_u or "f" == faus:
551                self.init_incomplete = False
552                msft = abs(self.msft[key])
553                sprd = self.sprd[key]
554                if np.any(np.isinf(msft)):
555                    not_available_text(ax, "Spectral stats not finite")
556                    self.is_active = False
557                else:
558                    (self.line_msft,) = ax.plot(msft, "k", lw=2, label="Error")
559                    (self.line_sprd,) = ax.plot(
560                        sprd, "b", lw=2, label="Spread", alpha=0.9
561                    )
562                    ax.get_xaxis().set_major_locator(MaxNLocator(integer=True))
563                    ax.legend()
564        else:
565            msft = abs(self.msft[key])
566            sprd = self.sprd[key]
567            self.line_sprd.set_ydata(sprd)
568            self.line_msft.set_ydata(msft)
569        # ax.set_ylim(*d_ylim(msft))
570        # ax.set_ylim(bottom=1e-5)
571        ax.set_ylim([1e-3, 1e1])

Plots the (spatial-RMS) error and spread as a function of the singular-value index.

spectral_errors(fignum, stats, key0, plot_u, E, P, **kwargs)
530    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
531        fig, ax = place.freshfig(fignum, figsize=(6, 3))
532        ax.set_xlabel("Sing. value index")
533        ax.set_yscale("log")
534        self.init_incomplete = True
535        self.ax = ax
536        self.plot_u = plot_u
538        try:
539            self.msft = stats.umisf
540            self.sprd = stats.svals
541        except AttributeError:
542            self.is_active = False
543            not_available_text(ax, "Spectral stats not being computed")
class correlations:
574class correlations:
575    """Plots the state (auto-)correlation matrix."""
577    half = True  # Whether to show half/full (symmetric) corr matrix.
579    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
580        GS = {"height_ratios": [4, 1], "hspace": 0.09, "top": 0.95}
581        fig, (ax, ax2) = place.freshfig(fignum, figsize=(5, 6), nrows=2, gridspec_kw=GS)
583        if E is None and np.isnan(P.diag if isinstance(P, CovMat) else P).all():
584            not_available_text(
585                ax, ("Not available in replays" "\ncoz full Ens/Cov not stored.")
586            )
587            self.is_active = False
588            return
590        Nx = len(stats.mu[key0])
591        if Nx <= 1003:
592            C = np.eye(Nx)
593            # Mask half
594            mask = np.zeros_like(C, dtype=bool)
595            mask[np.tril_indices_from(mask)] = True
596            cmap = plt.get_cmap("RdBu_r")
597            VM = 1.0  # abs(np.percentile(C,[1,99])).max()
598            im = ax.imshow(C, cmap=cmap, vmin=-VM, vmax=VM)
599            # Colorbar
600            _ = ax.figure.colorbar(im, ax=ax, shrink=0.8)
601            # Tune plot
602            plt.box(False)
603            ax.set_facecolor("w")
604            ax.grid(False)
605            ax.set_title("State correlation matrix:", y=1.07)
606            ax.xaxis.tick_top()
608            # ax2 = inset_axes(ax,width="30%",height="60%",loc=3)
609            (line_AC,) = ax2.plot(arange(Nx), ones(Nx), label="Correlation")
610            (line_AA,) = ax2.plot(arange(Nx), ones(Nx), label="Abs. corr.")
611            _ = ax2.hlines(0, 0, Nx - 1, "k", "dotted", lw=1)
612            # Align ax2 with ax
613            bb_AC = ax2.get_position()
614            bb_C = ax.get_position()
615            ax2.set_position([bb_C.x0, bb_AC.y0, bb_C.width, bb_AC.height])
616            # Tune plot
617            ax2.set_title("Auto-correlation:")
618            ax2.set_ylabel("Mean value")
619            ax2.set_xlabel("Distance (in state indices)")
620            ax2.set_xticklabels([])
621            ax2.set_yticks([0, 1] + list(ax2.get_yticks()[[0, -1]]))
622            ax2.set_ylim(top=1)
623            ax2.legend(
624                frameon=True,
625                facecolor="w",
626                bbox_to_anchor=(1, 1),
627                loc="upper left",
628                borderaxespad=0.02,
629            )
631            self.ax = ax
632            self.ax2 = ax2
633            self.im = im
634            self.line_AC = line_AC
635            self.line_AA = line_AA
636            self.mask = mask
637            if hasattr(stats, "w"):
638                self.w = stats.w
639        else:
640            not_available_text(ax)
642    # Update plot
643    def __call__(self, key, E, P):
644        # Get cov matrix
645        if E is not None:
646            if hasattr(self, "w"):
647                C = np.cov(E, rowvar=False, aweights=self.w[key])
648            else:
649                C = np.cov(E, rowvar=False)
650        else:
651            assert P is not None
652            C = P.full if isinstance(P, CovMat) else P
653            C = C.copy()
654        # Compute corr from cov
655        std = np.sqrt(np.diag(C))
656        C /= std[:, None]
657        C /= std[None, :]
658        # Mask
659        if self.half:
660            C = np.ma.masked_where(self.mask, C)
661        # Plot
662        self.im.set_data(C)
663        # Auto-corr function
664        ACF = circulant_ACF(C)
665        AAF = circulant_ACF(C, do_abs=True)
666        self.line_AC.set_ydata(ACF)
667        self.line_AA.set_ydata(AAF)

Plots the state correlation matrix, together with its mean auto-correlation function.

correlations(fignum, stats, key0, plot_u, E, P, **kwargs)
579    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
580        GS = {"height_ratios": [4, 1], "hspace": 0.09, "top": 0.95}
581        fig, (ax, ax2) = place.freshfig(fignum, figsize=(5, 6), nrows=2, gridspec_kw=GS)
583        if E is None and np.isnan(P.diag if isinstance(P, CovMat) else P).all():
584            not_available_text(
585                ax, ("Not available in replays" "\ncoz full Ens/Cov not stored.")
586            )
587            self.is_active = False
588            return
590        Nx = len(stats.mu[key0])
591        if Nx <= 1003:
592            C = np.eye(Nx)
593            # Mask half
594            mask = np.zeros_like(C, dtype=bool)
595            mask[np.tril_indices_from(mask)] = True
596            cmap = plt.get_cmap("RdBu_r")
597            VM = 1.0  # abs(np.percentile(C,[1,99])).max()
598            im = ax.imshow(C, cmap=cmap, vmin=-VM, vmax=VM)
599            # Colorbar
600            _ = ax.figure.colorbar(im, ax=ax, shrink=0.8)
601            # Tune plot
602            plt.box(False)
603            ax.set_facecolor("w")
604            ax.grid(False)
605            ax.set_title("State correlation matrix:", y=1.07)
606            ax.xaxis.tick_top()
608            # ax2 = inset_axes(ax,width="30%",height="60%",loc=3)
609            (line_AC,) = ax2.plot(arange(Nx), ones(Nx), label="Correlation")
610            (line_AA,) = ax2.plot(arange(Nx), ones(Nx), label="Abs. corr.")
611            _ = ax2.hlines(0, 0, Nx - 1, "k", "dotted", lw=1)
612            # Align ax2 with ax
613            bb_AC = ax2.get_position()
614            bb_C = ax.get_position()
615            ax2.set_position([bb_C.x0, bb_AC.y0, bb_C.width, bb_AC.height])
616            # Tune plot
617            ax2.set_title("Auto-correlation:")
618            ax2.set_ylabel("Mean value")
619            ax2.set_xlabel("Distance (in state indices)")
620            ax2.set_xticklabels([])
621            ax2.set_yticks([0, 1] + list(ax2.get_yticks()[[0, -1]]))
622            ax2.set_ylim(top=1)
623            ax2.legend(
624                frameon=True,
625                facecolor="w",
626                bbox_to_anchor=(1, 1),
627                loc="upper left",
628                borderaxespad=0.02,
629            )
631            self.ax = ax
632            self.ax2 = ax2
633            self.im = im
634            self.line_AC = line_AC
635            self.line_AA = line_AA
636            self.mask = mask
637            if hasattr(stats, "w"):
638                self.w = stats.w
639        else:
640            not_available_text(ax)
half = True
def circulant_ACF(C, do_abs=False):
670def circulant_ACF(C, do_abs=False):
671    """Compute the auto-covariance-function corresponding to `C`.
673    This assumes it is the cov/corr matrix of a 1D periodic domain.
675    Vectorized or FFT implementations are
676    [possible](https://stackoverflow.com/questions/20360675).
677    """
678    M = len(C)
679    # cols = np.flipud(sla.circulant(np.arange(M)[::-1]))
680    cols = sla.circulant(np.arange(M))
681    ACF = np.zeros(M)
682    for i in range(M):
683        row = C[i, cols[i]]
684        if do_abs:
685            row = abs(row)
686        ACF += row
687        # Note: this actually also accesses masked values in C.
688    return ACF / M

Compute the auto-covariance function corresponding to C.

This assumes C is the covariance/correlation matrix of a 1D periodic domain.

Vectorized or FFT implementations are possible (see https://stackoverflow.com/questions/20360675).

def sliding_marginals( obs_inds=(), dims=(), labels=(), Tplot=None, ens_props={'alpha': 0.4}, zoomy=1.0):
691def sliding_marginals(
692    obs_inds=(),
693    dims=(),
694    labels=(),
695    Tplot=None,
696    ens_props=dict(alpha=0.4),  # noqa
697    zoomy=1.0,
699    # Store parameters
700    params_orig = DotDict(**locals())
702    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
703        xx, yy, mu, spread, tseq = (
704            stats.xx,
705            stats.yy,
706            stats.mu,
707            stats.spread,
708            stats.HMM.tseq,
709        )
711        # Set parameters (kwargs takes precedence over params_orig)
712        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})
714        # Chose marginal dims to plot
715        if not len(p.dims):
716            p.dims = linspace_int(xx.shape[-1], min(10, xx.shape[-1]))
718        # Lag settings:
719        T_lag, K_lag, a_lag = validate_lag(p.Tplot, tseq)
720        K_plot = comp_K_plot(K_lag, a_lag, plot_u)
721        # Extend K_plot forther for adding blanks in resampling (PartFilt):
722        has_w = hasattr(stats, "w")
723        if has_w:
724            K_plot += a_lag
726        # Set up figure, axes
727        fig, axs = place.freshfig(
728            fignum, figsize=(5, 7), squeeze=False, nrows=len(p.dims), sharex=True
729        )
730        axs = axs.reshape(len(p.dims))
732        # Tune plots
733        axs[0].set_title("Marginal time series")
734        for ix, (m, ax) in enumerate(zip(p.dims, axs)):
735            # ax.set_ylim(*viz.stretch(*viz.xtrema(xx[:, m]), 1/p.zoomy))
736            if not p.labels:
737                ax.set_ylabel("$x_{%d}$" % m)
738            else:
739                ax.set_ylabel(p.labels[ix])
740        axs[-1].set_xlabel("Time (t)")
742        plot_pause(0.05)
743        plt.tight_layout()
745        # Allocate
746        d = DotDict()  # data arrays
747        h = DotDict()  # plot handles
748        # Why "if True" ? Just to indent the rest of the line...
749        if True:
750            d.t = RollingArray((K_plot,))
751        if True:
752            d.x = RollingArray((K_plot, len(p.dims)))
753            h.x = []
754        if not_empty(p.obs_inds):
755            d.y = RollingArray((K_plot, len(p.dims)))
756            h.y = []
757        if E is not None:
758            d.E = RollingArray((K_plot, len(E), len(p.dims)))
759            h.E = []
760        if P is not None:
761            d.mu = RollingArray((K_plot, len(p.dims)))
762            h.mu = []
763        if P is not None:
764            d.s = RollingArray((K_plot, 2, len(p.dims)))
765            h.s = []
767        # Plot (invisible coz everything here is nan, for the moment).
768        for ix, ax in zip(p.dims, axs):
769            if True:
770                h.x += ax.plot(d.t, d.x[:, ix], "k")
771            if not_empty(p.obs_inds):
772                h.y += ax.plot(d.t, d.y[:, ix], "g*", ms=10)
773            if "E" in d:
774                h.E += [ax.plot(d.t, d.E[:, :, ix], **p.ens_props)]
775            if "mu" in d:
776                h.mu += ax.plot(d.t, d.mu[:, ix], "b")
777            if "s" in d:
778                h.s += [ax.plot(d.t, d.s[:, :, ix], "b--", lw=1)]
780        def update(key, E, P):
781            k, ko, faus = key
783            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)
785            # Roll data array
786            ind = k if plot_u else ko
787            for Ens in EE:  # If E is duplicated, so must the others be.
788                if "E" in d:
789                    d.E.insert(ind, Ens)
790                if "mu" in d:
791                    d.mu.insert(ind, mu[key][p.dims])
792                if "s" in d:
793                    d.s.insert(ind, mu[key][p.dims] + [[1], [-1]] * spread[key][p.dims])
794                if True:
795                    d.t.insert(ind, tseq.tt[k])
796                if not_empty(p.obs_inds):
797                    xy = nan * ones(len(p.dims))
798                    if ko is not None:
799                        jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
800                        xy[jj] = yy[ko]
801                    d.y.insert(ind, xy)
802                if True:
803                    d.x.insert(ind, xx[k, p.dims])
805            # Update graphs
806            for ix, ax in zip(p.dims, axs):
807                sliding_xlim(ax, d.t, T_lag, True)
808                if True:
809                    h.x[ix].set_data(d.t, d.x[:, ix])
810                if not_empty(p.obs_inds):
811                    h.y[ix].set_data(d.t, d.y[:, ix])
812                if "mu" in d:
813                    h.mu[ix].set_data(d.t, d.mu[:, ix])
814                if "s" in d:
815                    [h.s[ix][b].set_data(d.t, d.s[:, b, ix]) for b in [0, 1]]
816                if "E" in d:
817                    [h.E[ix][n].set_data(d.t, d.E[:, n, ix]) for n in range(len(E))]
818                if "E" in d:
819                    update_alpha(key, stats, h.E[ix])
821                # TODO 3: fixup. This might be slow?
822                # In any case, it is very far from tested.
823                # Also, relim'iting all of the time is distracting.
824                # Use d_ylim?
825                if "E" in d:
826                    lims = d.E
827                elif "mu" in d:
828                    lims = d.mu
829                lims = np.array(viz.xtrema(lims[..., ix]))
830                if lims[0] == lims[1]:
831                    lims += [-0.5, +0.5]
832                ax.set_ylim(*viz.stretch(*lims, 1 / p.zoomy))
834            return
836        return update
838    return init
def phase_particles(
    is_3d=True,
    obs_inds=(),
    dims=(),
    labels=(),
    Tplot=None,
    ens_props=dict(alpha=0.4),  # noqa
    zoom=1.5,
):
    """Liveplotter: trajectories ("tails") in phase space.

    Shows the truth, the observations, and either the mean (if `P` is given)
    or the ensemble members, each with a sliding tail of recent states.

    - `is_3d`: use a 3d projection (3 dims); otherwise 2d (2 dims).
    - `obs_inds`: indices of observed components,
      or a callable `obs_inds(ko)` returning them.
    - `dims`: state indices to plot (length must be 3, or 2 if not `is_3d`).
      Default: the first 3 (or 2) components.
    - `labels`: axis labels. Default: `$x_i$` for each `i` in `dims`.
    - `Tplot`: time length of the tails (validated by `validate_lag`).
      `0` shows only the current point.
    - `ens_props`: keyword arguments for the ensemble tail lines.
    - `zoom`: axis limits are `viz.stretch`ed by `1/zoom` from the
      extremes of the truth.

    Any of these can be overridden by `**kwargs` passed to the returned `init`.

    Returns `init(fignum, stats, key0, plot_u, E, P, **kwargs)`,
    which sets up the figure and returns `update(key, E, P)`.
    """
    # Store parameters
    params_orig = DotDict(**locals())

    M = 3 if is_3d else 2

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        xx, yy, mu, tseq = stats.xx, stats.yy, stats.mu, stats.HMM.tseq

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

        # Lag settings:
        has_w = hasattr(stats, "w")
        if p.Tplot == 0:
            K_plot = 1
        else:
            _, K_lag, a_lag = validate_lag(p.Tplot, tseq)
            K_plot = comp_K_plot(K_lag, a_lag, plot_u)
            # Extend K_plot further for adding blanks in resampling (PartFilt):
            if has_w:
                K_plot += a_lag

        # Dimension settings.
        # NB: `not_empty` (unlike `not p.dims`) also works for ndarray input.
        if not not_empty(p.dims):
            p.dims = arange(M)
        # An ndarray (unlike a tuple/list) is a valid fancy index, e.g. `E.T[p.dims]`.
        p.dims = np.asarray(p.dims)
        if not not_empty(p.labels):
            p.labels = ["$x_%d$" % d for d in p.dims]
        assert len(p.dims) == M

        # Set up figure, axes
        place.freshfig(fignum, figsize=(5, 5))
        ax = plt.subplot(111, projection="3d" if is_3d else None)
        ax.set_facecolor("w")
        ax.set_title("Phase space trajectories")
        # Tune plot.
        # NB: call the label setters directly (no `eval`), so that labels
        # containing quotes or backslashes (e.g. LaTeX) are used verbatim.
        for ind, (label, i, t) in enumerate(zip(p.labels, p.dims, "xyz")):
            viz.set_ilim(ax, ind, *viz.stretch(*viz.xtrema(xx[:, i]), 1 / p.zoom))
            getattr(ax, f"set_{t}label")(str(label))

        # Allocate
        d = DotDict()  # data arrays
        h = DotDict()  # plot handles
        s = DotDict()  # scatter handles
        if E is not None:
            d.E = RollingArray((K_plot, len(E), M))
            h.E = []
        if P is not None:
            d.mu = RollingArray((K_plot, M))
        d.x = RollingArray((K_plot, M))
        if not_empty(p.obs_inds):
            d.y = RollingArray((K_plot, M))

        # Plot tails (invisible coz everything here is nan, for the moment).
        if "E" in d:
            h.E += [
                ax.plot(*xn, **p.ens_props)[0] for xn in np.transpose(d.E, [1, 2, 0])
            ]
        if "mu" in d:
            h.mu = ax.plot(*d.mu.T, "b", lw=2)[0]
        h.x = ax.plot(*d.x.T, "k", lw=3)[0]
        if "y" in d:
            h.y = ax.plot(*d.y.T, "g*", ms=14)[0]

        # Scatter. NB: don't init with nan's coz it's buggy
        # (wrt. get_color() and _offsets3d) since mpl 3.1.
        if "E" in d:
            s.E = ax.scatter(*E.T[p.dims], s=3**2, c=[hn.get_color() for hn in h.E])
        if "mu" in d:
            s.mu = ax.scatter(*ones(M), s=8**2, c=[h.mu.get_color()])
        s.x = ax.scatter(
            *ones(M), s=14**2, c=[h.x.get_color()], marker=(5, 1), zorder=99
        )

        def update(key, E, P):
            k, ko, faus = key

            def update_tail(handle, newdata):
                handle.set_data(newdata[:, 0], newdata[:, 1])
                if is_3d:
                    handle.set_3d_properties(newdata[:, 2])

            def update_sctr(handle, newdata):
                if is_3d:
                    handle._offsets3d = juggle_axes(*newdata.T, "z")
                else:
                    handle.set_offsets(newdata)

            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)

            # Roll data array
            ind = k if plot_u else ko
            for Ens in EE:  # If E is duplicated, so must the others be.
                if "E" in d:
                    d.E.insert(ind, Ens)
                d.x.insert(ind, xx[k, p.dims])
                if "y" in d:
                    xy = nan * ones(len(p.dims))
                    if ko is not None:
                        jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
                        jj = list(jj)
                        for i, dim in enumerate(p.dims):
                            try:
                                iobs = jj.index(dim)
                            except ValueError:
                                pass  # This plotted dim is not observed.
                            else:
                                xy[i] = yy[ko][iobs]
                    d.y.insert(ind, xy)
                if "mu" in d:
                    d.mu.insert(ind, mu[key][p.dims])

            # Update graph
            update_sctr(s.x, d.x[[-1]])
            update_tail(h.x, d.x)
            if "y" in d:
                update_tail(h.y, d.y)
            if "mu" in d:
                update_sctr(s.mu, d.mu[[-1]])
                update_tail(h.mu, d.mu)
            elif "E" in d:
                update_sctr(s.E, d.E[-1])
                for n in range(len(E)):
                    update_tail(h.E[n], d.E[:, n, :])
                update_alpha(key, stats, h.E, s.E)

            return

        return update

    return init
def validate_lag(Tplot, tseq):
    """Return validated `T_lag` such that it is:

    - equal to `Tplot` with fallback: `HMM.tseq.Tplot`.
    - no longer than `HMM.tseq.T`.

    Also return corresponding `K_lag`, `a_lag`.
    """
    T_lag = tseq.Tplot if Tplot is None else Tplot
    assert T_lag >= 0

    # Clip the lag so that it does not reach back beyond the first time.
    t_end = tseq.tt[-1]
    t_start = max(tseq.tt[0], t_end - T_lag)
    T_lag = t_end - t_start

    K_lag = 1 + int(T_lag / tseq.dt)  # Lag in indices
    a_lag = 1 + K_lag // tseq.dko  # Lag in obs indices
    return T_lag, K_lag, a_lag

Return validated T_lag such that it is:

  • equal to Tplot with fallback: HMM.tseq.Tplot.
  • no longer than HMM.tseq.T.

Also return corresponding K_lag, a_lag.

def comp_K_plot(K_lag, a_lag, plot_u):
    """Return the number of rows needed to store the plotted tails.

    The forecast and analysis series each contribute `a_lag` rows;
    the "u" series (only if `plot_u`) contributes another `K_lag`.
    """
    u_rows = K_lag if plot_u else 0
    return 2 * a_lag + u_rows
def update_alpha(key, stats, lines, scatters=None):
    """Adjust color alpha (for particle filters).

    Only acts at observation times (`ko` not None), for non-forecast stages,
    and only if `stats` has weights `w`. The alpha of each member is its
    weight relative to the max weight, clipped to [0.1, 0.4].
    Scatter plots get their alpha via an RGBA facecolor array.
    """
    _, ko, faus = key
    if ko is None or faus == "f" or not hasattr(stats, "w"):
        return

    # Relative weights -> alpha values
    weights = stats.w[key]
    alpha = (weights / weights.max()).clip(0.1, 0.4)

    for handle, val in zip(lines, alpha):
        handle.set_alpha(val)

    if scatters is None:
        return

    # Scatter plot does not have (per-point) alpha. => Fake it via RGBA.
    rgb = scatters.get_facecolor()[:, :3]
    if len(rgb) == 1:
        rgb = np.repeat(rgb, len(weights), axis=0)
    scatters.set_color(np.column_stack([rgb, alpha]))

Adjust color alpha (for particle filters).

def not_empty(xx):
    """Tell whether `xx` is non-empty.

    Sized objects (lists, tuples, ndarrays) are non-empty if their length is
    positive; objects without a length (e.g. callables, scalars, None)
    fall back to their truthiness.
    """
    try:
        size = len(xx)
    except TypeError:
        return bool(xx)
    return size > 0

Works for non-iterable and iterables (including ndarrays).

def duplicate_with_blanks_for_resampled(E, dims, key, has_w):
    """Particle filter: insert breaks for resampled particles.

    Returns a list of ensembles (columns `dims` of `E`) to be inserted in turn.
    Normally this is just `[E[:, dims]]`. For weighted ensembles (`has_w`):

    - at "a" stages, the first selected column is stored in `_Ea[0]`;
    - at "u" stages at obs times, a copy is prepended in which the members
      whose first column differs from the stored one (i.e. resampled)
      are set to nan, which produces breaks in the plotted tails.

    `[None]` is returned if `E` is None.
    """
    if E is None:
        return [E]
    sub = E[:, dims]
    if not has_w:
        return [sub]

    _, ko, faus = key
    if faus == "a":
        _Ea[0] = sub[:, 0]  # Store (1st dim of) ens.
    elif faus == "u" and ko is not None:
        # Mark as resampled if ens changed; blank those out in a copy.
        blanked = sub.copy()
        blanked[_Ea[0] != sub[:, 0]] = nan
        return [blanked, sub]
    return [sub]

Particle filter: insert breaks for resampled particles.

def d_ylim(data, ax=None, cC=0, cE=1, pp=(1, 99), Min=-1e20, Max=+1e20):
    """Provide new ylims intelligently, from percentiles of the data.

    - `data`: iterable of arrays for computing percentiles.
      Non-finite values are ignored. The result spans the lowest lower
      percentile and the highest upper percentile over all the arrays.
    - `pp`: (lower, upper) percentiles.

    - `ax`: If present, then the delta_zoom in/out is also considered,
      relative to `ax.get_ylim()`.

      - `cE`: expansion (widening) rate ∈ [0,1].
        Default: 1, which immediately expands to percentile.
      - `cC`: compression (narrowing) rate ∈ [0,1].
        Default: 0, which does not allow compression.

    - `Min`/`Max`: bounds

    Returns `(ymin, ymax)`. Either is `None` if not finite
    (e.g. if `data` holds no finite values). The limits are NOT set on `ax`.

    Despite being a little involved,
    the cost of this subroutine is typically not substantial
    because there's usually not that much data to sort through.
    """
    # Find "reasonable" limits (by percentiles), looping over data.
    # Both ends are tracked as running maxima; the lower one is negated.
    maxv = minv = -np.inf  # init
    for arr in data:
        arr = arr[np.isfinite(arr)]
        if len(arr):
            perc = np.array([-1, 1]) * np.percentile(arr, pp)
            minv, maxv = np.maximum([minv, maxv], perc)
    minv *= -1

    # Pry apart equal values
    if np.isclose(minv, maxv):
        maxv += 0.5
        minv -= 0.5

    # Make the zooming transition smooth
    if ax is not None:
        current = ax.get_ylim()
        # Use the compression rate when narrowing, the expansion rate when widening.
        c0 = cC if minv > current[0] else cE
        c1 = cC if maxv < current[1] else cE
        # Interpolate between the current and the target limits.
        minv = np.interp(c0, (0, 1), (current[0], minv))
        maxv = np.interp(c1, (0, 1), (current[1], maxv))

    # Bounds
    maxv = min(Max, maxv)
    minv = max(Min, minv)

    # Some mpl versions don't handle inf limits.
    if not np.isfinite(minv):
        minv = None
    if not np.isfinite(maxv):
        maxv = None

    return minv, maxv

Provide new ylims intelligently, from percentiles of the data.

  • data: iterable of arrays for computing percentiles.
  • pp: percentiles

  • ax: If present, then the delta_zoom in/out is also considered.

    • cE: expansion (widening) rate ∈ [0,1]. Default: 1, which immediately expands to percentile.
    • cC: compression (narrowing) rate ∈ [0,1]. Default: 0, which does not allow compression.
  • Min/Max: bounds

Despite being a little involved, the cost of this subroutine is typically not substantial because there's usually not that much data to sort through.

def spatial1d(
    obs_inds=(),
    periodicity=None,
    dims=(),
    ens_props={"color": "b", "alpha": 0.1},  # noqa
    conf_mult=None,
):
    """Liveplotter: amplitude of the state vector versus state index (1d).

    - `obs_inds`: indices of observed components,
      or a callable `obs_inds(ko)` returning them.
    - `periodicity`: passed to `viz.setup_wrapping` (wrapping of the index axis).
    - `dims`: state indices to plot. Default: all.
    - `ens_props`: keyword arguments for the ensemble member lines.
    - `conf_mult`: if set, plot the mean ± `conf_mult` × spread
      instead of the ensemble members. Set to 2 if there is no ensemble.

    Any of these can be overridden by `**kwargs` passed to the returned `init`.

    Returns `init(fignum, stats, key0, plot_u, E, P, **kwargs)`,
    which sets up the figure and returns `update(key, E, P)`.
    """
    # Store parameters
    params_orig = DotDict(**locals())

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        xx, yy, mu = stats.xx, stats.yy, stats.mu

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

        # Dimension settings.
        # NB: `not_empty` (unlike `not p.dims`) also works for ndarrays, and
        # the ndarray conversion makes `p.dims` a valid fancy index below
        # (`p.dims[xt]`, `xy[p.dims]`), which a tuple/list is not.
        if not_empty(p.dims):
            p.dims = np.asarray(p.dims)
            M = len(p.dims)
        else:
            M = xx.shape[-1]
            p.dims = arange(M)

        # Make periodic wrapper
        ii, wrap = viz.setup_wrapping(M, p.periodicity)

        # Set up figure, axes
        fig, ax = place.freshfig(fignum, figsize=(8, 5))
        fig.suptitle("1d amplitude plot")

        # Nans
        nan1 = wrap(nan * ones(M))

        if E is None and p.conf_mult is None:
            p.conf_mult = 2

        # Init plots
        if p.conf_mult:
            lines_s = ax.plot(
                ii, nan1, "b-", lw=1, label=(str(p.conf_mult) + r"$\sigma$ conf")
            )
            lines_s += ax.plot(ii, nan1, "b-", lw=1)
            (line_mu,) = ax.plot(ii, nan1, "b-", lw=2, label="DA mean")
        else:
            nanE = nan * ones((stats.xp.N, M))
            lines_E = ax.plot(ii, wrap(nanE[0]), **p.ens_props, lw=1, label="Ensemble")
            lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)
        # Truth, Obs
        (line_x,) = ax.plot(ii, nan1, "k-", lw=3, label="Truth")
        if not_empty(p.obs_inds):
            (line_y,) = ax.plot(ii, nan1, "g*", ms=5, label="Obs")

        # Tune plot
        ax.set_ylim(*viz.xtrema(xx))
        ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))
        # Xticks
        xt = ax.get_xticks()
        xt = xt[abs(xt % 1) < 0.01].astype(int)  # Keep only the integer ticks
        xt = xt[xt >= 0]
        xt = xt[xt < len(p.dims)]
        ax.set_xticks(xt)
        ax.set_xticklabels(p.dims[xt])

        ax.set_xlabel("State index")
        ax.set_ylabel("Value")
        ax.legend(loc="upper right")

        text_t = ax.text(
            0.01,
            0.01,
            format_time(None, None, None),
            transform=ax.transAxes,
            family="monospace",
            ha="left",
        )

        def update(key, E, P):
            k, ko, faus = key

            if p.conf_mult:
                sigma = mu[key] + p.conf_mult * stats.spread[key] * [[1], [-1]]
                lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
                lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
                line_mu.set_ydata(wrap(mu[key][p.dims]))
            else:
                for n, line in enumerate(lines_E):
                    line.set_ydata(wrap(E[n, p.dims]))
                update_alpha(key, stats, lines_E)

            line_x.set_ydata(wrap(xx[k, p.dims]))

            text_t.set_text(format_time(k, ko, stats.HMM.tseq.tt[k]))

            # Show obs at the forecast stage; hide them again at the "u" stage.
            if "f" in faus:
                if not_empty(p.obs_inds):
                    xy = nan * ones(len(xx[0]))
                    jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
                    xy[jj] = yy[ko]
                    line_y.set_ydata(wrap(xy[p.dims]))
                    line_y.set_zorder(5)
                    line_y.set_visible(True)

            if "u" in faus:
                if not_empty(p.obs_inds):
                    line_y.set_visible(False)

            return

        return update

    return init
def spatial2d(
    square,
    ind2sub,
    obs_inds=(),
    cm=plt.cm.jet,
    clims=((-40, 40), (-40, 40), (-10, 10), (-10, 10)),
):
    """Liveplotter: 2d fields of the mean, truth, spread and error.

    - `square`: maps a state vector to a 2d array (for `imshow`).
    - `ind2sub`: maps (flat) state indices to `(iy, ix)` subscripts.
    - `obs_inds`: indices of observed components,
      or a callable `obs_inds(ko)` returning them.
    - `cm`: colormap for the mean and truth panels.
    - `clims`: color limits for the (mean, truth, spread, error) panels.

    Returns `init(fignum, stats, key0, plot_u, E, P, **kwargs)`,
    which sets up the figure and returns `update(key, E, P)`.
    """

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        GS = {"left": 0.125 - 0.04, "right": 0.9 - 0.04}
        fig, axs = place.freshfig(
            fignum,
            figsize=(6, 6),
            nrows=2,
            ncols=2,
            sharex=True,
            sharey=True,
            gridspec_kw=GS,
        )

        for ax in axs.flatten():
            ax.set_aspect("equal", "box")

        ((ax_11, ax_12), (ax_21, ax_22)) = axs

        ax_11.grid(color="w", linewidth=0.2)
        ax_12.grid(color="w", linewidth=0.2)
        ax_21.grid(color="k", linewidth=0.1)
        ax_22.grid(color="k", linewidth=0.1)

        # Upper colorbar -- position relative to ax_12
        bb = ax_12.get_position()
        dy = 0.1 * bb.height
        ax_13 = fig.add_axes([bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])
        # Lower colorbar -- position relative to ax_22
        bb = ax_22.get_position()
        dy = 0.1 * bb.height
        ax_23 = fig.add_axes([bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])

        # Extract data arrays
        xx, _, mu, spread, err = stats.xx, stats.yy, stats.mu, stats.spread, stats.err
        k = key0[0]
        tt = stats.HMM.tseq.tt

        # Plot
        # - origin='lower' might get overturned by set_ylim() below.
        im_11 = ax_11.imshow(square(mu[key0]), cmap=cm)
        im_12 = ax_12.imshow(square(xx[k]), cmap=cm)
        # hot is better, but needs +1 colorbar
        im_21 = ax_21.imshow(square(spread[key0]), cmap=plt.cm.bwr)
        im_22 = ax_22.imshow(square(err[key0]), cmap=plt.cm.bwr)
        ims = (im_11, im_12, im_21, im_22)
        # Obs init -- a list where item 0 is the handle of something invisible.
        lh = list(ax_12.plot(0, 0)[0:1])

        sx = "$\\psi$"
        ax_11.set_title("mean " + sx)
        ax_12.set_title("true " + sx)
        ax_21.set_title("spread. " + sx)
        ax_22.set_title("err. " + sx)

        # TODO 7
        # for ax in axs.flatten():
        # Crop boundries (which should be 0, i.e. yield harsh q gradients):
        # lims = (1, nx-2)
        # step = (nx - 1)/8
        # ticks = arange(step,nx-1,step)
        # ax.set_xlim  (lims)
        # ax.set_ylim  (lims[::-1])
        # ax.set_xticks(ticks)
        # ax.set_yticks(ticks)

        for im, clim in zip(ims, clims):
            im.set_clim(clim)

        fig.colorbar(im_12, cax=ax_13)
        fig.colorbar(im_22, cax=ax_23)
        for ax in [ax_13, ax_23]:
            ax.yaxis.set_tick_params(
                "major", length=2, width=0.5, direction="in", left=True, right=True
            )
            ax.set_axisbelow("line")  # make ticks appear over colorbar patch

        # Title
        title = "Streamfunction (" + sx + ")"
        fig.suptitle(title)
        # Time info
        text_t = ax_12.text(
            1,
            1.1,
            format_time(None, None, None),
            transform=ax_12.transAxes,
            family="monospace",
            ha="left",
        )

        def update(key, E, P):
            k, ko, faus = key
            t = tt[k]

            im_11.set_data(square(mu[key]))
            im_12.set_data(square(xx[k]))
            im_21.set_data(square(spread[key]))
            im_22.set_data(square(err[key]))

            # Remove previous obs (ValueError if it was already removed)
            try:
                lh[0].remove()
            except ValueError:
                pass
            # Plot current obs.
            #  - plot() automatically adjusts to direction of y-axis in use.
            #  - ind2sub returns (iy,ix), while plot takes (ix,iy) => reverse.
            #  - obs_inds may be a callable of ko, or a fixed sequence
            #    (as in the other liveplotters).
            if ko is not None and not_empty(obs_inds):
                jj = obs_inds(ko) if callable(obs_inds) else obs_inds
                lh[0] = ax_12.plot(*ind2sub(jj)[::-1], "k.", ms=1, zorder=5)[0]

            text_t.set_text(format_time(k, ko, t))

            return

        return update

    return init
default_liveplotters = [(1, <class 'sliding_diagnostics'>), (1, <class 'weight_histogram'>)]