dapper.tools.liveplotting
On-line (live) plots of the DA process for various models and methods.
Liveplotters are given by a list of tuples, as a property of or an argument to `dapper.mods.HiddenMarkovModel`.
- The first element of the tuple determines whether the liveplotter is shown if the names of liveplotters are not given by the `liveplots` argument in `assimilate`.
- The second element in the tuple gives the corresponding liveplotter function/class. See the example function `LPs` in `dapper.mods.Lorenz63`.
The liveplotters can be fine-tuned by each DA experiment via the `liveplots` argument when calling `assimilate`.
- `liveplots = True` turns on the liveplotters set to default in the first argument of the `HMM.liveplotter` and the default liveplotters defined in this module (`sliding_diagnostics` and `weight_histogram`).
- `liveplots` can also be a list of specified liveplotter names, i.e. the names of the corresponding liveplotting classes/functions.
class LivePlot:
    """Live plotting manager.

    Deals with

    - Pause, skip.
    - Which liveploters to call.
    - `plot_u`
    - Figure window (title and number).

    State kept on the instance (all set in `__init__`):

    - `any_figs`: whether any managed figure is (still) open.
    - `plot_u`: whether keys whose last element is `"u"` are plotted.
    - `params`: pause durations (`pause_f/a/u/s`) plus user overrides.
    - `figures`: dict mapping liveplotter name to its updater callable.
    - `paused`, `skipping`, `run_ipdb`: playback-control flags.
    """

    def __init__(
        self,
        stats,
        liveplots,
        key0=(0, None, "u"),
        E=None,
        P=None,
        speed=1.0,
        replay=False,
        **kwargs,
    ):
        """
        Initialize plots.

        - stats: statistics object. This method reads `stats.store_u`,
          `stats.xp` (for an optional `LP_kwargs` dict) and `stats.HMM`
          (for an optional `liveplotters` list); it is also passed on
          to every liveplotter.
        - liveplots: figures to plot; alternatives:
            - `"default"/True`   : All default figures for this HMM.
            - `"all"`            : Even more.
            - non-empty `list`   : Only the figures with these names (str).
            - `False`            : None.
          NOTE(review): anything with a `__len__` is used as-is, so an
          empty list selects no figures, and a string other than
          `"all"`/`"*default*"` is matched against liveplotter names by
          substring containment (`name in figlist`).
        - key0: the initial key `(k, ko, faus)` used for the first update.
        - E, P: initial ensemble / covariance handed to the liveplotters.
        - speed: speed of animation.
            - `>=100`: instantaneous (replaced by `inf`).
            - `1`    : (default) as quick as possible allowing for
                       plt.draw() to work on a moderately fast computer.
            - `<1`   : slower.
        - replay: if true (together with infinite `speed`), plotting is
          enabled even when `rc.liveplotting` is falsy.
        - kwargs: merged into `self.params` and forwarded to the
          initializer of every liveplotter.
        """
        # Disable if not rc.liveplotting
        # (except for replays at infinite speed, which are always allowed).
        self.any_figs = False
        if rc.liveplotting:
            pass
        elif replay and np.isinf(speed):
            pass
        else:
            # NB: none of the other attributes get set in this case;
            # `update` guards against that by checking `any_figs` first.
            return

        # Determine whether all/universal/intermediate stats are plotted
        self.plot_u = not replay or stats.store_u

        # Set speed/pause params
        self.params = {
            "pause_f": 0.05,
            "pause_a": 0.05,
            "pause_s": 0.05,
            "pause_u": 0.001,
        }
        # If speed>100: set to inf. Coz pause=1e-99 causes hangup.
        # NB: `speed` is rebound here, which also affects the
        # `np.isinf(speed)` test at the end of this method.
        for pause in ["pause_" + x for x in "faus"]:
            speed = speed if speed < 100 else np.inf
            self.params[pause] /= speed

        # Write params
        # (experiment-level LP_kwargs first, then explicit kwargs on top).
        self.params.update(getattr(stats.xp, "LP_kwargs", {}))
        self.params.update(kwargs)

        def get_name(init):
            """Get name of liveplotter function/class."""
            try:
                # For a closure, `__qualname__` looks like
                # "factory.<locals>.init"; keep only the outermost name.
                return init.__qualname__.split(".")[0]
            except AttributeError:
                return init.__class__.__name__

        # Set up dict of liveplotters
        # (name -> (show_by_default, initializer); later entries override).
        potential_LPs = {}
        for show, init in default_liveplotters:
            potential_LPs[get_name(init)] = show, init
        # Add HMM-specific liveplotters
        for show, init in getattr(stats.HMM, "liveplotters", {}):
            potential_LPs[get_name(init)] = show, init

        def parse_figlist(lst):
            """Figures requested for this xp. Convert to list."""
            if isinstance(lst, str):
                fn = lst.lower()
                if "all" == fn:
                    lst = ["all"]  # All potential_LPs
                elif "default" in fn:
                    lst = ["default"]  # All show_by_default
                # Any other string is returned unchanged.
            elif hasattr(lst, "__len__"):
                lst = lst  # This list (only)
            elif lst:
                lst = ["default"]  # All show_by_default
            else:
                lst = [None]  # None
            return lst

        figlist = parse_figlist(liveplots)

        # Loop over requeted figures
        self.figures = {}
        for name, (show_by_default, init) in potential_LPs.items():
            if (
                (figlist == ["all"])
                or (name in figlist)
                or (figlist == ["default"] and show_by_default)
            ):
                # Startup message (printed once, before the first figure)
                if not self.any_figs:
                    print("Initializing liveplots...")
                    if is_notebook_or_qt:
                        pauses = [self.params["pause_" + x] for x in "faus"]
                        if any((p > 0) for p in pauses):
                            print(
                                "Note: liveplotting does not work very well"
                                " inside Jupyter notebooks. In particular,"
                                " there is no way to stop/skip them except"
                                " to interrupt the kernel (the stop button"
                                " in the toolbar). Consider using instead"
                                " only the replay functionality (with infinite"
                                " playback speed)."
                            )
                    elif not pb.disable_user_interaction:
                        print("Hit <Space> to pause/step.")
                        print("Hit <Enter> to resume/skip.")
                        print("Hit <i> to enter debug mode.")
                    self.paused = False
                    self.run_ipdb = False
                    self.skipping = False
                    self.any_figs = True

                # Init figure
                post_title = "" if self.plot_u else "\n(obs times only)"
                updater = init(name, stats, key0, self.plot_u, E, P, **kwargs)
                # Only keep liveplotters that actually opened a figure
                # and did not flag themselves as inactive.
                if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
                    self.figures[name] = updater
                    fig = plt.figure(name)
                    win = fig.canvas
                    ax0 = fig.axes[0]
                    win.manager.set_window_title(str(name))
                    ax0.set_title(ax0.get_title() + post_title)
                    self.update(key0, E, P)  # Call initial update
                    if not (replay and np.isinf(speed)):
                        plt.pause(0.01)  # Draw

    def update(self, key, E, P):
        """Update liveplots.

        - key: tuple `(k, ko, faus)`; its last element selects the pause
          parameter (`"pause_" + faus`) and, if it is `"u"`, whether
          plotting happens at all (see `plot_u`).
        - E, P: current ensemble / covariance, passed to each updater.

        Also polls the keyboard (via `read1`) for pause/skip/debug
        requests, and does nothing once all managed figures are closed.
        """
        # Check if there are still open figures
        if self.any_figs:
            open_figns = plt.get_figlabels()
            live_figns = set(self.figures.keys())
            self.any_figs = bool(live_figns.intersection(open_figns))
        else:
            return

        # Playback control
        SPACE = b" "
        CHAR_I = b"i"
        ENTERs = [b"\n", b"\r"]  # Linux + Windows

        def pause():
            """Loop until user decision is made."""
            ch = read1()
            while True:
                # Set state (pause, skipping, ipdb)
                if ch in ENTERs:
                    self.paused = False
                elif ch == CHAR_I:
                    self.run_ipdb = True
                # If keypress valid, resume execution
                # (SPACE leaves `paused` set, i.e. steps a single update).
                if ch in ENTERs + [SPACE, CHAR_I]:
                    break
                ch = read1()
                # Pause to enable zoom, pan, etc. of mpl GUI
                plot_pause(0.01)  # Don't use time.sleep()!

        # Enter pause loop
        if self.paused:
            pause()

        else:
            if key == (0, None, "u"):
                # Skip read1 for key0 (coz it blocks)
                pass
            else:
                ch = read1()
                if ch == SPACE:
                    # Pause
                    self.paused = True
                    self.skipping = False
                    pause()
                elif ch in ENTERs:
                    # Toggle skipping
                    self.skipping = not self.skipping
                elif ch == CHAR_I:
                    # Schedule debug
                    # Note: The reason we dont set_trace(frame) right here is:
                    # - I could not find the right frame, even doing
                    #   > frame = inspect.stack()[0]
                    #   > while frame.f_code.co_name != "assimilate":
                    #   >     frame = frame.f_back
                    # - It just restarts the plot.
                    self.run_ipdb = True

        # Update figures
        if not self.skipping:
            faus = key[-1]
            if faus != "u" or self.plot_u:
                for name, (updater) in self.figures.items():
                    if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
                        _ = plt.figure(name)
                        updater(key, E, P)
                        plot_pause(self.params["pause_" + faus])

        if self.run_ipdb:
            self.run_ipdb = False
            import inspect

            import ipdb

            print("Entering debug mode (ipdb).")
            print("Type '?' (and Enter) for usage help.")
            print("Type 'c' to continue the assimilation.")
            # Break in the frame two levels up the call stack
            # (i.e. the caller of whoever called `update`).
            ipdb.set_trace(inspect.stack()[2].frame)
319 try: 320 stat = deep_getattr(stats, name) 321 except AttributeError: 322 continue 323 # try: val0 = stat[key0[0]] 324 # except KeyError: continue 325 # PS: recall (from series.py) that even if store_u is false, stat[k] is 326 # still present if liveplots=True via the k_tmp functionality. 327 328 # Unpack style 329 ln = {} 330 ln["transf"] = style_table[name][0] or (lambda x: x) 331 ln["shape"] = style_table[name][1] 332 ln["plt"] = style_table[name][2] 333 334 # Create series 335 if isinstance(stat, FAUSt): 336 ln["plot_u"] = plot_u 337 K_plot = comp_K_plot(K_lag, a_lag, ln["plot_u"]) 338 else: 339 ln["plot_u"] = False 340 K_plot = a_lag 341 ln["data"] = RollingArray(K_plot) 342 ln["tt"] = RollingArray(K_plot) 343 344 # Plot (init) 345 (ln["handle"],) = ax.plot(ln["tt"], ln["data"], **ln["plt"]) 346 347 # Plotting only nans yield ugly limits. Revert to defaults. 348 ax.set_xlim(0, 1) 349 ax.set_ylim(0, 1) 350 351 lines[name] = ln 352 return lines 353 354 # Plot 355 self.d = [init_ax(ax, styles[style]) for style, ax in zip(styles, axs)] 356 357 # Horizontal line at y=0 358 (self.baseline0,) = ax.plot( 359 ax.get_xlim(), [0, 0], c=0.5 * ones(3), lw=0.7, label="_nolegend_" 360 ) 361 362 # Store 363 self.axs = axs 364 self.stats = stats 365 self.init_incomplete = True 366 367 # Update plot 368 def __call__(self, key, E, P): 369 k, ko, faus = key 370 371 stats = self.stats 372 tseq = stats.HMM.tseq 373 ax0, ax1 = self.axs 374 375 def update_arrays(lines): 376 for name, ln in lines.items(): 377 stat = deep_getattr(stats, name) 378 t = tseq.tt[k] # == tseq.tto[ko] 379 if isinstance(stat, FAUSt): 380 # ln['data'] will contain duplicates for f/a times. 381 if ln["plot_u"]: 382 val = stat[key] 383 ln["tt"].insert(k, t) 384 ln["data"].insert(k, ln["transf"](val)) 385 elif "u" not in faus: 386 val = stat[key] 387 ln["tt"].insert(ko, t) 388 ln["data"].insert(ko, ln["transf"](val)) 389 else: 390 # ln['data'] will not contain duplicates, coz only 'a' is input. 
391 if "a" in faus: 392 val = stat[ko] 393 ln["tt"].insert(ko, t) 394 ln["data"].insert(ko, ln["transf"](val)) 395 elif "f" in faus: 396 pass 397 398 def update_plot_data(ax, lines): 399 def bend_into(shape, xx, yy): 400 # Get arrays. Repeat (to use for intermediate nodes). 401 yy = yy.array.repeat(3) 402 xx = xx.array.repeat(3) 403 if len(xx) == 0: 404 pass # shortcircuit any modifications 405 elif shape == "step": 406 yy = np.hstack([yy[1:], nan]) # roll leftward 407 elif shape == "dirac": 408 nonlocal nDirac 409 axW = np.diff(ax.get_xlim()) 410 yy[0::3] = False # set datapoin to 0 411 xx[2::3] = nan # make datapoint disappear 412 xx += nDirac * axW / 100 # offset datapoint horizontally 413 nDirac += 1 414 return xx, yy 415 416 nDirac = 1 417 for _name, ln in lines.items(): 418 ln["handle"].set_data(*bend_into(ln["shape"], ln["tt"], ln["data"])) 419 420 def finalize_init(ax, lines, mm): 421 # Rm lines that only contain NaNs 422 for name in list(lines): 423 ln = lines[name] 424 stat = deep_getattr(stats, name) 425 if not stat.were_changed: 426 ln["handle"].remove() # rm from axes 427 del lines[name] # rm from dict 428 # Add legends 429 if lines: 430 ax.legend(loc="upper left", bbox_to_anchor=(1.01, 1), borderaxespad=0) 431 if mm: 432 ax.annotate( 433 star + ": mean of\nmarginals", 434 xy=(0, -1.5 / len(lines)), 435 xycoords=ax.get_legend().get_frame(), 436 bbox=dict(alpha=0.0), 437 fontsize="small", 438 ) 439 # coz placement of annotate needs flush sometimes: 440 plot_pause(0.01) 441 442 # Insert current stats 443 for lines, ax in zip(self.d, self.axs): 444 update_arrays(lines) 445 update_plot_data(ax, lines) 446 447 # Set x-limits (time) 448 sliding_xlim(ax0, self.d[0]["err.rms"]["tt"], self.T_lag, margin=True) 449 self.baseline0.set_xdata(ax0.get_xlim()) 450 451 # Set y-limits 452 data0 = [ln["data"].array for ln in self.d[0].values()] 453 data1 = [ln["data"].array for ln in self.d[1].values()] 454 ax0.set_ylim(0, d_ylim(data0, ax0, cC=0.2, cE=0.9)[1]) 455 
class weight_histogram:
    """Plots histogram of weights. Refreshed each analysis.

    The liveplotter flags itself as inactive (`is_active = False`) when
    `stats` has no attribute `w`, or when more than 10000 weights are given.
    """

    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
        """Create the figure and axes, unless `stats` carries no weights.

        Only `fignum` and `stats` are used; the remaining arguments exist
        because all liveplotters are initialized with the same signature.
        """
        if not hasattr(stats, "w"):
            self.is_active = False
            return
        fig, ax = place.freshfig(fignum, figsize=(7, 3), gridspec_kw={"bottom": 0.15})

        ax.set_xscale("log")
        ax.set_xlabel("Weight")
        ax.set_ylabel("Count")
        self.stats = stats
        self.ax = ax
        self.hist = []  # bar patches of the histogram currently drawn
        # 31 log-spaced bin edges spanning [1e-10, 1].
        self.bins = np.exp(np.linspace(np.log(1e-10), np.log(1), 31))

    def __call__(self, key, E, P):
        """Redraw the histogram of `stats.w[key]`.

        Does nothing unless the last element of `key` is `"a"`.
        Weights not exceeding the lowest bin edge are left out of the
        histogram; their number is reported in the title.
        """
        k, ko, faus = key
        if faus != "a":
            return

        w = self.stats.w[key]
        N = len(w)
        ax = self.ax

        self.is_active = N < 10001
        if not self.is_active:
            not_available_text(ax, "Not computed (N > threshold)")
            return

        counted = w > self.bins[0]
        # Remove the bars of the previous histogram before drawing anew.
        for bar in self.hist:
            bar.remove()
        nn, _, self.hist = ax.hist(w[counted], bins=self.bins, color="b")
        ax.set_ylim(top=max(nn))

        # Both literals need the `f` prefix: implicit concatenation does not
        # propagate it, so the second part used to be printed verbatim.
        ax.set_title(
            f"N: {N:d}. N_eff: {1/(w@w):.4g}."
            f" Not shown: {N - np.sum(counted):d}. "
        )
