"""Provide functions and objects that use mask information from SWIFT snapshots."""
import warnings
import h5py
import numpy as np
from pathlib import Path
from swiftsimio.metadata.field.attr_reader import (
load_field_units as _load_field_units,
load_field_cosmo_factor as _load_field_cosmo_factor,
load_field_physical as _load_field_physical,
)
from swiftsimio.metadata.objects import SWIFTMetadata
from swiftsimio.objects import InvalidSnapshot, cosmo_array, cosmo_quantity
from swiftsimio.accelerated import ranges_from_array
from swiftsimio._handle_provider import HandleProvider
from typing import Callable, Sequence
# Default value of the ``safe_padding`` argument of ``SWIFTMask``; also used when
# ``safe_padding=True`` is passed. Expressed as a fraction of the cell length.
_DEFAULT_SAFE_PADDING = 0.1
# Output types handled as group catalogues: they get a single "shared" cell mask
# and their cell extents are never padded.
_GROUPCAT_OUTPUT_TYPES = ["FOF", "SOAP", "FOFSubset", "SOAPSubset"]
def _constraint(method: Callable) -> Callable:
    """
    Decorate a function that constrains a :class:`~swiftsimio.masks.SWIFTMask`.

    Once the decorated method has returned, its name is stored in the
    ``constrained`` attribute of the mask. The attribute is left untouched if the
    method raises, and still holds its previous value while the method runs.

    Parameters
    ----------
    method : Callable
        A method of the :class:`~swiftsimio.masks.SWIFTMask` class that applies a
        constraint to the mask.

    Returns
    -------
    Callable
        The decorated method. Its name, docstring and signature metadata are those
        of ``method``.
    """
    # local import: keeps this helper self-contained
    from functools import wraps

    @wraps(method)
    def wrapped(
        self: "SWIFTMask",
        *args: object,
        **kwargs: object,
    ) -> object:  # noqa numpydoc ignore=GL08
        # no docstring here: ``wraps`` copies the docstring of the wrapped function
        retval = method(self, *args, **kwargs)
        self.constrained = method.__name__  # only set after function call
        return retval

    return wrapped
[docs]
class SWIFTMask(HandleProvider):
    """
    Main masking object. This can have masks for any present particle type in it.

    For catalogues (e.g. SOAP) where all arrays are the same size there can be a
    "shared" mask that is stored once and used for all of them.

    Takes the SWIFT metadata and enables individual property-by-property masking
    when reading from snapshots. When masking like this order-in-file is not
    preserved, i.e. the 7th particle may not be the 7th particle in the file.

    Parameters
    ----------
    filename : Path
        File to read cell metadata from.

    metadata : SWIFTMetadata
        Metadata loaded from snapshot.

    safe_padding : bool or float, optional
        If snapshot does not specify bounding box of cell particles
        (``MinPositions``, ``MaxPositions``), pad the mask to guarantee that *all*
        particles in requested spatial region(s) are selected. If the bounding box
        metadata is present, this argument is ignored. The default (``0.1``) is to
        pad by 0.1 times the cell length. Padding can be disabled (``False``) or set
        to a different fraction of the cell length (e.g. ``0.5``). Only entire cells
        are loaded, but if the region boundary is more than ``safe_padding`` from a
        cell boundary the neighbouring cell is not read. Switching off can reduce
        I/O load by up to a factor of 30 in some cases (but a few particles in
        region could be missing). See
        https://swiftsimio.readthedocs.io/en/latest/masking/index.html for further
        details.

    handle : h5py.File, optional
        The file handle to read metadata from.

    spatial_only : bool, optional
        Deprecated, any necessary conversions handled automatically.
    """

    # file that the cell metadata is read from
    filename: Path
    # name of the last constraint method applied to the mask (``None`` if none yet)
    constrained: str | None
