Source code for brainspace.mesh.mesh_operations

"""
Basic functions on surface meshes.
"""

# Author: Oualid Benkarim <oualid.benkarim@mcgill.ca>
# License: BSD 3 clause


import warnings
from itertools import combinations
import scipy.sparse as ssp
from scipy.spatial.distance import cdist
from scipy.sparse import csgraph as csg
import numpy as np

from vtk import (vtkDataObject, vtkThreshold, vtkGeometryFilter,
                 vtkAppendPolyData, vtkPolyDataConnectivityFilter)

from .mesh_creation import build_polydata
from .mesh_elements import get_immediate_adjacency
from ..vtk_interface import wrap_vtk, serial_connect, get_output
from ..vtk_interface.pipeline import connect
from ..vtk_interface.decorators import wrap_input, append_vtk
from ..utils.parcellation import (relabel_consecutive, map_to_mask,
                                  reduce_by_labels)

ASSOC_CELLS = vtkDataObject.FIELD_ASSOCIATION_CELLS
ASSOC_POINTS = vtkDataObject.FIELD_ASSOCIATION_POINTS


@wrap_input(0)
def _surface_selection(surf, array, low=-np.inf, upp=np.inf, use_cell=False):
    """Selection of points or cells meeting some thresholding criteria.

    Parameters
    ----------
    surf : vtkPolyData or BSPolyData
        Input surface.
    array : str or ndarray
        Array used to perform selection.
    low : float or -np.inf
        Lower threshold. Default is -np.inf.
    upp : float or np.inf
        Upper threshold. Default is +np.inf.
    use_cell : bool, optional
        If True, apply selection to cells. Otherwise, use points.
        Default is False.

    Returns
    -------
    surf_selected : BSPolyData
        Surface after thresholding.

    """

    if low > upp:
        raise ValueError('Threshold not valid: [{},{}]'.format(low, upp))

    at = 'c' if use_cell else 'p'
    if isinstance(array, np.ndarray):
        drop_array = True
        array_name = surf.append_array(array, at=at)
    else:
        drop_array = False
        array_name = array
        array = surf.get_array(name=array, at=at, return_name=False)

    if array.ndim > 1:
        raise ValueError('Arrays has more than one dimension.')

    if not use_cell:
        order_name = surf.append_array(np.arange(surf.n_points), at='p')

    if low == -np.inf:
        low = array.min()
    if upp == np.inf:
        upp = array.max()

    tf = wrap_vtk(vtkThreshold, allScalars=True)
    tf.ThresholdBetween(low, upp)
    if use_cell:
        tf.SetInputArrayToProcess(0, 0, 0, ASSOC_CELLS, array_name)
    else:
        tf.SetInputArrayToProcess(0, 0, 0, ASSOC_POINTS, array_name)

    gf = wrap_vtk(vtkGeometryFilter(), merging=False)
    surf_sel = serial_connect(surf, tf, gf)

    # Check results
    n_exp = np.logical_and(array >= low, array <= upp).sum()
    n_sel = surf_sel.n_cells if use_cell else surf_sel.n_points
    if n_exp != n_sel:
        element = 'cells' if use_cell else 'points'
        warnings.warn('Number of selected {}={}. Expected {}.'
                      'This may be due to the topology after selection.'.
                      format(element, n_exp, n_sel))

    if drop_array:
        surf.remove_array(name=array_name, at=at)
        surf_sel.remove_array(name=array_name, at=at)

    if not use_cell:
        surf_sel = sort_polydata_points(surf_sel, order_name)
        surf_sel.remove_array(name=order_name, at='p')

        surf.remove_array(name=order_name, at='p')

    return surf_sel


@wrap_input(0)
def _surface_mask(surf, mask, use_cell=False):
    """Selection fo points or cells meeting some criteria.

    Parameters
    ----------
    surf : vtkPolyData or BSPolyData
        Input surface.
    mask : str or ndarray
        Binary boolean or integer array. Zero or False elements are
        discarded.
    use_cell : bool, optional
        If True, apply selection to cells. Otherwise, use points.
        Default is False.

    Returns
    -------
    surf_masked : BSPolyData
        PolyData after masking.

    """

    if isinstance(mask, np.ndarray):
        if np.issubdtype(mask.dtype, np.bool_):
            mask = mask.astype(np.uint8)
    else:
        mask = surf.get_array(name=mask, at='c' if use_cell else 'p')

    if np.any(np.unique(mask) > 1):
        raise ValueError('Cannot work with non-binary mask.')

    return _surface_selection(surf, mask, low=1, upp=1, use_cell=use_cell)


