"""
Plot classes for visualizing the solutions of differential equations
(source code of the ``pararealml.plot`` module).
"""

from __future__ import annotations

import warnings
from typing import Optional, Callable, Union, List, Tuple

import matplotlib.pyplot as plt
import numpy as np

from matplotlib import cm
from matplotlib.animation import FuncAnimation
from matplotlib.cm import ScalarMappable
from matplotlib.collections import Collection, PathCollection
from matplotlib.colors import Colormap
from matplotlib.contour import ContourSet
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from matplotlib.quiver import Quiver
from matplotlib.streamplot import StreamplotSet
from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection

from pararealml.differential_equation import NBodyGravitationalEquation
from pararealml.mesh import Mesh, CoordinateSystem


class Plot:
    """
    Base class of all plots of differential equation solutions.

    It wraps a Matplotlib figure and offers fluent show/save/close
    operations; every method except close returns the plot itself.
    """

    def __init__(self, figure: Figure):
        """
        :param figure: the Matplotlib figure that holds the plotted solution
        """
        self._figure = figure

    def show(self) -> Plot:
        """
        Displays the plot using :func:`matplotlib.pyplot.show`.

        Any other plot objects that were instantiated but not closed are
        displayed too. Calling :func:`~pararealml.plot.Plot.save` after this
        method results in undefined behaviour, since displaying the plot may
        close it as a side effect.

        :return: this plot object
        """
        plt.show()
        return self

    def save(self, file_path: str, extension: str = 'png', **kwargs) -> Plot:
        """
        Writes the plot to the file system.

        Calling this method after :func:`~pararealml.plot.Plot.show` results
        in undefined behaviour.

        :param file_path: the destination path of the image file, without
            any extension
        :param extension: the file extension appended to the path
        :param kwargs: extra keyword arguments forwarded to the figure's
            ``savefig`` method
        :return: this plot object
        """
        destination = f'{file_path}.{extension}'
        self._figure.savefig(destination, **kwargs)
        return self

    def close(self):
        """
        Closes the figure backing the plot.
        """
        plt.close(self._figure)
class AnimatedPlot(Plot):
    """
    A base class for animated plots of the solutions of differential
    equations.
    """

    def __init__(
            self,
            figure: Figure,
            init_func: Callable[[], None],
            update_func: Callable[[int], None],
            n_time_steps: int,
            n_frames: int,
            interval: int):
        """
        :param figure: the figure of the plotted solution
        :param init_func: the animation initialization function
        :param update_func: the animation update function; it is called with
            the index of the time step to render
        :param n_time_steps: the total number of time steps included in the
            solution
        :param n_frames: the number of frames to display
        :param interval: the number of milliseconds to pause between each
            frame
        """
        super().__init__(figure)

        # evenly spaced time step indices between the first and the last time
        # step, converted to integers; if n_frames exceeds n_time_steps, some
        # time steps are rendered more than once
        time_steps = np.linspace(0, n_time_steps - 1, n_frames, dtype=int)
        self._animation = FuncAnimation(
            figure,
            func=update_func,
            init_func=init_func,
            frames=time_steps,
            interval=interval)

    def save(self, file_path: str, extension: str = 'gif', **kwargs) -> Plot:
        """
        Saves the animation to the file system.

        :param file_path: the path to save the animation file to excluding
            any extensions
        :param extension: the file extension to use
        :param kwargs: any extra arguments forwarded to the animation's
            ``save`` method
        :return: the plot object the method is invoked on
        """
        self._animation.save(f'{file_path}.{extension}', **kwargs)
        return self

    @staticmethod
    def _verify_pde_solution_shape_matches_problem(
            y: np.ndarray,
            mesh: Mesh,
            vertex_oriented: bool,
            expected_x_dims: Union[int, Tuple[int, int]],
            is_vector_field: bool):
        """
        Verifies that the shape of the input array representing the solution
        of a partial differential equation over the provided mesh and with
        the specified vertex orientation matches expectations.

        :param y: an array representing the solution of the partial
            differential equation; its first axis is time, its middle axes
            follow the mesh and its last axis holds the solution components
        :param mesh: the spatial mesh over which the solution is evaluated
        :param vertex_oriented: whether the solution is evaluated over the
            vertices or the cell centers of the mesh
        :param expected_x_dims: the expected number of spatial dimensions, or
            an inclusive (min, max) range of allowed spatial dimensions
        :param is_vector_field: whether the solution is supposed to be a
            vector field or a scalar field
        :raises ValueError: if any of the above expectations is violated
        """
        if isinstance(expected_x_dims, int):
            if mesh.dimensions != expected_x_dims:
                raise ValueError(
                    f'mesh must be {expected_x_dims} dimensional')
        elif not (expected_x_dims[0] <= mesh.dimensions <= expected_x_dims[1]):
            raise ValueError(
                f'mesh must be between {expected_x_dims[0]} and '
                f'{expected_x_dims[1]} dimensional')

        # one leading time axis and one trailing component axis
        if y.ndim != mesh.dimensions + 2:
            raise ValueError(
                f'number of y axes ({y.ndim}) must be two larger than mesh '
                f'dimensions ({mesh.dimensions})')

        if y.shape[1:-1] != mesh.shape(vertex_oriented):
            raise ValueError(
                f'y shape {y.shape} must be compatible with mesh shape '
                f'{mesh.shape(vertex_oriented)}')

        if is_vector_field:
            if y.shape[-1] != mesh.dimensions:
                raise ValueError(
                    f'number of y components ({y.shape[-1]}) must match '
                    f'x dimensions ({mesh.dimensions})')
        elif y.shape[-1] != 1:
            raise ValueError(
                f'number of y components ({y.shape[-1]}) must be one')
