Source code for swiftsimio.masks

"""Provide functions and objects that use mask information from SWIFT snapshots."""

import warnings
import h5py
import numpy as np
from pathlib import Path

from swiftsimio.metadata.field.attr_reader import (
    load_field_units as _load_field_units,
    load_field_cosmo_factor as _load_field_cosmo_factor,
    load_field_physical as _load_field_physical,
)
from swiftsimio.metadata.objects import SWIFTMetadata
from swiftsimio.objects import InvalidSnapshot, cosmo_array, cosmo_quantity
from swiftsimio.accelerated import ranges_from_array
from swiftsimio._handle_provider import HandleProvider

from typing import Callable, Sequence

# Default padding, expressed as a fraction of the cell length, applied around cells
# when the snapshot has no Cells/MinPositions and Cells/MaxPositions metadata.
_DEFAULT_SAFE_PADDING = 0.1
# Output types handled as group catalogues: these are never padded and use a single
# "shared" cell mask instead of one cell mask per particle type.
_GROUPCAT_OUTPUT_TYPES = ["FOF", "SOAP", "FOFSubset", "SOAPSubset"]


def _constraint(method: Callable) -> Callable:
    """
    Decorate a function that constrains a :class:`~swiftsimio.masks.SWIFTMask`.

    After the decorated method returns successfully, the name of the method is
    recorded in the ``constrained`` attribute of the mask. If the method raises, the
    attribute is left untouched.

    Parameters
    ----------
    method : Callable
        A method of the :class:`~swiftsimio.masks.SWIFTMask` class that applies a
        constraint to the mask.

    Returns
    -------
    Callable
        The decorated method. It carries the name, docstring and signature metadata
        of ``method``.
    """
    # local import keeps this block self-contained
    from functools import wraps

    @wraps(method)  # preserve __name__/__doc__ so docs & introspection see `method`
    def wrapped(
        self: "SWIFTMask",
        *args: object,
        **kwargs: object,
    ) -> object:  # noqa numpydoc ignore=GL08
        # omit docstring so that sphinx picks up docstring of wrapped function
        retval = method(self, *args, **kwargs)
        # only set after function call: the method itself may need to inspect the
        # previous value of `constrained`, and a failed call must not record anything
        self.constrained = method.__name__
        return retval

    return wrapped