577 578 def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs): 579 GS = {"height_ratios": [4, 1], "hspace": 0.09, "top": 0.95} 580 fig, (ax, ax2) = place.freshfig(fignum, figsize=(5, 6), nrows=2, gridspec_kw=GS) 581 582 if E is None and np.isnan(P.diag if isinstance(P, CovMat) else P).all(): 583 not_available_text( 584 ax, ("Not available in replays" "\ncoz full Ens/Cov not stored.") 585 ) 586 self.is_active = False 587 return 588 589 Nx = len(stats.mu[key0]) 590 if Nx <= 1003: 591 C = np.eye(Nx) 592 # Mask half 593 mask = np.zeros_like(C, dtype=bool) 594 mask[np.tril_indices_from(mask)] = True 595 cmap = plt.get_cmap("RdBu_r") 596 VM = 1.0 # abs(np.percentile(C,[1,99])).max() 597 im = ax.imshow(C, cmap=cmap, vmin=-VM, vmax=VM) 598 # Colorbar 599 _ = ax.figure.colorbar(im, ax=ax, shrink=0.8) 600 # Tune plot 601 plt.box(False) 602 ax.set_facecolor("w") 603 ax.grid(False) 604 ax.set_title("State correlation matrix:", y=1.07) 605 ax.xaxis.tick_top() 606 607 # ax2 = inset_axes(ax,width="30%",height="60%",loc=3) 608 (line_AC,) = ax2.plot(arange(Nx), ones(Nx), label="Correlation") 609 (line_AA,) = ax2.plot(arange(Nx), ones(Nx), label="Abs. 
def circulant_ACF(C, do_abs=False):
    """Compute the auto-covariance-function corresponding to `C`.

    This assumes it is the cov/corr matrix of a 1D periodic domain.
    Entry `j` of the result is the mean, over rows `i`, of
    `C[i, (i - j) % M]` (or of its absolute value if `do_abs`).

    Vectorized or FFT implementations are
    [possible](https://stackoverflow.com/questions/20360675).
    """
    size = len(C)
    lags = np.arange(size)
    total = np.zeros(size)
    for i in range(size):
        # Column indices of row `i`, shifted circularly so that
        # position `j` holds the entry at distance `j` from the diagonal.
        shifted = C[i, (i - lags) % size]
        # Note: this actually also accesses masked values in C.
        total += abs(shifted) if do_abs else shifted
    return total / size
748 if True: 749 d.t = RollingArray((K_plot,)) 750 if True: 751 d.x = RollingArray((K_plot, len(p.dims))) 752 h.x = [] 753 if not_empty(p.obs_inds): 754 d.y = RollingArray((K_plot, len(p.dims))) 755 h.y = [] 756 if E is not None: 757 d.E = RollingArray((K_plot, len(E), len(p.dims))) 758 h.E = [] 759 if P is not None: 760 d.mu = RollingArray((K_plot, len(p.dims))) 761 h.mu = [] 762 if P is not None: 763 d.s = RollingArray((K_plot, 2, len(p.dims))) 764 h.s = [] 765 766 # Plot (invisible coz everything here is nan, for the moment). 767 for ix, ax in zip(p.dims, axs): 768 if True: 769 h.x += ax.plot(d.t, d.x[:, ix], "k") 770 if not_empty(p.obs_inds): 771 h.y += ax.plot(d.t, d.y[:, ix], "g*", ms=10) 772 if "E" in d: 773 h.E += [ax.plot(d.t, d.E[:, :, ix], **p.ens_props)] 774 if "mu" in d: 775 h.mu += ax.plot(d.t, d.mu[:, ix], "b") 776 if "s" in d: 777 h.s += [ax.plot(d.t, d.s[:, :, ix], "b--", lw=1)] 778 779 def update(key, E, P): 780 k, ko, faus = key 781 782 EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w) 783 784 # Roll data array 785 ind = k if plot_u else ko 786 for Ens in EE: # If E is duplicated, so must the others be. 
787 if "E" in d: 788 d.E.insert(ind, Ens) 789 if "mu" in d: 790 d.mu.insert(ind, mu[key][p.dims]) 791 if "s" in d: 792 d.s.insert(ind, mu[key][p.dims] + [[1], [-1]] * spread[key][p.dims]) 793 if True: 794 d.t.insert(ind, tseq.tt[k]) 795 if not_empty(p.obs_inds): 796 xy = nan * ones(len(p.dims)) 797 if ko is not None: 798 jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds 799 xy[jj] = yy[ko] 800 d.y.insert(ind, xy) 801 if True: 802 d.x.insert(ind, xx[k, p.dims]) 803 804 # Update graphs 805 for ix, ax in zip(p.dims, axs): 806 sliding_xlim(ax, d.t, T_lag, True) 807 if True: 808 h.x[ix].set_data(d.t, d.x[:, ix]) 809 if not_empty(p.obs_inds): 810 h.y[ix].set_data(d.t, d.y[:, ix]) 811 if "mu" in d: 812 h.mu[ix].set_data(d.t, d.mu[:, ix]) 813 if "s" in d: 814 [h.s[ix][b].set_data(d.t, d.s[:, b, ix]) for b in [0, 1]] 815 if "E" in d: 816 [h.E[ix][n].set_data(d.t, d.E[:, n, ix]) for n in range(len(E))] 817 if "E" in d: 818 update_alpha(key, stats, h.E[ix]) 819 820 # TODO 3: fixup. This might be slow? 821 # In any case, it is very far from tested. 822 # Also, relim'iting all of the time is distracting. 823 # Use d_ylim? 
def phase_particles(
    is_3d=True,
    obs_inds=(),
    dims=(),
    labels=(),
    Tplot=None,
    ens_props=dict(alpha=0.4),  # noqa
    zoom=1.5,
):
    """Make an initializer for the "Phase space trajectories" liveplotter.

    - `is_3d`: plot 3 state components on a "3d" axes, otherwise 2.
    - `obs_inds`: state indices that are observed, or a callable mapping
      `ko` to such indices. Used to place `yy[ko]` among `dims`.
    - `dims`: state indices to plot (as many as there are axes).
      Default: the first 2 or 3.
    - `labels`: axis labels. Default: `"$x_i$"` for each `i` in `dims`.
    - `Tplot`: length of the plotted tails, passed to `validate_lag`.
      The value `0` keeps only a single point.
    - `ens_props`: keyword arguments for plotting the ensemble tails.
      It is only read here, never modified.
    - `zoom`: the axis limits (extrema of the truth) are stretched
      by the factor `1/zoom`.

    Returns the function `init(fignum, stats, key0, plot_u, E, P, **kwargs)`,
    which creates the figure and returns `update(key, E, P)`.
    Any of the above parameters may be overridden through `kwargs`.
    """
    # Store parameters
    params_orig = DotDict(**locals())

    M = 3 if is_3d else 2

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        xx, yy, mu, _, tseq = stats.xx, stats.yy, stats.mu, stats.spread, stats.HMM.tseq

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

        # Lag settings:
        has_w = hasattr(stats, "w")
        if p.Tplot == 0:
            K_plot = 1
        else:
            _, K_lag, a_lag = validate_lag(p.Tplot, tseq)
            K_plot = comp_K_plot(K_lag, a_lag, plot_u)
            # Extend K_plot further for adding blanks in resampling (PartFilt):
            if has_w:
                K_plot += a_lag

        # Dimension settings.
        # Test by length (not truthiness) so that ndarrays are accepted too.
        if not not_empty(p.dims):
            p.dims = arange(M)
        if not not_empty(p.labels):
            p.labels = ["$x_%d$" % d for d in p.dims]
        assert len(p.dims) == M

        # Set up figure, axes
        fig, _ = place.freshfig(fignum, figsize=(5, 5))
        ax = plt.subplot(111, projection="3d" if is_3d else None)
        ax.set_facecolor("w")
        ax.set_title("Phase space trajectories")
        # Tune plot
        for ind, (label, dim, axname) in enumerate(zip(p.labels, p.dims, "xyz")):
            viz.set_ilim(ax, ind, *viz.stretch(*viz.xtrema(xx[:, dim]), 1 / p.zoom))
            # Look up set_xlabel/set_ylabel/set_zlabel by name.
            # (Building source code for `eval` would re-interpret backslashes
            # and quotes in the label, which breaks e.g. LaTeX labels.)
            getattr(ax, f"set_{axname}label")(str(label))

        # Allocate
        d = DotDict()  # data arrays
        h = DotDict()  # plot handles
        s = DotDict()  # scatter handles
        if E is not None:
            d.E = RollingArray((K_plot, len(E), M))
            h.E = []
        if P is not None:
            d.mu = RollingArray((K_plot, M))
        if True:
            d.x = RollingArray((K_plot, M))
        if not_empty(p.obs_inds):
            d.y = RollingArray((K_plot, M))

        # Plot tails (invisible coz everything here is nan, for the moment).
        if "E" in d:
            h.E += [
                ax.plot(*xn, **p.ens_props)[0] for xn in np.transpose(d.E, [1, 2, 0])
            ]
        if "mu" in d:
            h.mu = ax.plot(*d.mu.T, "b", lw=2)[0]
        if True:
            h.x = ax.plot(*d.x.T, "k", lw=3)[0]
        if "y" in d:
            h.y = ax.plot(*d.y.T, "g*", ms=14)[0]

        # Scatter. NB: don't init with nan's coz it's buggy
        # (wrt. get_color() and _offsets3d) since mpl 3.1.
        if "E" in d:
            s.E = ax.scatter(*E.T[p.dims], s=3**2, c=[hn.get_color() for hn in h.E])
        if "mu" in d:
            s.mu = ax.scatter(*ones(M), s=8**2, c=[h.mu.get_color()])
        if True:
            s.x = ax.scatter(
                *ones(M), s=14**2, c=[h.x.get_color()], marker=(5, 1), zorder=99
            )

        def update(key, E, P):
            k, ko, faus = key

            def update_tail(handle, newdata):
                handle.set_data(newdata[:, 0], newdata[:, 1])
                if is_3d:
                    handle.set_3d_properties(newdata[:, 2])

            def update_sctr(handle, newdata):
                if is_3d:
                    handle._offsets3d = juggle_axes(*newdata.T, "z")
                else:
                    handle.set_offsets(newdata)

            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)

            # Roll data array
            ind = k if plot_u else ko
            for Ens in EE:  # If E is duplicated, so must the others be.
                if "E" in d:
                    d.E.insert(ind, Ens)
                if True:
                    d.x.insert(ind, xx[k, p.dims])
                if "y" in d:
                    # Obs. values for the plotted dims; nan where unobserved.
                    xy = nan * ones(len(p.dims))
                    if ko is not None:
                        jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
                        jj = list(jj)
                        for i, dim in enumerate(p.dims):
                            try:
                                iobs = jj.index(dim)
                            except ValueError:
                                pass  # this dim is not observed
                            else:
                                xy[i] = yy[ko][iobs]
                    d.y.insert(ind, xy)
                if "mu" in d:
                    d.mu.insert(ind, mu[key][p.dims])

            # Update graph
            update_sctr(s.x, d.x[[-1]])
            update_tail(h.x, d.x)
            if "y" in d:
                update_tail(h.y, d.y)
            if "mu" in d:
                update_sctr(s.mu, d.mu[[-1]])
                update_tail(h.mu, d.mu)
            else:
                update_sctr(s.E, d.E[-1])
                for n in range(len(E)):
                    update_tail(h.E[n], d.E[:, n, :])
                update_alpha(key, stats, h.E, s.E)

            return

        return update

    return init