class TimePlot(Plot):
    """
    A line plot of each component of the solution of a system of ordinary
    differential equations against time.
    """

    def __init__(
            self,
            y: np.ndarray,
            t: np.ndarray,
            legend_location: Optional[str] = None,
            **_):
        """
        :param y: the solution of the ordinary differential equation system;
            a 2D array whose rows correspond to the entries of t and whose
            columns are the components of the solution
        :param t: the 1D array of time coordinates of the solution
        :param legend_location: where to place the legend that labels each
            component's graph; no legend is drawn if None
        :param _: any ignored extra arguments
        :raises ValueError: if y is not 2D, t is not 1D, or their lengths
            along the time axis differ
        """
        y_is_matrix = y.ndim == 2
        if not y_is_matrix:
            raise ValueError(f'number of y axes ({y.ndim}) must be 2')
        t_is_vector = t.ndim == 1
        if not t_is_vector:
            raise ValueError(f'number of t axes ({t.ndim}) must be 1')
        lengths_agree = y.shape[0] == t.shape[0]
        if not lengths_agree:
            raise ValueError(
                f'first axis of y ({y.shape[0]}) must match length of t '
                f'({t.shape[0]})')

        fig, ax = plt.subplots()

        # one graph per column of y, labelled y0, y1, ...
        for component_index, component in enumerate(y.T):
            ax.plot(t, component, label=f'y{component_index}')

        ax.set_xlabel('t')
        ax.set_ylabel('y')
        if legend_location is not None:
            ax.legend(loc=legend_location)

        fig.tight_layout()
        super().__init__(fig)
class PhaseSpacePlot(Plot):
    """
    A phase-space trajectory plot of the solution of a system of two (2D
    plot) or three (3D plot) ordinary differential equations.
    """

    def __init__(self, y: np.ndarray, **_):
        """
        :param y: the solution of the ordinary differential equation system;
            a 2D array with 2 or 3 columns, one per component
        :param _: any ignored extra arguments
        :raises ValueError: if y is not 2D or does not have 2 or 3 components
        """
        if y.ndim != 2:
            raise ValueError(f'number of y axes ({y.ndim}) must be 2')
        n_components = y.shape[1]
        if n_components < 2 or n_components > 3:
            raise ValueError(
                f'number of y components ({y.shape[1]}) must be either 2 or 3')

        fig = plt.figure()

        if n_components == 3:
            ax = fig.add_subplot(projection='3d')
            ax.plot3D(y[:, 0], y[:, 1], y[:, 2])
            ax.set_xlabel('y0')
            ax.set_ylabel('y1')
            ax.set_zlabel('y2')
            # scale each box edge by the extent of the trajectory along it
            ax.set_box_aspect(
                tuple(np.ptp(y[:, axis]) for axis in range(3)))
        else:
            ax = fig.add_subplot()
            ax.plot(y[:, 0], y[:, 1])
            ax.set_xlabel('y0')
            ax.set_ylabel('y1')
            ax.axis('equal')

        super().__init__(fig)
