Source code for pararealml.operators.ml.pidon.collocation_point_sampler

from abc import ABC, abstractmethod
from typing import Sequence, NamedTuple, Optional, List

import numpy as np

from pararealml.initial_value_problem import TemporalDomainInterval
from pararealml.mesh import SpatialDomainInterval


[docs]class CollocationPoints(NamedTuple): """ Collocation points from a spatio-temporal domain. """ t: np.ndarray x: Optional[np.ndarray]
[docs]class AxialBoundaryPoints(NamedTuple): """ Spatio-temporal collocation points sampled from the lower and upper boundaries of a spatial axis. """ lower_boundary_points: Optional[CollocationPoints] upper_boundary_points: Optional[CollocationPoints]
[docs]class CollocationPointSampler(ABC): """ A base class for collocation point samplers. """
[docs] @abstractmethod def sample_domain_points( self, n_points: int, t_interval: TemporalDomainInterval, x_intervals: Optional[Sequence[SpatialDomainInterval]] ) -> CollocationPoints: """ Samples a set of points from a spatio-temporal domain. If the spatial domain intervals are undefined, it only samples from the temporal domain. :param n_points: the number of points to sample :param t_interval: the bounds of the temporal domain :param x_intervals: a sequence of the bounds of the spatial domain :return: a set of domain points """
[docs] @abstractmethod def sample_boundary_points( self, n_points: int, t_interval: TemporalDomainInterval, x_intervals: Sequence[SpatialDomainInterval] ) -> Sequence[AxialBoundaryPoints]: """ Samples a set of points organized into a sequence of pairs from the boundaries of a spatio-temporal domain. :param n_points: the number of points to sample :param t_interval: the bounds of the temporal domain :param x_intervals: a sequence of the bounds of the spatial domain :return: a set of boundary points organized into a sequence of pairs """
[docs]class UniformRandomCollocationPointSampler(CollocationPointSampler): """ A uniform random collocation point sampler. """
[docs] def sample_domain_points( self, n_points: int, t_interval: TemporalDomainInterval, x_intervals: Optional[Sequence[SpatialDomainInterval]] ) -> CollocationPoints: if n_points <= 0: raise ValueError( f'number of domain points ({n_points}) must be greater than 0') t = np.random.uniform(*t_interval, (n_points, 1)) if x_intervals is not None: x_lower_bounds, x_upper_bounds = zip(*x_intervals) x = np.random.uniform( x_lower_bounds, x_upper_bounds, (n_points, len(x_intervals))) else: x = None return CollocationPoints(t, x)
[docs] def sample_boundary_points( self, n_points: int, t_interval: TemporalDomainInterval, x_intervals: Sequence[SpatialDomainInterval] ) -> Sequence[AxialBoundaryPoints]: if n_points <= 0: raise ValueError( f'number of boundary points ({n_points}) must be greater ' f'than 0') (lower_t_bound, upper_t_bound) = t_interval (lower_x_bounds, upper_x_bounds) = zip(*x_intervals) x_interval_lengths = np.subtract(upper_x_bounds, lower_x_bounds) domain_size = np.prod(x_interval_lengths) boundary_sizes_at_ends_of_axes = np.array([ domain_size / x_interval_length for x_interval_length in x_interval_lengths ]) axial_boundary_pmf = boundary_sizes_at_ends_of_axes / \ boundary_sizes_at_ends_of_axes.sum() n_boundary_points_per_axis = np.random.multinomial( n_points, axial_boundary_pmf) boundary_points = [] for axis, n_boundary_points in enumerate(n_boundary_points_per_axis): n_lower_boundary_points = np.random.binomial(n_boundary_points, .5) n_axial_boundary_points = \ (n_lower_boundary_points, n_boundary_points - n_lower_boundary_points) axial_bounds = (lower_x_bounds[axis], upper_x_bounds[axis]) axial_boundary_points: List[Optional[CollocationPoints]] = [] for axis_end in range(2): n_samples = n_axial_boundary_points[axis_end] if n_samples == 0: axial_boundary_points.append(None) continue t = np.random.uniform( lower_t_bound, upper_t_bound, (n_samples, 1)) x = np.random.uniform( lower_x_bounds, upper_x_bounds, (n_samples, len(x_intervals))) x[:, axis] = axial_bounds[axis_end] axial_boundary_points.append(CollocationPoints(t, x)) boundary_points.append(AxialBoundaryPoints( axial_boundary_points[0], axial_boundary_points[1])) return boundary_points