Source code for brainspace.mesh.mesh_cluster

"""
Clustering and sampling of surface mesh points.
"""

# Author: Oualid Benkarim <oualid.benkarim@mcgill.ca>
# License: BSD 3 clause

from warnings import warn

import numpy as np

from sklearn.preprocessing import normalize
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import k_means

# from vtkmodules.vtkFiltersCorePython import vtkDecimatePro
from vtk import vtkDecimatePro

from . import mesh_elements as me
from .mesh_correspondence import find_point_correspondence
from ..vtk_interface.pipeline import serial_connect
from ..utils.parcellation import map_to_mask, reduce_by_labels
from ..gradient import diffusion_mapping


[docs]def cluster_points(surf, n_clusters=100, is_size=False, mask=None, with_centers=True, random_state=None, approach='kmeans', n_init=3, n_jobs=1): """Clustering of surface points. Parameters ---------- surf : vtkPolyData or BSPolyData Input surface. n_clusters : int, optional Number of clusters. Default is 100. is_size : bool, optional If True, interpret `n_clusters` as cluster size. Default is False. mask : 1D ndarray, optional Mask for surface points. Points outside the mask (i.e., False) are discarded from clustering. Default is None. with_centers : bool, optional If True, an array of labels with the closest points to the centroid of each cluster is returned. Default is True. random_state : int, RandomState instance or None, optional Random state. Default is None. approach : {'kmeans', 'ward'}, optional Clustering method: k-means or hierarchical with ward linkage. Hierarchical clustering is faster but k-means provides better results. Default is 'kmeans'. n_init : int, optional Number of k-means repetitions. Only used when ``approach == 'kmeans'``. Default is 3. n_jobs : int or None, optional The number of parallel jobs. Only used when ``approach == 'kmeans'``. Default is 1. Returns ------- cluster_labels : 1D ndarray, shape (n_points,) Array of cluster labels. If `mask` is provided, points out of the mask are assigned label 0. center_labels : 1D ndarray, shape (n_points,) Array with centers labeled with their corresponding cluster label. The rest of points is assigned label 0. Returned only if ``with_centers=True``. Notes ----- Valid cluster labels start from 1. If the mask is provided, zeros are assigned to the points outside the mask. """ # Get immediate geodesic distance a = me.get_ring_distance(surf, n_ring=1, metric='geodesic') if mask is not None: a = a[mask, :][:, mask] a.data = np.exp(-a.data/np.median(a.data)) a.tolil().setdiag(1) # Embedding evs, _ = diffusion_mapping(a, n_components=30, alpha=0, diffusion_time=1, random_state=random_state) evs = normalize(evs) # To find spherical clusters if is_size: n_clusters = evs.shape[0] // n_clusters # Find clusters if approach == 'kmeans': if n_jobs != 1: warn('The n_jobs parameter is deprecated and will be removed ' 'in a future version', DeprecationWarning) _, cluster_labs, _ = k_means(evs, n_clusters=n_clusters, random_state=random_state, n_init=n_init) else: conn = me.get_immediate_adjacency(surf, include_self=False) if mask is not None: conn = conn[mask, :][:, mask] hc = AgglomerativeClustering(n_clusters=n_clusters, connectivity=conn, linkage='ward') cluster_labs = hc.fit(evs).labels_ # Valid clusters start from 1 cluster_labs += 1 # Find centers if with_centers: points = surf.Points if mask is None else surf.Points[mask] centroids = reduce_by_labels(points, cluster_labs, red_op='mean', axis=1) centroid_labs = np.zeros_like(cluster_labs) idx_samples = np.arange(points.shape[0]) for i, lab in enumerate(range(1, n_clusters+1)): mask_cl = cluster_labs == lab dif = centroids[i] - points[mask_cl] idx = np.einsum('ij,ij->i', dif, dif).argmin() idx_centroid = idx_samples[mask_cl][idx] centroid_labs[idx_centroid] = lab if mask is not None: centroid_labs = map_to_mask(centroid_labs, mask) if mask is not None: cluster_labs = map_to_mask(cluster_labs, mask) if with_centers: return cluster_labs, centroid_labs return cluster_labs
[docs]def sample_points_clustering(surf, keep=0.1, mask=None, random_state=None, approach='kmeans', n_init=3, n_jobs=1): """Sample equidistant points from surface based on clustering. Parameters ---------- surf : vtkPolyData or BSPolyData Input surface. keep : float or int, optional If float, percentage of points to sample. Must be ``0 < keep < 1``. If int, number of points to sample. Default is 0.1. mask : 1D ndarray, optional Mask for surface points. Points outside the mask (i.e., False) are discarded from sampling. Default is None. random_state : int, RandomState instance or None, optional Random state. Default is None. approach : {'kmeans', 'ward'}, optional Clustering approach: k-means or hierarchical with ward linkage. Hierarchical is faster but k-means provides better results. Default is 'kmeans'. n_init : int, optional Number of k-means repetitions. Only used when ``approach == 'kmeans'``. Default is 3. n_jobs : int or None, optional The number of parallel jobs. Only used when ``approach == 'kmeans'``. Default is 1. Returns ------- sampled : 1D ndarray, shape (n_points,) Array with sampled points marked with 1 in their corresponding positions. The rest is 0. See Also -------- :func:`cluster_points` :func:`sample_points_decimation` Notes ----- This method first clusters the surface points and then selects the points closest to the centroids as the sampled points. """ if isinstance(keep, int): n_clusters = keep elif 0 < keep < 1: n_clusters = int(keep * surf.GetNumberOfPoints()) else: ValueError('The value of \'keep\' is not valid.') sampled_points = cluster_points(surf, n_clusters=n_clusters, mask=mask, random_state=random_state, with_centers=True, approach=approach, n_init=n_init, n_jobs=n_jobs)[1] return (sampled_points > 0).astype(np.uint8)
def sample_points_decimation(surf, keep=0.1, mask=None, n_jobs=1): """Sample points from surface based on surface decimation. Parameters ---------- surf : vtkPolyData or BSPolyData Input surface. keep : float or int, optional If float, percentage of points to sample. Must be ``0 < keep < 1``. If int, number of points to sample. Default is 0.1. mask : 1D ndarray, optional Mask for surface points. Points outside the mask (i.e., False) are discarded from sampling. Default is None. n_jobs : int, optional Number of jobs. Default is 1. Returns ------- sampled : 1D ndarray, shape (n_points,) Array with sampled points marked with 1 in their corresponding positions. The rest is 0. Notes ----- In this method, sampling is based on mesh decimation. Number of sampled points will most probably not coincide with the requested number. See Also -------- :func:`sample_points_clustering` """ n_pts = surf.GetNumberOfPoints() # Increase to account for points outside mask if mask is not None: keep *= (1 + np.count_nonzero(~mask) / n_pts) if isinstance(keep, int): factor = (n_pts - keep) / n_pts elif 0 < keep < 1: factor = 1 - keep else: ValueError('The value of \'keep\' is not valid.') cf = vtkDecimatePro() cf.SetTargetReduction(factor) decimated_surf = serial_connect(surf, cf) idx = find_point_correspondence(decimated_surf, surf, eps=0, n_jobs=n_jobs) sampled_points = np.zeros(n_pts, dtype=np.uint8) sampled_points[idx] = 1 # Remove points sampled outside mask if mask is not None: sampled_points[~mask] = 0 return sampled_points