class NBodyPlot(AnimatedPlot):
    """
    A 2D or 3D animated scatter plot to visualize the solutions of 2D or 3D
    n-body gravitational simulations.
    """

    def __init__(
            self,
            y: np.ndarray,
            diff_eq: NBodyGravitationalEquation,
            n_frames: int = 100,
            interval: int = 100,
            color_map: Colormap = cm.cividis,
            smallest_marker_size: float = 10.,
            draw_trajectory: bool = True,
            trajectory_line_style: str = ':',
            trajectory_line_width: float = .5,
            span_scaling_factor: float = .25,
            **_):
        """
        :param y: an array representing the solution of the n-body
            gravitational differential equation; its first axis is time and
            its second axis holds the components of the solution
        :param diff_eq: the n-body gravitational differential equation solved
        :param n_frames: the number of frames to display
        :param interval: the number of milliseconds to pause between each
            frame
        :param color_map: the color map to use for coloring the celestial
            objects
        :param smallest_marker_size: the size of the marker representing the
            smallest mass
        :param draw_trajectory: whether the trajectory of the objects should
            be plotted as well
        :param trajectory_line_style: the style of the trajectory line
        :param trajectory_line_width: the width of the trajectory line
        :param span_scaling_factor: the fraction of the peak-to-peak value of
            the object positions along each axis to pad the axis limits with;
            for example if the lowest and highest x coordinates of any object
            are -5 and 5 respectively, the limits of the x-axis are set at
            -5 - 10 * span_scaling_factor and 5 + 10 * span_scaling_factor
            respectively
        :param _: any ignored extra arguments
        :raises ValueError: if y is not two dimensional or its number of
            components does not match the y dimension of the differential
            equation
        """
        if y.ndim != 2:
            raise ValueError(f'number of y axes ({y.ndim}) must be 2')
        if y.shape[1] != diff_eq.y_dimension:
            # report the number of components, not the number of axes
            raise ValueError(
                f'number of y components ({y.shape[1]}) must match '
                f'differential equation y dimension ({diff_eq.y_dimension})')

        n_obj = diff_eq.n_objects
        spatial_dim = diff_eq.spatial_dimension
        n_obj_by_dims = n_obj * spatial_dim

        # the first n_obj * spatial_dim components hold the object positions
        # interleaved per object, so strided slices extract each axis
        x_coordinates = y[:, :n_obj_by_dims:spatial_dim]
        y_coordinates = y[:, 1:n_obj_by_dims:spatial_dim]

        x_min, x_max = self._padded_limits(x_coordinates, span_scaling_factor)
        y_min, y_max = self._padded_limits(y_coordinates, span_scaling_factor)

        # treat each scaled mass as the volume of a sphere and use the area
        # of a disc with the same radius as the marker size, so that the
        # smallest mass maps to smallest_marker_size before this conversion
        masses = np.asarray(diff_eq.masses)
        scaled_masses = (smallest_marker_size / np.min(masses)) * masses
        radii = np.power(3. * scaled_masses / (4. * np.pi), 1. / 3.)
        marker_sizes = np.power(radii, 2) * np.pi

        colors = color_map(np.linspace(0., 1., n_obj))

        self._scatter_plot: Optional[PathCollection] = None
        self._line_plots: Optional[List[Union[Line2D, Line3D]]] = None

        style = 'dark_background'

        with plt.style.context(style):
            fig = plt.figure()
            ax = fig.add_subplot(
                projection='3d' if spatial_dim == 3 else None)

        if spatial_dim == 2:
            # (time, object, axis) positions for updating the scatter offsets
            coordinates = np.stack((x_coordinates, y_coordinates), axis=2)

            def init_plot():
                with plt.style.context(style):
                    ax.clear()

                    self._scatter_plot = ax.scatter(
                        x_coordinates[0, :],
                        y_coordinates[0, :],
                        s=marker_sizes,
                        c=colors)

                    if draw_trajectory:
                        self._line_plots = []
                        for i in range(n_obj):
                            self._line_plots.append(ax.plot(
                                x_coordinates[:1, i],
                                y_coordinates[:1, i],
                                color=colors[i],
                                linestyle=trajectory_line_style,
                                linewidth=trajectory_line_width)[0])

                    ax.set_xlabel('x')
                    ax.set_ylabel('y')
                    ax.axis('scaled')
                    ax.set_xlim(x_min, x_max)
                    ax.set_ylim(y_min, y_max)

            def update_plot(time_step: int):
                self._scatter_plot.set_offsets(coordinates[time_step, ...])

                if draw_trajectory:
                    for i in range(n_obj):
                        line_plot = self._line_plots[i]
                        line_plot.set_xdata(x_coordinates[:time_step + 1, i])
                        line_plot.set_ydata(y_coordinates[:time_step + 1, i])

        else:
            z_coordinates = y[:, 2:n_obj_by_dims:spatial_dim]
            z_min, z_max = self._padded_limits(
                z_coordinates, span_scaling_factor)

            def init_plot():
                with plt.style.context(style):
                    ax.clear()

                    self._scatter_plot = ax.scatter(
                        x_coordinates[0, :],
                        y_coordinates[0, :],
                        z_coordinates[0, :],
                        s=marker_sizes,
                        c=colors,
                        depthshade=False)

                    if draw_trajectory:
                        self._line_plots = []
                        for i in range(n_obj):
                            self._line_plots.append(ax.plot(
                                x_coordinates[:1, i],
                                y_coordinates[:1, i],
                                z_coordinates[:1, i],
                                color=colors[i],
                                linestyle=trajectory_line_style,
                                linewidth=trajectory_line_width)[0])

                    ax.set_xlabel('x')
                    ax.set_ylabel('y')
                    ax.set_zlabel('z')
                    ax.set_xlim(x_min, x_max)
                    ax.set_ylim(y_min, y_max)
                    ax.set_zlim(z_min, z_max)
                    ax.set_box_aspect(
                        (x_max - x_min, y_max - y_min, z_max - z_min))
                    ax.set_facecolor('black')

                    ax.xaxis.pane.fill = False
                    ax.yaxis.pane.fill = False
                    ax.zaxis.pane.fill = False

                    ax.grid(False)

            def update_plot(time_step: int):
                # 3D scatter collections keep their positions in _offsets3d
                self._scatter_plot._offsets3d = (
                    x_coordinates[time_step, ...],
                    y_coordinates[time_step, ...],
                    z_coordinates[time_step, ...]
                )

                if draw_trajectory:
                    for i in range(n_obj):
                        line_plot = self._line_plots[i]
                        line_plot.set_xdata(x_coordinates[:time_step + 1, i])
                        line_plot.set_ydata(y_coordinates[:time_step + 1, i])
                        line_plot.set_3d_properties(
                            z_coordinates[:time_step + 1, i])

        super().__init__(
            fig, init_plot, update_plot, y.shape[0], n_frames, interval)

    @staticmethod
    def _padded_limits(
            coordinates: np.ndarray,
            span_scaling_factor: float) -> Tuple[float, float]:
        """
        Computes axis limits enclosing all the provided coordinates, padded
        on both sides by span_scaling_factor times their peak-to-peak value.

        :param coordinates: the coordinates along a single axis
        :param span_scaling_factor: the fraction of the peak-to-peak value to
            pad the limits with
        :return: the lower and upper axis limits
        """
        lower = coordinates.min()
        upper = coordinates.max()
        span = upper - lower
        return lower - span_scaling_factor * span, \
            upper + span_scaling_factor * span