def comp_K_plot(K_lag, a_lag, plot_u):
    """Return the length needed for the rolling plot buffers.

    The forecast and analysis series contribute `a_lag` entries each;
    `K_lag` more are added for the intermediate ("u") series if `plot_u`.
    """
    K_plot = 2 * a_lag  # Sum of lags of {f,a} series.
    if plot_u:
        K_plot += K_lag  # Add lag of u series.
    return K_plot


def update_alpha(key, stats, lines, scatters=None):
    """Adjust color alpha (for particle filters).

    Nothing is done when `key` is not at an observation time (`ko is None`),
    when it is a forecast key (`faus == "f"`), or when `stats` has no `w`.

    Otherwise, the alpha of each item of `lines` is set to the corresponding
    weight relative to the largest weight, clipped to [0.1, 0.4].
    If `scatters` is given, the same alphas are written into its colors.
    """
    _, ko, faus = key
    if ko is None or faus == "f" or not hasattr(stats, "w"):
        return

    # Compute alpha values
    w = stats.w[key]
    alpha = (w / w.max()).clip(0.1, 0.4)

    # Set line alpha
    for line, a in zip(lines, alpha):
        line.set_alpha(a)

    # Scatter plot does not have (per-point) alpha. => Fake it via RGBA colors.
    if scatters is not None:
        colors = scatters.get_facecolor()[:, :3]
        if len(colors) == 1:
            # A single shared color: replicate it, one row per weight.
            colors = colors.repeat(len(w), axis=0)
        scatters.set_color(np.hstack([colors, alpha[:, None]]))


def not_empty(xx):
    """Works for non-iterable and iterables (including ndarrays).

    Returns `len(xx) > 0` if `xx` has a length, otherwise `bool(xx)`.
    """
    try:
        return len(xx) > 0
    except TypeError:
        return bool(xx)


def duplicate_with_blanks_for_resampled(E, dims, key, has_w):
    """Particle filter: insert breaks for resampled particles.

    Returns a list of the ensemble(s) to be inserted in the plot buffers:

    - `[None]` if `E` is None.
    - Otherwise `[E[:, dims]]`, preceded (if `has_w`, the key is a "u" key
      and `ko` is not None) by a copy wherein the rows of the particles
      whose first plotted dim differs from that stored at the latest "a"
      key are set to nan, which yields a break in the plotted lines.
    """
    if E is None:
        return [E]
    E = E[:, dims]
    EE = []
    if has_w:
        _, ko, faus = key
        if faus == "a":
            _Ea[0] = E[:, 0]  # Store (1st dim of) ens.
        elif faus == "u" and ko is not None:
            # Mark as resampled if ens changed since the analysis.
            resampled = _Ea[0] != E[:, 0]
            # Duplicate (copy to avoid overwriting) and write breaks.
            blanked = E.copy()
            blanked[resampled] = nan
            EE.append(blanked)
    # Always: append current ensemble
    EE.append(E)
    return EE


_Ea = [None]  # persistent storage for ens


def d_ylim(data, ax=None, cC=0, cE=1, pp=(1, 99), Min=-1e20, Max=+1e20):
    """Provide new ylim's intelligently, from percentiles of the data.

    - `data`: iterable of arrays for computing percentiles.
    - `pp`: percentiles

    - `ax`: If present, then the delta_zoom in/out is also considered.

      - `cE`: expansion (widening) rate ∈ [0,1].
        Default: 1, which immediately expands to percentile.
      - `cC`: compression (narrowing) rate ∈ [0,1].
        Default: 0, which does not allow compression.

    - `Min`/`Max`: bounds

    Returns the tuple `(minv, maxv)`. A limit that is not finite
    (e.g. because `data` contained no finite values) is returned as `None`.

    Despite being a little involved,
    the cost of this subroutine is typically not substantial
    because there's usually not that much data to sort through.
    """
    # Find "reasonable" limits (by percentiles), looping over data.
    # NB: `minv` accumulates the max of the *negated* lower percentile,
    # so that a single `np.maximum` serves both limits. Its sign is restored below.
    maxv = minv = -np.inf  # init
    for d in data:
        d = d[np.isfinite(d)]
        if len(d):
            perc = np.array([-1, 1]) * np.percentile(d, pp)
            minv, maxv = np.maximum([minv, maxv], perc)
    minv *= -1

    # Pry apart equal values
    if np.isclose(minv, maxv):
        maxv += 0.5
        minv -= 0.5

    # Make the zooming transition smooth
    if ax is not None:
        current = ax.get_ylim()
        # Set rate factor as compress or expand factor.
        c0 = cC if minv > current[0] else cE
        c1 = cC if maxv < current[1] else cE
        # Adjust
        minv = np.interp(c0, (0, 1), (current[0], minv))
        maxv = np.interp(c1, (0, 1), (current[1], maxv))

    # Bounds
    maxv = min(Max, maxv)
    minv = max(Min, minv)

    # Some mpl versions don't handle inf limits.
    if not np.isfinite(minv):
        minv = None
    if not np.isfinite(maxv):
        maxv = None

    return minv, maxv


def spatial1d(
    obs_inds=(),
    periodicity=None,
    dims=(),
    ens_props={"color": "b", "alpha": 0.1},  # noqa (only read, never mutated)
    conf_mult=None,
):
    """Make a liveplotter of the state as a "1d amplitude plot".

    The plot shows the truth, the obs (if `obs_inds` is not empty),
    and either the ensemble members or the mean with +/- `conf_mult` spreads.

    - `obs_inds`: indices of the observed state components,
      or a callable returning them given `ko`.
    - `periodicity`: forwarded to `viz.setup_wrapping`.
    - `dims`: the state indices to plot. Default (empty): all of them.
    - `ens_props`: keyword arguments for plotting the ensemble lines.
    - `conf_mult`: if truthy, plot confidence lines instead of the ensemble.
      Set to 2 if it is `None` and no ensemble is provided.

    Returns `init(fignum, stats, key0, plot_u, E, P, **kwargs)`, which sets up
    the figure and returns the function `update(key, E, P)`.
    Any of the above parameters can be overridden through `kwargs`.
    """
    # Store parameters. NB: must come first, while locals() only holds the arguments.
    params_orig = DotDict(**locals())

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        xx, yy, mu = stats.xx, stats.yy, stats.mu

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

        # Use `not_empty` (rather than truthiness, which is ambiguous or wrong
        # for ndarrays), and convert to ndarray so that `dims` given as a
        # list/tuple also works for the (fancy) indexing below.
        if not_empty(p.dims):
            p.dims = np.asarray(p.dims)
            M = len(p.dims)
        else:
            M = xx.shape[-1]
            p.dims = arange(M)

        # Make periodic wrapper
        ii, wrap = viz.setup_wrapping(M, p.periodicity)

        # Set up figure, axes
        fig, ax = place.freshfig(fignum, figsize=(8, 5))
        fig.suptitle("1d amplitude plot")

        # Nans
        nan1 = wrap(nan * ones(M))

        if E is None and p.conf_mult is None:
            p.conf_mult = 2

        # Init plots
        if p.conf_mult:
            lines_s = ax.plot(
                ii, nan1, "b-", lw=1, label=(str(p.conf_mult) + r"$\sigma$ conf")
            )
            lines_s += ax.plot(ii, nan1, "b-", lw=1)
            (line_mu,) = ax.plot(ii, nan1, "b-", lw=2, label="DA mean")
        else:
            nanE = nan * ones((stats.xp.N, M))
            # Only the first member is labelled, to get a single legend entry.
            lines_E = ax.plot(ii, wrap(nanE[0]), **p.ens_props, lw=1, label="Ensemble")
            lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)
        # Truth, Obs
        (line_x,) = ax.plot(ii, nan1, "k-", lw=3, label="Truth")
        if not_empty(p.obs_inds):
            (line_y,) = ax.plot(ii, nan1, "g*", ms=5, label="Obs")

        # Tune plot
        ax.set_ylim(*viz.xtrema(xx))
        ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))
        # Xticks
        xt = ax.get_xticks()
        xt = xt[abs(xt % 1) < 0.01].astype(int)  # Keep only the integer ticks
        xt = xt[xt >= 0]
        xt = xt[xt < len(p.dims)]
        ax.set_xticks(xt)
        ax.set_xticklabels(p.dims[xt])

        ax.set_xlabel("State index")
        ax.set_ylabel("Value")
        ax.legend(loc="upper right")

        text_t = ax.text(
            0.01,
            0.01,
            format_time(None, None, None),
            transform=ax.transAxes,
            family="monospace",
            ha="left",
        )

        def update(key, E, P):
            k, ko, faus = key

            if p.conf_mult:
                sigma = mu[key] + p.conf_mult * stats.spread[key] * [[1], [-1]]
                lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
                lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
                line_mu.set_ydata(wrap(mu[key][p.dims]))
            else:
                for n, line in enumerate(lines_E):
                    line.set_ydata(wrap(E[n, p.dims]))
                update_alpha(key, stats, lines_E)

            line_x.set_ydata(wrap(xx[k, p.dims]))

            text_t.set_text(format_time(k, ko, stats.HMM.tseq.tt[k]))

            # Show the obs from the forecast key ...
            if "f" in faus:
                if not_empty(p.obs_inds):
                    xy = nan * ones(len(xx[0]))
                    jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
                    xy[jj] = yy[ko]
                    line_y.set_ydata(wrap(xy[p.dims]))
                    line_y.set_zorder(5)
                    line_y.set_visible(True)

            # ... until the next "u" key.
            if "u" in faus:
                if not_empty(p.obs_inds):
                    line_y.set_visible(False)

            return

        return update

    return init


def spatial2d(
    square,
    ind2sub,
    obs_inds=(),
    cm=plt.cm.jet,
    clims=((-40, 40), (-40, 40), (-10, 10), (-10, 10)),
):
    """Make a liveplotter of a 2d field: mean, truth, spread and error.

    - `square`: function applied to a state vector before `imshow`.
    - `ind2sub`: function converting (obs) indices to subscripts `(iy, ix)`.
    - `obs_inds`: indices of the observed state components,
      or a callable returning them given `ko`.
    - `cm`: colormap for the mean and truth panels.
    - `clims`: color limits for the (mean, truth, spread, err) panels.

    Returns `init(fignum, stats, key0, plot_u, E, P, **kwargs)`, which sets up
    the figure and returns the function `update(key, E, P)`.
    """

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        GS = {"left": 0.125 - 0.04, "right": 0.9 - 0.04}
        fig, axs = place.freshfig(
            fignum,
            figsize=(6, 6),
            nrows=2,
            ncols=2,
            sharex=True,
            sharey=True,
            gridspec_kw=GS,
        )

        for ax in axs.flatten():
            ax.set_aspect("equal", "box")

        ((ax_11, ax_12), (ax_21, ax_22)) = axs

        ax_11.grid(color="w", linewidth=0.2)
        ax_12.grid(color="w", linewidth=0.2)
        ax_21.grid(color="k", linewidth=0.1)
        ax_22.grid(color="k", linewidth=0.1)

        # Upper colorbar -- position relative to ax_12
        bb = ax_12.get_position()
        dy = 0.1 * bb.height
        ax_13 = fig.add_axes([bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])
        # Lower colorbar -- position relative to ax_22
        bb = ax_22.get_position()
        dy = 0.1 * bb.height
        ax_23 = fig.add_axes([bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])

        # Extract data arrays
        xx, _, mu, spread, err = stats.xx, stats.yy, stats.mu, stats.spread, stats.err
        k = key0[0]
        tt = stats.HMM.tseq.tt

        # Plot
        # - origin='lower' might get overturned by set_ylim() below.
        im_11 = ax_11.imshow(square(mu[key0]), cmap=cm)
        im_12 = ax_12.imshow(square(xx[k]), cmap=cm)
        # hot is better, but needs +1 colorbar
        im_21 = ax_21.imshow(square(spread[key0]), cmap=plt.cm.bwr)
        im_22 = ax_22.imshow(square(err[key0]), cmap=plt.cm.bwr)
        ims = (im_11, im_12, im_21, im_22)
        # Obs init -- a list where item 0 is the handle of something invisible.
        lh = list(ax_12.plot(0, 0)[0:1])

        sx = "$\\psi$"
        ax_11.set_title("mean " + sx)
        ax_12.set_title("true " + sx)
        ax_21.set_title("spread. " + sx)
        ax_22.set_title("err. " + sx)

        # TODO 7
        # for ax in axs.flatten():
        # Crop boundries (which should be 0, i.e. yield harsh q gradients):
        # lims = (1, nx-2)
        # step = (nx - 1)/8
        # ticks = arange(step,nx-1,step)
        # ax.set_xlim  (lims)
        # ax.set_ylim  (lims[::-1])
        # ax.set_xticks(ticks)
        # ax.set_yticks(ticks)

        for im, clim in zip(ims, clims):
            im.set_clim(clim)

        fig.colorbar(im_12, cax=ax_13)
        fig.colorbar(im_22, cax=ax_23)
        for ax in [ax_13, ax_23]:
            ax.yaxis.set_tick_params(
                "major", length=2, width=0.5, direction="in", left=True, right=True
            )
            ax.set_axisbelow("line")  # make ticks appear over colorbar patch

        # Title
        title = "Streamfunction (" + sx + ")"
        fig.suptitle(title)
        # Time info
        text_t = ax_12.text(
            1,
            1.1,
            format_time(None, None, None),
            transform=ax_12.transAxes,
            family="monospace",
            ha="left",
        )

        def update(key, E, P):
            k, ko, faus = key
            t = tt[k]

            im_11.set_data(square(mu[key]))
            im_12.set_data(square(xx[k]))
            im_21.set_data(square(spread[key]))
            im_22.set_data(square(err[key]))

            # Remove previous obs (already gone if there were none at the previous key)
            try:
                lh[0].remove()
            except ValueError:
                pass
            # Plot current obs.
            #  - plot() automatically adjusts to direction of y-axis in use.
            #  - ind2sub returns (iy,ix), while plot takes (ix,iy) => reverse.
            if ko is not None and not_empty(obs_inds):
                # Like spatial1d: obs_inds may be time-dependent (callable) or fixed.
                jj = obs_inds(ko) if callable(obs_inds) else obs_inds
                lh[0] = ax_12.plot(*ind2sub(jj)[::-1], "k.", ms=1, zorder=5)[0]

            text_t.set_text(format_time(k, ko, t))

            return

        return update

    return init


