Source code for ograph.swarm

from typing import Optional, Callable
import numpy as np
from .oplot import contour
from .oplot import Vec1D, Vec2D
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from matplotlib.collections import Collection
from collections import deque
from pathlib import Path


def plot_positions(
        data: Vec1D | Vec2D,
        objective: Optional[Callable[[float, float], float]] = None,
        *,
        override_region: Optional[
            tuple[tuple[float, float], tuple[float, float]]] = None,
        override_margin: Optional[float] = None) -> None:
    """Plot a sequence of collections of points as they change over time.

    Every time step is drawn as one scatter layer on the current
    matplotlib axes. Layers are coloured along the ``RdYlGn`` colormap
    by ``t / len(data)``.

    Args:
        data: Either a 2-D tensor or a 3-D tensor.

            * If :arg:`data` is 2-D: data[t] is a single point.
              Plot its change over time.
            * If :arg:`data` is 3-D: data[t] is a collection of points.
              Plot the change of these points over time.
        objective: An objective function to be plotted as background.
        override_region: Only plot points that fall strictly inside this
            region. Has form ``((x_min, x_max), (y_min, y_max))``.
            When omitted, the region is the bounding box of
            :arg:`data`.
        override_margin: Padding added on every side of the region over
            which :arg:`objective` is drawn. Defaults to ``1``. The axis
            limits themselves are set to the unpadded region.
    """
    points = np.array(data)
    margin: float = 1 if override_margin is None else override_margin

    if override_region is None:
        x_min = np.min(points.T[0])
        x_max = np.max(points.T[0])
        y_min = np.min(points.T[1])
        y_max = np.max(points.T[1])
    else:
        (x_min, x_max), (y_min, y_max) = override_region
        xs = points[..., 0]
        ys = points[..., 1]
        # Blank out coordinates outside the region. np.where with a NaN
        # fallback always produces a float array, so integer input is
        # masked correctly as well.
        points = np.stack(
            [np.where((x_min < xs) & (xs < x_max), xs, np.nan),
             np.where((y_min < ys) & (ys < y_max), ys, np.nan)],
            axis=-1)

    if objective is not None:
        # The background covers a slightly larger area than the axis
        # limits set below, so it reaches past the borders of the plot.
        # NOTE(review): the lambda wrapper is kept from the original;
        # contour's calling convention is not visible here.
        contour(fun=lambda x, y: objective(x, y),
                x_range=(x_min - margin, x_max + margin),
                y_range=(y_min - margin, y_max + margin))  # type: ignore

    step_count = len(points)
    colormap = mpl.colormaps['RdYlGn']
    for step, frame in enumerate(points):
        frame_xs, frame_ys = np.array(frame).T
        plt.scatter(frame_xs,
                    frame_ys,
                    s=9,
                    color=colormap(step / step_count),
                    alpha=0.5)  # type: ignore

    plt.xlim((x_min, x_max))
    plt.ylim((y_min, y_max))
def plot_fitnesses(
        data: Vec2D,
        best_selector: Callable[[Vec2D], np.float64] = np.max,
        best_label: str = "Best value",
        all_label: str = "All values",) -> None:
    """Draw a stem-style plot of values with the best one highlighted.

    Each value in :arg:`data` is drawn on the current matplotlib axes
    as a hollow blue marker at its index, joined by a thin vertical line
    to the minimum of :arg:`data`. The value returned by
    :arg:`best_selector` is drawn as a green horizontal line.

    Args:
        data: The values to plot; ``data[i]`` is drawn at ``x = i``.
            NOTE(review): the annotation says ``Vec2D`` but the values
            are passed straight to matplotlib as y-coordinates --
            confirm the intended shape with callers.
        best_selector: A :class:`Callable` that returns the best value
            in :arg:`data`. To plot a pre-determined, constant best
            value, let :arg:`best_selector` always return that value.
        best_label: Legend label for the best-value line.
        all_label: Legend label for the individual values.
    """
    values = data
    count = len(values)
    positions = range(count)

    # Stems: from the smallest value up to each value, drawn behind
    # everything else.
    baseline = np.repeat(np.min(values), count)
    plt.vlines(x=positions,
               ymin=baseline,
               ymax=values,
               zorder=-9,
               linewidth=0.8)

    plt.hlines(best_selector(values),
               xmin=0,
               xmax=count,
               color="tab:green",
               label=best_label)

    # Empty line: exists only to put a blue entry in the legend.
    plt.plot([], [], color="tab:blue", label=all_label)

    plt.scatter(positions,
                values,
                s=12,
                color="tab:blue",
                facecolors="white",)
    plt.legend()