class SpaceLinePlot(AnimatedPlot):
    """
    An animated line plot of the solution of a 1D partial differential
    equation, showing the solution over space at each rendered time step.
    """

    def __init__(
            self,
            y: np.ndarray,
            mesh: Mesh,
            vertex_oriented: bool,
            n_frames: int = 100,
            interval: int = 100,
            v_min: Optional[float] = None,
            v_max: Optional[float] = None,
            equal_scale: bool = False,
            **_):
        """
        :param y: the solution of the 1D partial differential equation
        :param mesh: the 1D spatial mesh the solution is evaluated over
        :param vertex_oriented: True if the solution is evaluated at the mesh
            vertices, False if at the cell centers
        :param n_frames: how many frames the animation displays
        :param interval: the pause between consecutive frames in milliseconds
        :param v_min: the lower y-axis limit; defaults to the minimum of the
            solution when None
        :param v_max: the upper y-axis limit; defaults to the maximum of the
            solution when None
        :param equal_scale: if True, the solution values share the scale of
            the spatial dimension (i.e. the values represent height)
        :param _: any ignored extra arguments
        """
        self._verify_pde_solution_shape_matches_problem(
            y, mesh, vertex_oriented, 1, False)

        self._line_plot: Optional[Line2D] = None

        fig, ax = plt.subplots()

        def init_plot():
            ax.clear()

            x_grid = mesh.coordinate_grids(vertex_oriented)[0]
            [self._line_plot] = ax.plot(x_grid, y[0, ..., 0])

            lower_limit = v_min
            if lower_limit is None:
                lower_limit = np.min(y)
            upper_limit = v_max
            if upper_limit is None:
                upper_limit = np.max(y)
            ax.set_ylim(lower_limit, upper_limit)

            ax.set_xlabel('x')
            ax.set_ylabel('y')

            if equal_scale:
                ax.axis('equal')

        def update_plot(time_step: int):
            # only the values change between frames; the x grid stays fixed
            self._line_plot.set_ydata(y[time_step, ..., 0])

        super().__init__(
            fig, init_plot, update_plot, y.shape[0], n_frames, interval)