[docs]def drop_points(surf, array, low=-np.inf, upp=np.inf): """Remove surface points whose values fall within the threshold. Cells corresponding to these points are also removed. Parameters ---------- surf : vtkPolyData or BSPolyData Input surface. array : str or 1D ndarray Array used to perform selection. If str, it must be an array in the PointData attributes of the PolyData. low : float or -np.inf Lower threshold. Default is -np.inf. upp : float or np.inf Upper threshold. Default is np.inf. Returns ------- surf_selected : vtkPolyData or BSPolyData PolyData after thresholding. See Also -------- :func:`drop_cells` :func:`select_points` :func:`mask_points` """ if isinstance(array, str): array = surf.get_array(name=array, at='p') mask = np.logical_or(array < low, array > upp) return mask_points(surf, mask)
[docs]def drop_cells(surf, array, low=-np.inf, upp=np.inf): """Remove surface cells whose values fall within the threshold. Points corresponding to these cells are also removed. Parameters ---------- surf : vtkPolyData or BSPolyData Input surface. array : str or 1D ndarray Array used to perform selection. If str, it must be an array in the CellData attributes of the PolyData. low : float or -np.inf Lower threshold. Default is -np.inf. upp : float or np.inf Upper threshold. Default is np.inf. Returns ------- surf_selected : vtkPolyData or BSPolyData PolyData after thresholding. See Also -------- :func:`drop_points` :func:`select_cells` :func:`mask_cells` """ if isinstance(array, str): array = surf.get_array(name=array, at='c') mask = np.logical_or(array < low, array > upp) return mask_cells(surf, mask)
[docs]def select_points(surf, array, low=-np.inf, upp=np.inf): """Select surface points whose values fall within the threshold. Cells corresponding to these points are also kept. Parameters ---------- surf : vtkPolyData or BSPolyData Input surface. array : str or 1D ndarray Array used to perform selection. If str, it must be an array in the PointData attributes of the PolyData. low : float or -np.inf Lower threshold. Default is -np.inf. upp : float or np.inf Upper threshold. Default is np.inf. Returns ------- surf_selected : vtkPolyData or BSPolyData PolyData after selection. See Also -------- :func:`select_cells` :func:`drop_points` :func:`mask_points` """ return _surface_selection(surf, array, low=low, upp=upp)
[docs]def select_cells(surf, array, low=-np.inf, upp=np.inf): """Select surface cells whose values fall within the threshold. Points corresponding to these cells are also kept. Parameters ---------- surf : vtkPolyData or BSPolyData Input surface. array : str or 1D ndarray Array used to perform selection. If str, it must be an array in the CellData attributes of the PolyData. low : float or -np.inf Lower threshold. Default is -np.inf. upp : float or np.inf Upper threshold. Default is np.inf. Returns ------- surf_selected : vtkPolyData or BSPolyData PolyData after selection. See Also -------- :func:`select_points` :func:`drop_cells` :func:`mask_cells` """ return _surface_selection(surf, array, low=low, upp=upp, use_cell=True)
[docs]def mask_points(surf, mask): """Mask surface points. Cells corresponding to these points are also kept. Parameters ---------- surf : vtkPolyData or BSPolyData Input surface. mask : 1D ndarray Binary boolean array. Zero elements are discarded. Returns ------- surf_masked : vtkPolyData or BSPolyData PolyData after masking. See Also -------- :func:`mask_cells` :func:`drop_points` :func:`select_points` """ return _surface_mask(surf, mask)
[docs]def mask_cells(surf, mask): """Mask surface cells. Points corresponding to these cells are also kept. Parameters ---------- surf : vtkPolyData or BSPolyData Input surface. mask : 1D ndarray Binary boolean array. Zero elements are discarded. Returns ------- surf_masked : vtkPolyData or BSPolyData PolyData after masking. See Also -------- :func:`mask_points` :func:`drop_cells` :func:`select_cells` """ return _surface_mask(surf, mask, use_cell=True)
def combine_surfaces(*surfs): """ Combine surfaces. Parameters ---------- surfs : sequence of vtkPolyData and/or BSPolyData Input surfaces. Returns ------- res : BSPolyData Combination of input surfaces. See Also -------- :func:`split_surface` """ alg = vtkAppendPolyData() for s in surfs: alg = connect(s, alg, add_conn=True) return get_output(alg)
[docs]@append_vtk(to='point') def get_connected_components(surf, labeling=None, mask=None, fill=0, append=False, key='components'): """Get connected components. Connected components are based on connectivity (and same label if `labeling` is provided). Parameters ---------- surf : vtkPolyData or BSPolyData Input surface. labeling : str or 1D ndarray, optional Array with labels. If str, it must be in the point data attributes of `surf`. Default is None. If provided, connectivity is based on neighboring points with the same label. mask : str or 1D ndarray, optional Boolean mask. If str, it must be in the point data attributes of `surf`. Default is None. If specified, only consider points within the mask. fill : int or float, optional Value used for entries out of the mask. Only used if the `target_mask` is provided. Default is 0. append : bool, optional If True, append array to point data attributes of input surface and return surface. Otherwise, only return array. Default is False. key : str, optional Array name to append to surface's point data attributes. Only used if ``append == True``. Default is 'components'. Returns ------- output : vtkPolyData, BSPolyData or ndarray 1D array with different labels for each connected component. Return ndarray if ``append == False``. Otherwise, return input surface with the new array. Notes ----- VTK point data does not accept boolean arrays. If the mask is provided as a string, the mask is built from the corresponding array such that any value larger than 0 is True. """ if isinstance(mask, str): mask = surf.get_array(name=mask, at='p') > 0 if labeling is None: alg = wrap_vtk(vtkPolyDataConnectivityFilter, colorRegions=True, extractionMode='AllRegions') cc = serial_connect(surf, alg).PointData['RegionId'] + 1 if mask is not None: cc[~mask] = 0 return cc if isinstance(labeling, str): labeling = surf.get_array(name=labeling, at='p') mlab = labeling if mask is None else labeling[mask] adj = get_immediate_adjacency(surf, mask=mask) adj = ssp.triu(adj, 1) # Converts to coo # Zero-out neighbors with different labels mask_remove = mlab[adj.row] != mlab[adj.col] adj.data[mask_remove] = 0 adj.eliminate_zeros() nc, cc = csg.connected_components(adj, directed=True, connection='weak') cc += 1 if mask is not None: cc = map_to_mask(cc, mask=mask, fill=fill) return cc
@wrap_input(0) def sort_polydata_points(surf, labeling, append_data=True): if isinstance(labeling, str): labeling = surf.get_array(labeling, at='p') lab_con = relabel_consecutive(labeling) idx_sorted = np.argsort(lab_con) new_pts = surf.Points[idx_sorted] # new_cells = relabel(surf.GetCells2D().ravel(), lab_con).reshape(-1, 3) new_cells = lab_con[surf.GetCells2D()] s = build_polydata(new_pts, cells=new_cells) if append_data is None or append_data is False: return s if append_data is True: append_data = {'p', 'c', 'f'} elif isinstance(append_data, str): append_data = {append_data} for at in append_data: for v, k in zip(*surf.get_array(at=at, return_name=True)): if at in {'p', 'point'}: v = v[idx_sorted] s.append_array(v, name=k, at=at) return s @wrap_input(0) def split_surface(surf, labeling=None): """ Split surface according to the labeling. Parameters ---------- surf : vtkPolyData or BSPolyData Input surface. labeling : str, 1D ndarray or None, optional Array used to perform the splitting. If str, it must be an array in the PointData attributes of `surf`. If None, split surface in its connected components. Default is None. Returns ------- res : dict[int, BSPolyData] Dictionary of sub-surfaces for each label. See Also -------- :func:`combine_surfaces` :func:`mask_points` """ if labeling is None: labeling = get_connected_components(surf) elif isinstance(labeling, str): labeling = surf.get_array(labeling, at='p') ulab = np.unique(labeling) if ulab.size == 1: return {1: surf} return {k: mask_points(surf, labeling == k) for k in ulab}
[docs]@wrap_input(0) def downsample_with_parcellation(surf, labeling, name='parcel'): """ Downsample surface according to labeling. Such that, each parcel centroid is used as a point in the new downsampled surface. Connectivity is based on neighboring parcels. Parameters ---------- surf : vtkPolyData or BSPolyData Input surface. labeling : str or 1D ndarray Array of labels used to perform the downsampling. If str, it must be an array in the PointData attributes of `surf`. name : str, optional Name of the downsampled parcellation appended to the PointData of the new surface. Default is 'parcel'. Returns ------- res : BSPolyData Downsampled surface. """ if isinstance(labeling, str): labeling = surf.get_array(labeling, at='p') cc = get_connected_components(surf, labeling=labeling) - 1 lab_small = reduce_by_labels(labeling, cc, red_op='min') nlabs = lab_small.size adj = get_immediate_adjacency(surf) adj_neigh = adj.multiply(cc).tocsr() adj_small = np.zeros((nlabs, nlabs), dtype=np.bool) for i in range(nlabs): arow = adj_neigh[cc == i] for j in range(i + 1, nlabs): adj_small[j, i] = adj_small[i, j] = np.any(arow.data == j) points = np.empty((nlabs, 3)) cells = [] for i in range(nlabs): m = cc == i neigh = np.unique(adj_neigh[m].data) neigh = neigh[neigh != i] if neigh.size < 2: continue edges = np.array(list(combinations(neigh, 2))) edges = edges[adj_small[edges[:, 0], edges[:, 1]]] c = np.hstack([np.full(edges.shape[0], i)[:, None], edges]) cells.append(c) p = surf.Points[m] d = cdist(p, p.mean(0, keepdims=True))[:, 0] points[i] = p[np.argmin(d)] cells = np.unique(np.sort(np.vstack(cells), axis=1), axis=0) surf_small = build_polydata(points, cells=cells) surf_small.append_array(lab_small, name=name, at='p') if nlabs == np.unique(labeling).size: return surf_small d = split_surface(surf_small) if len(d) > 1: for k, v in d.items(): d[k] = sort_polydata_points(v, 'parcel') surf_small = combine_surfaces(*list(d.values())) return surf_small