screenshots.py
1 ##############################################################################
2 # Medical Image Registration ToolKit (MIRTK)
3 #
4 # Copyright 2016 Imperial College London
5 # Copyright 2016 Andreas Schuh
6 #
7 # Licensed under the Apache License, Version 2.0 (the "License");
8 # you may not use this file except in compliance with the License.
9 # You may obtain a copy of the License at
10 #
11 # http://www.apache.org/licenses/LICENSE-2.0
12 #
13 # Unless required by applicable law or agreed to in writing, software
14 # distributed under the License is distributed on an "AS IS" BASIS,
15 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 # See the License for the specific language governing permissions and
17 # limitations under the License.
18 ##############################################################################
19 
20 """Python module for the rendering of screenshots."""
21 
22 import os
23 import vtk
24 
25 
# Default contour colors used by slice_view, as RGB triples with components in [0, 1].
default_colors = [
    (1, 0, 0),  # red
    (0, 0, 1),  # blue
    (0, 1, 0),  # green
    (0, 1, 1),  # cyan
    (1, 0, 1),  # magenta
    (1, 1, 0),  # yellow
    (1, 1, 1),  # white
]
35 
36 
def deep_copy(obj):
    """Return a new object of the same VTK class holding a deep copy of `obj`."""
    duplicate = obj.NewInstance()  # empty instance of the same concrete class
    duplicate.DeepCopy(obj)
    return duplicate
42 
43 
def iround(x):
    """Return `x` rounded to the nearest integer, as an `int`.

    Uses the built-in `round`, whose tie-breaking rule differs between
    Python 2 (halves away from zero) and Python 3 (halves to even).
    """
    rounded = round(x)
    return int(rounded)
47 
48 
def nearest_voxel(index):
    """Round continuous voxel coordinates (i, j, k) to the nearest voxel indices.

    Returns a 3-tuple of ints; only the first three entries of `index` are used.
    """
    return tuple(iround(index[axis]) for axis in range(3))
55 
56 
def invert_matrix(m):
    """Return the inverse of vtkMatrix4x4 `m` as a new matrix, leaving `m` unchanged."""
    result = deep_copy(m)
    result.Invert()  # inverts the copy in place
    return result
62 
63 
def index_to_point(index, origin, spacing):
    """Map voxel indices to vtkImageData point coordinates.

    Computes `origin + index * spacing` per dimension for the first three
    dimensions and returns the result as a 3-tuple.
    """
    return tuple(origin[d] + index[d] * spacing[d] for d in range(3))
70 
71 
def point_to_index(point, origin, spacing):
    """Map vtkImageData point coordinates to continuous (unrounded) voxel indices.

    Inverse of `index_to_point`: computes `(point - origin) / spacing` per
    dimension for the first three dimensions and returns a 3-tuple.
    """
    return tuple((point[d] - origin[d]) / spacing[d] for d in range(3))
78 
79 
def matrix_to_affine(matrix):
    """Convert a vtkMatrix4x4 to a NiBabel-style 'affine' (4x4 nested list, row-major)."""
    affine = []
    for row in range(4):
        affine.append([matrix.GetElement(row, col) for col in range(4)])
    return affine
90 
91 
def range_to_level_window(min_value, max_value):
    """Convert an intensity range [min_value, max_value] to (level, window).

    The window is the width of the range and the level is its midpoint.
    """
    width = max_value - min_value
    midpoint = min_value + .5 * width
    return midpoint, width
97 
98 
def auto_image_range(image, percentiles=(1, 99)):
    """Compute an intensity range for the color transfer function.

    Uses vtkImageHistogramStatistics with automatic binning (at most 512 bins)
    and returns the values at the given lower/upper `percentiles` as a tuple.
    """
    histogram = vtk.vtkImageHistogramStatistics()
    histogram.SetInputData(image)
    histogram.AutomaticBinningOn()
    histogram.SetMaximumNumberOfBins(512)
    histogram.SetAutoRangePercentiles(percentiles)
    histogram.UpdateWholeExtent()
    auto_range = histogram.GetAutoRange()
    return tuple(auto_range)
108 
109 
def auto_level_window(image, percentiles=(1, 99)):
    """Compute (level, window) for the color transfer function from intensity percentiles."""
    lower, upper = auto_image_range(image, percentiles)
    return range_to_level_window(lower, upper)