def __init__(
    self,
    filename: Path,
    metadata: SWIFTMetadata,
    *,
    safe_padding: bool | float = _DEFAULT_SAFE_PADDING,
    handle: h5py.File | None = None,
    spatial_only: bool | None = None,  # deprecated
) -> None:
    # ``spatial_only`` no longer has any effect, only tell the user about it
    if spatial_only is not None:
        warnings.warn(
            "`spatial_only` is deprecated, any necessary conversions are now handled"
            " automatically.",
            DeprecationWarning,
        )

    super().__init__(filename, handle=handle)

    self.metadata = metadata
    self.units = metadata.units

    # lazily filled caches backing the ``group_mapping``, ``group_size_mapping``
    # and ``update_list`` properties
    self._group_mapping: dict | None = None
    self._group_size_mapping: dict | None = None
    self._update_list: list | None = None

    # masks start out as ranges and no constraint has been applied yet
    self.range_mask = True
    self.constrained = None

    # ``True`` means "use the default padding", ``False`` means "no padding",
    # anything else is taken as the padding (in units of the cell length)
    self.safe_padding = (
        _DEFAULT_SAFE_PADDING
        if safe_padding is True
        else 0.0
        if safe_padding is False
        else safe_padding
    )

    if not self.metadata.masking_valid:
        raise NotImplementedError(
            f"Masking not supported for {self.metadata.output_type} filetype"
        )
    if self.metadata.partial_snapshot:
        raise InvalidSnapshot(
            "You cannot use masks on partial snapshots. Please use the virtual "
            "file generated by SWIFT (use snapshot.hdf5, not snapshot.0.hdf5)."
        )

    # read the cell metadata, then start from a mask that selects every cell
    self._unpack_cell_metadata()
    self.cell_mask = self._generate_cell_mask(restrict=None)
    self._update_range_masks()
    self._close_handle_if_manager()
@property
def group_mapping(self) -> dict[str, str]:
    """
    Map each present group name onto the attribute that stores its mask.

    If the metadata declares shared cell counts every group is an alias of the
    single ``"_shared"`` mask, otherwise each group ``name`` has its own
    ``"_name"`` mask. The mapping is built on first access and cached.

    Returns
    -------
    dict[str, str]:
        The dictionary of corresponding names.
    """
    if self._group_mapping is None:
        single_mask = self.metadata.shared_cell_counts is not None
        self._group_mapping = {
            name: "_shared" if single_mask else f"_{name}"
            for name in self.metadata.present_group_names
        }
    return self._group_mapping
@property
def group_size_mapping(self) -> dict[str, str]:
    """
    Map ``"<group>_size"`` names onto the attribute that stores the mask size.

    If the metadata declares shared cell counts every entry is an alias of
    ``"_shared_size"``, otherwise ``"<group>_size"`` maps to ``"_<group>_size"``.
    The mapping is built on first access and cached.

    Returns
    -------
    dict[str, str]:
        The dictionary of corresponding names.
    """
    if self._group_size_mapping is None:
        single_mask = self.metadata.shared_cell_counts is not None
        self._group_size_mapping = {
            f"{name}_size": "_shared_size" if single_mask else f"_{name}_size"
            for name in self.metadata.present_group_names
        }
    return self._group_size_mapping
@property
def update_list(self) -> list[str]:
    """
    Get list of internal mask variables that need updating when changing spatial mask.

    The list is built on first access and cached.

    Returns
    -------
    list[str]
        List of the variable names that need updating.
    """
    if self._update_list is None:
        if self.metadata.shared_cell_counts is None:
            # one mask per particle type, each with its own counts and offsets
            names = [f"_{group}" for group in self.metadata.present_group_names]
        else:
            # a single mask shared by everything
            names = ["_shared"]
        self._update_list = names
    return self._update_list
def __getattr__(self, name: str) -> np.ndarray:
    """
    Overload the ``__getattr__`` method to allow for direct access to the masks.

    Group names (and group names suffixed with ``"_size"``) are translated to the
    private attribute that stores the corresponding mask (or mask size).

    Parameters
    ----------
    name : str
        Name of one of the available particle types.

    Returns
    -------
    np.ndarray[bool]
        The requested mask.

    Raises
    ------
    AttributeError
        If the requested particle type is not found.
    """
    # This method only runs when normal attribute lookup has failed. The mappings
    # used below are themselves built from ``self.metadata`` and cached in private
    # attributes, so if one of those is what is missing (e.g. on an instance that
    # ``copy`` or ``pickle`` created without calling ``__init__``), or if one of the
    # class's own properties raised ``AttributeError``, consulting the mappings
    # would re-enter this method until a ``RecursionError``. Public group names
    # never look like any of these, so fail immediately instead.
    if name.startswith("_") or name == "metadata" or hasattr(type(self), name):
        raise AttributeError(f"Attribute {name} not found in SWIFTMask")

    mappings = {
        **self.group_mapping,
        **self.group_size_mapping,
    }
    underlying_name = mappings.get(name, None)
    if underlying_name is None:
        raise AttributeError(f"Attribute {name} not found in SWIFTMask")
    return getattr(self, underlying_name)