# List of liveplotters available for all HMMs.
# Columns:
# - show_by_default
# - function/class
default_liveplotters = [
    (1, sliding_diagnostics),
    (1, weight_histogram),
]
class LivePlot:
    """Live plotting manager.

    Deals with

    - Pause, skip.
    - Which liveplotters to call.
    - `plot_u`
    - Figure window (title and number).
    """

    def __init__(
        self,
        stats,
        liveplots,
        key0=(0, None, "u"),
        E=None,
        P=None,
        speed=1.0,
        replay=False,
        **kwargs,
    ):
        """Initialize plots.

        - liveplots: figures to plot; alternatives:
            - a `str` equal (ignoring case) to `"all"`: all known liveplotters.
            - a `str` containing `"default"`, or a truthy object without
              `__len__` (e.g. `True`): the liveplotters shown by default.
            - another object with `__len__` (e.g. a `list` of names):
              only the liveplotters whose name is `in` it.
            - a falsy object without `__len__` (e.g. `False`): none.
        - speed: speed of animation; the default pauses get divided by it.
          Values of 100 and above are treated as infinite.
        - replay: if true, "u" keys only get plotted if `stats.store_u`.
        - kwargs: merged into the (pause) params,
          and forwarded to the initializer of each liveplotter.

        Nothing (except `any_figs = False`) gets set up unless
        `rc.liveplotting`, or replaying at infinite speed.
        """
        # Disable if not rc.liveplotting (except for infinite-speed replays)
        self.any_figs = False
        if not (rc.liveplotting or (replay and np.isinf(speed))):
            return

        # Determine whether all/universal/intermediate stats are plotted
        self.plot_u = not replay or stats.store_u

        # Set speed/pause params.
        # If speed>=100: set to inf. Coz pause=1e-99 causes hangup.
        speed = speed if speed < 100 else np.inf
        default_pauses = {
            "pause_f": 0.05,
            "pause_a": 0.05,
            "pause_s": 0.05,
            "pause_u": 0.001,
        }
        self.params = {name: dt / speed for name, dt in default_pauses.items()}

        # Write params
        self.params.update(getattr(stats.xp, "LP_kwargs", {}))
        self.params.update(kwargs)

        def get_name(init):
            """Get name of liveplotter function/class."""
            try:
                return init.__qualname__.split(".")[0]
            except AttributeError:
                return init.__class__.__name__

        # Set up dict of liveplotters: the universal ones first,
        # then the HMM-specific ones (which override upon name clashes).
        potential_LPs = {}
        for show, init in default_liveplotters:
            potential_LPs[get_name(init)] = show, init
        for show, init in getattr(stats.HMM, "liveplotters", {}):
            potential_LPs[get_name(init)] = show, init

        def parse_figlist(lst):
            """Figures requested for this xp. Convert to list."""
            if isinstance(lst, str):
                fn = lst.lower()
                if "all" == fn:
                    return ["all"]  # All potential_LPs
                if "default" in fn:
                    return ["default"]  # All show_by_default
                return lst
            if hasattr(lst, "__len__"):
                return lst  # This list (only)
            if lst:
                return ["default"]  # All show_by_default
            return [None]  # None

        figlist = parse_figlist(liveplots)

        # Loop over requested figures
        self.figures = {}
        for name, (show_by_default, init) in potential_LPs.items():
            requested = (
                (figlist == ["all"])
                or (name in figlist)
                or (figlist == ["default"] and show_by_default)
            )
            if not requested:
                continue

            # Startup message (printed once only)
            if not self.any_figs:
                print("Initializing liveplots...")
                if is_notebook_or_qt:
                    pauses = [self.params["pause_" + x] for x in "faus"]
                    if any((p > 0) for p in pauses):
                        print(
                            "Note: liveplotting does not work very well"
                            " inside Jupyter notebooks. In particular,"
                            " there is no way to stop/skip them except"
                            " to interrupt the kernel (the stop button"
                            " in the toolbar). Consider using instead"
                            " only the replay functionality (with infinite"
                            " playback speed)."
                        )
                elif not pb.disable_user_interaction:
                    print("Hit <Space> to pause/step.")
                    print("Hit <Enter> to resume/skip.")
                    print("Hit <i> to enter debug mode.")
                self.paused = False
                self.run_ipdb = False
                self.skipping = False
                self.any_figs = True

            # Init figure
            post_title = "" if self.plot_u else "\n(obs times only)"
            updater = init(name, stats, key0, self.plot_u, E, P, **kwargs)
            if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
                self.figures[name] = updater
                fig = plt.figure(name)
                win = fig.canvas
                ax0 = fig.axes[0]
                win.manager.set_window_title(str(name))
                ax0.set_title(ax0.get_title() + post_title)
                self.update(key0, E, P)  # Call initial update
                if not (replay and np.isinf(speed)):
                    plt.pause(0.01)  # Draw

    def update(self, key, E, P):
        """Update liveplots

        Handles the keyboard controls (pause, skip, debug),
        then calls the updater of each figure that is still open and active.
        Does nothing if `any_figs` is false.
        """
        if not self.any_figs:
            return

        # Check if there are still open figures
        open_figns = plt.get_figlabels()
        live_figns = set(self.figures.keys())
        self.any_figs = bool(live_figns.intersection(open_figns))

        # Playback control
        SPACE = b" "
        CHAR_I = b"i"
        ENTERs = [b"\n", b"\r"]  # Linux + Windows

        def pause():
            """Loop until user decision is made."""
            valid_keys = ENTERs + [SPACE, CHAR_I]
            ch = read1()
            while True:
                # Set state (pause, skipping, ipdb)
                if ch in ENTERs:
                    self.paused = False
                elif ch == CHAR_I:
                    self.run_ipdb = True
                # If keypress valid, resume execution
                if ch in valid_keys:
                    break
                ch = read1()
                # Pause to enable zoom, pan, etc. of mpl GUI
                plot_pause(0.01)  # Don't use time.sleep()!

        if self.paused:
            # Enter pause loop
            pause()
        elif key != (0, None, "u"):
            # NB: read1 is skipped for key0 (coz it blocks)
            ch = read1()
            if ch == SPACE:
                # Pause
                self.paused = True
                self.skipping = False
                pause()
            elif ch in ENTERs:
                # Toggle skipping
                self.skipping = not self.skipping
            elif ch == CHAR_I:
                # Schedule debug
                # Note: The reason we dont set_trace(frame) right here is:
                # - I could not find the right frame, even doing
                #   > frame = inspect.stack()[0]
                #   > while frame.f_code.co_name != "assimilate":
                #   >     frame = frame.f_back
                # - It just restarts the plot.
                self.run_ipdb = True

        # Update figures
        if not self.skipping:
            faus = key[-1]
            if faus != "u" or self.plot_u:
                for name, updater in self.figures.items():
                    if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
                        _ = plt.figure(name)
                        updater(key, E, P)
                        plot_pause(self.params["pause_" + faus])

        if self.run_ipdb:
            self.run_ipdb = False
            import inspect

            import ipdb

            print("Entering debug mode (ipdb).")
            print("Type '?' (and Enter) for usage help.")
            print("Type 'c' to continue the assimilation.")
            ipdb.set_trace(inspect.stack()[2].frame)
Live plotting manager.
Deals with
- Pause, skip.
- Which liveplotters to call.
- `plot_u`
- Figure window (title and number).
def __init__(
    self,
    stats,
    liveplots,
    key0=(0, None, "u"),
    E=None,
    P=None,
    speed=1.0,
    replay=False,
    **kwargs,
):
    """Initialize plots.

    - liveplots: figures to plot; alternatives:
        - a `str` equal (ignoring case) to `"all"`: all known liveplotters.
        - a `str` containing `"default"`, or a truthy object without
          `__len__` (e.g. `True`): the liveplotters shown by default.
        - another object with `__len__` (e.g. a `list` of names):
          only the liveplotters whose name is `in` it.
        - a falsy object without `__len__` (e.g. `False`): none.
    - speed: speed of animation; the default pauses get divided by it.
      Values of 100 and above are treated as infinite.
    - replay: if true, "u" keys only get plotted if `stats.store_u`.
    - kwargs: merged into the (pause) params,
      and forwarded to the initializer of each liveplotter.

    Nothing (except `any_figs = False`) gets set up unless
    `rc.liveplotting`, or replaying at infinite speed.
    """
    # Disable if not rc.liveplotting (except for infinite-speed replays)
    self.any_figs = False
    if not (rc.liveplotting or (replay and np.isinf(speed))):
        return

    # Determine whether all/universal/intermediate stats are plotted
    self.plot_u = not replay or stats.store_u

    # Set speed/pause params.
    # If speed>=100: set to inf. Coz pause=1e-99 causes hangup.
    speed = speed if speed < 100 else np.inf
    default_pauses = {
        "pause_f": 0.05,
        "pause_a": 0.05,
        "pause_s": 0.05,
        "pause_u": 0.001,
    }
    self.params = {name: dt / speed for name, dt in default_pauses.items()}

    # Write params
    self.params.update(getattr(stats.xp, "LP_kwargs", {}))
    self.params.update(kwargs)

    def get_name(init):
        """Get name of liveplotter function/class."""
        try:
            return init.__qualname__.split(".")[0]
        except AttributeError:
            return init.__class__.__name__

    # Set up dict of liveplotters: the universal ones first,
    # then the HMM-specific ones (which override upon name clashes).
    potential_LPs = {}
    for show, init in default_liveplotters:
        potential_LPs[get_name(init)] = show, init
    for show, init in getattr(stats.HMM, "liveplotters", {}):
        potential_LPs[get_name(init)] = show, init

    def parse_figlist(lst):
        """Figures requested for this xp. Convert to list."""
        if isinstance(lst, str):
            fn = lst.lower()
            if "all" == fn:
                return ["all"]  # All potential_LPs
            if "default" in fn:
                return ["default"]  # All show_by_default
            return lst
        if hasattr(lst, "__len__"):
            return lst  # This list (only)
        if lst:
            return ["default"]  # All show_by_default
        return [None]  # None

    figlist = parse_figlist(liveplots)

    # Loop over requested figures
    self.figures = {}
    for name, (show_by_default, init) in potential_LPs.items():
        requested = (
            (figlist == ["all"])
            or (name in figlist)
            or (figlist == ["default"] and show_by_default)
        )
        if not requested:
            continue

        # Startup message (printed once only)
        if not self.any_figs:
            print("Initializing liveplots...")
            if is_notebook_or_qt:
                pauses = [self.params["pause_" + x] for x in "faus"]
                if any((p > 0) for p in pauses):
                    print(
                        "Note: liveplotting does not work very well"
                        " inside Jupyter notebooks. In particular,"
                        " there is no way to stop/skip them except"
                        " to interrupt the kernel (the stop button"
                        " in the toolbar). Consider using instead"
                        " only the replay functionality (with infinite"
                        " playback speed)."
                    )
            elif not pb.disable_user_interaction:
                print("Hit <Space> to pause/step.")
                print("Hit <Enter> to resume/skip.")
                print("Hit <i> to enter debug mode.")
            self.paused = False
            self.run_ipdb = False
            self.skipping = False
            self.any_figs = True

        # Init figure
        post_title = "" if self.plot_u else "\n(obs times only)"
        updater = init(name, stats, key0, self.plot_u, E, P, **kwargs)
        if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
            self.figures[name] = updater
            fig = plt.figure(name)
            win = fig.canvas
            ax0 = fig.axes[0]
            win.manager.set_window_title(str(name))
            ax0.set_title(ax0.get_title() + post_title)
            self.update(key0, E, P)  # Call initial update
            if not (replay and np.isinf(speed)):
                plt.pause(0.01)  # Draw
Initialize plots.
- liveplots: figures to plot; alternatives:
  - `"default"`/`[]`/`True`: All default figures for this HMM.
  - `"all"`: Even more.
  - non-empty `list`: Only the figures with these numbers (int) or names (str).
  - `False`: None.
- speed: speed of animation.
  - `>100`: instantaneous
  - `1`: (default) as quick as possible allowing for plt.draw() to work on a moderately fast computer.
  - `<1`: slower.