class ContourPlot(AnimatedPlot):
    """
    A contour plot to visualize the solutions of 2D partial differential
    equations.
    """

    def __init__(
            self,
            y: np.ndarray,
            mesh: Mesh,
            vertex_oriented: bool,
            n_frames: int = 100,
            interval: int = 100,
            color_map: Colormap = cm.viridis,
            v_min: Optional[float] = None,
            v_max: Optional[float] = None,
            **_):
        """
        :param y: an array representing the solution scalar field of the 2D
            partial differential equation
        :param mesh: the spatial mesh over which the solution is evaluated
        :param vertex_oriented: whether the solution is evaluated over the
            vertices or the cell centers of the mesh
        :param n_frames: the number of frames to display
        :param interval: the number of milliseconds to pause between each
            frame
        :param color_map: the color map to use to map the values of the
            solution scalar field to colors
        :param v_min: the lower limit of the color map; if None, the limit is
            set to the minimum of the solution
        :param v_max: the upper limit of the color map; if None, the limit is
            set to the maximum of the solution
        :param _: any ignored extra arguments
        """
        self._verify_pde_solution_shape_matches_problem(
            y, mesh, vertex_oriented, 2, False)

        x_cartesian_coordinate_grids = \
            mesh.cartesian_coordinate_grids(vertex_oriented)

        v_min = np.min(y) if v_min is None else v_min
        v_max = np.max(y) if v_max is None else v_max

        self._contour_plot: Optional[ContourSet] = None

        fig = plt.figure()

        def init_plot():
            fig.clear()
            ax = fig.add_subplot()

            self._contour_plot = ax.contourf(
                *x_cartesian_coordinate_grids,
                y[0, ..., 0],
                vmin=v_min,
                vmax=v_max,
                cmap=color_map)

            ax.set_xlabel('x0')
            ax.set_ylabel('x1')
            ax.axis('scaled')

            mappable = ScalarMappable(cmap=color_map)
            mappable.set_clim(v_min, v_max)
            # the mappable is not attached to any axes, so attach the colorbar
            # to this figure and take space from this axes explicitly instead
            # of relying on pyplot's current figure/axes
            fig.colorbar(mappable, ax=ax)

        def update_plot(time_step: int):
            contour_plot = self._contour_plot
            # capture the axes first: removing an artist detaches it from its
            # axes, leaving its axes attribute set to None
            ax = contour_plot.axes

            if isinstance(contour_plot, Collection):
                # Matplotlib >= 3.8: a ContourSet is a single Collection
                contour_plot.remove()
            else:
                # older Matplotlib: a ContourSet holds one collection per level
                for collection in contour_plot.collections:
                    collection.remove()

            self._contour_plot = ax.contourf(
                *x_cartesian_coordinate_grids,
                y[time_step, ..., 0],
                vmin=v_min,
                vmax=v_max,
                cmap=color_map)

        super().__init__(
            fig, init_plot, update_plot, y.shape[0], n_frames, interval)
