"""
Plotting functionality based on VTK.
"""
# Author: Oualid Benkarim <oualid.benkarim@mcgill.ca>
# License: BSD 3 clause
import os
import warnings
from collections import defaultdict
import numpy as np
from numpy.lib.stride_tricks import as_strided
from vtk import vtkCommand
import vtk.qt as vtk_qt
from brainspace import OFF_SCREEN
from ..vtk_interface.pipeline import serial_connect, get_output
from ..vtk_interface.wrappers import (BSWindowToImageFilter, BSPNGWriter,
BSBMPWriter, BSJPEGWriter, BSTIFFWriter,
BSRenderWindow, BSRenderWindowInteractor,
BSGenericRenderWindowInteractor,
BSGL2PSExporter)
# for display bugs due to older intel integrated GPUs (see PyVista)
vtk_qt.QVTKRWIBase = 'QGLWidget'
try:
import IPython
has_ipython = True
except ImportError:
has_ipython = False
try:
import panel as pn
pn.extension('vtk')
has_panel = True
except ImportError:
has_panel = False
try:
from PyQt5 import QtGui
from PyQt5.QtWidgets import QVBoxLayout, QFrame
from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor
from .utils_qt import MainWindow
has_pyqt = True
except ImportError:
has_pyqt = False
def in_ipython():
    """Check whether the code is running inside an IPython session.

    Returns
    -------
    bool
        True if IPython could be imported and ``IPython.get_ipython()``
        returns a shell instance. False otherwise, including when the
        lookup itself fails.
    """
    if not has_ipython:
        return False
    try:
        return IPython.get_ipython() is not None
    except Exception:
        # Best-effort detection: a failing lookup means "not in IPython".
        # Only Exception is caught so KeyboardInterrupt/SystemExit propagate.
        return False
def in_notebook():
    """Check whether the code is running inside an ipykernel-based shell.

    Returns
    -------
    bool
        True if IPython could be imported, a shell instance exists and the
        module of its type starts with ``'ipykernel.'``. False otherwise,
        including when the lookup itself fails.
    """
    if not has_ipython:
        return False
    try:
        ipy = IPython.get_ipython()
        if ipy is None:
            return False
        return type(ipy).__module__.startswith('ipykernel.')
    except Exception:
        # Best-effort detection: a failing lookup means "not in a notebook".
        # Only Exception is caught so KeyboardInterrupt/SystemExit propagate.
        return False
def _get_qt_app():
    """Return a Qt application instance, creating one if necessary.

    When running inside IPython, the Qt GUI event loop integration is
    enabled first and the application is looked up through IPython's Qt
    bindings. If no application is found that way, the current PyQt5
    ``QApplication`` instance is used, and a new one is created when none
    exists.

    Returns
    -------
    app : QApplication
        The Qt application instance.
    """
    app = None
    if in_ipython():
        from IPython import get_ipython
        # `run_line_magic` replaces the deprecated `InteractiveShell.magic`.
        get_ipython().run_line_magic('gui', 'qt')
        from IPython.external.qt_for_kernel import QtGui
        app = QtGui.QApplication.instance()

    if app is None:
        from PyQt5.QtWidgets import QApplication
        app = QApplication.instance()
        if not app:
            app = QApplication([''])
    return app