def update(self, key, E, P):
    """Update liveplots

    Handles the keyboard controls (pause, skip, debug),
    then calls the updater of each figure that is still open and active.
    Does nothing if `any_figs` is false.
    """
    if not self.any_figs:
        return

    # Check if there are still open figures
    open_figns = plt.get_figlabels()
    live_figns = set(self.figures.keys())
    self.any_figs = bool(live_figns.intersection(open_figns))

    # Playback control
    SPACE = b" "
    CHAR_I = b"i"
    ENTERs = [b"\n", b"\r"]  # Linux + Windows

    def pause():
        """Loop until user decision is made."""
        valid_keys = ENTERs + [SPACE, CHAR_I]
        ch = read1()
        while True:
            # Set state (pause, skipping, ipdb)
            if ch in ENTERs:
                self.paused = False
            elif ch == CHAR_I:
                self.run_ipdb = True
            # If keypress valid, resume execution
            if ch in valid_keys:
                break
            ch = read1()
            # Pause to enable zoom, pan, etc. of mpl GUI
            plot_pause(0.01)  # Don't use time.sleep()!

    if self.paused:
        # Enter pause loop
        pause()
    elif key != (0, None, "u"):
        # NB: read1 is skipped for key0 (coz it blocks)
        ch = read1()
        if ch == SPACE:
            # Pause
            self.paused = True
            self.skipping = False
            pause()
        elif ch in ENTERs:
            # Toggle skipping
            self.skipping = not self.skipping
        elif ch == CHAR_I:
            # Schedule debug
            # Note: The reason we dont set_trace(frame) right here is:
            # - I could not find the right frame, even doing
            #   > frame = inspect.stack()[0]
            #   > while frame.f_code.co_name != "assimilate":
            #   >     frame = frame.f_back
            # - It just restarts the plot.
            self.run_ipdb = True

    # Update figures
    if not self.skipping:
        faus = key[-1]
        if faus != "u" or self.plot_u:
            for name, updater in self.figures.items():
                if plt.fignum_exists(name) and getattr(updater, "is_active", 1):
                    _ = plt.figure(name)
                    updater(key, E, P)
                    plot_pause(self.params["pause_" + faus])

    if self.run_ipdb:
        self.run_ipdb = False
        import inspect

        import ipdb

        print("Entering debug mode (ipdb).")
        print("Type '?' (and Enter) for usage help.")
        print("Type 'c' to continue the assimilation.")
        ipdb.set_trace(inspect.stack()[2].frame)
Update liveplots
class sliding_diagnostics:
    """Plots a sliding window (like a heart rate monitor) of certain diagnostics."""

    def __init__(self, fignum, stats, key0, plot_u, E, P, Tplot=None, **kwargs):
        """Set up one axes per style table, and one (empty) line per diagnostic.

        Diagnostics that are not attributes of `stats` are skipped.
        `Tplot` (the window length) is validated by `validate_lag`.
        `key0`, `E`, `P` and `kwargs` are not used here.
        """

        def affine(a, b):
            return lambda x: a + b * x

        divN = 1 / getattr(stats.xp, "N", 99)

        # STYLE TABLES - Defines which/how diagnostics get plotted
        # Columns: transf, shape, plt kwargs
        styles = {
            "RMS": {
                "err.rms": [None, None, {"c": "k", "label": "Error"}],
                "spread.rms": [
                    None,
                    None,
                    {"c": "b", "label": "Spread", "alpha": 0.6},
                ],
            },
            "Values": {
                "skew": [None, None, {"c": "g", "label": star + r"Skew/$\sigma^3$"}],
                "kurt": [
                    None,
                    None,
                    {"c": "r", "label": star + r"Kurt$/\sigma^4{-}3$"},
                ],
                "trHK": [None, None, {"c": "k", "label": star + "HK"}],
                "infl": [affine(-10, 10), "step", {"c": "c", "label": "10(infl-1)"}],
                "N_eff": [
                    affine(0, divN),
                    "dirac",
                    {"c": "y", "label": "N_eff/N", "lw": 3},
                ],
                "iters": [affine(0, 0.1), "dirac", {"c": "m", "label": "iters/10"}],
                "resmpl": [None, "dirac", {"c": "k", "label": "resampled?"}],
            },
        }

        nAx = len(styles)
        GS = {"left": 0.125, "right": 0.76}
        fig, axs = place.freshfig(
            fignum, figsize=(5, 1 + nAx), nrows=nAx, sharex=True, gridspec_kw=GS
        )

        axs[0].set_title("Diagnostics")
        for style, ax in zip(styles, axs):
            ax.set_ylabel(style)
        # The bottom axes gets the time label, and the baseline (further below).
        ax_bottom = axs[-1]
        ax_bottom.set_xlabel("Time (t)")
        place_ax.adjust_position(ax_bottom, y0=0.03)

        self.T_lag, K_lag, a_lag = validate_lag(Tplot, stats.HMM.tseq)

        def init_ax(ax, style_table):
            lines = {}
            for name, (transf, shape, plt_kwargs) in style_table.items():
                # SKIP -- if stats[name] is not in existence
                # Note: The nan check/deletion comes after the first ko.
                try:
                    stat = deep_getattr(stats, name)
                except AttributeError:
                    continue
                # PS: recall (from series.py) that even if store_u is false, stat[k] is
                # still present if liveplots=True via the k_tmp functionality.

                # Unpack style
                ln = {
                    "transf": transf or (lambda x: x),
                    "shape": shape,
                    "plt": plt_kwargs,
                }

                # Create series
                if isinstance(stat, FAUSt):
                    ln["plot_u"] = plot_u
                    K_plot = comp_K_plot(K_lag, a_lag, ln["plot_u"])
                else:
                    ln["plot_u"] = False
                    K_plot = a_lag
                ln["data"] = RollingArray(K_plot)
                ln["tt"] = RollingArray(K_plot)

                # Plot (init)
                (ln["handle"],) = ax.plot(ln["tt"], ln["data"], **ln["plt"])

                # Plotting only nans yield ugly limits. Revert to defaults.
                ax.set_xlim(0, 1)
                ax.set_ylim(0, 1)

                lines[name] = ln
            return lines

        # Plot
        self.d = [init_ax(ax, styles[style]) for style, ax in zip(styles, axs)]

        # Horizontal line at y=0
        (self.baseline0,) = ax_bottom.plot(
            ax_bottom.get_xlim(), [0, 0], c=0.5 * ones(3), lw=0.7, label="_nolegend_"
        )

        # Store
        self.axs = axs
        self.stats = stats
        self.init_incomplete = True

    # Update plot
    def __call__(self, key, E, P):
        """Insert the stats of `key` in the rolling arrays, and refresh the plot.

        The legends are added (and the lines of unchanged stats removed)
        upon the first "a" key.
        """
        k, ko, faus = key

        stats = self.stats
        tseq = stats.HMM.tseq
        ax0, ax1 = self.axs

        def update_arrays(lines):
            for name, ln in lines.items():
                stat = deep_getattr(stats, name)
                t = tseq.tt[k]  # == tseq.tto[ko]
                if isinstance(stat, FAUSt):
                    # ln['data'] will contain duplicates for f/a times.
                    if ln["plot_u"]:
                        ind = k
                    elif "u" not in faus:
                        ind = ko
                    else:
                        continue
                    val = stat[key]
                else:
                    # ln['data'] will not contain duplicates, coz only 'a' is input.
                    if "a" not in faus:
                        continue
                    ind = ko
                    val = stat[ko]
                ln["tt"].insert(ind, t)
                ln["data"].insert(ind, ln["transf"](val))

        def update_plot_data(ax, lines):
            def bend_into(shape, xx, yy):
                # Get arrays. Repeat (to use for intermediate nodes).
                yy = yy.array.repeat(3)
                xx = xx.array.repeat(3)
                if len(xx) == 0:
                    pass  # shortcircuit any modifications
                elif shape == "step":
                    yy = np.hstack([yy[1:], nan])  # roll leftward
                elif shape == "dirac":
                    nonlocal nDirac
                    axW = np.diff(ax.get_xlim())
                    yy[0::3] = False  # set datapoin to 0
                    xx[2::3] = nan  # make datapoint disappear
                    xx += nDirac * axW / 100  # offset datapoint horizontally
                    nDirac += 1
                return xx, yy

            nDirac = 1
            for ln in lines.values():
                ln["handle"].set_data(*bend_into(ln["shape"], ln["tt"], ln["data"]))

        def finalize_init(ax, lines, mm):
            # Rm lines that only contain NaNs
            for name in list(lines):
                if not deep_getattr(stats, name).were_changed:
                    lines[name]["handle"].remove()  # rm from axes
                    del lines[name]  # rm from dict
            if not lines:
                return
            # Add legends
            ax.legend(loc="upper left", bbox_to_anchor=(1.01, 1), borderaxespad=0)
            if mm:
                ax.annotate(
                    star + ": mean of\nmarginals",
                    xy=(0, -1.5 / len(lines)),
                    xycoords=ax.get_legend().get_frame(),
                    bbox={"alpha": 0.0},
                    fontsize="small",
                )
            # coz placement of annotate needs flush sometimes:
            plot_pause(0.01)

        # Insert current stats
        for lines, ax in zip(self.d, self.axs):
            update_arrays(lines)
            update_plot_data(ax, lines)

        # Set x-limits (time)
        sliding_xlim(ax0, self.d[0]["err.rms"]["tt"], self.T_lag, margin=True)
        self.baseline0.set_xdata(ax0.get_xlim())

        # Set y-limits
        data0 = [ln["data"].array for ln in self.d[0].values()]
        data1 = [ln["data"].array for ln in self.d[1].values()]
        ax0.set_ylim(0, d_ylim(data0, ax0, cC=0.2, cE=0.9)[1])
        ax1.set_ylim(*d_ylim(data1, ax1, Max=4, Min=-4, cC=0.3, cE=0.9))

        # Init legend. Rm nan lines.
        if self.init_incomplete and "a" == faus:
            self.init_incomplete = False
            finalize_init(ax0, self.d[0], False)
            finalize_init(ax1, self.d[1], True)
Plots a sliding window (like a heart rate monitor) of certain diagnostics.
def __init__(self, fignum, stats, key0, plot_u, E, P, Tplot=None, **kwargs):
    """Set up one axes per style table, and one (empty) line per diagnostic.

    Diagnostics that are not attributes of `stats` are skipped.
    `Tplot` (the window length) is validated by `validate_lag`.
    `key0`, `E`, `P` and `kwargs` are not used here.
    """

    def affine(a, b):
        return lambda x: a + b * x

    divN = 1 / getattr(stats.xp, "N", 99)

    # STYLE TABLES - Defines which/how diagnostics get plotted
    # Columns: transf, shape, plt kwargs
    styles = {
        "RMS": {
            "err.rms": [None, None, {"c": "k", "label": "Error"}],
            "spread.rms": [
                None,
                None,
                {"c": "b", "label": "Spread", "alpha": 0.6},
            ],
        },
        "Values": {
            "skew": [None, None, {"c": "g", "label": star + r"Skew/$\sigma^3$"}],
            "kurt": [
                None,
                None,
                {"c": "r", "label": star + r"Kurt$/\sigma^4{-}3$"},
            ],
            "trHK": [None, None, {"c": "k", "label": star + "HK"}],
            "infl": [affine(-10, 10), "step", {"c": "c", "label": "10(infl-1)"}],
            "N_eff": [
                affine(0, divN),
                "dirac",
                {"c": "y", "label": "N_eff/N", "lw": 3},
            ],
            "iters": [affine(0, 0.1), "dirac", {"c": "m", "label": "iters/10"}],
            "resmpl": [None, "dirac", {"c": "k", "label": "resampled?"}],
        },
    }

    nAx = len(styles)
    GS = {"left": 0.125, "right": 0.76}
    fig, axs = place.freshfig(
        fignum, figsize=(5, 1 + nAx), nrows=nAx, sharex=True, gridspec_kw=GS
    )

    axs[0].set_title("Diagnostics")
    for style, ax in zip(styles, axs):
        ax.set_ylabel(style)
    # The bottom axes gets the time label, and the baseline (further below).
    ax_bottom = axs[-1]
    ax_bottom.set_xlabel("Time (t)")
    place_ax.adjust_position(ax_bottom, y0=0.03)

    self.T_lag, K_lag, a_lag = validate_lag(Tplot, stats.HMM.tseq)

    def init_ax(ax, style_table):
        lines = {}
        for name, (transf, shape, plt_kwargs) in style_table.items():
            # SKIP -- if stats[name] is not in existence
            # Note: The nan check/deletion comes after the first ko.
            try:
                stat = deep_getattr(stats, name)
            except AttributeError:
                continue
            # PS: recall (from series.py) that even if store_u is false, stat[k] is
            # still present if liveplots=True via the k_tmp functionality.

            # Unpack style
            ln = {
                "transf": transf or (lambda x: x),
                "shape": shape,
                "plt": plt_kwargs,
            }

            # Create series
            if isinstance(stat, FAUSt):
                ln["plot_u"] = plot_u
                K_plot = comp_K_plot(K_lag, a_lag, ln["plot_u"])
            else:
                ln["plot_u"] = False
                K_plot = a_lag
            ln["data"] = RollingArray(K_plot)
            ln["tt"] = RollingArray(K_plot)

            # Plot (init)
            (ln["handle"],) = ax.plot(ln["tt"], ln["data"], **ln["plt"])

            # Plotting only nans yield ugly limits. Revert to defaults.
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)

            lines[name] = ln
        return lines

    # Plot
    self.d = [init_ax(ax, styles[style]) for style, ax in zip(styles, axs)]

    # Horizontal line at y=0
    (self.baseline0,) = ax_bottom.plot(
        ax_bottom.get_xlim(), [0, 0], c=0.5 * ones(3), lw=0.7, label="_nolegend_"
    )

    # Store
    self.axs = axs
    self.stats = stats
    self.init_incomplete = True