class SurfacePlot(AnimatedPlot):
    """
    An animated 3D surface plot of the scalar solution field of a 2D partial
    differential equation.
    """

    def __init__(
            self,
            y: np.ndarray,
            mesh: Mesh,
            vertex_oriented: bool,
            n_frames: int = 100,
            interval: int = 100,
            color_map: Colormap = cm.viridis,
            v_min: Optional[float] = None,
            v_max: Optional[float] = None,
            equal_scale: bool = False,
            **_):
        """
        :param y: the solution scalar field of the 2D partial differential
            equation
        :param mesh: the 2D spatial mesh the solution is evaluated over
        :param vertex_oriented: True if the solution is evaluated at the mesh
            vertices, False if at the cell centers
        :param n_frames: how many frames the animation displays
        :param interval: the pause between consecutive frames in milliseconds
        :param color_map: the color map translating solution values to colors
        :param v_min: the lower z-axis and color map limit; both default to
            the minimum of the solution when None
        :param v_max: the upper z-axis and color map limit; both default to
            the maximum of the solution when None
        :param equal_scale: if True, the solution values share the scale of
            the spatial dimensions (i.e. the values represent height);
            otherwise the z edge of the box matches the shorter of the two
            spatial edges
        :param _: any ignored extra arguments
        """
        self._verify_pde_solution_shape_matches_problem(
            y, mesh, vertex_oriented, 2, False)

        grids = mesh.cartesian_coordinate_grids(vertex_oriented)

        if v_min is None:
            v_min = np.min(y)
        if v_max is None:
            v_max = np.max(y)

        x_0_extent = np.ptp(grids[0])
        x_1_extent = np.ptp(grids[1])
        if equal_scale:
            z_extent = v_max - v_min
        else:
            z_extent = min(x_0_extent, x_1_extent)

        surface_kwargs = dict(
            vmin=v_min,
            vmax=v_max,
            rstride=1,
            cstride=1,
            linewidth=0,
            antialiased=False,
            cmap=color_map)

        self._surface_plot: Optional[Poly3DCollection] = None

        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')

        def draw_surface(time_step: int) -> Poly3DCollection:
            # renders the solution at the given time step as a new surface
            return ax.plot_surface(
                *grids, y[time_step, ..., 0], **surface_kwargs)

        def init_plot():
            ax.clear()
            self._surface_plot = draw_surface(0)
            ax.set_xlabel('x0')
            ax.set_ylabel('x1')
            ax.set_zlabel('y')
            ax.set_zlim(v_min, v_max)
            ax.set_box_aspect((x_0_extent, x_1_extent, z_extent))

        def update_plot(time_step: int):
            # surfaces cannot be updated in place, so replace the old one
            self._surface_plot.remove()
            self._surface_plot = draw_surface(time_step)

        super().__init__(
            fig, init_plot, update_plot, y.shape[0], n_frames, interval)