class SWIFTMask(HandleProvider):
    """
    Main masking object.

    This can have masks for any present particle type in it. For catalogues (e.g.
    SOAP) where all arrays are the same size there can be a "shared" mask that is
    stored once and used for all of them. Takes the SWIFT metadata and enables
    individual property-by-property masking when reading from snapshots.

    When masking like this order-in-file is not preserved, i.e. the 7th particle may
    not be the 7th particle in the file.

    Parameters
    ----------
    filename : Path
        File to read cell metadata from.

    metadata : SWIFTMetadata
        Metadata loaded from snapshot.

    safe_padding : bool or float, optional
        If snapshot does not specify bounding box of cell particles
        (``MinPositions``, ``MaxPositions``), pad the mask to guarantee that *all*
        particles in requested spatial region(s) are selected. If the bounding box
        metadata is present, this argument is ignored. The default (``0.1``) is to
        pad by 0.1 times the cell length. Padding can be disabled (``False``) or set
        to a different fraction of the cell length (e.g. ``0.5``). Only entire cells
        are loaded, but if the region boundary is more than ``safe_padding`` from a
        cell boundary the neighbouring cell is not read. Switching off can reduce I/O
        load by up to a factor of 30 in some cases (but a few particles in region
        could be missing). See
        https://swiftsimio.readthedocs.io/en/latest/masking/index.html for further
        details.

    handle : h5py.File, optional
        The file handle to read metadata from.

    spatial_only : bool, optional
        Deprecated, any necessary conversions handled automatically.
    """

    filename: Path
    constrained: str | None

    def __init__(
        self,
        filename: Path,
        metadata: SWIFTMetadata,
        *,
        safe_padding: bool | float = _DEFAULT_SAFE_PADDING,
        handle: h5py.File | None = None,
        spatial_only: bool | None = None,  # deprecated
    ) -> None:
        if spatial_only is not None:
            warnings.warn(
                "`spatial_only` is deprecated, any necessary conversions are now handled"
                " automatically.",
                DeprecationWarning,
                # attribute the warning to the caller so default filters show it
                stacklevel=2,
            )

        super().__init__(filename, handle=handle)
        self.metadata = metadata
        self.units = metadata.units
        # Lazily-built caches, see the corresponding properties.
        self._group_mapping: dict | None = None
        self._group_size_mapping: dict | None = None
        self._update_list: list | None = None
        # Masks start out as (start, stop) ranges, not boolean arrays.
        self.range_mask = True
        self.constrained = None

        # `True` means "use the default fraction", `False` means "no padding".
        if safe_padding is True:
            self.safe_padding = _DEFAULT_SAFE_PADDING
        elif safe_padding is False:
            self.safe_padding = 0.0
        else:
            self.safe_padding = safe_padding

        if not self.metadata.masking_valid:
            raise NotImplementedError(
                f"Masking not supported for {self.metadata.output_type} filetype"
            )

        if self.metadata.partial_snapshot:
            raise InvalidSnapshot(
                "You cannot use masks on partial snapshots. Please use the virtual "
                "file generated by SWIFT (use snapshot.hdf5, not snapshot.0.hdf5)."
            )

        self._unpack_cell_metadata()
        self.cell_mask = self._generate_cell_mask(restrict=None)
        self._update_range_masks()
        self._close_handle_if_manager()

    @property
    def group_mapping(self) -> dict[str, str]:
        """
        Create mapping between "group names" and their underlying cell metadata names.

        Allows for aliases to be used instead of re-creating masks.

        Returns
        -------
        dict[str, str]:
            The dictionary of corresponding names.
        """
        if self._group_mapping is not None:
            return self._group_mapping

        if self.metadata.shared_cell_counts is None:
            # Each and every particle type has its own cell counts, offsets,
            # and hence masks.
            self._group_mapping = {
                group: f"_{group}" for group in self.metadata.present_group_names
            }
        else:
            # We actually only have _one_ mask!
            self._group_mapping = {
                group: "_shared" for group in self.metadata.present_group_names
            }

        return self._group_mapping

    @property
    def group_size_mapping(self) -> dict[str, str]:
        """
        Create mapping between "group size names" and the underlying size attributes.

        Maps ``"<group>_size"`` onto the name of the attribute storing the size of
        the corresponding mask, so that aliases can be used for the sizes, too.

        Returns
        -------
        dict[str, str]:
            The dictionary of corresponding names.
        """
        if self._group_size_mapping is not None:
            return self._group_size_mapping

        if self.metadata.shared_cell_counts is None:
            # Each and every particle type has its own cell counts, offsets,
            # and hence masks.
            self._group_size_mapping = {
                f"{group}_size": f"_{group}_size"
                for group in self.metadata.present_group_names
            }
        else:
            # We actually only have _one_ mask!
            self._group_size_mapping = {
                f"{group}_size": "_shared_size"
                for group in self.metadata.present_group_names
            }

        return self._group_size_mapping

    @property
    def update_list(self) -> list[str]:
        """
        Get list of internal mask variables that need updating when changing spatial mask.

        Returns
        -------
        list[str]
            List of the variable names that need updating.
        """
        if self._update_list is not None:
            return self._update_list

        self._update_list = (
            # Each and every particle type has its own cell counts, offsets,
            # and hence masks:
            [f"_{group}" for group in self.metadata.present_group_names]
            if self.metadata.shared_cell_counts is None
            # Or there is only one shared mask:
            else ["_shared"]
        )

        return self._update_list

    def __getattr__(self, name: str) -> np.ndarray:
        """
        Overload the ``__getattr__`` method to allow for direct access to the masks.

        Parameters
        ----------
        name : str
            Name of one of the available particle types.

        Returns
        -------
        np.ndarray[bool]
            The requested mask.

        Raises
        ------
        AttributeError
            If the requested particle type is not found.
        """
        # The alias tables are built from attributes assigned in ``__init__``. If
        # they are not there yet (e.g. ``copy``/``pickle`` probing a blank instance)
        # looking them up would re-enter this method forever, so fail cleanly.
        instance_state = self.__dict__
        if (
            "_group_mapping" not in instance_state
            or "_group_size_mapping" not in instance_state
        ):
            raise AttributeError(f"Attribute {name} not found in SWIFTMask")

        mappings = {
            **self.group_mapping,
            **self.group_size_mapping,
        }

        underlying_name = mappings.get(name, None)

        if underlying_name is None:
            raise AttributeError(f"Attribute {name} not found in SWIFTMask")

        return getattr(self, underlying_name)

    def _generate_empty_masks(
        self, fill_value: bool = True
    ) -> tuple[dict[str, np.ndarray], dict[str, int]]:
        """
        Generate empty (all ``True``) masks for all available particle types.

        Parameters
        ----------
        fill_value : bool
            Value to fill the arrays with, either ``True`` or ``False``.

        Returns
        -------
        dict
            Contains the names of the masks as keys and the masks themselves as
            corresponding values.
        dict
            Contains the sizes of the masks labelled with keys matching the masks
            names suffixed with ``"_size"``.
        """
        mask_func = np.ones if fill_value else np.zeros
        empty_masks = {}
        sizes = {}

        if self.metadata.shared_cell_counts is not None:
            # One mask shared by every group in the file.
            size = getattr(
                self.metadata, f"n_{self.metadata.shared_cell_counts.lower()}"
            )
            empty_masks["_shared"] = mask_func(size, dtype=bool)
            sizes["_shared_size"] = size
        else:
            for group_name, data_name in self.group_mapping.items():
                size = getattr(self.metadata, f"n_{group_name}")
                empty_masks[data_name] = mask_func(size, dtype=bool)
                sizes[f"{data_name}_size"] = size

        return empty_masks, sizes

    def _to_comoving_length(self, values: np.ndarray) -> cosmo_array:
        """
        Attach the snapshot length units and comoving cosmology to raw values.

        Parameters
        ----------
        values : np.ndarray
            Bare values read from the cell metadata.

        Returns
        -------
        ~swiftsimio.objects.cosmo_array
            The values as a comoving length with scale exponent 1.
        """
        return cosmo_array(
            values,
            units=self.units.length,
            comoving=True,
            scale_factor=self.metadata.scale_factor,
            scale_exponent=1,
        )

    def _unpack_cell_metadata(self) -> None:
        """
        Unpack the cell metadata into local (to the class) variables.

        We do not read in information for empty cells.
        """
        # Reset this in case for any reason we have messed them up
        self.counts = {}
        self.offsets = {}
        self.minpositions = {}
        self.maxpositions = {}

        cell_handle = self.handle["Cells"]
        count_handle = cell_handle["Counts"]
        metadata_handle = cell_handle["Meta-data"]
        centers_values = cell_handle["Centres"][...]
        cell_extent = metadata_handle.attrs["size"]

        if (
            "MinPositions" in cell_handle.keys()
            and "MaxPositions" in cell_handle.keys()
        ):
            minpos_handle = cell_handle["MinPositions"]
            maxpos_handle = cell_handle["MaxPositions"]
        else:
            # Older versions of SWIFT don't have this information
            minpos_handle, maxpos_handle = None, None

        try:
            offset_handle = cell_handle["OffsetsInFile"]
        except KeyError:
            # Previous version of SWIFT did not have distributed
            # file i/o implemented
            offset_handle = cell_handle["Offsets"]

        if self.metadata.shared_cell_counts is not None:
            # Single - called _shared.
            self.offsets["shared"] = offset_handle[self.metadata.shared_cell_counts][:]
            self.counts["shared"] = count_handle[self.metadata.shared_cell_counts][:]
        else:
            for group, group_name in zip(
                self.metadata.present_groups, self.metadata.present_group_names
            ):
                self.offsets[group_name] = offset_handle[group][:]
                self.counts[group_name] = count_handle[group][:]

        if minpos_handle is not None and maxpos_handle is not None:
            # Geometric edges of every cell; the bounding box stored in the file is
            # only used where it reaches beyond the cell itself.
            cell_lower = centers_values - 0.5 * cell_extent
            cell_upper = centers_values + 0.5 * cell_extent
            for group, group_name in zip(
                self.metadata.present_groups, self.metadata.present_group_names
            ):
                minpos_values = minpos_handle[group][:]
                maxpos_values = maxpos_handle[group][:]
                self.minpositions[group_name] = np.where(
                    cell_lower < minpos_values,
                    cell_lower,
                    minpos_values,
                )
                self.maxpositions[group_name] = np.where(
                    cell_upper > maxpos_values,
                    cell_upper,
                    maxpos_values,
                )
        else:
            # be conservative: pad (default by 0.1 cell) in case particles drifted
            # (unless for group catalogues)
            pad_cells = (
                0
                if self.metadata.output_type in _GROUPCAT_OUTPUT_TYPES
                else self.safe_padding
            )
            if self.metadata.output_type not in _GROUPCAT_OUTPUT_TYPES:
                warnings.warn(
                    "Snapshot does not contain Cells/MinPositions and Cells/MaxPositions"
                    f" metadata. Padding region by {pad_cells} times cell length to"
                    " account for drifted particles. This behaviour can be"
                    " configured/disabled with the `safe_padding` parameter when creating"
                    " the mask. See "
                    "https://swiftsimio.readthedocs.io/en/latest/masking/index.html"
                    " for further details."
                )
            # +/- 0.5 here is the cell size itself:
            self.minpositions["shared"] = (
                centers_values - (pad_cells + 0.5) * cell_extent
            )
            self.maxpositions["shared"] = (
                centers_values + (pad_cells + 0.5) * cell_extent
            )

        # Only want to compute this once (even if it is fast, we do not
        # have a reliable stable sort in the case where cells do not
        # contain at least one of each type of particle).
        self.cell_sort = None

        # Now perform sort:
        for key in self.offsets.keys():
            offsets = self.offsets[key]
            counts = self.counts[key]

            # When using MPI, we cannot assume that these are sorted.
            if self.cell_sort is None:
                # Only compute once; not stable between particle
                # types if some datasets do not have particles in a cell!
                self.cell_sort = np.argsort(offsets, stable=True)

            self.offsets[key] = offsets[self.cell_sort]
            self.counts[key] = counts[self.cell_sort]

        # Also need to sort centers in the same way
        self.centers = self._to_comoving_length(centers_values[self.cell_sort])

        # And sort min & max positions, too.
        for k in self.minpositions.keys():
            self.minpositions[k] = self._to_comoving_length(
                self.minpositions[k][self.cell_sort]
            )
        for k in self.maxpositions.keys():
            self.maxpositions[k] = self._to_comoving_length(
                self.maxpositions[k][self.cell_sort]
            )

        if minpos_handle is None and maxpos_handle is None:
            # No per-type bounding boxes in the file: every group uses the (padded)
            # cell extent.
            for group_name in self.metadata.present_group_names:
                self.minpositions[group_name] = self.minpositions["shared"]
                self.maxpositions[group_name] = self.maxpositions["shared"]

        # Note that we cannot assume that these are cubic, unfortunately.
        self.cell_size = self._to_comoving_length(cell_extent)

        return

    def constrain_mask(
        self,
        group_name: str,
        quantity: str,
        lower: cosmo_quantity,
        upper: cosmo_quantity,
    ) -> None:
        """
        Use :meth:`~swiftsimio.masks.SWIFTMask.constrain_property` instead.

        This name is deprecated and will be removed in the future.

        Parameters
        ----------
        group_name : str
            Particle type (e.g. ``"gas"``).

        quantity : str
            Quantity being constrained (e.g. ``"temperatures"``).

        lower : ~swiftsimio.objects.cosmo_quantity
            Constraint lower bound.

        upper : ~swiftsimio.objects.cosmo_quantity
            Constraint upper bound.
        """
        warnings.warn(
            "`constrain_mask` is deprecated, use `constrain_property` with the same "
            "arguments instead.",
            DeprecationWarning,
            # attribute the warning to the caller so default filters show it
            stacklevel=2,
        )
        self.constrain_property(group_name, quantity, lower, upper)

    @_constraint
    def constrain_property(
        self,
        group_name: str,
        quantity: str,
        lower: cosmo_quantity,
        upper: cosmo_quantity,
    ) -> None:
        """
        Constrain the mask further for a given particle type.

        Chooses only particles with a property bounded between lower and upper
        values. We update the mask such that:

        .. code-block:: python

            lower < group_name.quantity <= upper

        The quantities must be :class:`~swiftsimio.objects.cosmo_quantity`, i.e. must
        have units and cosmology information attached.

        Parameters
        ----------
        group_name : str
            Particle type (e.g. ``"gas"``).

        quantity : str
            Quantity being constrained (e.g. ``"temperatures"``).

        lower : ~swiftsimio.objects.cosmo_quantity
            Constraint lower bound.

        upper : ~swiftsimio.objects.cosmo_quantity
            Constraint upper bound.

        See Also
        --------
        constrain_spatial
            Method to generate spatially constrained cell mask.
        """
        self.convert_masks_to_bool()  # no-op if already bool
        data_name = self.group_mapping[group_name]
        current_mask = getattr(self, data_name)

        # Look up the path of the requested field inside the file.
        group_metadata = getattr(self.metadata, f"{group_name}_properties")
        handle_dict = dict(
            zip(group_metadata.field_names, group_metadata.field_paths)
        )
        handle = handle_dict[quantity]

        # Load in the relevant data.
        h5file: h5py.File
        with self.metadata.open_file() as h5file:
            field_attributes = h5file[handle].attrs
            unit = _load_field_units(field_attributes, self.metadata.units)
            physical = _load_field_physical(field_attributes)
            cf = _load_field_cosmo_factor(field_attributes, self.metadata)
            if isinstance(h5file, (h5py.File, h5py.Group)):
                # When reading from a local HDF5 file this is faster than
                # just using the boolean indexing because h5py has slow
                # indexing routines.
                data = np.take(h5file[handle], np.where(current_mask)[0], axis=0)
            else:
                # Otherwise, assume we can just index the dataset. This
                # generates a single http request for remote datasets.
                data = h5file[handle][current_mask, ...]

        # Wrap result in a cosmo_array
        data = cosmo_array(
            data,
            units=unit,
            comoving=not physical,
            cosmo_factor=cf,
        )

        new_mask = np.logical_and.reduce([data > lower, data <= upper])

        # Only entries that were already selected can stay selected.
        current_mask[current_mask] = new_mask

        setattr(self, data_name, current_mask)

        return

    def _generate_cell_mask(
        self, restrict: cosmo_array | None = None
    ) -> dict[str, np.ndarray]:
        """
        Generate a spatially restricted mask for cells.

        Takes the cell metadata and finds the mask for the _cells_ that are within
        the spatial region defined by the spatial mask. Not for user use.

        Parameters
        ----------
        restrict : cosmo_array
            Restrict is a (3,2) cosmo_array giving the lower and upper bounds for
            each axis, e.g.

            .. code-block:: python

                restrict = cosmo_array(
                    [
                        [0.5, 0.7],
                        [0.1, 0.9],
                        [0.0, 0.1]
                    ],
                    u.Mpc,
                    comoving=True,
                    scale_factor=1.0,
                    scale_exponent=1,
                )

            If ``None``, all cells are selected.

        Returns
        -------
        dict[str, np.ndarray]
            Boolean masks indicating which cells are within the specified spatial
            range, keyed by group name (or ``"shared"`` for group catalogues).

        Raises
        ------
        ValueError
            If the mask boundaries are outside the interval [-Lbox/2, 3*Lbox/2].
        """
        if self.metadata.output_type in _GROUPCAT_OUTPUT_TYPES:
            group_names = ["shared"]
        else:
            # particles may drift from their cells, mask each type separately
            group_names = self.metadata.present_group_names

        cell_mask = {
            group_name: np.ones(len(self.centers), dtype=bool)
            for group_name in group_names
        }

        if restrict is None:
            return cell_mask

        for dimension in range(0, 3):
            lower = restrict[dimension][0]
            upper = restrict[dimension][1]
            boxsize = self.metadata.boxsize[dimension]

            if np.logical_or.reduce(
                (
                    lower < -boxsize / 2,
                    upper < -boxsize / 2,
                    lower > 3 * boxsize / 2,
                    upper > 3 * boxsize / 2,
                )
            ):
                # because we're only going to make one periodic copy on either side,
                # we're in trouble
                raise ValueError(
                    "Mask region boundaries must be in interval [-boxsize/2, 3*boxsize/2]"
                    f" along {'xyz'[dimension]}-axis."
                )

            # (rows of `restrict` have already been indexed above, so they cannot be
            # ``None`` here; unconstrained axes arrive as the full box extent)
            if np.abs(upper - lower) > boxsize:
                # keep everything along this axis
                continue

            if upper < lower:
                # inverted case, convert to a "normal" case in target window
                if lower > boxsize / 2:
                    lower -= boxsize
                elif upper < boxsize / 2:
                    # don't shift both else we get the whole box!
                    upper += boxsize

            for group_name in group_names:
                # Only cells still selected by previous axes need testing.
                selected = cell_mask[group_name]
                # selection intersects one of the 3 periodic copies of a cell:
                this_mask = np.logical_or.reduce(
                    [
                        np.logical_and(
                            self.maxpositions[group_name][selected, dimension]
                            + shift * boxsize
                            > lower,
                            self.minpositions[group_name][selected, dimension]
                            + shift * boxsize
                            < upper,
                        )
                        for shift in (-1, 0, 1)
                    ]
                )
                cell_mask[group_name][selected] = this_mask

        return cell_mask

    def _update_range_masks(self) -> None:
        """
        Update the particle masks using the cell masks.

        We actually overwrite all non-used cells with ``False``, rather than the
        inverse, as we assume initially that we want to read all particles in, and we
        want to respect other masks that may have been applied to the data.
        """
        for data_name in self.update_list:
            count_name = data_name[1:]  # Remove the underscore

            if self.range_mask:
                counts = self.counts[count_name][self.cell_mask[count_name]]
                offsets = self.offsets[count_name][self.cell_mask[count_name]]

                # One [start, stop) range per selected cell.
                this_mask = [[o, c + o] for c, o in zip(counts, offsets)]

                setattr(self, data_name, np.array(this_mask))
                setattr(self, f"{data_name}_size", np.sum(counts))
            else:
                counts = self.counts[count_name][
                    np.logical_not(self.cell_mask[count_name])
                ]
                offsets = self.offsets[count_name][
                    np.logical_not(self.cell_mask[count_name])
                ]

                # We must do the whole boolean mask business.
                this_mask = getattr(self, data_name)

                for count, offset in zip(counts, offsets):
                    this_mask[offset : count + offset] = False

        return

    def _sanitize_region(self, region: Sequence) -> cosmo_array:
        """
        Coerce user-provided region to (3, 2) :class:`~swiftsimio.objects.cosmo_array`.

        The input region should have length 3 (otherwise error). The rows should have
        length 2, but may be ``None``. We check these conditions, then package
        everything up into a :class:`~swiftsimio.objects.cosmo_array`.

        We have to be careful: :class:`~swiftsimio.objects.cosmo_array` will check the
        input for compatibility of units, cosmo_factors, etc. and convert them to be
        consistent if needed, but only up to a depth of 1, i.e. in this:
        ``cosmo_array([[cosmo_quantity(...), cosmo_quantity(...)], None, None])``
        the cosmo-ness of the :class:`~swiftsimio.objects.cosmo_quantity` elements
        would be ignored because they are nested two-deep in a list.

        Parameters
        ----------
        region : list
            The user-provided region.

        Returns
        -------
        ~swiftsimio.objects.cosmo_array
            The region as a (3, 2) :class:`~swiftsimio.objects.cosmo_array`, with
            ``None`` rows replaced by the full extent of the box.

        Raises
        ------
        ValueError
            If the region does not have length 3, or a row that is not ``None`` does
            not have length 2.
        """
        if len(region) != 3:
            raise ValueError("`restrict` must have length == 3.")
        for ax_region in region:
            if ax_region is not None and len(ax_region) != 2:
                raise ValueError(
                    "Rows of `restrict` must have length == 2 (or be ``None``)."
                )
        # Wrap each row separately so that the nesting depth stays at 1.
        return cosmo_array(
            [
                cosmo_array(ax_region)
                if ax_region is not None
                else np.array([0, 1]) * b
                for ax_region, b in zip(region, self.metadata.boxsize)
            ]
        )

    @_constraint
    def constrain_spatial(
        self,
        restrict: Sequence,
        union: bool = False,
        intersect: bool | None = None,  # deprecated
    ) -> None:
        """
        Use the cell metadata to select particles within a cuboid region.

        This mask is necessarily approximate and is coarse-grained to the cell size.

        Parameters
        ----------
        restrict : list
            Restrict is a shape (3, 2) iterable giving the lower and upper bounds for
            each axis. The bounds need to be given as
            :class:`~swiftsimio.objects.cosmo_array` or
            :class:`~swiftsimio.objects.cosmo_quantity`. For example:

            .. code-block:: python

                restrict = cosmo_array(
                    [
                        [0.5, 0.7],
                        [0.1, 0.9],
                        [0.0, 0.1]
                    ],
                    u.Mpc,
                    comoving=True,
                    scale_factor=1.0,
                    scale_exponent=1,
                )

            If no constraint is desired along an axis, ``None`` can be given instead.
            For example, to select a "slab" in the x-y plane:

            .. code-block:: python

                zmin = cosmo_quantity(
                    5.0, u.Mpc, comoving=True, scale_factor=1.0, scale_exponent=1
                )
                zmax = cosmo_quantity(
                    6.0, u.Mpc, comoving=True, scale_factor=1.0, scale_exponent=1
                )
                restrict = [
                    None,
                    None,
                    [zmin, zmax],
                ]

        union : bool, optional
            If ``True``, combine the spatial mask with any existing spatial mask to
            select two (or more) regions with repeated calls to
            ``constrain_spatial``.

        intersect : bool, optional
            Deprecated due to misleading name, use ``union`` instead.

        Raises
        ------
        RuntimeError
            If a non-spatial constraint has already been applied, or a spatial
            constraint has already been applied and ``union`` is not set.

        See Also
        --------
        constrain_property
            Method to further refine mask, selecting particles by their properties.
        """
        if intersect is not None:
            warnings.warn(
                "`intersect` is deprecated because this term is misleading, use `union` "
                "(with the same value) instead. Overriding `union` value if provided.",
                DeprecationWarning,
                # 3: skip this method and the `_constraint` wrapper around it
                stacklevel=3,
            )
            union = intersect

        sanitized_region = self._sanitize_region(restrict)

        if self.constrained == "constrain_spatial" and union:
            # we are in union mode and already have a spatial constraint
            new_mask = self._generate_cell_mask(sanitized_region)
            for group_name in self.update_list:
                group_name = group_name[1:]  # remove leading underscore
                self.cell_mask[group_name] = np.logical_or(
                    self.cell_mask[group_name], new_mask[group_name]
                )
        elif self.constrained is None:
            # union or not union, doesn't matter, just make a new mask
            self.cell_mask = self._generate_cell_mask(sanitized_region)
        else:
            # union or not union, doesn't matter
            # we don't allow combining with current mask
            msg = f"Can't `constrain_spatial` after `{self.constrained}`."
            if self.constrained == "constrain_spatial":
                msg += (
                    " To combine multiple `constrain_spatial` calls use `union` kwarg."
                )
            elif union:
                msg += (
                    " `union` kwarg can only be used to combine with"
                    " `constrain_spatial` constraints."
                )
            raise RuntimeError(msg)

        self._update_range_masks()

        return

    def convert_masks_to_ranges(self) -> None:
        """
        Convert the masks to range masks.

        These are more compact than boolean masks so they can help save space on
        highly constrained machines.

        See Also
        --------
        convert_masks_to_bool
        """
        if not self.range_mask:
            # We must do the whole boolean mask stuff. To do that, we
            # First, convert each boolean mask into an integer mask
            # Use the accelerate.ranges_from_array function to convert
            # This into a set of ranges.
            for mask in self.update_list:
                where_array = np.where(getattr(self, mask))[0]
                setattr(self, f"{mask}_size", where_array.size)
                setattr(self, mask, ranges_from_array(where_array))

        self.range_mask = True

        return

    def convert_masks_to_bool(self) -> None:
        """
        Convert the masks to boolean masks.

        These are sometimes easier to work with than range masks but usually use more
        memory.

        See Also
        --------
        convert_masks_to_ranges
        """
        if self.range_mask:
            # Start from all-``False`` masks and switch on every stored range.
            empty_masks, sizes = self._generate_empty_masks(fill_value=False)
            for mask_key, mask_value in empty_masks.items():
                for r in getattr(self, mask_key):
                    mask_value[slice(*r)] = True
                setattr(self, mask_key, mask_value)
            for size_key, size_value in sizes.items():
                setattr(self, size_key, size_value)

        self.range_mask = False

        return

    @_constraint
    def constrain_index(self, index: int) -> None:
        """
        Constrain the mask to a single row.

        Only works for files where all arrays have the same length. Intended for use
        with SOAP catalogues, mask to read only a single row.

        Parameters
        ----------
        index : int
            The index of the row to select.

        Raises
        ------
        RuntimeError
            If another constraint has already been applied to the mask.
        """
        # constraint decorator will set correctly, overwriting `constrain_indices`
        if self.constrained is not None:
            # could let constrain_indices check this, but want custom message
            raise RuntimeError(f"Can't `constrain_index` after `{self.constrained}`.")
        self.constrain_indices([index])

        return

    @_constraint
    def constrain_indices(self, indices: list[int]) -> None:
        """
        Constrain the mask to a list of rows.

        Only works for files where all arrays have the same length.

        Parameters
        ----------
        indices : list[int]
            A list of the indices of the rows to mask.

        Raises
        ------
        RuntimeError
            If another constraint has already been applied to the mask, or the file
            does not contain homogeneous arrays.
        """
        if self.constrained is not None:
            raise RuntimeError(f"Can't `constrain_indices` after `{self.constrained}`.")
        if not self.metadata.homogeneous_arrays:
            raise RuntimeError(
                "Cannot constrain to specific rows in a non-homogeneous array; you "
                f"currently are using a {self.metadata.output_type} file"
            )

        if not np.all(np.asarray(indices)[:-1] <= np.asarray(indices)[1:]):
            # indices list is not sorted
            warnings.warn(
                "`constrain_indices` selects indices in order, sorting list of indices "
                "before masking."
            )
            sorted_indices = np.sort(indices)
        else:
            sorted_indices = np.asarray(indices)

        for mask in self.update_list:
            if self.range_mask:
                setattr(self, mask, ranges_from_array(sorted_indices))
                setattr(self, f"{mask}_size", len(indices))
            else:
                # Keep only rows that are both requested and currently selected.
                comparison_array = np.zeros(getattr(self, mask).size, dtype=bool)
                comparison_array[indices] = True
                setattr(
                    self, mask, np.logical_and(getattr(self, mask), comparison_array)
                )

        return

    def _get_masked_counts_offsets(
        self,
    ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
        """
        Return the particle counts and offsets in cells selected by the mask.

        Returns
        -------
        dict[str, np.array], dict[str, np.array]
            Dictionaries containing the particle counts and offsets for each particle
            type. For example, the particle counts dictionary would be of the form

            .. code-block:: python

                {"gas": [g_0, g_1, ...], "dark matter": [dm_0, dm_1, ...], ...}

            where the keys would be each of the particle types and values are arrays
            of the number of corresponding particles in each cell (in this case there
            would be g_0 gas particles in the first cell, g_1 in the second, etc.).
            The structure of the dictionaries is the same for the offsets, with the
            arrays now storing the offset of the first particle in the cell.
        """
        masked_counts = {}
        masked_offsets = {}
        for part_type in self.counts.keys():
            counts = self.counts[part_type]
            offsets = self.offsets[part_type]
            particle_mask = getattr(self, f"_{part_type}")
            if self.range_mask:
                # figure out what cell each range starts in (cell_ranges[:, 0]) and
                # ends in (cell_ranges[:, 1])
                # counts and offsets are sorted by offsets when loaded: searchsorted
                # valid
                cell_ranges = np.searchsorted(
                    offsets[1:],
                    particle_mask,
                    side="right",
                )
                masked_counts[part_type] = np.zeros(counts.shape, dtype=np.int64)
                for (start_cell, end_cell), mask_range in zip(
                    cell_ranges, particle_mask
                ):
                    for cell in range(start_cell, end_cell + 1):
                        # overlap between the range and the particles of this cell
                        masked_counts[part_type][cell] += min(
                            mask_range[1], offsets[cell] + counts[cell]
                        ) - max(mask_range[0], offsets[cell])
            else:
                # count the selected particles inside each cell's slice
                masked_counts[part_type] = np.array(
                    [
                        particle_mask[slice(*cell_slice)].sum()
                        for cell_slice in np.vstack((offsets, offsets + counts)).T
                    ]
                )
            # offsets of the cells within the masked (i.e. as-read) arrays
            masked_offsets[part_type] = np.r_[
                0, np.cumsum(masked_counts[part_type])[:-1]
            ]

        return masked_counts, masked_offsets