def sliding_xlim(ax, tt, lag, margin=False):
    """Set the x-limits of `ax` so as to follow the times in `tt`.

    - `tt`: provides `nFilled` and `span()`; nothing is done while it is empty.
    - `lag`: upper bound for the width of the window.
    - `margin`: if true, `lag/20` is added to the right of the window.
    """
    if tt.nFilled == 0:
        return  # Quit
    dt = lag / 20 if margin else 0
    t1, t2 = tt.span()  # Get suggested span.
    s1, s2 = ax.get_xlim()  # Get previous lims.

    if t1 == t2:
        # Zero span (eg tt holds single 'f' and 'a'): add width
        t1, t2 = t1 - 1, t2 + 1
    elif np.isnan(t1):
        # User has skipped (too much): re-use the previous span,
        previous = (s2 - dt) - s1  # (corrected for dt)
        if previous < lag:
            previous += t2 - (s2 - dt)  # grown by "dt",
        t1 = t2 - min(lag, previous)  # and bounded.

    ax.set_xlim(t1, t2 + dt)  # Set xlim to span
class weight_histogram:
    """Plots histogram of weights. Refreshed each analysis."""

    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up the figure, unless `stats` has no weights (`w`).

        In the latter case the plotter is flagged as inactive (`is_active = False`)
        and no figure is created.
        `key0`, `plot_u`, `E`, `P` and `kwargs` are not used.
        """
        if not hasattr(stats, "w"):
            self.is_active = False
            return
        fig, ax = place.freshfig(fignum, figsize=(7, 3), gridspec_kw={"bottom": 0.15})

        ax.set_xscale("log")
        ax.set_xlabel("Weight")
        ax.set_ylabel("Count")
        self.stats = stats
        self.ax = ax
        self.hist = []  # Handles of the current bars
        # 31 log-spaced bin edges, from 1e-10 to 1.
        self.bins = np.exp(np.linspace(np.log(1e-10), np.log(1), 31))

    def __call__(self, key, E, P):
        """Redraw the histogram of `stats.w[key]`. Only done for analysis ("a") keys.

        If there are more than 10000 weights, the plotter is flagged as inactive
        instead. Weights not exceeding the lowest bin edge are left out of the
        histogram, but counted in the title.
        """
        k, ko, faus = key
        if "a" != faus:
            return

        w = self.stats.w[key]
        N = len(w)
        ax = self.ax

        self.is_active = N < 10001
        if not self.is_active:
            not_available_text(ax, "Not computed (N > threshold)")
            return

        counted = w > self.bins[0]
        # Replace the previous bars
        for bar in self.hist:
            bar.remove()
        nn, _, self.hist = ax.hist(w[counted], bins=self.bins, color="b")
        ax.set_ylim(top=max(nn))

        # NB: both fragments must be f-strings for the fields to get interpolated.
        ax.set_title(
            f"N: {N:d}. N_eff: {1/(w@w):.4g}."
            f" Not shown: {N-np.sum(counted):d}. "
        )
Plots histogram of weights. Refreshed each analysis.
def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
    """Set up the figure, with 30 log-spaced bins on `[1e-10, 1]`.

    If `stats` has no weights (attribute `w`), the plotter is marked as
    inactive and no figure is created.
    """
    if not hasattr(stats, "w"):
        self.is_active = False
        return
    _, ax = place.freshfig(fignum, figsize=(7, 3), gridspec_kw={"bottom": 0.15})

    ax.set_xscale("log")
    ax.set_xlabel("Weight")
    ax.set_ylabel("Count")
    self.stats = stats
    self.ax = ax
    self.hist = []  # Handles of the bars currently drawn.
    self.bins = np.exp(np.linspace(np.log(1e-10), np.log(1), 31))
class spectral_errors:
    """Plots the (spatial-RMS) error as a functional of the SVD index.

    Two curves are drawn against the singular value index: the absolute
    value of `stats.umisf` ("Error") and `stats.svals` ("Spread").
    """

    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
        """Create the (log-y) figure and fetch the spectral statistics."""
        _, axes = place.freshfig(fignum, figsize=(6, 3))
        axes.set_xlabel("Sing. value index")
        axes.set_yscale("log")
        self.init_incomplete = True  # Lines get created on first usable call.
        self.ax = axes
        self.plot_u = plot_u

        try:
            self.msft = stats.umisf
            self.sprd = stats.svals
        except AttributeError:
            self.is_active = False
            not_available_text(axes, "Spectral stats not being computed")

    # Update plot
    def __call__(self, key, E, P):
        """Create the lines (first time) or refresh their y-data."""
        _, _, faus = key
        axes = self.ax

        if not self.init_incomplete:
            # Regular update of the existing lines.
            error = abs(self.msft[key])
            spread = self.sprd[key]
            self.line_sprd.set_ydata(spread)
            self.line_msft.set_ydata(error)
        elif self.plot_u or "f" == faus:
            # First usable key: create the lines.
            self.init_incomplete = False
            error = abs(self.msft[key])
            spread = self.sprd[key]
            if np.any(np.isinf(error)):
                not_available_text(axes, "Spectral stats not finite")
                self.is_active = False
            else:
                (self.line_msft,) = axes.plot(error, "k", lw=2, label="Error")
                (self.line_sprd,) = axes.plot(
                    spread, "b", lw=2, label="Spread", alpha=0.9
                )
                axes.get_xaxis().set_major_locator(MaxNLocator(integer=True))
                axes.legend()

        # Fixed limits (no adaptive re-limiting).
        axes.set_ylim([1e-3, 1e1])
Plots the (spatial-RMS) error as a functional of the SVD index.
def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
    """Create the (log-y) figure and fetch the spectral statistics.

    If `stats` lacks `umisf` or `svals`, the plotter is marked inactive
    and a notice is written on the axes.
    """
    _, axes = place.freshfig(fignum, figsize=(6, 3))
    axes.set_xlabel("Sing. value index")
    axes.set_yscale("log")
    self.init_incomplete = True  # Lines get created on first usable call.
    self.ax = axes
    self.plot_u = plot_u

    try:
        self.msft = stats.umisf
        self.sprd = stats.svals
    except AttributeError:
        self.is_active = False
        not_available_text(axes, "Spectral stats not being computed")
class correlations:
    """Plots the state (auto-)correlation matrix.

    Upper panel: the correlation matrix (optionally only one half of it).
    Lower panel: the mean (and mean absolute) correlation as a function of
    distance in state indices, as computed by `circulant_ACF`.

    The plotter marks itself inactive (`is_active = False`) when neither an
    ensemble nor a covariance is available, or when the state is too large
    (more than 1003 elements) to plot.
    """

    half = True  # Whether to show half/full (symmetric) corr matrix.

    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up both panels with placeholder data."""
        GS = {"height_ratios": [4, 1], "hspace": 0.09, "top": 0.95}
        fig, (ax, ax2) = place.freshfig(fignum, figsize=(5, 6), nrows=2, gridspec_kw=GS)

        if E is None and np.isnan(P.diag if isinstance(P, CovMat) else P).all():
            not_available_text(
                ax, ("Not available in replays" "\ncoz full Ens/Cov not stored.")
            )
            self.is_active = False
            return

        Nx = len(stats.mu[key0])
        if Nx > 1003:
            # Too large to plot. Deactivate, since none of the handles
            # that `__call__` relies on get created.
            not_available_text(ax)
            self.is_active = False
            return

        C = np.eye(Nx)
        # Mask half
        mask = np.zeros_like(C, dtype=bool)
        mask[np.tril_indices_from(mask)] = True
        cmap = plt.get_cmap("RdBu_r")
        VM = 1.0  # abs(np.percentile(C,[1,99])).max()
        im = ax.imshow(C, cmap=cmap, vmin=-VM, vmax=VM)
        # Colorbar
        _ = ax.figure.colorbar(im, ax=ax, shrink=0.8)
        # Tune plot
        plt.box(False)
        ax.set_facecolor("w")
        ax.grid(False)
        ax.set_title("State correlation matrix:", y=1.07)
        ax.xaxis.tick_top()

        # ax2 = inset_axes(ax,width="30%",height="60%",loc=3)
        (line_AC,) = ax2.plot(arange(Nx), ones(Nx), label="Correlation")
        (line_AA,) = ax2.plot(arange(Nx), ones(Nx), label="Abs. corr.")
        _ = ax2.hlines(0, 0, Nx - 1, "k", "dotted", lw=1)
        # Align ax2 with ax
        bb_AC = ax2.get_position()
        bb_C = ax.get_position()
        ax2.set_position([bb_C.x0, bb_AC.y0, bb_C.width, bb_AC.height])
        # Tune plot
        ax2.set_title("Auto-correlation:")
        ax2.set_ylabel("Mean value")
        ax2.set_xlabel("Distance (in state indices)")
        ax2.set_xticklabels([])
        ax2.set_yticks([0, 1] + list(ax2.get_yticks()[[0, -1]]))
        ax2.set_ylim(top=1)
        ax2.legend(
            frameon=True,
            facecolor="w",
            bbox_to_anchor=(1, 1),
            loc="upper left",
            borderaxespad=0.02,
        )

        self.ax = ax
        self.ax2 = ax2
        self.im = im
        self.line_AC = line_AC
        self.line_AA = line_AA
        self.mask = mask
        if hasattr(stats, "w"):
            # Weighted ensemble (particle filter).
            self.w = stats.w

    # Update plot
    def __call__(self, key, E, P):
        """Recompute the correlation matrix from `E` (or `P`) and redraw."""
        # Get cov matrix
        if E is not None:
            if hasattr(self, "w"):
                C = np.cov(E, rowvar=False, aweights=self.w[key])
            else:
                C = np.cov(E, rowvar=False)
        else:
            assert P is not None
            C = P.full if isinstance(P, CovMat) else P
            C = C.copy()  # Don't overwrite P below.
        # Compute corr from cov
        std = np.sqrt(np.diag(C))
        C /= std[:, None]
        C /= std[None, :]
        # Mask
        if self.half:
            C = np.ma.masked_where(self.mask, C)
        # Plot
        self.im.set_data(C)
        # Auto-corr function
        ACF = circulant_ACF(C)
        AAF = circulant_ACF(C, do_abs=True)
        self.line_AC.set_ydata(ACF)
        self.line_AA.set_ydata(AAF)
Plots the state (auto-)correlation matrix.
def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
    """Set up both panels (matrix and auto-correlation) with placeholder data.

    The plotter is marked inactive (`is_active = False`) when neither an
    ensemble nor a covariance is available, or when the state has more
    than 1003 elements.
    """
    GS = {"height_ratios": [4, 1], "hspace": 0.09, "top": 0.95}
    fig, (ax, ax2) = place.freshfig(fignum, figsize=(5, 6), nrows=2, gridspec_kw=GS)

    if E is None and np.isnan(P.diag if isinstance(P, CovMat) else P).all():
        not_available_text(
            ax, ("Not available in replays" "\ncoz full Ens/Cov not stored.")
        )
        self.is_active = False
        return

    Nx = len(stats.mu[key0])
    if Nx > 1003:
        # Too large to plot. Deactivate, since none of the handles
        # that `__call__` relies on get created.
        not_available_text(ax)
        self.is_active = False
        return

    C = np.eye(Nx)
    # Mask half
    mask = np.zeros_like(C, dtype=bool)
    mask[np.tril_indices_from(mask)] = True
    cmap = plt.get_cmap("RdBu_r")
    VM = 1.0  # abs(np.percentile(C,[1,99])).max()
    im = ax.imshow(C, cmap=cmap, vmin=-VM, vmax=VM)
    # Colorbar
    _ = ax.figure.colorbar(im, ax=ax, shrink=0.8)
    # Tune plot
    plt.box(False)
    ax.set_facecolor("w")
    ax.grid(False)
    ax.set_title("State correlation matrix:", y=1.07)
    ax.xaxis.tick_top()

    # ax2 = inset_axes(ax,width="30%",height="60%",loc=3)
    (line_AC,) = ax2.plot(arange(Nx), ones(Nx), label="Correlation")
    (line_AA,) = ax2.plot(arange(Nx), ones(Nx), label="Abs. corr.")
    _ = ax2.hlines(0, 0, Nx - 1, "k", "dotted", lw=1)
    # Align ax2 with ax
    bb_AC = ax2.get_position()
    bb_C = ax.get_position()
    ax2.set_position([bb_C.x0, bb_AC.y0, bb_C.width, bb_AC.height])
    # Tune plot
    ax2.set_title("Auto-correlation:")
    ax2.set_ylabel("Mean value")
    ax2.set_xlabel("Distance (in state indices)")
    ax2.set_xticklabels([])
    ax2.set_yticks([0, 1] + list(ax2.get_yticks()[[0, -1]]))
    ax2.set_ylim(top=1)
    ax2.legend(
        frameon=True,
        facecolor="w",
        bbox_to_anchor=(1, 1),
        loc="upper left",
        borderaxespad=0.02,
    )

    self.ax = ax
    self.ax2 = ax2
    self.im = im
    self.line_AC = line_AC
    self.line_AA = line_AA
    self.mask = mask
    if hasattr(stats, "w"):
        # Weighted ensemble (particle filter).
        self.w = stats.w