def _create_grid(nrow, ncol):
""" Create bounds for vtk rendering
Parameters
----------
nrow : int or array-like
Number of rows. If array-like, must be an array with values in
ascending order between 0 and 1.
ncol : int or array-like
Number of columns. If array-like, must be an array with values in
ascending order between 0 and 1.
Returns
-------
grid: ndarray, shape = (nrow, ncol, 4)
Grid for vtk rendering.
Examples
--------
>>> _create_grid(1, 2)
array([[[0. , 0. , 0.5, 1. ],
[0.5, 0. , 1. , 1. ]]])
>>> _create_grid(1, [0, .5, 1])
array([[[0. , 0. , 0.5, 1. ],
[0.5, 0. , 1. , 1. ]]])
>>> _create_grid(1, [0, .5, .9])
array([[[0. , 0. , 0.5, 1. ],
[0.5, 0. , 0.9, 1. ]]])
>>> _create_grid(1, [0, .5, .9, 1])
array([[[0. , 0. , 0.5, 1. ],
[0.5, 0. , 0.9, 1. ],
[0.9, 0. , 1. , 1. ]]])
>>> _create_grid(2, [.5, .6, .7])
array([[[0.5, 0.5, 0.6, 1. ],
[0.6, 0.5, 0.7, 1. ]],
[[0.5, 0. , 0.6, 0.5],
[0.6, 0. , 0.7, 0.5]]])
"""
if not isinstance(nrow, int):
nrow = np.atleast_1d(nrow)
if nrow.size < 2 or np.any(np.sort(nrow) != nrow) or \
nrow[0] < 0 or nrow[-1] > 1:
raise ValueError('Incorrect row values.')
if not isinstance(ncol, int):
ncol = np.atleast_1d(ncol)
if ncol.size < 2 or np.any(np.sort(ncol) != ncol) or \
ncol[0] < 0 or ncol[-1] > 1:
raise ValueError('Incorrect column values.')
if isinstance(ncol, np.ndarray):
x_min, x_max = ncol[:-1], ncol[1:]
ncol = x_min.size
else:
dx = 1 / ncol
x_min = np.arange(0, 1, dx)
x_max = x_min + dx
if isinstance(nrow, np.ndarray):
y_min, y_max = nrow[:-1], nrow[1:]
nrow = y_min.size
else:
dy = 1 / nrow
y_min = np.arange(0, 1, dy)
y_max = y_min + dy
y_min = np.repeat(y_min, ncol)[::-1]
y_max = np.repeat(y_max, ncol)[::-1]
x_min = np.tile(x_min, nrow)
x_max = np.tile(x_max, nrow)
g = np.column_stack([x_min, y_min, x_max, y_max])
strides = (4 * g.itemsize * ncol, 4 * g.itemsize, g.itemsize)
return as_strided(g, shape=(nrow, ncol, 4), strides=strides)
class Plotter(object):
    """Base plotter arranging VTK renderers on a grid of subplots.

    Unknown attributes are forwarded to the underlying render window.

    Parameters
    ----------
    nrow : int or array-like, optional
        Number of rows or row edges, see `_create_grid`. Default is 1.
    ncol : int or array-like, optional
        Number of columns or column edges, see `_create_grid`. Default is 1.
    offscreen : bool or None, optional
        Whether to render offscreen. If None, the package-level
        ``OFF_SCREEN`` value is used. Default is None.
    force_close : bool, optional
        If True, `quit` fully closes the plotter. Default is False.
    try_qt : bool, optional
        Currently ignored: a warning is issued and Qt is disabled.
        Default is False.
    kwargs : dict
        Keyword arguments forwarded to the render window constructor.
    """

    # Registry of created plotters, keyed by id(); used by `close_all`.
    DICT_PLOTTERS = dict()

    def __init__(self, nrow=1, ncol=1, offscreen=None, force_close=False,
                 try_qt=False, **kwargs):
        if try_qt:
            warnings.warn('Qt rendering is not supported for the moment.')
            try_qt = False

        self.grid = _create_grid(nrow, ncol)
        self.nrow, self.ncol = self.grid.shape[:2]
        self.offscreen = OFF_SCREEN if offscreen is None else offscreen
        self.force_close = force_close
        self.use_qt = has_pyqt and try_qt and not self.offscreen

        self.ren_win = BSRenderWindow(**kwargs)
        if self.offscreen:
            self.iren = None
            self.ren_win.offScreenRendering = True
        else:
            if self.use_qt:
                self.iren = BSGenericRenderWindowInteractor()
            else:
                self.iren = BSRenderWindowInteractor()
            self.iren.renderWindow = self.ren_win
            self.iren_interactorStyle = self.iren.interactorStyle
            self.iren.AddObserver(vtkCommand.ExitEvent, self.quit)

        if self.use_qt:
            self.app = _get_qt_app()
            self.app_window = MainWindow()
            self.app_window.signal_close.connect(self.quit)

            self.frame = QFrame()
            self.frame.setFrameStyle(QFrame.NoFrame)
            self.qt_ren = QVTKRenderWindowInteractor(
                parent=self.frame, iren=self.iren.VTKObject,
                rw=self.ren_win.VTKObject)

            self.vlayout = QVBoxLayout()
            self.vlayout.addWidget(self.qt_ren)
            self.frame.setLayout(self.vlayout)
            self.app_window.setCentralWidget(self.frame)

        self.n_renderers = 0
        self.renderers = defaultdict(list)
        # Index of the subplot occupying each grid cell, -1 when empty.
        self.populated = -np.ones((self.nrow, self.ncol), dtype=np.int32)

        self.panel = None
        self._cancel_show = False
        self._rendered_once = False

        self.DICT_PLOTTERS[id(self)] = self

    @classmethod
    def close_all(cls):
        """Close every registered plotter and empty the registry."""
        for k in list(cls.DICT_PLOTTERS.keys()):
            plotter = cls.DICT_PLOTTERS.pop(k, None)
            if plotter is not None:
                plotter.close()

    def AddRenderer(self, row=None, col=None, renderer=None, **kwargs):
        """Add a renderer spanning the selected grid cells.

        Parameters
        ----------
        row : int, tuple or None, optional
            Row index, tuple of slice arguments (e.g., ``(0, 2)`` or
            ``(None, 2)``), or None to span all rows. Default is None.
        col : int, tuple or None, optional
            Column selection, same conventions as `row`. Default is None.
        renderer : object, optional
            Renderer forwarded to the render window. Default is None.
        kwargs : dict
            Keyword arguments forwarded to the render window's
            ``AddRenderer``.

        Returns
        -------
        renderer
            The renderer returned by the render window, with its viewport
            set to the bounds of the selected cells.

        Raises
        ------
        ValueError
            If the selection is empty or overlaps several different
            subplots.
        """
        def _as_slice(idx):
            if idx is None:
                return slice(None)
            if isinstance(idx, tuple):
                return slice(*idx)
            return slice(idx, idx + 1)

        row, col = _as_slice(row), _as_slice(col)

        p = np.unique(self.populated[row, col])
        if p.size == 0:
            # An empty selection would otherwise fail with an IndexError.
            raise ValueError('Subplot selection is empty or out of bounds.')
        if p.size > 1:
            raise ValueError('Subplot overlaps with existing subplots.')

        p = p[0]
        if p == -1:
            self.populated[row, col] = p = self.n_renderers
            self.n_renderers += 1

        # Viewport bounds are (xmin, ymin, xmax, ymax) over the selection.
        subgrid = self.grid[row, col]
        bounds = np.empty(4)
        bounds[:2] = subgrid[..., :2].min(axis=(0, 1))
        bounds[2:] = subgrid[..., 2:].max(axis=(0, 1))

        renderer = self.ren_win.AddRenderer(obj=renderer, **kwargs)
        renderer.SetViewport(*bounds)
        self.renderers[p].append(renderer)
        return renderer

    def __getattr__(self, name):
        """Forwards unknown attribute requests to BSRenderWindow."""
        # __getattr__ only runs when normal lookup fails. If 'ren_win'
        # itself is missing, forwarding would recurse without bound.
        if name == 'ren_win':
            raise AttributeError(name)
        return getattr(self.ren_win, name)

    def _check_interactive(self, embed_nb, interactive):
        """Return whether interactive mode can be honored.

        Interactive embedding requires 'panel' and a single-cell grid;
        otherwise a warning is issued and False is returned.
        """
        if not embed_nb or not interactive:
            return interactive

        if not has_panel:
            warnings.warn("Interactive mode requires 'panel'. "
                          "Setting 'interactive=False'")
            return False

        if self.nrow > 1 or self.ncol > 1:
            warnings.warn("Support for interactive mode is only provided for "
                          "a single renderer: 'nrow=1' and 'ncol=1'. Setting "
                          "'interactive=False'")
            return False

        return interactive

    def show(self, embed_nb=False, interactive=True, transparent_bg=True,
             scale=(1, 1)):
        """Render the plotter.

        Parameters
        ----------
        embed_nb : bool, optional
            If True, return an embeddable object instead of opening a
            window. Default is False.
        interactive : bool, optional
            Whether to enable interaction. Default is True.
        transparent_bg : bool, optional
            Whether to use a transparent background (static embedding
            only). Default is True.
        scale : tuple, optional
            Scale applied to the window size for embedding.
            Default is (1, 1).

        Returns
        -------
        panel, Image or None
            A panel or an IPython image when `embed_nb` is True, None
            otherwise.

        Raises
        ------
        ValueError
            If the plotter is closed, was already rendered, or was
            previously switched to offscreen rendering.
        """
        if embed_nb:
            interactive = self._check_interactive(embed_nb, interactive)
            if interactive:
                return self.to_panel(scale)
            return self.to_notebook(transparent_bg, scale)

        self._check_closed()
        if self.offscreen:
            return None

        if self._rendered_once:
            raise ValueError('Cannot render multiple times.')
        if self._cancel_show:
            raise ValueError('Cannot render after offscreen rendering.')

        self.iren.Initialize()
        if not interactive:
            # Without an interactor style, install our own key handler so
            # the window can still be dismissed from the keyboard.
            self.iren.interactorStyle = None
            self.iren.AddObserver(vtkCommand.KeyPressEvent, self.key_quit)

        self.ren_win.Render()
        if self.use_qt:
            self.app_window.show()
        else:
            self.iren.Start()
        self._rendered_once = True
        return None

    def key_quit(self, obj=None, event=None):
        """Quit when the pressed key is 'q' or 'e' (case-insensitive)."""
        if self.iren.keySym.lower() in ['q', 'e']:
            self.quit()

    def close(self):
        """Release the render window and interactor.

        Safe to call more than once. The plotter is also removed from the
        class registry so that it does not stay referenced after closing.
        """
        self.DICT_PLOTTERS.pop(id(self), None)

        if self.ren_win is not None:
            self.ren_win.Finalize()
            self.ren_win = None

        if self.iren:
            try:
                self.iren.SetDone(True)
            except Exception:
                # Best effort: not every interactor accepts SetDone.
                pass
            self.iren.TerminateApp()
            self.iren = None

        if self.use_qt:
            self.app_window.close()

    def quit(self, *args):
        """Stop rendering; fully close the plotter if `force_close` is set.

        Extra positional arguments are accepted and ignored so the method
        can be used as an observer/signal callback.
        """
        if self.force_close:
            self.close()
            return

        if self.ren_win is not None:
            self.ren_win.Finalize()
        if self.iren:
            try:
                self.iren.SetDone(True)
            except Exception:
                # Best effort: not every interactor accepts SetDone.
                pass
            self.iren.TerminateApp()
        if self.use_qt:
            self.app_window.close()

    def _check_closed(self):
        """Raise ValueError if the plotter has been closed."""
        if self.ren_win is None:
            raise ValueError('This plotter has been closed.')

    def _check_offscreen(self):
        """Switch to offscreen rendering if needed, then render.

        Switching an onscreen plotter marks it so that a later `show`
        in a window is refused.
        """
        if not self.offscreen:
            self.ren_win.offScreenRendering = True
            self.ren_win.interactor = None
            self._cancel_show = True
        self.ren_win.Render()

    def to_panel(self, scale=(1, 1)):
        """Return an interactive panel pane for the render window.

        Falls back to `to_notebook` when interactive mode is unavailable.

        Parameters
        ----------
        scale : tuple, optional
            Scale applied to the window size. Default is (1, 1).

        Returns
        -------
        panel or Image
            The panel pane, or a static image on fallback.
        """
        if not self._check_interactive(True, True):
            return self.to_notebook(scale=scale)

        self._check_closed()
        self._check_offscreen()
        w, h = np.asarray(self.ren_win.size) * scale
        self.panel = pn.pane.VTK(self.ren_win.VTKObject, width=int(w),
                                 height=int(h))
        return self.panel

    def _win2img(self, transparent_bg, scale):
        """Build a window-to-image filter for the render window."""
        self._check_closed()
        self._check_offscreen()
        wf = BSWindowToImageFilter(input=self.ren_win, readFrontBuffer=False,
                                   shouldRerender=True, fixBoundary=True,
                                   scale=scale)
        wf.inputBufferType = 'RGBA' if transparent_bg else 'RGB'
        return wf

    def to_notebook(self, transparent_bg=True, scale=(1, 1)):
        """Return the rendering as an IPython image built from PNG data.

        Parameters
        ----------
        transparent_bg : bool, optional
            Whether to use a transparent background. Default is True.
        scale : tuple, optional
            Scale passed to the window-to-image filter. Default is (1, 1).

        Returns
        -------
        IPython.display.Image
            The rendered image.
        """
        wimg = self._win2img(transparent_bg, scale)
        writer = BSPNGWriter(writeToMemory=True)
        result = serial_connect(wimg, writer, as_data=False).result
        data = memoryview(result).tobytes()

        from IPython.display import Image
        return Image(data)

    def to_numpy(self, transparent_bg=True, scale=(1, 1)):
        """Return the rendering as an array.

        Parameters
        ----------
        transparent_bg : bool, optional
            Whether to use a transparent background. Default is True.
        scale : tuple, optional
            Scale passed to the window-to-image filter. Default is (1, 1).

        Returns
        -------
        img : ndarray
            Image scalars reshaped from the image dimensions, with the
            first axis flipped.
        """
        wf = self._win2img(transparent_bg, scale)
        img = get_output(wf)
        shape = img.dimensions[::-1][1:] + (-1,)
        return img.PointData['ImageScalars'].reshape(shape)[::-1]

    def _to_image(self, filename, transparent_bg, scale):
        """Write the rendering to `filename`; format is taken from its
        extension.

        Returns
        -------
        pth : str
            Absolute, user-expanded path of the written file.

        Raises
        ------
        ValueError
            If the extension is not a supported format.
        """
        # Write to the expanded path so that the returned location is the
        # one actually used (e.g., for filenames starting with '~').
        pth = os.path.abspath(os.path.expanduser(filename))
        pth_no_ext, ext = os.path.splitext(pth)
        ext = ext[1:]

        fmts1 = {'bmp', 'jpeg', 'jpg', 'png', 'tif', 'tiff'}
        fmts2 = {'eps', 'pdf', 'ps', 'svg'}

        if ext in fmts1:
            wimg = self._win2img(transparent_bg, scale)
            if ext == 'bmp':
                writer = BSBMPWriter(filename=pth)
            elif ext in ['jpg', 'jpeg']:
                writer = BSJPEGWriter(filename=pth)
            elif ext == 'png':
                writer = BSPNGWriter(filename=pth)
            else:  # 'tif' or 'tiff'
                writer = BSTIFFWriter(filename=pth)
            serial_connect(wimg, writer, as_data=False)

        elif ext in fmts2:
            self._check_closed()
            self._check_offscreen()

            orig_sz = self.ren_win.size
            self.ren_win.size = np.array(scale) * orig_sz
            try:
                w = BSGL2PSExporter(input=self.ren_win, fileFormat=ext,
                                    compress=False, simpleLineOffset=True,
                                    filePrefix=pth_no_ext,
                                    title='', write3DPropsAsRasterImage=True)
                w.UsePainterSettings()
                w.Update()
            finally:
                # Restore the window size even if the export fails.
                self.ren_win.size = orig_sz

        else:
            raise ValueError("Format '%s' not supported. Supported formats "
                             "are: %s" % (ext, fmts1.union(fmts2)))

        return pth

    def screenshot(self, filename, transparent_bg=True, scale=(1, 1)):
        """Save a screenshot of the rendering.

        Parameters
        ----------
        filename : str
            Output file; the extension selects the format.
        transparent_bg : bool, optional
            Whether to use a transparent background. Default is True.
        scale : tuple, optional
            Scale applied to the rendering. Default is (1, 1).

        Returns
        -------
        str
            Absolute path of the written file.
        """
        return self._to_image(filename, transparent_bg, scale)