def _generate_empty_masks(
    self, fill_value: bool = True
) -> tuple[dict[str, np.ndarray], dict[str, int]]:
    """
    Generate uniformly filled boolean masks for all available particle types.

    With shared cell counts a single ``"_shared"`` mask is produced, otherwise one
    mask per entry of ``group_mapping``. The mask lengths are read from the
    ``n_<name>`` attributes of the metadata.

    Parameters
    ----------
    fill_value : bool
        Value to fill the arrays with, either ``True`` or ``False``.

    Returns
    -------
    dict
        Contains the names of the masks as keys and the masks themselves as
        corresponding values.

    dict
        Contains the sizes of the masks labelled with keys matching the masks names
        suffixed with ``"_size"``.
    """
    make_mask = np.ones if fill_value else np.zeros

    if self.metadata.shared_cell_counts is not None:
        n_shared = getattr(
            self.metadata, f"n_{self.metadata.shared_cell_counts.lower()}"
        )
        return (
            {"_shared": make_mask(n_shared, dtype=bool)},
            {"_shared_size": n_shared},
        )

    masks = {}
    mask_sizes = {}
    for group_name, data_name in self.group_mapping.items():
        n_group = getattr(self.metadata, f"n_{group_name}")
        masks[data_name] = make_mask(n_group, dtype=bool)
        mask_sizes[f"{data_name}_size"] = n_group
    return masks, mask_sizes
def _unpack_cell_metadata(self) -> None:
    """
    Unpack the cell metadata into local (to the class) variables.

    Reads the ``Cells`` group of the file and fills ``self.counts``,
    ``self.offsets``, ``self.minpositions`` and ``self.maxpositions`` (dictionaries
    keyed by group name, or by ``"shared"``), as well as ``self.centers``,
    ``self.cell_size`` and ``self.cell_sort``. Every per-cell array is reordered
    with ``self.cell_sort``, the argsort of the first set of offsets found.

    We do not read in information for empty cells.
    """
    # Reset this in case for any reason we have messed them up
    self.counts = {}
    self.offsets = {}
    self.minpositions = {}
    self.maxpositions = {}
    cell_handle = self.handle["Cells"]
    count_handle = cell_handle["Counts"]
    metadata_handle = cell_handle["Meta-data"]
    centers_values = cell_handle["Centres"][...]
    if (
        "MinPositions" in cell_handle.keys()
        and "MaxPositions" in cell_handle.keys()
    ):
        # Older versions of SWIFT don't have this information
        minpos_handle = cell_handle["MinPositions"]
        maxpos_handle = cell_handle["MaxPositions"]
    else:
        minpos_handle, maxpos_handle = None, None
    try:
        offset_handle = cell_handle["OffsetsInFile"]
    except KeyError:
        # Previous version of SWIFT did not have distributed
        # file i/o implemented
        offset_handle = cell_handle["Offsets"]
    if self.metadata.shared_cell_counts is not None:
        # Single - called _shared.
        self.offsets["shared"] = offset_handle[self.metadata.shared_cell_counts][:]
        self.counts["shared"] = count_handle[self.metadata.shared_cell_counts][:]
    else:
        for group, group_name in zip(
            self.metadata.present_groups, self.metadata.present_group_names
        ):
            self.offsets[group_name] = offset_handle[group][:]
            self.counts[group_name] = count_handle[group][:]
    if minpos_handle is not None and maxpos_handle is not None:
        for group, group_name in zip(
            self.metadata.present_groups, self.metadata.present_group_names
        ):
            minpos_values = minpos_handle[group][:]
            maxpos_values = maxpos_handle[group][:]
            # lower bound: whichever is smaller of the lower cell edge and the
            # stored minimum position
            self.minpositions[group_name] = np.where(
                centers_values - 0.5 * metadata_handle.attrs["size"]
                < minpos_values,
                centers_values - 0.5 * metadata_handle.attrs["size"],
                minpos_values,
            )
            # upper bound: whichever is larger of the upper cell edge and the
            # stored maximum position
            self.maxpositions[group_name] = np.where(
                centers_values + 0.5 * metadata_handle.attrs["size"]
                > maxpos_values,
                centers_values + 0.5 * metadata_handle.attrs["size"],
                maxpos_values,
            )
    else:
        # be conservative: pad (default by 0.1 cell) in case particles drifted
        # (unless for group catalogues)
        pad_cells = (
            0
            if self.metadata.output_type in _GROUPCAT_OUTPUT_TYPES
            else self.safe_padding
        )
        if self.metadata.output_type not in _GROUPCAT_OUTPUT_TYPES:
            warnings.warn(
                "Snapshot does not contain Cells/MinPositions and Cells/MaxPositions"
                f" metadata. Padding region by {pad_cells} times cell length to"
                " account for drifted particles. This behaviour can be"
                " configured/disabled with the `safe_padding` parameter when creating"
                " the mask. See "
                "https://swiftsimio.readthedocs.io/en/latest/masking/index.html"
                " for further details."
            )
        # +/- 0.5 here is the cell size itself:
        self.minpositions["shared"] = (
            centers_values - (pad_cells + 0.5) * metadata_handle.attrs["size"]
        )
        self.maxpositions["shared"] = (
            centers_values + (pad_cells + 0.5) * metadata_handle.attrs["size"]
        )
    # Only want to compute this once (even if it is fast, we do not
    # have a reliable stable sort in the case where cells do not
    # contain at least one of each type of particle).
    self.cell_sort = None
    # Now perform sort:
    for key in self.offsets.keys():
        offsets = self.offsets[key]
        counts = self.counts[key]
        # When using MPI, we cannot assume that these are sorted.
        if self.cell_sort is None:
            # Only compute once; not stable between particle
            # types if some datasets do not have particles in a cell!
            self.cell_sort = np.argsort(offsets, stable=True)
        self.offsets[key] = offsets[self.cell_sort]
        self.counts[key] = counts[self.cell_sort]
    # Also need to sort centers in the same way
    self.centers = cosmo_array(
        centers_values[self.cell_sort],
        units=self.units.length,
        comoving=True,
        scale_factor=self.metadata.scale_factor,
        scale_exponent=1,
    )
    # And sort min & max positions, too.
    for k in self.minpositions.keys():
        self.minpositions[k] = cosmo_array(
            self.minpositions[k][self.cell_sort],
            units=self.units.length,
            comoving=True,
            scale_factor=self.metadata.scale_factor,
            scale_exponent=1,
        )
    for k in self.maxpositions.keys():
        self.maxpositions[k] = cosmo_array(
            self.maxpositions[k][self.cell_sort],
            units=self.units.length,
            comoving=True,
            scale_factor=self.metadata.scale_factor,
            scale_exponent=1,
        )
    if minpos_handle is None and maxpos_handle is None:
        # no per-group extents in the file: every group is an alias of the
        # (already sorted) padded cell extents stored under "shared"
        for group_name in self.metadata.present_group_names:
            self.minpositions[group_name] = self.minpositions["shared"]
            self.maxpositions[group_name] = self.maxpositions["shared"]
    # Note that we cannot assume that these are cubic, unfortunately.
    self.cell_size = cosmo_array(
        metadata_handle.attrs["size"],
        units=self.units.length,
        comoving=True,
        scale_factor=self.metadata.scale_factor,
        scale_exponent=1,
    )
    return