def circulant_ACF(C, do_abs=False):
    """Compute the auto-covariance-function corresponding to `C`.

    This assumes it is the cov/corr matrix of a 1D periodic domain:
    element `j` of the result is the average over rows `i` of
    `C[i, (i - j) mod M]` (or of its absolute value, if `do_abs`).

    Vectorized or FFT implementations are
    [possible](https://stackoverflow.com/questions/20360675).
    """
    M = len(C)
    # Row i of `shifts` lists the column indices i, i-1, i-2, ... (mod M).
    shifts = sla.circulant(np.arange(M))
    total = np.zeros(M)
    for i, idx in enumerate(shifts):
        vals = C[i, idx]
        # Note: this actually also accesses masked values in C.
        total += abs(vals) if do_abs else vals
    return total / M
Compute the auto-covariance-function corresponding to C
.
This assumes it is the cov/corr matrix of a 1D periodic domain.
Vectorized or FFT implementations are possible.
def sliding_marginals(
    obs_inds=(),
    dims=(),
    labels=(),
    Tplot=None,
    ens_props=dict(alpha=0.4),  # noqa
    zoomy=1.0,
):
    """Make a liveplotter of sliding time series of selected state marginals.

    - `obs_inds`: indices of the (directly) observed state components,
      or a function of `ko` returning them.
    - `dims`: state indices to plot. Default: (up to) 10, evenly spread.
    - `labels`: y-labels, one per element of `dims`.
    - `Tplot`: length of the time window (see `validate_lag`).
    - `ens_props`: plot properties of the ensemble lines.
    - `zoomy`: zoom factor of the y-axes.

    Returns `init`, which sets up the figure and returns `update`.
    Any of the above parameters may be overridden via keyword args to `init`.
    """
    # Store parameters
    params_orig = DotDict(**locals())

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        xx, yy, mu, spread, tseq = (
            stats.xx,
            stats.yy,
            stats.mu,
            stats.spread,
            stats.HMM.tseq,
        )

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

        # Chose marginal dims to plot
        if not len(p.dims):
            p.dims = linspace_int(xx.shape[-1], min(10, xx.shape[-1]))

        # Lag settings:
        T_lag, K_lag, a_lag = validate_lag(p.Tplot, tseq)
        K_plot = comp_K_plot(K_lag, a_lag, plot_u)
        # Extend K_plot further for adding blanks in resampling (PartFilt):
        has_w = hasattr(stats, "w")
        if has_w:
            K_plot += a_lag

        # Set up figure, axes
        fig, axs = place.freshfig(
            fignum, figsize=(5, 7), squeeze=False, nrows=len(p.dims), sharex=True
        )
        axs = axs.reshape(len(p.dims))

        # Tune plots
        axs[0].set_title("Marginal time series")
        for ix, (m, ax) in enumerate(zip(p.dims, axs)):
            # ax.set_ylim(*viz.stretch(*viz.xtrema(xx[:, m]), 1/p.zoomy))
            if not p.labels:
                ax.set_ylabel("$x_{%d}$" % m)
            else:
                ax.set_ylabel(p.labels[ix])
        axs[-1].set_xlabel("Time (t)")

        plot_pause(0.05)
        plt.tight_layout()

        # Allocate. NB: the last axis of each data array enumerates `p.dims`,
        # i.e. it is indexed by *position* in `p.dims`, not by state index.
        d = DotDict()  # data arrays
        h = DotDict()  # plot handles
        d.t = RollingArray((K_plot,))
        d.x = RollingArray((K_plot, len(p.dims)))
        h.x = []
        if not_empty(p.obs_inds):
            d.y = RollingArray((K_plot, len(p.dims)))
            h.y = []
        if E is not None:
            d.E = RollingArray((K_plot, len(E), len(p.dims)))
            h.E = []
        if P is not None:
            d.mu = RollingArray((K_plot, len(p.dims)))
            h.mu = []
            d.s = RollingArray((K_plot, 2, len(p.dims)))
            h.s = []

        # Plot (invisible coz everything here is nan, for the moment).
        for ix, ax in enumerate(axs):
            h.x += ax.plot(d.t, d.x[:, ix], "k")
            if not_empty(p.obs_inds):
                h.y += ax.plot(d.t, d.y[:, ix], "g*", ms=10)
            if "E" in d:
                h.E += [ax.plot(d.t, d.E[:, :, ix], **p.ens_props)]
            if "mu" in d:
                h.mu += ax.plot(d.t, d.mu[:, ix], "b")
            if "s" in d:
                h.s += [ax.plot(d.t, d.s[:, :, ix], "b--", lw=1)]

        def update(key, E, P):
            k, ko, faus = key

            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)

            # Roll data array
            ind = k if plot_u else ko
            for Ens in EE:  # If E is duplicated, so must the others be.
                if "E" in d:
                    d.E.insert(ind, Ens)
                if "mu" in d:
                    d.mu.insert(ind, mu[key][p.dims])
                if "s" in d:
                    d.s.insert(ind, mu[key][p.dims] + [[1], [-1]] * spread[key][p.dims])
                d.t.insert(ind, tseq.tt[k])
                if not_empty(p.obs_inds):
                    xy = nan * ones(len(p.dims))
                    if ko is not None:
                        jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
                        jj = list(jj)
                        # For each plotted dim, look up whether it is observed.
                        for i, dim in enumerate(p.dims):
                            try:
                                iobs = jj.index(dim)
                            except ValueError:
                                pass  # Not observed => leave as nan.
                            else:
                                xy[i] = yy[ko][iobs]
                    d.y.insert(ind, xy)
                d.x.insert(ind, xx[k, p.dims])

            # Update graphs
            for ix, ax in enumerate(axs):
                sliding_xlim(ax, d.t, T_lag, True)
                h.x[ix].set_data(d.t, d.x[:, ix])
                if not_empty(p.obs_inds):
                    h.y[ix].set_data(d.t, d.y[:, ix])
                if "mu" in d:
                    h.mu[ix].set_data(d.t, d.mu[:, ix])
                if "s" in d:
                    for b in [0, 1]:
                        h.s[ix][b].set_data(d.t, d.s[:, b, ix])
                if "E" in d:
                    for n in range(len(E)):
                        h.E[ix][n].set_data(d.t, d.E[:, n, ix])
                    update_alpha(key, stats, h.E[ix])

                # TODO 3: fixup. This might be slow?
                # In any case, it is very far from tested.
                # Also, relim'iting all of the time is distracting.
                # Use d_ylim?
                if "E" in d:
                    lims = d.E
                elif "mu" in d:
                    lims = d.mu
                else:
                    continue  # Nothing to base the y-limits on.
                lims = np.array(viz.xtrema(lims[..., ix]))
                if lims[0] == lims[1]:
                    lims += [-0.5, +0.5]
                ax.set_ylim(*viz.stretch(*lims, 1 / p.zoomy))

            return

        return update

    return init
def phase_particles(
    is_3d=True,
    obs_inds=(),
    dims=(),
    labels=(),
    Tplot=None,
    ens_props=dict(alpha=0.4),  # noqa
    zoom=1.5,
):
    """Make a liveplotter of phase-space trajectories (in 2 or 3 dims).

    - `is_3d`: plot in 3 dimensions (otherwise 2).
    - `obs_inds`: indices of the (directly) observed state components,
      or a function of `ko` returning them.
    - `dims`: state indices to plot. Default: the first 2 or 3.
    - `labels`: axis labels, one per element of `dims`.
    - `Tplot`: length (in time) of the tails. `0` yields no tails.
    - `ens_props`: plot properties of the ensemble tails.
    - `zoom`: zoom factor of the axes.

    Returns `init`, which sets up the figure and returns `update`.
    Any of the above parameters may be overridden via keyword args to `init`.
    """
    # Store parameters
    params_orig = DotDict(**locals())

    M = 3 if is_3d else 2

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        xx, yy, mu, _, tseq = stats.xx, stats.yy, stats.mu, stats.spread, stats.HMM.tseq

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

        # Lag settings:
        has_w = hasattr(stats, "w")
        if p.Tplot == 0:
            K_plot = 1
        else:
            T_lag, K_lag, a_lag = validate_lag(p.Tplot, tseq)
            K_plot = comp_K_plot(K_lag, a_lag, plot_u)
            # Extend K_plot further for adding blanks in resampling (PartFilt):
            if has_w:
                K_plot += a_lag

        # Dimension settings. Use len() so that ndarrays are also accepted.
        if not len(p.dims):
            p.dims = arange(M)
        if not len(p.labels):
            p.labels = ["$x_%d$" % d for d in p.dims]
        assert len(p.dims) == M

        # Set up figure, axes
        fig, _ = place.freshfig(fignum, figsize=(5, 5))
        ax = plt.subplot(111, projection="3d" if is_3d else None)
        ax.set_facecolor("w")
        ax.set_title("Phase space trajectories")
        # Tune plot
        for ind, (s, i, t) in enumerate(zip(p.labels, p.dims, "xyz")):
            viz.set_ilim(ax, ind, *viz.stretch(*viz.xtrema(xx[:, i]), 1 / p.zoom))
            # Look up set_xlabel/set_ylabel/set_zlabel by name. Unlike eval(),
            # this is safe for labels with quotes or (LaTeX) backslashes.
            getattr(ax, f"set_{t}label")(str(s))

        # Allocate
        d = DotDict()  # data arrays
        h = DotDict()  # plot handles
        s = DotDict()  # scatter handles
        if E is not None:
            d.E = RollingArray((K_plot, len(E), M))
            h.E = []
        if P is not None:
            d.mu = RollingArray((K_plot, M))
        d.x = RollingArray((K_plot, M))
        if not_empty(p.obs_inds):
            d.y = RollingArray((K_plot, M))

        # Plot tails (invisible coz everything here is nan, for the moment).
        if "E" in d:
            h.E += [
                ax.plot(*xn, **p.ens_props)[0] for xn in np.transpose(d.E, [1, 2, 0])
            ]
        if "mu" in d:
            h.mu = ax.plot(*d.mu.T, "b", lw=2)[0]
        h.x = ax.plot(*d.x.T, "k", lw=3)[0]
        if "y" in d:
            h.y = ax.plot(*d.y.T, "g*", ms=14)[0]

        # Scatter. NB: don't init with nan's coz it's buggy
        # (wrt. get_color() and _offsets3d) since mpl 3.1.
        if "E" in d:
            s.E = ax.scatter(*E.T[p.dims], s=3**2, c=[hn.get_color() for hn in h.E])
        if "mu" in d:
            s.mu = ax.scatter(*ones(M), s=8**2, c=[h.mu.get_color()])
        s.x = ax.scatter(
            *ones(M), s=14**2, c=[h.x.get_color()], marker=(5, 1), zorder=99
        )

        def update(key, E, P):
            k, ko, faus = key

            def update_tail(handle, newdata):
                handle.set_data(newdata[:, 0], newdata[:, 1])
                if is_3d:
                    handle.set_3d_properties(newdata[:, 2])

            def update_sctr(handle, newdata):
                if is_3d:
                    handle._offsets3d = juggle_axes(*newdata.T, "z")
                else:
                    handle.set_offsets(newdata)

            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)

            # Roll data array
            ind = k if plot_u else ko
            for Ens in EE:  # If E is duplicated, so must the others be.
                if "E" in d:
                    d.E.insert(ind, Ens)
                d.x.insert(ind, xx[k, p.dims])
                if "y" in d:
                    xy = nan * ones(len(p.dims))
                    if ko is not None:
                        jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
                        jj = list(jj)
                        # For each plotted dim, look up whether it is observed.
                        for i, dim in enumerate(p.dims):
                            try:
                                iobs = jj.index(dim)
                            except ValueError:
                                pass  # Not observed => leave as nan.
                            else:
                                xy[i] = yy[ko][iobs]
                    d.y.insert(ind, xy)
                if "mu" in d:
                    d.mu.insert(ind, mu[key][p.dims])

            # Update graph
            update_sctr(s.x, d.x[[-1]])
            update_tail(h.x, d.x)
            if "y" in d:
                update_tail(h.y, d.y)
            if "mu" in d:
                update_sctr(s.mu, d.mu[[-1]])
                update_tail(h.mu, d.mu)
            else:
                update_sctr(s.E, d.E[-1])
                for n in range(len(E)):
                    update_tail(h.E[n], d.E[:, n, :])
                update_alpha(key, stats, h.E, s.E)

            return

        return update

    return init
def validate_lag(Tplot, tseq):
    """Return validated `T_lag` such that it is:

    - equal to `Tplot` with fallback: `tseq.Tplot`.
    - no longer than the time span covered by `tseq.tt`.

    Also return corresponding `K_lag`, `a_lag`,
    i.e. the lag in time indices and in observation indices.
    """
    T_lag = tseq.Tplot if Tplot is None else Tplot
    assert T_lag >= 0

    # Clip the window to the start of the time sequence.
    t_end = tseq.tt[-1]
    t_start = max(tseq.tt[0], t_end - T_lag)
    T_lag = t_end - t_start

    K_lag = int(T_lag / tseq.dt) + 1  # Lag in indices
    a_lag = K_lag // tseq.dko + 1  # Lag in obs indices

    return T_lag, K_lag, a_lag
Return validated T_lag
such that it is:
- equal to
Tplot
with fallback:HMM.tseq.Tplot
. - no longer than
HMM.tseq.T
.
Also return corresponding K_lag
, a_lag
.
def update_alpha(key, stats, lines, scatters=None):
    """Adjust color alpha (for particle filters).

    The alpha of member `n` is its weight relative to the largest weight,
    clipped to `[0.1, 0.4]`. Nothing is done for keys without an observation
    index, for forecast (`"f"`) keys, or if `stats` has no weights.
    """
    _, ko, faus = key
    if ko is None or faus == "f" or not hasattr(stats, "w"):
        return

    # Compute alpha values
    weights = stats.w[key]
    alpha = np.clip(weights / weights.max(), 0.1, 0.4)

    # Set line alpha
    for handle, a in zip(lines, alpha):
        handle.set_alpha(a)

    if scatters is None:
        return

    # Scatter plot does not have alpha. => Fake it,
    # by appending alpha as the 4th (RGBA) color channel.
    rgb = scatters.get_facecolor()[:, :3]
    if len(rgb) == 1:
        rgb = rgb.repeat(len(weights), axis=0)
    scatters.set_color(np.hstack([rgb, alpha[:, None]]))
