Source code for swiftsimio.visualisation.volume_render

"""
Basic volume render for SPH data.

This takes the 3D positions of the particles and projects them onto a grid.
"""

from typing import Literal
import numpy as np
from swiftsimio import SWIFTDataset, cosmo_array
from swiftsimio.accelerated import jit

from swiftsimio.optional_packages import plt

from swiftsimio.visualisation.smoothing_length import backends_get_hsml
from swiftsimio.visualisation.volume_render_backends import backends, backends_parallel
from swiftsimio.visualisation._vistools import (
    _get_projection_field,
    _get_region_info,
    _get_rotated_and_wrapped_coordinates,
    backend_strip_and_restore_cosmo_and_units,
)


[docs] def render_gas( data: SWIFTDataset, resolution: int, project: str | None = "masses", parallel: bool = False, rotation_matrix: np.ndarray | None = None, rotation_center: cosmo_array | None = None, region: cosmo_array | None = None, periodic: bool = True, ) -> cosmo_array: """ Create a data-field weighted 3D render of a SWIFT dataset as a voxel grid. Parameters ---------- data : SWIFTDataset Dataset from which render is extracted. resolution : int Specifies size of return np.array. project : str, optional Data field to be projected. Default is ``"mass"``. If ``None`` then simply count number of particles. The result is comoving if this is comoving, else it is physical. parallel : bool Used to determine if we will create the image in parallel. This defaults to False, but can speed up the creation of large images significantly at the cost of increased memory usage. rotation_matrix : np.array, optional Rotation matrix (3x3) that describes the rotation of the box around ``rotation_center``. In the default case, this provides a volume render viewed along the z axis. rotation_center : cosmo_array, optional Center of the rotation. If you are trying to rotate around a galaxy, this should be the most bound particle. region : cosmo_array, optional Determines where the image will be created (this corresponds to the left and right-hand edges, and top and bottom edges, and front and back edges) if it is not None. It should have a length of six, and take the form: [x_min, x_max, y_min, y_max, z_min, z_max] Particles outside of this range are still considered if their smoothing lengths overlap with the range. periodic : bool, optional Account for periodic boundaries for the simulation box? Default is ``True``. Returns ------- cosmo_array Voxel grid with units of project / length^3, of size ``resolution`` x ``resolution`` x ``resolution``. Comoving if ``project`` data are comoving, else physical. See Also -------- slice_gas_pixel_grid Creates a 2D slice of a SWIFT dataset. """ data = data.gas m = _get_projection_field(data, project) region_info = _get_region_info(data, region, require_cubic=True, periodic=periodic) hsml = backends_get_hsml["sph"](data) x, y, z = _get_rotated_and_wrapped_coordinates( data, rotation_matrix, rotation_center, periodic ) normed_x = (x - region_info["x_min"]) / region_info["x_range"] normed_y = (y - region_info["y_min"]) / region_info["y_range"] normed_z = (z - region_info["z_min"]) / region_info["z_range"] if periodic: # place everything in the region inside [0, 1], the backend will tile as needed normed_x %= region_info["periodic_box_x"] normed_y %= region_info["periodic_box_y"] normed_z %= region_info["periodic_box_z"] kwargs = dict( x=normed_x, y=normed_y, z=normed_z, m=m, h=hsml / region_info["x_range"], # cubic so x_range == y_range == z_range res=resolution, box_x=region_info["periodic_box_x"], box_y=region_info["periodic_box_y"], box_z=region_info["periodic_box_z"], ) norm = region_info["x_range"] * region_info["y_range"] * region_info["z_range"] backend_func = (backends_parallel if parallel else backends)["scatter"] image = backend_strip_and_restore_cosmo_and_units(backend_func, norm=norm)(**kwargs) return image
[docs] @jit(nopython=True, fastmath=True) def render_voxels_to_array(data: np.array, center: float, width: float) -> np.array: """ Insert voxel values into a 2D image grid. Handles a single render function (call multiple times for multiple render functions). Parameters ---------- data : np.ndarray The 3D voxel array. center : float The center of the rendering function. width : float The width of the rendering function. Returns ------- np.ndarray The 2D image array. """ output = np.zeros((data.shape[0], data.shape[1])) for i in range(data.shape[0]): for j in range(data.shape[1]): out = 0.0 for k in range(data.shape[2]): inner = (center - data[i, j, k]) / width const = 1.0 / (width * np.sqrt(2.0 * np.pi)) fac = np.exp(-0.5 * inner * inner) out += fac * const output[j, i] = out return output
[docs] def visualise_render( render: np.ndarray, centers: list[float], widths: list[float] | float, cmap: str = "viridis", return_type: Literal["all", "lighten", "add"] = "lighten", norm: "list[plt.Normalize] | plt.Normalize | None" = None, ) -> "tuple[list[np.ndarray] | np.ndarray, list[plt.Normalize]]": """ Visualise a render with multiple centers and widths. Parameters ---------- render : np.array The render to visualise. You should scale this appropriately before using this function (e.g. use a logarithmic transform!) and pass in the 'value' np.array, not the original cosmo_array or unyt_array. centers : list[float] The centers of your rendering functions. widths : list[float] | float The widths of your rendering functions. If a single float, all functions will have the same width. cmap : str The colormap to use for the rendering functions. return_type : Literal["all", "lighten", "add"] The type of return. If "all", all images are returned. If "lighten", the maximum of all images is returned. If "add", the sum of all images is returned. norm : list[plt.Normalize] | plt.Normalize | None The normalisation to use for the rendering functions. If a single normalisation, all functions will use the same normalisation. Returns ------- list[np.array] | np.array The images of the rendering functions. If return_type is "all", this will be a list of images. If return_type is "lighten" or "add", this will be a single image. list[plt.Normalize] The normalisations used for the rendering functions. """ if isinstance(widths, float): widths = [widths] * len(centers) if norm is None: norm = [plt.Normalize() for _ in centers] elif not isinstance(norm, list): norm = [norm] * len(centers) colors = plt.get_cmap(cmap)(np.linspace(0, 1, len(centers)))[:, :3] images = [ n(render_voxels_to_array(render, center, width)) for n, center, width in zip(norm, centers, widths) ] images = [ np.array([color[0] * x, color[1] * x, color[2] * x]).T for color, x in zip(colors, images) ] if return_type == "all": return images, norm if return_type == "lighten": return np.max(images, axis=0), norm if return_type == "add": return sum(images), norm
[docs] def visualise_render_options( centers: list[float], widths: list[float] | float, cmap: str = "viridis" ) -> tuple["plt.Figure", "plt.Axes"]: """ Create a figure of your rendering options. The y-axis is the output value of the rendering function. The x-axis is your input quantity. You may wish to plot a histogram on top of this figure; this is why the figure axes and figure are returned. Parameters ---------- centers : list[float] The centers of your rendering functions. widths : list[float] | float The widths of your rendering functions. If a single float, all functions will have the same width. cmap : str The colormap to use for the rendering functions. Returns ------- plt.Figure The matplotlib figure object used for the plot. plt.Axes The matplotlib axes object used for the plot. """ fig, ax = plt.subplots() if isinstance(widths, float): widths = [widths] * len(centers) colors = plt.get_cmap(cmap)(np.linspace(0, 1, len(centers)))[:, :3] for center, width, color in zip(centers, widths, colors): xs = np.linspace(center - 5.0 * width, center + 5.0 * width, 100) ys = [ np.exp(-0.5 * ((center - x) / width) ** 2) / (width * np.sqrt(2.0 * np.pi)) for x in xs ] ax.axvline(center, color=color, linestyle="--") ax.plot(xs, ys, color=color) return fig, ax