class GridPlotter(Plotter):
    """Plotter that allows a single renderer per grid entry.

    Parameters
    ----------
    nrow : int or array-like, optional
        Number of rows or row edges. Default is 1.
    ncol : int or array-like, optional
        Number of columns or column edges. Default is 1.
    try_qt : bool, optional
        Forwarded to the base plotter. Default is True.
    offscreen : bool or None, optional
        Forwarded to the base plotter. Default is None.
    kwargs : dict
        Additional keyword arguments forwarded to the base plotter.
    """

    def __init__(self, nrow=1, ncol=1, try_qt=True, offscreen=None,
                 **kwargs):
        super().__init__(nrow=nrow, ncol=ncol, try_qt=try_qt,
                         offscreen=offscreen, **kwargs)

    def AddRenderer(self, row, col, renderer=None, **kwargs):
        """Add a renderer to a single grid entry.

        Parameters
        ----------
        row : int
            Row index of the entry.
        col : int
            Column index of the entry.
        renderer : object, optional
            Renderer forwarded to the base plotter. Default is None.
        kwargs : dict
            Keyword arguments forwarded to the base plotter.

        Returns
        -------
        renderer
            The renderer returned by the base plotter.

        Raises
        ------
        ValueError
            If `row` or `col` is not an integer.
        """
        # Both indices must be scalars; spans (tuples/None) are rejected.
        int_types = (int, np.integer)
        if not isinstance(row, int_types) or not isinstance(col, int_types):
            raise ValueError('GridPlotter only supports one renderer '
                             'for each grid entry')
        return super().AddRenderer(row=row, col=col, renderer=renderer,
                                   **kwargs)

    def AddRenderers(self, **kwargs):
        """Add one renderer to every grid entry.

        Parameters
        ----------
        kwargs : dict
            Keyword arguments forwarded to the base plotter for each entry.

        Returns
        -------
        ren : ndarray of object, shape = (nrow, ncol)
            Renderers created for each grid entry.
        """
        # `np.object` was removed from NumPy; the builtin is equivalent.
        ren = np.empty((self.nrow, self.ncol), dtype=object)
        for i in range(self.nrow):
            for j in range(self.ncol):
                ren[i, j] = super().AddRenderer(row=i, col=j, **kwargs)
        return ren