# def plot_fitnesses( # data: Vec2D, # fitnesses: Vec1D,) -> None: # """ # """ # assert len(data) == len(fitnesses),\ # "Fitness vector does not meet data length"\ # f" ({len(data)} != {len(fitnesses)})" # plt.xlabel("Generation") # plt.ylabel("Fitness") # mort = fitnesses # plt.scatter(range(len(mort)), # mort, # s=12, # color="tab:blue", # facecolors="white",) # plt.vlines(x=range(len(mort)), # ymin=np.repeat(a=data, repeats=len(mort)), # ymax=mort, # zorder=-9, # linewidth=0.8) # plt.hlines(fitnesses, # xmin=0, # xmax=len(mort), # color="tab:green", # label="Best fitness ever") # plt.plot([], # [], # color="tab:blue", # label="Best fitness in generation") # plt.legend()
def animate_positions(
        data: Vec1D | Vec2D,
        objective: Optional[Callable[[float, float], float]] = None,
        *,
        override_region: Optional[
            tuple[tuple[float, float], tuple[float, float]]] = None,
        override_margin: Optional[float] = None,
        save_as: Path | str = "./animation.gif") -> None:
    """Animate a sequence of collections of points and save it to a file.

    Each ``data[t]`` becomes one frame. The frame's points are drawn as
    a scatter coloured along the ``RdYlGn`` colormap by
    ``t / len(data)``, and the previous frame's points are removed
    first. The animation is written with
    :class:`~matplotlib.animation.PillowWriter` at 25 fps and 300 dpi.

    Args:
        data: Either a 2-D tensor or a 3-D tensor.

            * If :arg:`data` is 2-D: data[t] is a single point.
            * If :arg:`data` is 3-D: data[t] is a collection of points.
        objective: An objective function to be plotted as background.
        override_region: Axis limits of the animation. Has form
            ``((x_min, x_max), (y_min, y_max))``. When omitted, the
            bounding box of :arg:`data` is used.
        override_margin: Padding added on every side of the region over
            which :arg:`objective` is drawn. Defaults to ``1``. The axis
            limits themselves are not padded.
        save_as: Destination of the rendered animation.
    """
    if override_region is not None:
        (x_rect_lb, x_rect_ub), (y_rect_lb, y_rect_ub) = override_region
    else:
        # Collapse every leading axis so the bounds work the same for
        # 2-D (one point per frame) and 3-D (many points per frame)
        # input.
        points = np.asarray(data)
        flat = points.reshape(-1, points.shape[-1])
        x_rect_lb, y_rect_lb = np.min(flat, axis=0)
        x_rect_ub, y_rect_ub = np.max(flat, axis=0)

    margin: float = 1 if override_margin is None else override_margin

    fig, ax = plt.subplots()
    ax.set_xlim(x_rect_lb, x_rect_ub)
    ax.set_ylim(y_rect_lb, y_rect_ub)

    #: Artists drawn by the previous frame, removed before the next one.
    artists_to_remove: deque[Collection] = deque()

    if objective is not None:
        contour(objective,  # type: ignore
                x_range=(x_rect_lb - margin, x_rect_ub + margin),
                y_range=(y_rect_lb - margin, y_rect_ub + margin))

    frame_count = len(data)
    colormap = mpl.colormaps['RdYlGn']

    def _animate_frame(i: int):
        """Replace the previous frame's points with those of frame i."""
        while artists_to_remove:
            artists_to_remove.pop().remove()
        ax.set_xlim(x_rect_lb, x_rect_ub)
        ax.set_ylim(y_rect_lb, y_rect_ub)
        frame_xs, frame_ys = np.array(data[i]).T
        # Draw on the animation's own axes rather than on whatever
        # pyplot currently considers the active axes.
        new_artist = ax.scatter(
            frame_xs,
            frame_ys,
            s=1,
            color=colormap(i / frame_count))  # type: ignore
        artists_to_remove.append(new_artist)
        return [new_artist]

    ani = FuncAnimation(fig,
                        _animate_frame,
                        frames=frame_count,
                        blit=True,  # Supposedly "optimises drawing"
                        repeat=False,
                        interval=400,
                        cache_frame_data=False)
    # See here https://matplotlib.org/stable/api/animation_api.html
    # for available writers. PillowWriter is chosen here arbitrarily.
    ani.save(save_as, dpi=300, writer=PillowWriter(fps=25))