113 
114 
def add_contour(renderer, plane, polydata,
                transform=None, line_width=3, color=(1, 0, 0)):
    """Add the contour of a mesh cut by the given plane to a render scene.

    Arguments
    ---------

    renderer : vtkRenderer
        Render scene to which the contour actor is added.
    plane : vtkImplicitFunction
        Cut function, e.g., the slice plane of a vtkImageSliceMapper.
    polydata : vtkPolyData
        Surface mesh to cut.
    transform : VTK transform, optional
        When given, applied to the points of `polydata` before cutting.
    line_width : float
        Width of the rendered contour lines.
    color : RGB 3-tuple, optional
        Contour color. When falsy, the contour is colored by the point or cell
        scalars of the (transformed) input mesh, if present.

    Returns
    -------

    actor : vtkActor
        The actor that was added to `renderer`.
    """
    source = polydata
    if transform:
        mover = vtk.vtkTransformPolyDataFilter()
        mover.SetInputData(source)
        mover.SetTransform(transform)
        mover.Update()
        source = deep_copy(mover.GetOutput())
        mover = None  # drop the filter; only its output copy is kept

    slicer = vtk.vtkCutter()
    slicer.SetInputData(source)
    slicer.SetCutFunction(plane)
    slicer.Update()
    contour = deep_copy(slicer.GetOutput())
    slicer = None

    contour_mapper = vtk.vtkPolyDataMapper()
    contour_mapper.SetInputData(contour)
    # A fixed color disables scalar coloring; otherwise prefer point scalars
    # over cell scalars of the mesh that was cut.
    if color:
        contour_mapper.ScalarVisibilityOff()
    elif source.GetPointData().GetScalars():
        contour_mapper.SetScalarModeToUsePointData()
        contour_mapper.ScalarVisibilityOn()
    elif source.GetCellData().GetScalars():
        contour_mapper.SetScalarModeToUseCellData()
        contour_mapper.ScalarVisibilityOn()

    contour_actor = vtk.vtkActor()
    contour_actor.SetMapper(contour_mapper)
    line_style = contour_actor.GetProperty()
    line_style.LightingOff()
    line_style.SetRepresentationToWireframe()
    line_style.SetLineWidth(line_width)
    if color:
        line_style.SetColor(color)

    renderer.AddActor(contour_actor)
    return contour_actor
154 
155 
def slice_axes(zdir):
    """Get the volume dimensions of the in-plane x and y axes of a slice.

    Arguments
    ---------

    zdir : int
        Volume dimension orthogonal to the slice (0, 1, or 2).

    Returns
    -------

    axes : (int, int)
        Volume dimensions corresponding to the slice x and y axes.

    Raises
    ------

    ValueError
        If `zdir` is not 0, 1, or 2.
    """
    if zdir == 0:
        return (1, 2)
    if zdir == 1:
        return (0, 2)
    if zdir == 2:
        return (0, 1)
    # format() instead of "+": concatenating str and int raised a TypeError
    # that hid the intended message.
    raise ValueError("Invalid zdir argument: {}".format(zdir))
166 
167 
def cropping_region(center, axes, width, height):
    """Get the voxel extent of a slice patch centered at the given voxel.

    Arguments
    ---------

    center : sequence of int
        Voxel indices (i, j, k) of the patch center.
    axes : sequence of int
        Volume dimensions of the in-plane x and y axes. Only the first two
        entries are used, e.g., as returned by `slice_axes`.
    width, height : int
        Patch size in voxels along `axes[0]` and `axes[1]`, respectively.

    Returns
    -------

    bounds : list
        Extent [imin, imax, jmin, jmax, kmin, kmax]. Along the dimension not
        listed in `axes[:2]`, both bounds equal the center index. Along an
        in-plane axis, upper minus lower bound is `size - 1` for even and
        `size` for odd sizes.
    """
    bounds = [center[0], center[0], center[1], center[1], center[2], center[2]]
    i = 2 * axes[0]
    j = 2 * axes[1]
    # Floor division keeps integer bounds on Python 3 (plain "/" would yield
    # fractional extents) and is identical to "/" on Python 2 for ints.
    bounds[i] -= (width - 1) // 2
    bounds[i + 1] += (width + 1) // 2
    bounds[j] -= (height - 1) // 2
    bounds[j + 1] += (height + 1) // 2
    return bounds