Adjust color alpha (for particle filters).
def not_empty(xx):
    """Works for non-iterable and iterables (including ndarrays).

    Objects with a length are non-empty if that length is positive.
    Objects without a length are judged by their truthiness.
    """
    try:
        size = len(xx)
    except TypeError:
        return bool(xx)
    return size > 0
Works for non-iterable and iterables (including ndarrays).
def duplicate_with_blanks_for_resampled(E, dims, key, has_w):
    """Particle filter: insert breaks for resampled particles.

    Returns a list of ensembles (restricted to columns `dims`) to be inserted
    into the plot data. Usually that is just `[E[:, dims]]`. However, with
    weights (`has_w`), on a `"u"` key carrying an observation index, an extra
    copy is prepended in which the members that changed since the last
    analysis are set to nan, so that their lines get broken.

    If `E` is `None`, then `[None]` is returned.
    """
    if E is None:
        return [E]

    sub = E[:, dims]
    out = []
    if has_w:
        _, ko, faus = key
        if faus == "a":
            # Store (1st dim of) ens., for later comparison.
            _Ea[0] = sub[:, 0]
        elif faus == "u" and ko is not None:
            # Mark as resampled if ens changed.
            changed = _Ea[0] != sub[:, 0]
            # Insert duplicate ensemble (copy to avoid overwriting),
            # with nans (breaks) written for the resampled members.
            blanked = sub.copy()
            blanked[changed] = nan
            out.append(blanked)
    # Always: append current ensemble
    out.append(sub)
    return out
Particle filter: insert breaks for resampled particles.
def d_ylim(data, ax=None, cC=0, cE=1, pp=(1, 99), Min=-1e20, Max=+1e20):
    """Provide new ylim's intelligently, from percentiles of the data.

    - `data`: iterable of arrays for computing percentiles.
    - `pp`: percentiles

    - `ax`: If present, then the delta_zoom in/out is also considered.

      - `cE`: expansion (widening) rate ∈ [0,1].
        Default: 1, which immediately expands to percentile.
      - `cC`: compression (narrowing) rate ∈ [0,1].
        Default: 0, which does not allow compression.

    - `Min`/`Max`: bounds

    Returns `(minv, maxv)`, where a non-finite limit is replaced by `None`.

    Despite being a little involved,
    the cost of this subroutine is typically not substantial
    because there's usually not that much data to sort through.
    """
    # Find "reasonable" limits (by percentiles), looping over data.
    # The lower limit is tracked with flipped sign, so both use a maximum.
    neg_lo = hi = -np.inf
    for arr in data:
        finite = arr[np.isfinite(arr)]
        if len(finite) == 0:
            continue
        p_lo, p_hi = np.percentile(finite, pp)
        neg_lo = np.maximum(neg_lo, -p_lo)
        hi = np.maximum(hi, p_hi)
    lo = -neg_lo

    # Pry apart equal values
    if np.isclose(lo, hi):
        hi += 0.5
        lo -= 0.5

    # Make the zooming transition smooth
    if ax is not None:
        current = ax.get_ylim()
        # Set rate factor as compress or expand factor.
        rate_lo = cC if lo > current[0] else cE
        rate_hi = cC if hi < current[1] else cE
        # Adjust
        lo = np.interp(rate_lo, (0, 1), (current[0], lo))
        hi = np.interp(rate_hi, (0, 1), (current[1], hi))

    # Bounds
    hi = min(Max, hi)
    lo = max(Min, lo)

    # Set (if anything's changed). Currently unused, see below.
    def worth_updating(a, b, curr):
        # Note: should depend on cC and cE
        width = abs(curr[1] - curr[0])
        lower = abs(a - curr[0]) > 0.002 * width
        upper = abs(b - curr[1]) > 0.002 * width
        return lower and upper

    # if worth_updating(minv,maxv,current):
    #     ax.set_ylim(minv,maxv)

    # Some mpl versions don't handle inf limits.
    if not np.isfinite(lo):
        lo = None
    if not np.isfinite(hi):
        hi = None

    return lo, hi
Provide new ylim's intelligently, from percentiles of the data.
data
: iterable of arrays for computing percentiles.pp
: percentilesax
: If present, then the delta_zoom in/out is also considered.cE
: expansion (widening) rate ∈ [0,1]. Default: 1, which immediately expands to percentile. cC
: compression (narrowing) rate ∈ [0,1]. Default: 0, which does not allow compression.
Min
/Max
: bounds
Despite being a little involved, the cost of this subroutine is typically not substantial because there's usually not that much data to sort through.
def spatial1d(
    obs_inds=(),
    periodicity=None,
    dims=(),
    ens_props={"color": "b", "alpha": 0.1},  # noqa
    conf_mult=None,
):
    """Make a liveplotter of the state as a curve over the state indices.

    - `obs_inds`: indices of the (directly) observed state components,
      or a function of `ko` returning them.
    - `periodicity`: passed on to `viz.setup_wrapping`.
    - `dims`: state indices to plot (list, tuple or array). Default: all.
    - `ens_props`: plot properties of the ensemble lines.
    - `conf_mult`: if set, plot the mean ± `conf_mult` spreads instead of
      the ensemble. Defaults to 2 if no ensemble is available.

    Returns `init`, which sets up the figure and returns `update`.
    Any of the above parameters may be overridden via keyword args to `init`.
    """
    # Store parameters
    params_orig = DotDict(**locals())

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        xx, yy, mu = stats.xx, stats.yy, stats.mu

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

        # Normalize dims to an array: a plain truth test fails for ndarrays,
        # while the fancy indexing (for the tick labels) fails for lists.
        if not_empty(p.dims):
            p.dims = np.asarray(p.dims)
            M = len(p.dims)
        else:
            M = xx.shape[-1]
            p.dims = arange(M)

        # Make periodic wrapper
        ii, wrap = viz.setup_wrapping(M, p.periodicity)

        # Set up figure, axes
        fig, ax = place.freshfig(fignum, figsize=(8, 5))
        fig.suptitle("1d amplitude plot")

        # Nans
        nan1 = wrap(nan * ones(M))

        if E is None and p.conf_mult is None:
            p.conf_mult = 2

        # Init plots
        if p.conf_mult:
            lines_s = ax.plot(
                ii, nan1, "b-", lw=1, label=(str(p.conf_mult) + r"$\sigma$ conf")
            )
            lines_s += ax.plot(ii, nan1, "b-", lw=1)
            (line_mu,) = ax.plot(ii, nan1, "b-", lw=2, label="DA mean")
        else:
            nanE = nan * ones((stats.xp.N, M))
            lines_E = ax.plot(ii, wrap(nanE[0]), **p.ens_props, lw=1, label="Ensemble")
            lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)
        # Truth, Obs
        (line_x,) = ax.plot(ii, nan1, "k-", lw=3, label="Truth")
        if not_empty(p.obs_inds):
            (line_y,) = ax.plot(ii, nan1, "g*", ms=5, label="Obs")

        # Tune plot
        ax.set_ylim(*viz.xtrema(xx))
        ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))
        # Xticks
        xt = ax.get_xticks()
        xt = xt[abs(xt % 1) < 0.01].astype(int)  # Keep only the integer ticks
        xt = xt[xt >= 0]
        xt = xt[xt < len(p.dims)]
        ax.set_xticks(xt)
        ax.set_xticklabels(p.dims[xt])  # Label by state index.

        ax.set_xlabel("State index")
        ax.set_ylabel("Value")
        ax.legend(loc="upper right")

        text_t = ax.text(
            0.01,
            0.01,
            format_time(None, None, None),
            transform=ax.transAxes,
            family="monospace",
            ha="left",
        )

        def update(key, E, P):
            k, ko, faus = key

            if p.conf_mult:
                sigma = mu[key] + p.conf_mult * stats.spread[key] * [[1], [-1]]
                lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
                lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
                line_mu.set_ydata(wrap(mu[key][p.dims]))
            else:
                for n, line in enumerate(lines_E):
                    line.set_ydata(wrap(E[n, p.dims]))
                update_alpha(key, stats, lines_E)

            line_x.set_ydata(wrap(xx[k, p.dims]))

            text_t.set_text(format_time(k, ko, stats.HMM.tseq.tt[k]))

            # Show the obs from the forecast until the next "u" key.
            if "f" in faus:
                if not_empty(p.obs_inds):
                    xy = nan * ones(len(xx[0]))
                    jj = p.obs_inds(ko) if callable(p.obs_inds) else p.obs_inds
                    xy[jj] = yy[ko]
                    line_y.set_ydata(wrap(xy[p.dims]))
                    line_y.set_zorder(5)
                    line_y.set_visible(True)

            if "u" in faus:
                if not_empty(p.obs_inds):
                    line_y.set_visible(False)

            return

        return update

    return init
def spatial2d(
    square,
    ind2sub,
    obs_inds=(),
    cm=plt.cm.jet,
    clims=((-40, 40), (-40, 40), (-10, 10), (-10, 10)),
):
    """Make a liveplotter of 2d fields: mean, truth, spread and error.

    - `square`: function reshaping a state vector into a 2d field.
    - `ind2sub`: function converting state indices to `(iy, ix)` subscripts.
    - `obs_inds`: indices of the (directly) observed state components,
      or a function of `ko` returning them.
    - `cm`: colormap of the mean and truth panels.
    - `clims`: color limits of the 4 panels (mean, truth, spread, error).

    Returns `init`, which sets up the figure and returns `update`.
    """

    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        GS = {"left": 0.125 - 0.04, "right": 0.9 - 0.04}
        fig, axs = place.freshfig(
            fignum,
            figsize=(6, 6),
            nrows=2,
            ncols=2,
            sharex=True,
            sharey=True,
            gridspec_kw=GS,
        )

        for ax in axs.flatten():
            ax.set_aspect("equal", "box")

        ((ax_11, ax_12), (ax_21, ax_22)) = axs

        ax_11.grid(color="w", linewidth=0.2)
        ax_12.grid(color="w", linewidth=0.2)
        ax_21.grid(color="k", linewidth=0.1)
        ax_22.grid(color="k", linewidth=0.1)

        # Upper colorbar -- position relative to ax_12
        bb = ax_12.get_position()
        dy = 0.1 * bb.height
        ax_13 = fig.add_axes([bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])
        # Lower colorbar -- position relative to ax_22
        bb = ax_22.get_position()
        dy = 0.1 * bb.height
        ax_23 = fig.add_axes([bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])

        # Extract data arrays
        xx, _, mu, spread, err = stats.xx, stats.yy, stats.mu, stats.spread, stats.err
        k = key0[0]
        tt = stats.HMM.tseq.tt

        # Plot
        # - origin='lower' might get overturned by set_ylim() below.
        im_11 = ax_11.imshow(square(mu[key0]), cmap=cm)
        im_12 = ax_12.imshow(square(xx[k]), cmap=cm)
        # hot is better, but needs +1 colorbar
        im_21 = ax_21.imshow(square(spread[key0]), cmap=plt.cm.bwr)
        im_22 = ax_22.imshow(square(err[key0]), cmap=plt.cm.bwr)
        ims = (im_11, im_12, im_21, im_22)
        # Obs init -- a list where item 0 is the handle of something invisible.
        lh = list(ax_12.plot(0, 0)[0:1])

        sx = "$\\psi$"
        ax_11.set_title("mean " + sx)
        ax_12.set_title("true " + sx)
        ax_21.set_title("spread. " + sx)
        ax_22.set_title("err. " + sx)

        # TODO 7
        # for ax in axs.flatten():
        # Crop boundries (which should be 0, i.e. yield harsh q gradients):
        # lims = (1, nx-2)
        # step = (nx - 1)/8
        # ticks = arange(step,nx-1,step)
        # ax.set_xlim (lims)
        # ax.set_ylim (lims[::-1])
        # ax.set_xticks(ticks)
        # ax.set_yticks(ticks)

        for im, clim in zip(ims, clims):
            im.set_clim(clim)

        fig.colorbar(im_12, cax=ax_13)
        fig.colorbar(im_22, cax=ax_23)
        for ax in [ax_13, ax_23]:
            ax.yaxis.set_tick_params(
                "major", length=2, width=0.5, direction="in", left=True, right=True
            )
            ax.set_axisbelow("line")  # make ticks appear over colorbar patch

        # Title
        title = "Streamfunction (" + sx + ")"
        fig.suptitle(title)
        # Time info
        text_t = ax_12.text(
            1,
            1.1,
            format_time(None, None, None),
            transform=ax_12.transAxes,
            family="monospace",
            ha="left",
        )

        def update(key, E, P):
            k, ko, faus = key
            t = tt[k]

            im_11.set_data(square(mu[key]))
            im_12.set_data(square(xx[k]))
            im_21.set_data(square(spread[key]))
            im_22.set_data(square(err[key]))

            # Remove previous obs
            try:
                lh[0].remove()
            except ValueError:
                pass
            # Plot current obs.
            #  - plot() automatically adjusts to direction of y-axis in use.
            #  - ind2sub returns (iy,ix), while plot takes (ix,iy) => reverse.

            if ko is not None and not_empty(obs_inds):
                # Accept static indices as well as a function of ko.
                jj = obs_inds(ko) if callable(obs_inds) else obs_inds
                lh[0] = ax_12.plot(*ind2sub(jj)[::-1], "k.", ms=1, zorder=5)[0]

            text_t.set_text(format_time(k, ko, t))

            return

        return update

    return init