class ScatterPlot(AnimatedPlot):
    """
    An animated 3D scatter plot of the scalar solution field of a 3D partial
    differential equation, with one colored marker per mesh point.
    """

    def __init__(
            self,
            y: np.ndarray,
            mesh: Mesh,
            vertex_oriented: bool,
            n_frames: int = 100,
            interval: int = 100,
            color_map: Colormap = cm.viridis,
            v_min: Optional[float] = None,
            v_max: Optional[float] = None,
            marker_shape: str = 'o',
            marker_size: Union[float, np.ndarray] = 20.,
            marker_opacity: float = 1.,
            **_):
        """
        :param y: the solution of the 3D partial differential equation
        :param mesh: the 3D spatial mesh the solution is evaluated over
        :param vertex_oriented: True if the solution is evaluated at the mesh
            vertices, False if at the cell centers
        :param n_frames: how many frames the animation displays
        :param interval: the pause between consecutive frames in milliseconds
        :param color_map: the color map translating solution values to colors
        :param v_min: the lower color map limit; defaults to the minimum of
            the solution when None
        :param v_max: the upper color map limit; defaults to the maximum of
            the solution when None
        :param marker_shape: the shape of the point markers
        :param marker_size: the size of the point markers
        :param marker_opacity: the opacity of the point markers
        :param _: any ignored extra arguments
        """
        self._verify_pde_solution_shape_matches_problem(
            y, mesh, vertex_oriented, 3, False)

        grids = mesh.cartesian_coordinate_grids(vertex_oriented)

        lower_limit = np.min(y) if v_min is None else v_min
        upper_limit = np.max(y) if v_max is None else v_max
        mappable = ScalarMappable(cmap=color_map)
        mappable.set_clim(lower_limit, upper_limit)

        def colors_at(time_step: int) -> np.ndarray:
            # RGBA colors of the flattened solution values at the time step
            return mappable.to_rgba(y[time_step, ..., 0].flatten())

        self._scatter_plot: Optional[PathCollection] = None

        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')

        def init_plot():
            ax.clear()
            ax.set_xlabel('x0')
            ax.set_ylabel('x1')
            ax.set_zlabel('x2')
            ax.set_box_aspect(
                tuple(np.ptp(grids[axis]) for axis in range(3)))

            self._scatter_plot = ax.scatter(
                *grids,
                c=colors_at(0),
                marker=marker_shape,
                s=marker_size,
                alpha=marker_opacity)

        def update_plot(time_step: int):
            # the points stay put; only their colors change between frames
            self._scatter_plot.set_color(colors_at(time_step))

        super().__init__(
            fig, init_plot, update_plot, y.shape[0], n_frames, interval)
class StreamPlot(AnimatedPlot):
    """
    A 2D stream plot to visualize the solution vector fields of 2D partial
    differential equation systems.
    """

    def __init__(
            self,
            y: np.ndarray,
            mesh: Mesh,
            vertex_oriented: bool,
            n_frames: int = 100,
            interval: int = 100,
            color: str = 'black',
            density: float = 1.,
            **_):
        """
        :param y: an array representing the solution vector field of the 2D
            partial differential equation system
        :param mesh: the spatial mesh over which the solution is evaluated
        :param vertex_oriented: whether the solution is evaluated over the
            vertices or the cell centers of the mesh
        :param n_frames: the number of frames to display
        :param interval: the number of milliseconds to pause between each
            frame
        :param color: the color to use for the lines and arrows of the stream
            plot
        :param density: the density of the stream lines
        :param _: any ignored extra arguments
        """
        self._verify_pde_solution_shape_matches_problem(
            y, mesh, vertex_oriented, 2, True)

        coordinate_grids = mesh.coordinate_grids(vertex_oriented)

        self._stream_plot: Optional[StreamplotSet] = None

        fig = plt.figure()

        if mesh.coordinate_system_type == CoordinateSystem.POLAR:
            # the mesh coordinates are swapped for the polar axes; presumably
            # the mesh orders them (r, theta) while polar axes plot
            # (theta, r) — the radial lower limit is pinned to 0
            (x_1_min, x_1_max), (x_0_min, x_0_max) = mesh.x_intervals
            x_1_min = 0
            x_0 = coordinate_grids[1]
            x_1 = coordinate_grids[0]
            y_0 = y[..., 1]
            y_1 = y[..., 0]
            ax = fig.add_subplot(projection='polar')
        else:
            # the grids and field components are transposed, presumably
            # because the mesh indexes them (x0, x1) whereas streamplot
            # expects rows along the second axis and columns along the first
            (x_0_min, x_0_max), (x_1_min, x_1_max) = mesh.x_intervals
            x_0 = coordinate_grids[0].T
            x_1 = coordinate_grids[1].T
            y_0 = y[..., 0].transpose([0, 2, 1])
            y_1 = y[..., 1].transpose([0, 2, 1])
            ax = fig.add_subplot()

        def init_plot():
            ax.clear()

            self._stream_plot = ax.streamplot(
                x_0,
                x_1,
                y_0[0, ...],
                y_1[0, ...],
                color=color,
                density=density)

            ax.set_xlim(x_0_min, x_0_max)
            ax.set_ylim(x_1_min, x_1_max)

            if mesh.coordinate_system_type == CoordinateSystem.CARTESIAN:
                ax.axis('scaled')

            ax.set_xlabel('x')
            ax.set_ylabel('y')

        def update_plot(time_step: int):
            # the stream plot arrows are individual patches on the axes;
            # remove them one by one since the axes' artist lists themselves
            # must not be mutated (iterate over a copy while removing)
            for patch in list(ax.patches):
                patch.remove()
            self._stream_plot.lines.remove()

            self._stream_plot = ax.streamplot(
                x_0,
                x_1,
                y_0[time_step, ...],
                y_1[time_step, ...],
                color=color,
                density=density)

        super().__init__(
            fig, init_plot, update_plot, y.shape[0], n_frames, interval)