177 
178 
def slice_view(image, index, width, height, zdir=2, polydata=[], colors=default_colors,
               transform=None, line_width=3, level_window=None, image_lut=None, interpolation="nearest"):
    """Return vtkRenderer for orthogonal image slice.

    Arguments
    ---------

    image : vtkImageData
        Volume data to slice.
    index : sequence of int
        Voxel indices of the slice center point, used as camera focal point.
    width, height : int
        Size of the view in voxels along the slice x and y axes given by
        `slice_axes(zdir)`. Values less than 1 select the full image dimension.
    zdir : int
        Volume dimension orthogonal to the slice (viewing direction).
    polydata : vtkPolyData or list of vtkPolyData
        Surfaces whose intersection with the slice plane is drawn as contours.
    colors : RGB 3-tuple or list
        Contour color(s). A single RGB tuple applies to all contours. Contours
        beyond the end of the list reuse its last entry. An empty list (or None)
        selects `default_colors`. A None entry colors that contour by the
        scalars of its polydata, if any (see add_contour).
    transform : VTK transform with GetMatrix(), optional
        Maps `polydata` points to vtkImageData coordinates. Its matrix is also
        used to decide which image axes are displayed flipped.
    line_width : int, float, or list
        Contour line width(s), with the same reuse rule as `colors`.
    level_window : (float, float), optional
        Level and window of the color transfer function. When not given,
        `auto_level_window` is used.
    image_lut : optional
        Lookup table passed to the image property's SetLookupTable.
    interpolation : str
        Image interpolation mode: "nearest" (or "nn"), "linear", or "cubic".

    Raises
    ------

    ValueError
        If `interpolation` is not a supported mode.
    """
    # Determine the orientation of the medical volume: an image axis is
    # displayed flipped when its axis code is 'R', 'P', or 'I'.
    flip = [False, False, False]
    matrix = None
    if transform:
        transform.Update()
        matrix = deep_copy(transform.GetMatrix())
        try:
            from nibabel import aff2axcodes
            codes = aff2axcodes(matrix_to_affine(matrix))
        except Exception:
            # Fallback without nibabel: assume an axis-aligned matrix and derive
            # the codes from the signs of its diagonal, using the same RAS+
            # convention as aff2axcodes. (The previous fallback assigned a bare
            # string, e.g. codes = 'R', so codes[1] raised IndexError. It also
            # labelled a positive x axis 'L' where nibabel reports 'R'.)
            codes = [
                'L' if matrix.GetElement(0, 0) < 0 else 'R',
                'P' if matrix.GetElement(1, 1) < 0 else 'A',
                'I' if matrix.GetElement(2, 2) < 0 else 'S',
            ]
        flip = [codes[0] == 'R', codes[1] == 'P', codes[2] == 'I']

    dims = image.GetDimensions()
    axes = slice_axes(zdir)
    if width < 1:
        width = dims[axes[0]]
    if height < 1:
        height = dims[axes[1]]
    size = [1, 1, 1]
    size[axes[0]] = width
    size[axes[1]] = height
    up = (0, 1, 0) if zdir == 2 else (0, 0, 1)
    spacing = image.GetSpacing()
    distance = 10. * spacing[zdir]
    focal_point = index_to_point(index, image.GetOrigin(), spacing)
    position = list(focal_point)
    # View a flipped slice axis from the opposite side of the slice.
    if flip[zdir]:
        position[zdir] -= distance
    else:
        position[zdir] += distance

    margin = 2
    extent = cropping_region(index, axes, size[axes[0]] + margin, size[axes[1]] + margin)

    if any(flip):
        # Mirror the displayed slice about the focal point along flipped axes.
        flip_transform = vtk.vtkTransform()
        flip_transform.Translate(+focal_point[0], +focal_point[1], +focal_point[2])
        flip_transform.Scale(-1. if flip[0] else 1.,
                             -1. if flip[1] else 1.,
                             -1. if flip[2] else 1.)
        flip_transform.Translate(-focal_point[0], -focal_point[1], -focal_point[2])
        # Contour points: map to image coordinates first, then mirror like the slice.
        points_transform = vtk.vtkTransform()
        points_transform.SetMatrix(matrix)
        points_transform.PostMultiply()
        points_transform.Concatenate(flip_transform)
    else:
        flip_transform = None
        # Nothing to mirror. Contour points must still be mapped to image
        # coordinates; previously they were left untransformed in this case.
        points_transform = transform

    mapper = vtk.vtkImageSliceMapper()
    mapper.SetInputData(image)
    mapper.SetOrientation(zdir)
    mapper.SetSliceNumber(extent[2 * zdir])
    mapper.SetCroppingRegion(extent)
    mapper.CroppingOn()
    mapper.Update()

    actor = vtk.vtkImageSlice()
    actor.SetMapper(mapper)
    if flip_transform:
        actor.SetUserTransform(flip_transform)
    prop = actor.GetProperty()
    interpolation = interpolation.lower()
    if interpolation in ("nn", "nearest"):
        prop.SetInterpolationTypeToNearest()
    elif interpolation == "linear":
        prop.SetInterpolationTypeToLinear()
    elif interpolation == "cubic":
        prop.SetInterpolationTypeToCubic()
    else:
        raise ValueError("Invalid interpolation mode: {}".format(interpolation))

    if not level_window:
        level_window = auto_level_window(image)
    prop.SetColorLevel(level_window[0])
    prop.SetColorWindow(level_window[1])
    if image_lut:
        prop.SetLookupTable(image_lut)
        prop.UseLookupTableScalarRangeOn()

    renderer = vtk.vtkRenderer()
    renderer.AddActor(actor)

    camera = renderer.GetActiveCamera()
    camera.SetViewUp(up)
    camera.SetPosition(position)
    camera.SetFocalPoint(focal_point)
    camera.SetParallelScale(.5 * max((size[axes[0]] - 1) * spacing[axes[0]],
                                     (size[axes[1]] - 1) * spacing[axes[1]]))
    camera.SetClippingRange(distance - .5 * spacing[zdir],
                            distance + .5 * spacing[zdir])
    camera.ParallelProjectionOn()

    # Add contours of polygonal data intersected by the slice plane.
    if polydata is None:
        polydata = []
    elif isinstance(polydata, vtk.vtkPolyData):
        polydata = [polydata]
    if isinstance(line_width, (int, float)):
        line_width = [line_width]
    if not colors:
        colors = default_colors
    elif isinstance(colors[0], (int, float)):
        # A single RGB color, e.g. (1, 0, 0) or (1., 0., 0.).
        colors = [colors]
    plane = mapper.GetSlicePlane()
    for i, surface in enumerate(polydata):
        # Entries beyond the end of a list reuse its last element.
        color = colors[min(i, len(colors) - 1)]
        contour_width = line_width[min(i, len(line_width) - 1)]
        add_contour(renderer, plane=plane,
                    polydata=surface, transform=points_transform,
                    line_width=contour_width, color=color)
    return renderer