[docs]
def constrain_mask(
    self,
    group_name: str,
    quantity: str,
    lower: cosmo_quantity,
    upper: cosmo_quantity,
) -> None:
    """
    Use :meth:`~swiftsimio.masks.SWIFTMask.constrain_property` instead.

    This name is deprecated and will be removed in the future. A
    ``DeprecationWarning`` is issued and the call is forwarded unchanged to
    :meth:`~swiftsimio.masks.SWIFTMask.constrain_property`.

    Parameters
    ----------
    group_name : str
        Particle type (e.g. ``"gas"``).

    quantity : str
        Quantity being constrained (e.g. ``"temperatures"``).

    lower : ~swiftsimio.objects.cosmo_quantity
        Constraint lower bound.

    upper : ~swiftsimio.objects.cosmo_quantity
        Constraint upper bound.
    """
    # stacklevel=2 attributes the warning to the calling code: by default python
    # hides DeprecationWarnings that are attributed to library modules
    warnings.warn(
        "`constrain_mask` is deprecated, use `constrain_property` with the same "
        "arguments instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    self.constrain_property(group_name, quantity, lower, upper)
[docs]
@_constraint
def constrain_property(
    self,
    group_name: str,
    quantity: str,
    lower: cosmo_quantity,
    upper: cosmo_quantity,
) -> None:
    """
    Constrain the mask further for a given particle type.

    Keeps only the currently selected particles whose property lies between the
    lower and upper values, i.e. the mask is updated such that:

    .. code-block:: python

        lower < group_name.quantity <= upper

    The quantities must be :class:`~swiftsimio.objects.cosmo_quantity`, i.e. must
    have units and cosmology information attached. Range masks are converted to
    boolean masks first.

    Parameters
    ----------
    group_name : str
        Particle type (e.g. ``"gas"``).

    quantity : str
        Quantity being constrained (e.g. ``"temperatures"``).

    lower : ~swiftsimio.objects.cosmo_quantity
        Constraint lower bound.

    upper : ~swiftsimio.objects.cosmo_quantity
        Constraint upper bound.

    See Also
    --------
    constrain_spatial
        Method to generate spatially constrained cell mask.
    """
    self.convert_masks_to_bool()  # no-op if already bool

    data_name = self.group_mapping[group_name]
    particle_mask = getattr(self, data_name)

    # look up where the requested field lives in the file
    properties = getattr(self.metadata, f"{group_name}_properties")
    field_paths = dict(zip(properties.field_names, properties.field_paths))
    field_path = field_paths[quantity]

    with self.metadata.open_file() as h5file:
        attrs = h5file[field_path].attrs
        unit = _load_field_units(attrs, self.metadata.units)
        physical = _load_field_physical(attrs)
        cf = _load_field_cosmo_factor(attrs, self.metadata)
        if isinstance(h5file, (h5py.File, h5py.Group)):
            # When reading from a local HDF5 file this is faster than
            # just using the boolean indexing because h5py has slow
            # indexing routines.
            selected_rows = np.where(particle_mask)[0]
            raw_values = np.take(h5file[field_path], selected_rows, axis=0)
        else:
            # Otherwise, assume we can just index the dataset. This
            # generates a single http request for remote datasets.
            raw_values = h5file[field_path][particle_mask, ...]

    # attach units and cosmology information before comparing with the bounds
    values = cosmo_array(
        raw_values,
        units=unit,
        comoving=not physical,
        cosmo_factor=cf,
    )
    in_bounds = np.logical_and.reduce([values > lower, values <= upper])

    # only the particles that were selected so far can stay selected
    particle_mask[particle_mask] = in_bounds
    setattr(self, data_name, particle_mask)
def _generate_cell_mask(
    self, restrict: cosmo_array | None = None
) -> dict[str, np.ndarray]:
    """
    Generate a spatially restricted mask for cells.

    Takes the cell metadata and finds the mask for the _cells_ that are
    within the spatial region defined by the spatial mask. Not for
    user use.

    Parameters
    ----------
    restrict : cosmo_array
        Restrict is a (3,2) cosmo_array giving the lower and upper bounds for each
        axis, e.g.

        .. code-block:: python

            restrict = cosmo_array(
                [
                    [0.5, 0.7],
                    [0.1, 0.9],
                    [0.0, 0.1]
                ],
                u.Mpc,
                comoving=True,
                scale_factor=1.0,
                scale_exponent=1,
            )

        If ``None`` (or if the row for an axis is ``None``) no restriction is
        applied (along that axis).

    Returns
    -------
    dict[str, np.ndarray[bool]]
        For each group name (or only ``"shared"`` for group catalogues), a mask
        to indicate which cells are within the specified spatial range.

    Raises
    ------
    ValueError
        If the mask boundaries are outside the interval [-Lbox/2, 3*Lbox/2].
    """
    # particles may drift from their cells, so each particle type gets its own
    # mask; group catalogues have a single shared one
    if self.metadata.output_type in _GROUPCAT_OUTPUT_TYPES:
        group_names = ["shared"]
    else:
        group_names = self.metadata.present_group_names
    cell_mask = {
        group_name: np.ones(len(self.centers), dtype=bool)
        for group_name in group_names
    }
    if restrict is None:
        return cell_mask

    for dimension in range(3):
        ax_restrict = restrict[dimension]
        if ax_restrict is None:
            # no constraint along this axis (must be checked before indexing)
            continue
        lower = ax_restrict[0]
        upper = ax_restrict[1]
        boxsize = self.metadata.boxsize[dimension]
        if np.logical_or.reduce(
            (
                lower < -boxsize / 2,
                upper < -boxsize / 2,
                lower > 3 * boxsize / 2,
                upper > 3 * boxsize / 2,
            )
        ):
            # because we're only going to make one periodic copy on either side,
            # we're in trouble
            raise ValueError(
                "Mask region boundaries must be in interval [-boxsize/2, 3*boxsize/2]"
                f" along {'xyz'[dimension]}-axis."
            )
        if np.abs(upper - lower) > boxsize:
            # keep everything along this axis
            continue
        if upper < lower:
            # inverted case, convert to a "normal" case in target window
            if lower > boxsize / 2:
                lower -= boxsize
            elif upper < boxsize / 2:  # don't shift both else we get the whole box!
                upper += boxsize
        for group_name in group_names:
            still_selected = cell_mask[group_name]
            # selection intersects one of the 3 periodic copies of a cell:
            this_mask = np.logical_or.reduce(
                [
                    np.logical_and(
                        self.maxpositions[group_name][still_selected, dimension]
                        + shift * boxsize
                        > lower,
                        self.minpositions[group_name][still_selected, dimension]
                        + shift * boxsize
                        < upper,
                    )
                    for shift in (-1, 0, 1)
                ]
            )
            # only cells that passed the previous axes can remain selected
            still_selected[still_selected] = this_mask
    return cell_mask
def _update_range_masks(self) -> None:
    """
    Update the particle masks using the cell masks.

    For range masks, the mask of each entry of ``update_list`` is replaced by the
    ``[start, stop)`` ranges of the selected cells (an array of shape ``(n, 2)``,
    also when no cell is selected) and the matching ``_size`` attribute is set to
    the number of particles in these cells.

    For boolean masks we actually overwrite all non-used cells with ``False``,
    rather than the inverse, as we assume initially that we want to read all
    particles in, and we want to respect other masks that may have been applied to
    the data.
    """
    for data_name in self.update_list:
        count_name = data_name[1:]  # Remove the underscore
        selected_cells = self.cell_mask[count_name]
        if self.range_mask:
            counts = self.counts[count_name][selected_cells]
            offsets = self.offsets[count_name][selected_cells]
            # one [start, stop) row per selected cell; transposing a (2, n) array
            # keeps the (n, 2) shape and the dtype even when n == 0
            ranges = np.array([offsets, offsets + counts]).T.copy()
            setattr(self, data_name, ranges)
            setattr(self, f"{data_name}_size", np.sum(counts))
        else:
            excluded_cells = np.logical_not(selected_cells)
            counts = self.counts[count_name][excluded_cells]
            offsets = self.offsets[count_name][excluded_cells]
            # We must do the whole boolean mask business.
            this_mask = getattr(self, data_name)
            for count, offset in zip(counts, offsets):
                this_mask[offset : count + offset] = False
def _sanitize_region(self, region: Sequence) -> cosmo_array:
    """
    Coerce user-provided region to (3, 2) :class:`~swiftsimio.objects.cosmo_array`.

    The input region should have length 3 (otherwise error). The rows should have
    length 2, but may be ``None``. We check these conditions, then package everything
    up into a :class:`~swiftsimio.objects.cosmo_array`.

    We have to be careful: :class:`~swiftsimio.objects.cosmo_array` will check the
    input for compatibility of units, cosmo_factors, etc. and convert them to be
    consistent if needed, but only up to a depth of 1, i.e. in this:
    ``cosmo_array([[cosmo_quantity(...), cosmo_quantity(...)], None, None])`` the
    cosmo-ness of the :class:`~swiftsimio.objects.cosmo_quantity` elements would be
    ignored because they are nested two-deep in a list. Each row is therefore
    converted on its own before the rows are combined.

    Parameters
    ----------
    region : list
        The user-provided region.

    Returns
    -------
    ~swiftsimio.objects.cosmo_array
        The region as a (3, 2) :class:`~swiftsimio.objects.cosmo_array`, with ``None``
        rows replaced by the full extent of the box.

    Raises
    ------
    ValueError
        If the region does not have 3 rows, or a row is neither ``None`` nor of
        length 2.
    """
    if len(region) != 3:
        raise ValueError("`restrict` must have length == 3.")
    if any(row is not None and len(row) != 2 for row in region):
        raise ValueError(
            "Rows of `restrict` must have length == 2 (or be ``None``)."
        )

    rows = []
    for ax_region, box_length in zip(region, self.metadata.boxsize):
        if ax_region is None:
            # unconstrained axis: span the box from 0 to its full length
            rows.append(np.array([0, 1]) * box_length)
        else:
            rows.append(cosmo_array(ax_region))
    return cosmo_array(rows)
[docs]
@_constraint
def constrain_spatial(
    self,
    restrict: Sequence,
    union: bool = False,
    intersect: bool | None = None,  # deprecated
) -> None:
    """
    Use the cell metadata to select particles within a cuboid region.

    This mask is necessarily approximate and is coarse-grained to the cell size.

    Parameters
    ----------
    restrict : list
        Restrict is a shape (3, 2) iterable giving the lower and upper bounds for each
        axis. The bounds need to be given as :class:`~swiftsimio.objects.cosmo_array`
        or :class:`~swiftsimio.objects.cosmo_quantity`. For example:

        .. code-block:: python

            restrict = cosmo_array(
                [
                    [0.5, 0.7],
                    [0.1, 0.9],
                    [0.0, 0.1]
                ],
                u.Mpc,
                comoving=True,
                scale_factor=1.0,
                scale_exponent=1,
            )

        If no constraint is desired along an axis, ``None`` can be given instead. For
        example, to select a "slab" in the x-y plane:

        .. code-block:: python

            zmin = cosmo_quantity(
                5.0,
                u.Mpc,
                comoving=True,
                scale_factor=1.0,
                scale_exponent=1
            )
            zmax = cosmo_quantity(
                6.0,
                u.Mpc,
                comoving=True,
                scale_factor=1.0,
                scale_exponent=1
            )
            restrict = [
                None,
                None,
                [zmin, zmax],
            ]

    union : bool, optional
        If ``True``, combine the spatial mask with any existing spatial mask to
        select two (or more) regions with repeated calls to ``constrain_spatial``.

    intersect : bool, optional
        Deprecated due to misleading name, use ``union`` instead.

    Raises
    ------
    RuntimeError
        If another constraint was applied before, unless that constraint was
        ``constrain_spatial`` and ``union`` is ``True``.

    See Also
    --------
    constrain_property
        Method to further refine mask, selecting particles by their properties.
    """
    if intersect is not None:
        # stacklevel=3: skip this function and the ``_constraint`` wrapper so that
        # the warning is attributed to the calling code (python hides
        # DeprecationWarnings attributed to library modules by default)
        warnings.warn(
            "`intersect` is deprecated because this term is misleading, use `union` "
            "(with the same value) instead. Overriding `union` value if provided.",
            DeprecationWarning,
            stacklevel=3,
        )
        union = intersect
    sanitized_region = self._sanitize_region(restrict)
    if self.constrained == "constrain_spatial" and union:
        # we are in union mode and already have a spatial constraint
        new_mask = self._generate_cell_mask(sanitized_region)
        for data_name in self.update_list:
            group_name = data_name[1:]  # remove leading underscore
            self.cell_mask[group_name] = np.logical_or(
                self.cell_mask[group_name], new_mask[group_name]
            )
    elif self.constrained is None:
        # union or not union, doesn't matter, just make a new mask
        self.cell_mask = self._generate_cell_mask(sanitized_region)
    else:
        # union or not union, doesn't matter
        # we don't allow combining with current mask
        msg = f"Can't `constrain_spatial` after `{self.constrained}`."
        if self.constrained == "constrain_spatial":
            msg += (
                " To combine multiple `constrain_spatial` calls use `union` kwarg."
            )
        elif union:
            msg += (
                " `union` kwarg can only be used to combine with"
                " `constrain_spatial` constraints."
            )
        raise RuntimeError(msg)
    self._update_range_masks()
[docs]
def convert_masks_to_ranges(self) -> None:
    """
    Convert the masks to range masks.

    These are more compact than boolean masks so they can help save space on highly
    constrained machines. Nothing happens if the masks already are range masks.

    See Also
    --------
    convert_masks_to_bool
    """
    if self.range_mask:
        return

    # Turn each boolean mask into the indices of its selected elements, then let
    # the accelerated ``ranges_from_array`` collapse these into a set of ranges.
    for mask_name in self.update_list:
        selected = np.where(getattr(self, mask_name))[0]
        setattr(self, f"{mask_name}_size", selected.size)
        setattr(self, mask_name, ranges_from_array(selected))
    self.range_mask = True
[docs]
def convert_masks_to_bool(self) -> None:
    """
    Convert the masks to boolean masks.

    These are sometimes easier to work with than range masks but usually use more
    memory. Nothing happens if the masks already are boolean masks.

    See Also
    --------
    convert_masks_to_ranges
    """
    if not self.range_mask:
        return

    bool_masks, full_sizes = self._generate_empty_masks(fill_value=False)
    for mask_name, bool_mask in bool_masks.items():
        # switch on every element covered by one of the current ranges
        for bounds in getattr(self, mask_name):
            bool_mask[slice(*bounds)] = True
        setattr(self, mask_name, bool_mask)
    for size_name, full_size in full_sizes.items():
        setattr(self, size_name, full_size)
    self.range_mask = False
[docs]
@_constraint
def constrain_index(self, index: int) -> None:
    """
    Constrain the mask to a single row.

    Only works for files where all arrays have the same length. Intended for use with
    SOAP catalogues, mask to read only a single row.

    Parameters
    ----------
    index : int
        The index of the row to select.

    Raises
    ------
    RuntimeError
        If a constraint has already been applied to the mask.
    """
    # `constrain_indices` would refuse too, but we want our own name in the message
    previous_constraint = self.constrained
    if previous_constraint is not None:
        raise RuntimeError(f"Can't `constrain_index` after `{previous_constraint}`.")
    # the nested call records `constrain_indices`; the decorator on this method
    # then overwrites that with `constrain_index`
    self.constrain_indices([index])
[docs]
@_constraint
def constrain_indices(self, indices: list[int]) -> None:
    """
    Constrain the mask to a list of rows.

    Only works for files where all arrays have the same length.

    Parameters
    ----------
    indices : list[int]
        An list of the indices of the rows to mask.

    Raises
    ------
    RuntimeError
        If a constraint has already been applied to the mask, or if the arrays in
        the file are not homogeneous.
    """
    if self.constrained is not None:
        raise RuntimeError(f"Can't `constrain_indices` after `{self.constrained}`.")
    if not self.metadata.homogeneous_arrays:
        raise RuntimeError(
            "Cannot constrain to specific rows in a non-homogeneous array; you "
            f"currently are using a {self.metadata.output_type} file"
        )

    index_array = np.asarray(indices)
    already_sorted = np.all(index_array[:-1] <= index_array[1:])
    if already_sorted:
        sorted_indices = index_array
    else:
        warnings.warn(
            "`constrain_indices` selects indices in order, sorting list of indices "
            "before masking."
        )
        sorted_indices = np.sort(indices)

    for mask_name in self.update_list:
        if self.range_mask:
            # the ranges of the requested rows replace the current mask
            setattr(self, mask_name, ranges_from_array(sorted_indices))
            setattr(self, f"{mask_name}_size", len(indices))
        else:
            # keep only the requested rows of the current boolean mask
            current_mask = getattr(self, mask_name)
            requested = np.zeros(current_mask.size, dtype=bool)
            requested[indices] = True
            setattr(self, mask_name, np.logical_and(current_mask, requested))
def _get_masked_counts_offsets(
    self,
) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
    """
    Return the particle counts and offsets in cells selected by the mask.

    Returns
    -------
    dict[str, np.array], dict[str, np.array]
        Dictionaries containing the particle counts and offsets for each particle
        type. For example, the particle counts dictionary would be of the form

        .. code-block:: python

            {"gas": [g_0, g_1, ...],
             "dark_matter": [dm_0, dm_1, ...], ...}

        where the keys would be each of the particle types and values are arrays
        of the number of corresponding masked particles in each cell (in this case
        there would be g_0 gas particles in the first cell, g_1 in the second,
        etc.). The structure of the dictionaries is the same for the offsets, with
        the arrays now storing the offset of the first masked particle in the cell
        (the running total of the counts of the preceding cells).
    """
    masked_counts = {}
    masked_offsets = {}
    for part_type, counts in self.counts.items():
        offsets = self.offsets[part_type]
        mask = getattr(self, f"_{part_type}")
        if self.range_mask:
            # Find the cell in which each range starts (first column) and ends
            # (second column). The offsets were sorted when they were loaded, so
            # searchsorted is valid.
            cell_ranges = np.searchsorted(offsets[1:], mask, side="right")
            selected = np.zeros(counts.shape, dtype=np.int64)
            for (first_cell, last_cell), bounds in zip(cell_ranges, mask):
                for cell in range(first_cell, last_cell + 1):
                    # overlap between the range and the particles of this cell
                    cell_stop = offsets[cell] + counts[cell]
                    selected[cell] += min(bounds[1], cell_stop) - max(
                        bounds[0], offsets[cell]
                    )
        else:
            # count the selected elements of the boolean mask cell by cell
            cell_slices = np.vstack((offsets, offsets + counts)).T
            selected = np.array(
                [mask[slice(*cell_slice)].sum() for cell_slice in cell_slices]
            )
        masked_counts[part_type] = selected
        masked_offsets[part_type] = np.r_[0, np.cumsum(selected)[:-1]]
    return masked_counts, masked_offsets