class QuiverPlot(AnimatedPlot):
    """
    An animated 2D or 3D quiver (arrow) plot of the vector-field solution of
    a 2D or 3D partial differential equation system.
    """

    def __init__(
            self,
            y: np.ndarray,
            mesh: Mesh,
            vertex_oriented: bool,
            n_frames: int = 100,
            interval: int = 100,
            normalize: bool = False,
            pivot: str = 'middle',
            quiver_scale: float = 10.,
            **_):
        """
        :param y: the solution vector field of the partial differential
            equation system
        :param mesh: the 2D or 3D spatial mesh the solution is evaluated over
        :param vertex_oriented: True if the solution is evaluated at the mesh
            vertices, False if at the cell centers
        :param n_frames: how many frames the animation displays
        :param interval: the pause between consecutive frames in milliseconds
        :param normalize: whether to normalize the lengths of the arrows to
            one
        :param pivot: the point of each arrow that sits on its mesh point
        :param quiver_scale: the scaling factor applied to the arrow lengths
        :param _: any ignored extra arguments
        """
        self._verify_pde_solution_shape_matches_problem(
            y, mesh, vertex_oriented, (2, 3), True)

        grids = mesh.cartesian_coordinate_grids(vertex_oriented)
        unit_vectors = mesh.unit_vector_grids(vertex_oriented)

        # weight each unit vector grid by the matching component of y (the
        # length-1 trailing axis of each slice broadcasts over the unit
        # vectors' components) and add up the contributions
        y_cartesian: np.ndarray = np.asarray(sum(
            y[..., axis:axis + 1] * unit_vectors[axis][np.newaxis, ...]
            for axis in range(mesh.dimensions)))

        self._quiver_plot: Optional[Quiver] = None

        fig = plt.figure()

        if mesh.dimensions != 2:
            u, v, w = (
                y_cartesian[..., axis] * quiver_scale for axis in range(3))

            ax = fig.add_subplot(projection='3d')

            def init_plot():
                ax.clear()

                self._quiver_plot = ax.quiver(
                    *grids,
                    u[0, ...],
                    v[0, ...],
                    w[0, ...],
                    pivot=pivot,
                    normalize=normalize)

                ax.set_xlabel('x')
                ax.set_ylabel('y')
                ax.set_zlabel('z')
                ax.set_box_aspect(
                    tuple(np.ptp(grids[axis]) for axis in range(3)))

            def update_plot(time_step: int):
                # 3D quivers are redrawn from scratch on every frame
                self._quiver_plot.remove()
                self._quiver_plot = ax.quiver(
                    *grids,
                    u[time_step, ...],
                    v[time_step, ...],
                    w[time_step, ...],
                    pivot=pivot,
                    normalize=normalize)
        else:
            # views into y_cartesian, so the in-place normalization below
            # never touches the caller's y
            u = y_cartesian[..., 0]
            v = y_cartesian[..., 1]

            if normalize:
                magnitude = np.sqrt(np.square(u) + np.square(v))
                nonzero = magnitude > 0.
                u[nonzero] /= magnitude[nonzero]
                v[nonzero] /= magnitude[nonzero]

            ax = fig.add_subplot()

            def init_plot():
                ax.clear()
                ax.set_xlabel('x')
                ax.set_ylabel('y')

                self._quiver_plot = ax.quiver(
                    *grids,
                    u[0, ...],
                    v[0, ...],
                    pivot=pivot,
                    angles='xy',
                    scale_units='xy',
                    scale=1. / quiver_scale)

                ax.axis('scaled')

            def update_plot(time_step: int):
                # 2D quivers support in-place updates of their vectors
                self._quiver_plot.set_UVC(
                    u[time_step, ...], v[time_step, ...])

        super().__init__(
            fig, init_plot, update_plot, y.shape[0], n_frames, interval)