310 
311 
def take_screenshot(window, path=None):
    """Takes vtkRenderWindow instance and writes a screenshot of the rendering.

    The window is temporarily switched to off-screen rendering. Its previous
    off-screen setting is restored afterwards, also when an error occurs.

    window : vtkRenderWindow
        The render window from which a screenshot is taken.
    path : str, optional
        File name path of output PNG file.
        A .png file name extension is appended if missing.

    Returns None when `path` is given. Otherwise the PNG is encoded in memory
    and returned as an IPython.display.Image, e.g., for display in a notebook.

    """
    offscreen = window.GetOffScreenRendering()
    window.OffScreenRenderingOn()
    try:
        window_to_image = vtk.vtkWindowToImageFilter()
        window_to_image.SetInput(window)
        window_to_image.Update()
        writer = vtk.vtkPNGWriter()
        writer.SetInputConnection(window_to_image.GetOutputPort())
        if path:
            if os.path.splitext(path)[1].lower() != '.png':
                path += '.png'
            writer.SetFileName(path)
        else:
            writer.WriteToMemoryOn()
        writer.Write()
    finally:
        window.SetOffScreenRendering(offscreen)
    if not writer.GetWriteToMemory():
        return None
    from IPython.display import Image
    result = writer.GetResult()
    try:
        data = str(buffer(result))  # Python 2: old-style buffer object
    except NameError:
        # Python 3 has no `buffer`; copy the PNG bytes via the buffer protocol.
        data = bytes(memoryview(result))
    return Image(data)
341 
342 
def take_orthogonal_screenshots(image, center=None, length=0, offsets=None,
                                size=(512, 512), prefix='screenshot_{n}',
                                suffix=('axial', 'coronal', 'sagittal'),
                                path_format='{prefix}_{suffix}',
                                level_window=None, qform=None,
                                polydata=[], colors=[], line_width=3,
                                trim=False, overwrite=True):
    """Take three orthogonal screenshots of the given image patch.

    Arguments
    ---------

    image : vtkImageData
        Volume data.
    center : (float, float, float) or (int, int, int)
        vtkImageData coordinates or voxel indices (3-tuple of ints)
        of patch center point. Be sure to pass tuple with correct type.
        When not specified, the image center point is used.
    length : float or int
        Side length of patch either in mm (float) or number of voxels (int).
        When 0, the full image extent is used.
    offsets : list of float or int, optional
        One or more offsets in either mm (float) or number of voxels (int)
        from the center point along the orthogonal viewing direction.
        A screenshot of the volume of interest is taken for each combination
        of offset and orthogonal viewing direction, i.e., the number of
        screenshots is at most `3 * len(offsets)`; offsets that fall outside
        the image are skipped. For example, when the volume of interest is
        the entire image, i.e., `center=None` and `length=0`, this can be
        used to render multiple slices with one call of this function.
    size : (int, int)
        Either int or 2-tuple/-list of int values specifying the width and height of the screenshot.
    prefix : str
        Common output path prefix of screenshot files. The prefix may itself
        contain the placeholders listed for `path_format` (except `{prefix}`),
        e.g., `{n}` as in the default value. A prefix that cannot be formatted
        is used verbatim.
    suffix : (str, str, str)
        List or tuple of three strings used as suffix of the screenshot
        taken from the respective orthogonal view. The order is:
        axial, coronal, sagittal.
    path_format: str
        A format string used to construct the file name of each individual screenshot.
        A .png extension is appended when missing.
        Make sure to use a suitable format string such that every screenshot has a
        unique file name. The allowed placeholders are:
        - `{prefix}`: Substituted by the (formatted) `prefix` argument.
        - `{suffix}`: Substituted by the `suffix` argument corresponding to the orthogonal view.
        - `{n}`: The 1-based index of the screenshot.
        - `{i}`: The i (int) volume index of the screenshot center point before trimming.
        - `{j}`: The j (int) volume index of the screenshot center point before trimming.
        - `{k}`: The k (int) volume index of the screenshot center point before trimming.
        - `{x}`: The x (float) vtkImageData coordinates of the center point before trimming.
        - `{y}`: The y (float) vtkImageData coordinates of the center point before trimming.
        - `{z}`: The z (float) vtkImageData coordinates of the center point before trimming.
        - `{vi}`: The i (int) volume index of the viewport center point after trimming.
        - `{vj}`: The j (int) volume index of the viewport center point after trimming.
        - `{vk}`: The k (int) volume index of the viewport center point after trimming.
        - `{vx}`: The x (float) vtkImageData coordinates of the viewport center point after trimming.
        - `{vy}`: The y (float) vtkImageData coordinates of the viewport center point after trimming.
        - `{vz}`: The z (float) vtkImageData coordinates of the viewport center point after trimming.
    qform : vtkMatrix4x4, optional
        Homogeneous vtkImageData coordinates to world transformation matrix.
        For example, pass the vtkNIFTIImageReader.GetQFormMatrix().
    level_window : (float, float), optional
        2-tuple/-list of level and window color transfer function parameters.
        When not specified, the auto_level_window function with default percentiles
        is used to compute an intensity range that is robust to outliers.
    polydata : vtkPolyData, list, optional
        List of vtkPolyData objects to be cut by each orthogonal
        image slice plane and the contours rendered over the image.
        When a `qform` matrix is given, the points are transformed
        to image coordinates using the inverse of the `qform` matrix.
    colors : list, optional
        List of colors (3-tuples of float RGB values in [0, 1]) to use for each `polydata` contour.
        When empty, `default_colors` is used.
    line_width : int, float, or list, optional
        Contour line width(s) passed on to slice_view.
    trim : bool, optional
        Whether to trim cropping region such that screenshot is contained within image bounds.
    overwrite : bool, optional
        Whether to overwrite existing output files. When this option is False,
        a screenshot is only taken when the output file does not exist.

    Returns
    -------

    paths: list
        List of 6-tuples with parameters used to render each screenshot and the
        absolute file path of the written PNG image files. The screenshots are
        ordered by view: first all axial, then all coronal, and finally all
        sagittal screenshots.

        For each screenshot, the returned 6-tuple contains the following values:
        - path: Absolute path of PNG file.
        - zdir: Volume dimension corresponding to viewing direction.
        - center: List of screenshot center voxel indices before trimming.
        - origin: List of viewport center voxel indices after trimming.
        - size: 2-tuple of extracted slice width and height in voxels (after trimming).
        - isnew: Whether image file was newly written or existed already.
                 This value is always True when `overwrite=True`.

    """
    if isinstance(size, int):
        size = (size, size)
    if not level_window:
        level_window = auto_level_window(image)
    if not suffix or len(suffix) != 3:
        raise Exception("suffix must be a tuple/list of three elements")
    # Reverse such that suffix[zdir] is the name of the view orthogonal to zdir.
    suffix = list(suffix)
    suffix.reverse()
    origin = image.GetOrigin()
    spacing = image.GetSpacing()
    dims = image.GetDimensions()
    if os.path.splitext(path_format)[1] != '.png':
        path_format += '.png'
    if center and all(isinstance(center[d], int) for d in range(3)):
        index = center
    else:
        if not center:
            center = image.GetCenter()
        index = nearest_voxel(point_to_index(center, origin, spacing))
    if not offsets:
        offsets = [0]
    elif isinstance(offsets, (int, float)):
        offsets = [offsets]
    if qform:
        linear_transform = vtk.vtkMatrixToLinearTransform()
        linear_transform.SetInput(invert_matrix(qform))
        linear_transform.Update()
    else:
        linear_transform = None
    if not colors:
        # slice_view indexes colors[0]; an empty list would raise IndexError.
        colors = default_colors
    args = dict(
        polydata=polydata,
        transform=linear_transform,
        level_window=level_window,
        colors=colors,
        line_width=line_width
    )
    n = 0
    screenshots = []
    for zdir in (2, 1, 0):
        xdir, ydir = slice_axes(zdir)
        if isinstance(length, int):
            width = height = length
        else:
            width = iround(length / spacing[xdir])
            height = iround(length / spacing[ydir])
        for offset in offsets:
            ijk = list(index)
            if isinstance(offset, int):
                ijk[zdir] += offset
            else:
                ijk[zdir] += iround(offset / spacing[zdir])
            if ijk[zdir] < 0 or ijk[zdir] >= dims[zdir]:
                continue
            # Per-screenshot copies, so trimming does not shrink the patch
            # (and shift its center) for the following offsets.
            vox = list(ijk)
            view_width, view_height = width, height
            if trim:
                bounds = cropping_region(vox, (xdir, ydir, zdir), view_width, view_height)
                for d in range(3):
                    bounds[2 * d] = max(bounds[2 * d], 0)
                    bounds[2 * d + 1] = min(bounds[2 * d + 1], dims[d] - 1)
                lower, upper = bounds[2 * xdir], bounds[2 * xdir + 1]
                view_width = upper - lower
                vox[xdir] = lower + view_width // 2  # integer voxel index
                lower, upper = bounds[2 * ydir], bounds[2 * ydir + 1]
                view_height = upper - lower
                vox[ydir] = lower + view_height // 2
            n += 1
            xyz = index_to_point(ijk, origin, spacing)
            pos = index_to_point(vox, origin, spacing)
            fields = dict(
                n=n, suffix=suffix[zdir],
                vi=vox[0], vj=vox[1], vk=vox[2],
                vx=pos[0], vy=pos[1], vz=pos[2],
                i=ijk[0], j=ijk[1], k=ijk[2],
                x=xyz[0], y=xyz[1], z=xyz[2],
            )
            # Substitute placeholders in the prefix itself (e.g. the default
            # '{n}'), which str.format does not do for substituted values.
            try:
                name_prefix = prefix.format(**fields)
            except (AttributeError, KeyError, IndexError, ValueError):
                name_prefix = prefix
            path = os.path.abspath(path_format.format(prefix=name_prefix, **fields))
            if overwrite or not os.path.isfile(path):
                # Only build the render pipeline when a screenshot is written.
                renderer = slice_view(image, zdir=zdir, index=vox,
                                      width=view_width, height=view_height, **args)
                window = vtk.vtkRenderWindow()
                window.SetSize(size)
                window.AddRenderer(renderer)
                directory = os.path.dirname(path)
                if not os.path.isdir(directory):
                    os.makedirs(directory)
                take_screenshot(window, path=path)
                renderer = None
                window = None
                isnew = True
            else:
                isnew = False
            screenshots.append((path, zdir, ijk, vox, (view_width, view_height), isnew))
    return screenshots
540  return screenshots