spatiotemporal.py
1 ##############################################################################
2 # Medical Image Registration ToolKit (MIRTK)
3 #
4 # Copyright 2017 Imperial College London
5 # Copyright 2017 Andreas Schuh
6 #
7 # Licensed under the Apache License, Version 2.0 (the "License");
8 # you may not use this file except in compliance with the License.
9 # You may obtain a copy of the License at
10 #
11 # http://www.apache.org/licenses/LICENSE-2.0
12 #
13 # Unless required by applicable law or agreed to in writing, software
14 # distributed under the License is distributed on an "AS IS" BASIS,
15 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 # See the License for the specific language governing permissions and
17 # limitations under the License.
18 ##############################################################################
19 
20 import os
21 import sys
22 import math
23 import csv
24 import json
25 import re
26 import shutil
27 import subprocess
28 import time
29 from datetime import datetime
30 
31 import mirtk
32 from mirtk.utils import makedirs
33 from mirtk.batch.slurm import submit as sbatch, wait as swait
34 from mirtk.batch.condor import submit as cbatch, wait as cwait
35 
36 
37 ##############################################################################
38 # Spatio-temporal atlas
39 
40 class SpatioTemporalAtlas(object):
41  """Spatio-temporal image and deformation atlas
42 
43  Configuration entries such as file and directory paths, final atlas time
44  points (means) and corresponding temporal kernel width/standard deviation
45  (sigma) are read from a JSON configuration file.
46 
47  The `construct` function performs the group-wise atlas construction.
48  Template images for each atlas time point specified in the configuration
49  are written to the configured `outdir`. At intermediate steps, an average
50  image is created for each time point associated with an image in the
51  dataset from which the atlas is created. These can be found in the `img`
52  subdirectory of each iteration underneath the configured `tmpdir`.
53 
54  After atlas construction, average images and longitudinal deformations
55  for every continuous time (within the atlas precision/max. resolution)
56  can be computed as long as the required deformations are stored.
57 
58  """
59 
60  def _path(self, paths, name, default):
61  """Get absolute path from config dictionary."""
62  path = paths.get(name, default)
63  return os.path.normpath(os.path.join(self.topdir, path)) if path else None
64 
65  def __init__(self, config, root=None, step=-1, verbose=1, threads=-1, exit_on_error=False):
66  """Load spatio-temporal atlas configuration."""
67  self.step = step
68  if isinstance(config, str):
69  config = os.path.abspath(config)
70  with open(config, "rt") as f:
71  self.config = json.load(f)
72  self.root = os.path.dirname(config)
73  else:
74  self.config = config
75  if root:
76  self.root = os.path.abspath(root)
77  else:
78  self.root = os.getcwd()
79  # Paths of image file and output of prior global normalization step
80  paths = self.config.get("paths", {})
81  images = self.config.get("images", {})
82  self.topdir = os.path.normpath(os.path.join(self.root, paths.get("topdir", ".")))
83  self.agecsv = self._path(paths, "agecsv", os.path.join(self.topdir, "config", "ages.csv"))
84  self.imgcsv = self._path(paths, "imgcsv", os.path.join(self.topdir, "config", "subjects.csv"))
85  self.imgage = read_ages(self.agecsv)
86  self.imgids = read_imgids(self.imgcsv)
87  self.tmpdir = self._path(paths, "tmpdir", "cache")
88  self.outdir = self._path(paths, "outdir", os.path.join(self.topdir, "templates"))
89  self.refimg = self._path(images, "ref", os.path.join(self.topdir, "global", "average.nii.gz"))
90  self.channel = images.get("default", "t2w")
91  if len(self.config["images"]) == 1:
92  self.channel = self.config["images"].keys()[0]
93  # Parameters of temporal kernel regressions
94  regression = self.config.get("regression", {})
95  self.means = [float(t) for t in regression['means']]
96  self.sigma = regression.get("sigma", 1.)
97  if isinstance(self.sigma, (tuple, list)):
98  self.sigma = [float(t) for t in self.sigma]
99  if len(self.sigma) != len(self.means):
100  raise ValueError("Number of sigma values must equal number of mean values!")
101  else:
102  self.sigma = [float(self.sigma)] * len(self.means)
103  self.epsilon = float(regression.get("epsilon", .001))
104  self.precision = int(regression.get('precision', 2))
105  # Registration parameters
106  regcfg = self.config.get("registration", {})
107  growth = regcfg.get("growth", {})
108  self.num_bch_terms = growth.get("bchterms", 3) # composition should be symmetric, e.g., 2, 3, or 5 terms
109  self.age_specific_imgdof = growth.get("enabled", True)
110  if self.age_specific_imgdof:
111  self.age_specific_regdof = not growth.get("exclavg", False)
112  else:
113  self.age_specific_regdof = False
114  # Other workflow execution settings
115  envcfg = self.config.get("environment", {})
116  self.verbose = verbose
117  self.queue = {
118  "short": envcfg.get("queue", {}).get("short", "local").lower(),
119  "long": envcfg.get("queue", {}).get("long", "local").lower()
120  }
121  self.threads = envcfg.get("threads", 8) if threads < 0 else threads
122  self.mintasks = envcfg.get("mintasks", 1)
123  self.maxtasks = envcfg.get("maxtasks", 1000)
124  self.exit_on_error = exit_on_error
125  # Discard images not required for any final time point
126  imgids = set()
127  for t in self.means:
128  imgids |= set(self.weights(mean=t).keys())
129  self.imgids = list(imgids)
130  self.imgids.sort()
131 
132  def _run(self, command, args=[], opts={}, step=-1, workdir=None, queue=None, name=None, submit_kwargs={}, wait_kwargs={}):
133  """Execute single MIRTK command."""
134  if step < 0:
135  step = self.step
136  if not name:
137  name = command
138  if queue and queue.lower() == "local":
139  queue = None
140  if command in ("edit-dofs", "em-hard-segmentation", "average-measure"):
141  # These commands do not support the -threads option (yet)
142  threads = 0
143  else:
144  threads = self.threads
145  if "verbose" not in opts:
146  if isinstance(opts, list):
147  opts.append(("verbose", self.verbose - 2))
148  else:
149  opts["verbose"] = self.verbose - 2
150  if "verbose" in opts and command in ("average-measure"):
151  # These commands do not support the -verbose option (yet)
152  del opts["verbose"]
153  if queue:
154  job = self._submit(name, command=command, args=args, opts=opts, step=step, workdir=workdir, script=None, tasks=-1, group=1, queue=queue, **submit_kwargs)
155  self.wait(job, **wait_kwargs)
156  else:
157  prevdir = os.getcwd()
158  if workdir:
159  os.chdir(workdir)
160  try:
161  if self.verbose > 1:
162  sys.stdout.write("\n\n")
163  mirtk.run(command, args=args, opts=opts, showcmd=(self.verbose > 1), threads=threads, onexit='exit' if self.exit_on_error else 'throw')
164  finally:
165  os.chdir(prevdir)
166 
167  def _submit(self, name, command=None, args=[], opts={}, script=None, tasks=-1, group=1, step=-1, queue=None, memory=8 * 1024, workdir=None):
168  """Submit batch script."""
169  if step < 0:
170  step = self.step
171  if not queue:
172  queue = self.queue
173  if tasks == 0:
174  return (queue, 0)
175  groups = tasks
176  threads = self.threads if self.threads > 0 else 1
177  if script:
178  source = 'import sys\nimport socket\n'
179  source += 'sys.stdout.write("Host: " + socket.gethostname() + "\\n\\n")\n'
180  source += 'sys.path.insert(0, "{0}")\n'.format(os.path.dirname(os.path.dirname(mirtk.__file__)))
181  source += 'sys.path.insert(0, "{0}")\n'.format(os.path.dirname(__file__))
182  source += 'from {0} import SpatioTemporalAtlas\n'.format(os.path.splitext(os.path.basename(__file__))[0])
183  source += 'atlas = SpatioTemporalAtlas(root="{root}", config={config}, step={step}, threads={threads}, verbose=3, exit_on_error=True)\n'
184  if tasks > 0:
185  tasks_per_group = max(group, self.mintasks, (tasks + self.maxtasks - 1) // self.maxtasks)
186  if tasks_per_group > 1:
187  groups = range(0, tasks, tasks_per_group)
188  source += 'groupid = int(sys.argv[1])\n'
189  source += 'if groupid < 0 or groupid >= {0}:\n'.format(len(groups))
190  source += ' sys.stderr.write("Invalid group ID\\n")\n'
191  source += 'tasks_per_group = {0}\n'.format(tasks_per_group)
192  source += 'for taskid in range(groupid * tasks_per_group, (groupid + 1) * tasks_per_group):\n'
193  source += ' if taskid >= {0}: break\n'.format(tasks)
194  source += ' ' + '\n '.join(script.splitlines()) + '\n'
195  source += ' else: sys.stderr.write("Invalid task ID\\n")\n'
196  groups = len(groups)
197  else:
198  source += 'taskid = int(sys.argv[1])\n'
199  source += 'if taskid < 0: sys.stderr.write("Invalid task ID\\n")\n'
200  source += script
201  source += 'else: sys.stderr.write("Invalid task ID\\n")\n'
202  else:
203  source += script
204  opts.update({
205  "root": self.root,
206  "config": repr(self.config),
207  "step": step,
208  "threads": threads
209  })
210  script = source
211  elif not command:
212  command = name
213  jobname = "i{0:02d}_{1}".format(step, name) if step >= 0 else name
214  if queue.lower() in ("condor", "htcondor"):
215  if tasks > 0:
216  log = os.path.join(self.subdir(step), "log", name + "_$(Cluster).$(Process).log")
217  else:
218  log = os.path.join(self.subdir(step), "log", name + "_$(Cluster).log")
219  condor_config = self.config.get("environment", {}).get("condor", {})
220  requirements = condor_config.get("requirements", [])
221  environment = condor_config.get("environment", {})
222  jobid = cbatch(name=jobname, command=command, args=args, opts=opts, script=script, tasks=groups,
223  log=log, threads=threads, memory=memory, requirements=requirements, environment=environment,
224  workdir=workdir, verbose=0)
225  else:
226  if tasks > 0:
227  log = os.path.join(self.subdir(step), "log", name + "_%A.%a.log")
228  else:
229  log = os.path.join(self.subdir(step), "log", name + "_%j.log")
230  jobid = sbatch(name=jobname, command=command, args=args, opts=opts, script=script, tasks=groups,
231  log=log, threads=threads, memory=memory, queue=queue,
232  workdir=workdir, verbose=0)
233  if tasks > 0:
234  self.info("Submitted batch '{}' (id={}, #jobs={}, #tasks={})".format(
235  name, jobid[0] if isinstance(jobid, tuple) else jobid, groups, tasks)
236  )
237  else:
238  self.info("Submitted job '{}' (id={})".format(name, jobid[0] if isinstance(jobid, tuple) else jobid))
239  return (queue, jobid)
240 
241  def wait(self, jobs, interval=60, verbose=5):
242  """Wait for batch jobs to complete."""
243  if not isinstance(jobs, list):
244  jobs = [jobs]
245  condor_jobs = []
246  slurm_jobs = []
247  for queue, jobid in jobs:
248  if queue and queue.lower() != "local":
249  if queue.lower() in ("condor", "htcondor"):
250  condor_jobs.append(jobid)
251  else:
252  slurm_jobs.append(jobid)
253  if not cwait(condor_jobs, interval=interval, verbose=verbose):
254  raise Exception("Not all HTCondor jobs finished successfully!")
255  if not swait(slurm_jobs, interval=interval, verbose=verbose):
256  raise Exception("Not all SLURM jobs finished successfully!")
257 
258  def info(self, msg, step=-1):
259  """Print status message."""
260  if self.verbose > 0:
261  sys.stdout.write("{:%Y-%b-%d %H:%M:%S} INFO ".format(datetime.now()))
262  sys.stdout.write(msg)
263  if step >= 0:
264  sys.stdout.write(" (step={})".format(step))
265  sys.stdout.write("\n")
266 
267  def normtime(self, t):
268  """Clamp time to be within atlas domain."""
269  return round(min(max(t, self.means[0]), self.means[-1]), self.precision)
270 
271  def timeindex(self, t):
272  """Get discrete time point index."""
273  if t <= self.means[0]:
274  return 0
275  i = len(self.means) - 1
276  if t >= self.means[-1]:
277  return i
278  for j in range(1, len(self.means)):
279  if self.means[j] > t:
280  i = j - 1
281  break
282  return i
283 
284  def timename(self, t1, t2=None):
285  """Get time point string used in file names."""
286  t1 = self.normtime(t1)
287  if t2 is None:
288  t2 = t1
289  else:
290  t2 = self.normtime(t2)
291  if isclose(t1, t2):
292  w = 2
293  if self.precision > 0:
294  w += self.precision + 1
295  return "t{0:0{1}.{2}f}".format(self.normtime(t1), w, self.precision)
296  return "{0:s}-{1:s}".format(self.timename(t1), self.timename(t2), self.precision)
297 
298  def age(self, imgid):
299  """Get time associated with the specified image, i.e., mean of temporal Gaussian kernel."""
300  return self.normtime(self.imgage[imgid])
301 
302  def ages(self):
303  """Get set of all ages associated with the images from which atlas is constructed."""
304  ages = set()
305  for imgid in self.imgids:
306  ages.add(self.age(imgid))
307  return list(ages)
308 
309  def stdev(self, t):
310  """Get standard deviation of temporal Gaussian kernel centered at time t."""
311  i = self.timeindex(t)
312  if i == len(self.sigma) - 1:
313  return self.sigma[i]
314  else:
315  alpha = (t - self.means[i]) / (self.means[i + 1] - self.means[i])
316  return (1. - alpha) * self.sigma[i] + alpha * self.sigma[i + 1]
317 
318  def weight(self, t, mean, sigma=0, normalize=True, epsilon=None):
319  """Evaluate temporal kernel weight for specified image."""
320  if sigma <= 0:
321  sigma = self.stdev(mean)
322  w = math.exp(- .5 * math.pow((t - mean) / sigma, 2))
323  if normalize:
324  w /= sigma * math.sqrt(2. * math.pi)
325  if epsilon is None:
326  epsilon = self.epsilon
327  return 0. if w < epsilon else w
328 
329  def weights(self, mean=None, sigma=0, zero=False):
330  """Get weights of images within local support of given temporal kernel."""
331  if mean is None:
332  wmean = {}
333  for mean in self.means:
334  wimg = {}
335  for imgid in self.imgids:
336  w = self.weight(self.age(imgid), mean=mean, sigma=sigma)
337  if zero or w > 0.:
338  wimg[imgid] = w
339  wmean[mean] = wimg
340  return wmean
341  else:
342  wimg = {}
343  for imgid in self.imgids:
344  w = self.weight(self.age(imgid), mean=mean, sigma=sigma)
345  if zero or w > 0.:
346  wimg[imgid] = w
347  return wimg
348 
349  def subdir(self, step=-1):
350  """Get absolute path of directory of current iteration."""
351  path = self.tmpdir
352  if step < 0:
353  step = self.step
354  if step >= 0:
355  path = os.path.join(path, "i{0:02d}".format(step))
356  return path
357 
358  def splitchannel(self, channel=None):
359  """Split channel specification into name and label specification."""
360  if not channel:
361  channel = self.channel
362  if "=" in channel:
363  channel, label = channel.split("=")
364  try:
365  l = int(label)
366  label = l
367  except ValueError:
368  label = label.split(",")
369  else:
370  label = 0
371  return (channel, label)
372 
373  def image(self, imgid, channel=None, label=0, force=False, create=True):
374  """Get absolute path of requested input image."""
375  if not channel:
376  channel = self.channel
377  cfg = self.config["images"][channel]
378  prefix = cfg["prefix"].replace("/", os.path.sep)
379  suffix = cfg.get("suffix", ".nii.gz")
380  img = os.path.normpath(os.path.join(self.topdir, prefix + imgid + suffix))
381  if label:
382  if isinstance(label, (tuple, list)):
383  lblstr = ','.join([str(l).strip() for l in label])
384  else:
385  lblstr = str(label).strip()
386  lblstr = lblstr.replace("..", "-")
387  msk = os.path.join(self.topdir, "masks", "{}_{}".format(channel, lblstr), imgid + ".nii.gz")
388  if create and (force or not os.path.exists(msk)):
389  makedirs(os.path.dirname(msk))
390  if not isinstance(label, list):
391  label = [label]
392  self._run("calculate-element-wise", args=[img, "-label"] + label + ["-set", 1, "-pad", 0, "-out", msk, "binary"])
393  img = msk
394  return img
395 
396  def affdof(self, imgid):
397  """Get absolute path of affine transformation of global normalization."""
398  cfg = self.config.get("registration", {}).get("affine", None)
399  if cfg is None:
400  return "identity"
401  prefix = cfg["prefix"].replace("/", os.path.sep)
402  suffix = cfg.get("suffix", ".dof")
403  return os.path.normpath(os.path.join(self.topdir, prefix + imgid + suffix))
404 
405  def regcfg(self, step=-1):
406  """Get registration configuration entries."""
407  configs = self.config.get("registration", {}).get("config", {})
408  if not isinstance(configs, list):
409  configs = [configs]
410  cfg = {}
411  for i in range(min(step, len(configs)) if step > 0 else len(configs)):
412  cfg.update(configs[i])
413  return cfg
414 
415  def parin(self, step=-1, force=False, create=True):
416  """Write registration configuration file and return path."""
417  if step < 0:
418  step = self.step
419  parin = os.path.join(self.subdir(step), "config", "register.cfg")
420  if create and (force or not os.path.exists(parin)):
421  cfg = self.regcfg(step)
422  images = self.config["images"]
423  channels = cfg.get("channels", ["t2w"])
424  if not isinstance(channels, list):
425  channels = [channels]
426  if len(channels) == 0:
427  raise ValueError("Empty list of images/channels specified for registration!")
428  params = ["[default]"]
429  # transformation model
430  model = cfg.get("model", "SVFFD")
431  mffd, model = model.split(":") if ":" in model else "None", model
432  params.append("Transformation model = {}".format(model))
433  params.append("Multi-level transformation = {}".format(mffd))
434  # energy function
435  energy = cfg.get("energy", "sym" if "svffd" in model.lower() else "asym")
436  if energy.lower() in ("asym", "asymmetric"):
437  sim_term_type = 0
438  elif energy.lower() in ("ic", "inverseconsistent", "inverse-consistent"):
439  sim_term_type = 1
440  elif energy.lower() in ("sym", "symmetric"):
441  sim_term_type = 2
442  else:
443  sim_term_type = -1
444  if sim_term_type < 0:
445  formula = energy
446  else:
447  formula = ""
448  measures = cfg.get("measures", {})
449  for c in range(len(channels)):
450  target = 2 * c + 1
451  source = 2 * c + 2
452  if isinstance(measures, list):
453  measure = measures[c] if c < len(measures) else measures[-1]
454  elif isinstance(measures, dict):
455  channel = self.splitchannel(channels[c])[0]
456  measure = measures.get(channel, "NMI")
457  else:
458  measure = measures
459  if sim_term_type == 2:
460  term = "{sim}[{channel} sim](I({tgt}) o T^-0.5, I({src}) o T^0.5)"
461  elif sim_term_type == 1:
462  term = "{sim}[{channel} fwd-sim](I({tgt}) o T^-1, I({src})) + {sim}[{channel} bwd-sim](I({tgt}), I({src}) o T)"
463  else:
464  term = "{sim}[{channel} sim](I({tgt}), I({src}) o T)"
465  if c > 0:
466  formula += " + "
467  formula += term.format(sim=measure, channel=channels[c].capitalize().replace("=", " "), tgt=target, src=source)
468  if "bending" in cfg:
469  formula += " + 0 BE[Bending energy](T)"
470  if "jacobian" in cfg:
471  formula += " + 0 JAC[Jacobian penalty](T)"
472  params.append("Energy function = " + formula)
473  params.append("No. of bins = {}".format(cfg.get("bins", 64)))
474  params.append("Local window size [box] = {}".format(cfg.get("window", "5 vox")))
475  params.append("No. of last function values = 10")
476  if "svffd" in model.lower():
477  params.append("Integration method = {}".format(cfg.get("ffdim", "FastSS")))
478  params.append("No. of integration steps = {}".format(cfg.get("intsteps", 64)))
479  params.append("No. of BCH terms = {}".format(cfg.get("bchterms", 4)))
480  params.append("Use Lie derivative = {}".format(cfg.get("liederiv", "No")))
481  # resolution pyramid
482  spacings = cfg.get("spacing", [])
483  if not isinstance(spacings, list):
484  spacings = [spacings]
485  resolutions = cfg.get("resolution", [])
486  if not isinstance(resolutions, list):
487  resolutions = [resolutions]
488  blurring = cfg.get("blurring", {})
489  if isinstance(blurring, list):
490  blurring = {self.splitchannel(channels[0])[0]: blurring}
491  be_weights = cfg.get("bending", [0])
492  if not isinstance(be_weights, list):
493  be_weights = [be_weights]
494  lj_weights = cfg.get("jacobian", [0])
495  if not isinstance(lj_weights, list):
496  lj_weights = [lj_weights]
497  levels = cfg.get("levels", 0)
498  if levels <= 0:
499  levels = len(resolutions) if isinstance(resolutions, list) else 4
500  resolution = 0.
501  spacing = 0.
502  params.append("No. of resolution levels = {0}".format(levels))
503  params.append("Image interpolation mode = {0}".format(cfg.get("interpolation", "Linear with padding")))
504  params.append("Downsample images with padding = {0}".format("Yes" if cfg.get("padding", True) else "No"))
505  params.append("Image similarity foreground = {0}".format(cfg.get("foreground", "Overlap")))
506  params.append("Strict step length range = No")
507  params.append("Maximum streak of rejected steps = 2")
508  for level in range(1, levels + 1):
509  params.append("")
510  params.append("[level {}]".format(level))
511  resolution = float(resolutions[level - 1]) if level <= len(resolutions) else 2. * resolution
512  if resolution > 0.:
513  params.append("Resolution [mm] = {}".format(resolution))
514  for c in range(len(channels)):
515  target = 2 * c + 1
516  source = 2 * c + 2
517  image = images[self.splitchannel(channels[c])[0]]
518  bkgrnd = float(image.get("bkgrnd", -1))
519  params.append("Background value of image {0} = {1}".format(target, bkgrnd))
520  bkgrnd = -1. if "labels" in image else 0.
521  params.append("Background value of image {0} = {1}".format(source, bkgrnd))
522  for c in range(len(channels)):
523  target = 2 * c + 1
524  source = 2 * c + 2
525  channel = self.splitchannel(channels[c])[0]
526  sigmas = blurring.get(channel, [])
527  if not isinstance(sigmas, list):
528  sigmas = [sigmas]
529  if len(sigmas) == 0 and "labels" in images[channel]:
530  sigmas = [2]
531  if len(sigmas) > 0:
532  sigma = float(sigmas[level - 1] if level <= len(sigmas) else sigmas[-1])
533  params.append("Blurring of image {0} [vox] = {1}".format(target, sigma))
534  spacing = float(spacings[level - 1]) if level <= len(spacings) else 2. * spacing
535  if spacing > 0.:
536  params.append("Control point spacing = {0}".format(spacing))
537  be_weight = float(be_weights[level - 1] if level <= len(be_weights) else be_weights[-1])
538  params.append("Bending energy weight = {0}".format(be_weight))
539  lj_weight = float(lj_weights[level - 1] if level <= len(lj_weights) else lj_weights[-1])
540  params.append("Jacobian penalty weight = {0}".format(lj_weight))
541  # write -parin file for "register" command
542  makedirs(os.path.dirname(parin))
543  with open(parin, "wt") as f:
544  f.write("\n".join(params) + "\n")
545  return parin
546 
547  def regdof(self, imgid, t=None, step=-1, path=None, force=False, create=True, batch=False):
548  """Register atlas to specified image and return path of transformation file."""
549  if step < 0:
550  step = self.step
551  age = self.age(imgid)
552  if t is None or not self.age_specific_regdof:
553  t = age
554  if not path:
555  path = os.path.join(self.subdir(step), "dof", "deformation", self.timename(t), "{0}.dof.gz".format(imgid))
556  dof = path
557  if isclose(t, age):
558  if force or not os.path.exists(path):
559  if step < 1:
560  dof = "identity"
561  elif create:
562  cfg = self.regcfg(step)
563  args = []
564  affdof = self.affdof(imgid)
565  if affdof == "identity":
566  affdof = None
567  channels = cfg.get("channels", self.channel)
568  if not isinstance(channels, list):
569  channels = [channels]
570  if len(channels) == 0:
571  raise ValueError("Empty list of images/channels specified for registration!")
572  for channel in channels:
573  channel, label = self.splitchannel(channel)
574  args.append("-image")
575  args.append(self.image(imgid, channel=channel, label=label))
576  if affdof:
577  args.append("-dof")
578  args.append(affdof)
579  args.append("-image")
580  args.append(self.avgimg(t, channel=channel, label=label, step=step - 1, create=not batch))
581  makedirs(os.path.dirname(dof))
582  self._run("register", args=args, opts={
583  "parin": self.parin(step=step, force=force, create=not batch),
584  "mask": self.refimg,
585  "dofin": "identity",
586  "dofout": dof
587  })
588  else:
589  dof1 = self.regdof(imgid, t=age, step=step, force=force, create=create and not batch)
590  dof2 = self.growth(age, t, step=step - 1, force=force, create=create and not batch)
591  if dof1 == "identity" and dof2 == "identity":
592  dof = "identity"
593  elif dof1 == "identity":
594  dof = dof2
595  elif dof2 == "identity":
596  dof = dof1
597  elif create:
598  makedirs(os.path.dirname(dof))
599  self._run("compose-dofs", args=[dof1, dof2, dof], opts={"bch": self.num_bch_terms, "global": False})
600  return dof
601 
602  def regdofs(self, step=-1, force=False, create=True, queue=None, batchname="register"):
603  """Register all images to their age-specific template."""
604  if step < 0:
605  step = self.step
606  if not queue:
607  queue = self.queue["long"]
608  if queue == "local":
609  queue = None
610  self.info("Register images to template of corresponding age", step=step)
611  script = ""
612  tasks = 0
613  dofs = {}
614  for imgid in self.imgids:
615  dof = self.regdof(imgid, step=step, force=force, create=create and not queue)
616  if dof != "identity" and queue and (force or not os.path.exists(dof)):
617  remove_or_makedirs(dof)
618  script += 'elif taskid == {taskid}: atlas.regdof("{imgid}", batch=True)\n'.format(taskid=tasks, imgid=imgid)
619  tasks += 1
620  dofs[imgid] = dof
621  if create and queue:
622  self.parin(step=step, force=force) # create -parin file if missing
623  job = self._submit(batchname, script=script, tasks=tasks, step=step, queue=queue)
624  else:
625  job = (None, 0)
626  return (job, dofs)
627 
628  def doftable(self, t, step=-1, force=False, create=True, batch=False):
629  """Write table with image to atlas deformations of images with non-zero weight."""
630  dofs = []
631  t = self.normtime(t)
632  weights = self.weights(t)
633  all_dofs_are_identity = True
634  for imgid in self.imgids:
635  if imgid in weights:
636  if not self.age_specific_regdof or isclose(t, self.age(imgid)):
637  dof = self.regdof(imgid, t=t, step=step, force=force, create=create and not batch)
638  else:
639  dof = self.regdof(imgid, t=t, step=step, force=force, create=create, batch=batch)
640  dofs.append((dof, weights[imgid]))
641  if dof != "identity":
642  all_dofs_are_identity = False
643  if all_dofs_are_identity:
644  dofdir = None
645  dofnames = "identity"
646  elif len(dofs) > 0:
647  dofdir = self.topdir
648  dofnames = os.path.join(self.subdir(step), "config", "{}-dofs.tsv".format(self.timename(t)))
649  if create and (force or not os.path.exists(dofnames)):
650  makedirs(os.path.dirname(dofnames))
651  with open(dofnames, "wt") as table:
652  for dof, w in dofs:
653  if dofdir:
654  dof = os.path.relpath(dof, dofdir)
655  table.write("{}\t{}\n".format(dof, w))
656  else:
657  raise ValueError("No image has non-zero weight for time {0}!".format(t))
658  return (dofdir, dofnames)
659 
660  def avgdof(self, t, path=None, step=-1, force=False, create=True, batch=False):
661  """Get mean cross-sectional SV FFD transformation at given time."""
662  t = self.normtime(t)
663  if not path:
664  path = os.path.join(self.subdir(step), "dof", "average", "{0}.dof.gz".format(self.timename(t)))
665  if create and (force or not os.path.exists(path)):
666  dofdir, dofnames = self.doftable(t, step=step, force=force, batch=batch)
667  if dofnames == "identity":
668  path = "identity"
669  else:
670  makedirs(os.path.dirname(path))
671  self._run("average-dofs", args=[path], opts={
672  "dofdir": dofdir,
673  "dofnames": dofnames,
674  "type": "SVFFD",
675  "target": "common",
676  "invert": True,
677  "global": False
678  })
679  return path
680 
681  def avgdofs(self, step=-1, force=False, create=True, queue=None, batchname="avgdofs"):
682  """Compute all average SV FFDs needed for (parallel) atlas construction."""
683  if step < 0:
684  step = self.step
685  if not queue:
686  queue = self.queue["short"]
687  if queue == "local":
688  queue = None
689  self.info("Compute residual average deformations at each age", step=step)
690  script = ""
691  tasks = 0
692  dofs = {}
693  for t in self.ages():
694  dof = self.avgdof(t, step=step, force=force, create=create and not queue)
695  if queue and (force or not os.path.exists(dof)):
696  remove_or_makedirs(dof)
697  script += 'elif taskid == {taskid}: atlas.avgdof({t}, batch=True)\n'.format(taskid=tasks, t=t)
698  tasks += 1
699  dofs[t] = dof
700  if create and queue:
701  job = self._submit(batchname, script=script, tasks=tasks, step=step, queue=queue)
702  else:
703  job = (None, 0)
704  return (job, dofs)
705 
706  def growth(self, t1, t2, step=-1, force=False, create=True, batch=False):
707  """Make composite SV FFD corresponding to longitudinal change from t1 to t2."""
708  if step < 0:
709  step = self.step
710  t1 = self.normtime(t1)
711  t2 = self.normtime(t2)
712  if isclose(t1, t2):
713  return "identity"
714  dof = os.path.join(self.subdir(step), "dof", "growth", "{0}.dof.gz".format(self.timename(t1, t2)))
715  if step < 1:
716  return path if os.path.exists(dof) else "identity"
717  if create and (force or not os.path.exists(dof)):
718  dofs = [
719  self.avgdof(t1, step=step, force=force, create=not batch),
720  self.growth(t1, t2, step=step - 1, force=force, create=not batch),
721  self.avgdof(t2, step=step, force=force, create=not batch)
722  ]
723  if dofs[1] == "identity":
724  del dofs[1]
725  makedirs(os.path.dirname(dof))
726  self._run("compose-dofs", args=dofs + [dof], opts={"scale": -1., "bch": self.num_bch_terms, "global": False})
727  return dof
728 
729  def compose(self, step=-1, ages=[], allpairs=False, force=False, create=True, queue=None, batchname="compose"):
730  """Compose longitudinal deformations with residual average deformations."""
731  if step < 0:
732  step = self.step
733  if not queue:
734  queue = self.queue["short"]
735  if queue == "local":
736  queue = None
737  self.info("Update all pairs of longitudinal deformations", step=step)
738  script = ""
739  tasks = 0
740  dofs = {}
741  if not ages:
742  ages = self.ages()
743  for t1 in ages:
744  dofs[t1] = {}
745  for t2 in ages:
746  if allpairs or (t1 != t2 and self.weight(t1, mean=t2) > 0.):
747  dof = self.growth(t1, t2, step=step, force=force, create=create and not queue)
748  if dof != "identity" and queue and (force or not os.path.exists(dof)):
749  remove_or_makedirs(dof)
750  script += 'elif taskid == {taskid}: atlas.growth({t1}, {t2}, batch=True)\n'.format(taskid=tasks, t1=t1, t2=t2)
751  tasks += 1
752  dofs[t1][t2] = dof
753  if create and queue:
754  job = self._submit(batchname, script=script, tasks=tasks, step=step, queue=queue)
755  else:
756  job = (None, 0)
757  return (job, dofs)
758 
759  def imgdof(self, imgid, t, step=-1, decomposed=False, force=False, create=True, batch=False):
760  """Compute composite image to atlas transformation."""
761  if step < 0:
762  step = self.step
763  if step > 0:
764  dof = os.path.join(self.subdir(step), "dof", "composite", self.timename(t), "{0}.dof.gz".format(imgid))
765  if decomposed or (create and (force or not os.path.exists(dof))):
766  dofs = [
767  self.affdof(imgid),
768  self.regdof(imgid, step=step, force=force, create=create and not batch)
769  ]
770  if self.age_specific_imgdof:
771  growth = self.growth(self.age(imgid), t, step=step - 1, force=force, create=create and not batch)
772  if growth != "identity":
773  dofs.append(growth)
774  dofs.append(self.avgdof(t, step=step, force=force, create=create and not batch))
775  if decomposed:
776  dof = dofs
777  elif create:
778  makedirs(os.path.dirname(dof))
779  self._run("compose-dofs", args=dofs + [dof])
780  else:
781  dof = self.affdof(imgid)
782  if decomposed:
783  dof = [dof] + ['identity'] * 2
784  return dof
785 
786  def imgdofs(self, ages=[], step=-1, force=False, create=True, queue=None, batchname="imgdofs"):
787  """Compute all composite image to atlas transformations."""
788  if step < 0:
789  step = self.step
790  if not queue:
791  queue = self.queue["short"]
792  if queue == "local":
793  queue = None
794  self.info("Update composite image to atlas transformations", step=step)
795  if ages:
796  if not isinstance(ages, (tuple, list)):
797  ages = [ages]
798  ages = [self.normtime(t) for t in ages]
799  else:
800  ages = self.ages()
801  script = ""
802  tasks = 0
803  dofs = {}
804  for t in ages:
805  dofs[t] = {}
806  weights = self.weights(t)
807  for imgid in self.imgids:
808  if imgid in weights:
809  dof = self.imgdof(imgid=imgid, t=t, step=step, force=force, create=create and not queue)
810  if dof != "identity" and queue and (force or not os.path.exists(dof)):
811  remove_or_makedirs(dof)
812  script += 'elif taskid == {taskid}: atlas.imgdof(imgid="{imgid}", t={t}, batch=True)\n'.format(
813  taskid=tasks, imgid=imgid, t=t
814  )
815  tasks += 1
816  dofs[t][imgid] = dof
817  if create and queue:
818  job = self._submit(batchname, script=script, tasks=tasks, step=step, queue=queue)
819  else:
820  job = (None, 0)
821  return (job, dofs)
822 
823  def defimg(self, imgid, t, channel=None, path=None, step=-1, decomposed=True, force=False, create=True, batch=False):
824  """Transform sample image to atlas space at given time point."""
825  if not channel:
826  channel = self.channel
827  if not path:
828  path = os.path.join(self.subdir(step), channel, self.timename(t), imgid + ".nii.gz")
829  if create and (force or not os.path.exists(path)):
830  cfg = self.config["images"][channel]
831  img = self.image(imgid, channel=channel)
832  dof = self.imgdof(imgid=imgid, t=t, step=step, decomposed=decomposed, force=force, create=not batch)
833  opts = {
834  "interp": cfg.get("interp", "linear"),
835  "target": self.refimg,
836  "dofin": dof,
837  "invert": True
838  }
839  if "bkgrnd" in cfg:
840  opts["source-padding"] = cfg["bkgrnd"]
841  if "labels" in cfg:
842  opts["labels"] = cfg["labels"].split(",")
843  if "datatype" in cfg:
844  opts["datatype"] = cfg["datatype"]
845  makedirs(os.path.dirname(path))
846  self._run("transform-image", args=[img, path], opts=opts)
847  return path
848 
    def defimgs(self, ages=[], channels=[], step=-1, decomposed=True, force=False, create=True, queue=None, batchname="defimgs"):
        """Transform all images to discrete set of atlas time points.

        Args:
            ages: Atlas time point(s); a single value is wrapped in a list. When
                empty, the time points returned by `self.ages()` are used.
            channels: Image channel name or list of names. When empty, the
                "channels" entry of `self.regcfg(step)` is used, falling back to
                `self.channel`.
            step (int): Iteration index; negative values select `self.step`.
            decomposed (bool): Passed on to `defimg`.
            force (bool): Re-create already existing images.
            create (bool): Create missing images (directly, or via a batch job).
            queue (str): Batch queue (default: `self.queue["short"]`); "local"
                means images are created directly by `defimg`.
            batchname (str): Name of the batch job.

        Returns:
            tuple: `(job, imgs)` where `job` is the result of `self._submit` or
            `(None, 0)`, and `imgs[channel][t]` is the list of deformed image
            paths; `imgs[t]` when `channels` was given as a single string.

        Raises:
            Exception: When no image has non-zero weight for a time point.
        """
        if step < 0:
            step = self.step
        if not queue:
            queue = self.queue["short"]
        # From here on, queue=None denotes local execution
        if queue == "local":
            queue = None
        # A single channel name given as string selects the per-channel result dict
        single_channel = channels and isinstance(channels, basestring)
        if not channels:
            channels = self.regcfg(step).get("channels", self.channel)
        if not isinstance(channels, list):
            channels = [channels]
        if ages:
            self.info("Deform images to discrete time points", step=step)
            if not isinstance(ages, (tuple, list)):
                ages = [ages]
            ages = [self.normtime(t) for t in ages]
        else:
            self.info("Deform images to observed time points", step=step)
            ages = self.ages()
        script = ""
        tasks = 0
        imgs = {}
        for channel in channels:
            imgs[channel] = {}
            for t in ages:
                weights = self.weights(t)
                if len(weights) == 0:
                    raise Exception("No image has non-zero weight for t={}".format(t))
                imgs[channel][t] = []
                for imgid in self.imgids:
                    if imgid in weights:
                        # Without a queue, defimg creates the image directly
                        img = self.defimg(imgid=imgid, t=t, channel=channel, step=step, decomposed=decomposed, force=force, create=create and not queue)
                        if queue and (force or not os.path.exists(img)):
                            # Schedule a batch task for each image still to be created
                            remove_or_makedirs(img)
                            script += 'elif taskid == {taskid}: atlas.defimg(imgid="{imgid}", t={t}, channel="{channel}", path="{path}", decomposed={decomposed}, batch=True)\n'.format(
                                taskid=tasks, t=t, imgid=imgid, channel=channel, path=img, decomposed=decomposed
                            )
                            tasks += 1
                        imgs[channel][t].append(img)
        if create and queue:
            job = self._submit(batchname, script=script, tasks=tasks, step=step, queue=queue)
        else:
            job = (None, 0)
        if single_channel:
            # Only one channel was processed, i.e., the last loop value of `channel`
            imgs = imgs[channel]
        return (job, imgs)
897 
898  def imgtable(self, t, channel=None, step=-1, decomposed=True, force=False, create=True, batch=False):
899  """Write image table with weights for average-images, and deform images to given time point."""
900  t = self.normtime(t)
901  if not channel:
902  channel = self.channel
903  image = self.config["images"][channel]
904  table = os.path.join(self.subdir(step), "config", "{t}-{channel}.tsv".format(t=self.timename(t), channel=channel))
905  if create and (force or not os.path.exists(table)):
906  weights = self.weights(t)
907  if len(weights) == 0:
908  raise Exception("No image has non-zero weight for t={}".format(t))
909  makedirs(os.path.dirname(table))
910  with open(table, "wt") as f:
911  f.write(self.topdir)
912  f.write("\n")
913  for imgid in self.imgids:
914  if imgid in weights:
915  img = self.defimg(t=t, imgid=imgid, channel=channel, step=step, decomposed=decomposed, force=force, create=not batch)
916  f.write(os.path.relpath(img, self.topdir))
917  f.write("\t{0}\n".format(weights[imgid]))
918  return table
919 
    def avgimg(self, t, channel=None, label=0, path=None, sharpen=True, outdir=None, step=-1, decomposed=True, force=False, create=True, batch=False):
        """Create average image for a given time point.

        Args:
            t (float): Atlas time point.
            channel (str, optional): Image channel. (default: self.channel)
            label (int, str, list, optional): Label(s) of the probability map to
                average. A string is split at commas, and a one-element list is
                reduced to its element. When falsy and the channel configuration
                has "labels", a hard segmentation is computed from the per-label
                average images instead.
            path (str, optional): Output path. When not given, it is derived from
                `outdir` if set, or from the step subdirectory otherwise.
            sharpen (bool): For intensity images, request sharpening by
                average-images. Without `outdir`, also selects the "templates"
                (vs. "mean") output subdirectory.
            outdir (str, optional): Output directory of final atlas templates.
            step (int): Iteration index.
            decomposed (bool): Passed on to `imgtable`.
            force (bool): Re-create already existing output.
            create (bool): Create missing output.
            batch (bool): Called from a batch job.

        Returns:
            str: Path of the average image.
        """
        if not channel:
            channel = self.channel
        cfg = self.config["images"][channel]
        if isinstance(label, basestring):
            label = label.split(",")
        if isinstance(label, list) and len(label) == 1:
            label = label[0]
        if not path:
            # Label string used in output file/directory names
            if label:
                if isinstance(label, (tuple, list)):
                    lblstr = ','.join([str(l).strip() for l in label])
                elif isinstance(label, int):
                    # Zero-pad to the number of digits of the largest configured label
                    max_label = max(parselabels(cfg["labels"])) if "labels" in cfg else 9
                    lblstr = "{0:0{1}d}".format(label, len(str(max_label)))
                else:
                    lblstr = str(label).strip()
                # Replace ".." range notation in names
                lblstr = lblstr.replace("..", "-")
            else:
                lblstr = None
            if outdir:
                outdir = os.path.abspath(outdir)
                if "labels" in cfg:
                    if lblstr:
                        path = os.path.join(outdir, "pbmaps", "_".join([channel, lblstr]))
                    else:
                        path = os.path.join(outdir, "labels", channel)
                else:
                    path = os.path.join(outdir, "templates", channel)
            else:
                path = os.path.join(self.subdir(step), channel)
                if lblstr:
                    path = os.path.join(path, "prob_" + lblstr)
                elif sharpen:
                    path = os.path.join(path, "templates")
                else:
                    path = os.path.join(path, "mean")
            path = os.path.join(path, self.timename(t) + ".nii.gz")
        if create and (force or not os.path.exists(path)):
            makedirs(os.path.dirname(path))
            if "labels" in cfg and not label:
                # Hard segmentation from the average probability map of each label
                labels = parselabels(cfg["labels"])
                args = [self.avgimg(t, channel=channel, label=label, step=step, decomposed=decomposed, force=force, batch=batch) for label in labels]
                self._run("em-hard-segmentation", args=[len(args)] + args + [path])
            else:
                table = self.imgtable(t, step=step, channel=channel, decomposed=decomposed, force=force, batch=batch)
                opts = {
                    "images": table,
                    "reference": self.refimg
                }
                if "bkgrnd" in cfg:
                    opts["padding"] = float(cfg["bkgrnd"])
                opts["datatype"] = cfg.get("datatype", "float")
                if "labels" in cfg:
                    # Probability map of the selected label(s)
                    opts["label"] = label
                    if opts["datatype"] not in ["float", "double"]:
                        opts["rescaling"] = cfg.get("rescaling", [0, 100])
                    elif "rescaling" in cfg:
                        opts["rescaling"] = cfg["rescaling"]
                else:
                    # Intensity average
                    opts["threshold"] = .5
                    opts["normalization"] = cfg.get("normalization", "zscore")
                    opts["rescaling"] = cfg.get("rescaling", [0, 100])
                    if sharpen:
                        opts["sharpen"] = cfg.get("sharpen", True)
                self._run("average-images", args=[path], opts=opts)
        return path
988 
989  def avgimgs(self, step=-1, ages=[], channels=[], labels={}, sharpen=True, outdir=None, decomposed=True, force=False, create=True, queue=None, batchname="avgimgs"):
990  """Create all average images required for (parallel) atlas construction."""
991  if step < 0:
992  step = self.step
993  if not queue:
994  queue = self.queue["short"]
995  if queue == "local":
996  queue = None
997  if ages:
998  self.info("Average images at discrete time points", step=step)
999  ages = [self.normtime(t) for t in ages]
1000  else:
1001  self.info("Average images at observed time points", step=step)
1002  ages = self.ages()
1003  single_channel = channels and isinstance(channels, basestring)
1004  if not channels:
1005  channels = self.regcfg(step).get("channels", self.channel)
1006  if not isinstance(channels, list):
1007  channels = [channels]
1008  if not isinstance(labels, dict):
1009  dlabels = {}
1010  for channel in channels:
1011  dlabels[channel] = labels
1012  labels = dlabels
1013  script = ""
1014  tasks = 0
1015  imgs = {}
1016  for channel in channels:
1017  imgs[channel] = {}
1018  for t in ages:
1019  imgs[channel][t] = []
1020  for channel in channels:
1021  segments = [0]
1022  if "labels" in self.config["images"][channel]:
1023  lbls = labels.get(channel, [])
1024  if lbls:
1025  if isinstance(lbls, basestring):
1026  if lbls.lower() == "all":
1027  segments = parselabels(self.config["images"][channel]["labels"])
1028  else:
1029  segments = parselabels(lbls)
1030  else:
1031  segments = lbls
1032  for segment in segments:
1033  for t in ages:
1034  img = self.avgimg(t, channel=channel, label=segment, sharpen=sharpen, outdir=outdir, step=step,
1035  decomposed=decomposed, force=force, create=create and not queue)
1036  if queue and (force or not os.path.exists(img)):
1037  remove_or_makedirs(img)
1038  script += 'elif taskid == {taskid}: atlas.avgimg(t={t}, channel="{channel}", label={segment}, sharpen={sharpen}, outdir={outdir}, decomposed={decomposed}, batch=True)\n'.format(
1039  taskid=tasks, t=t, channel=channel, segment=repr(segment), sharpen=sharpen, outdir=repr(outdir), decomposed=decomposed
1040  )
1041  tasks += 1
1042  if len(segments) == 1:
1043  imgs[channel][t] = img
1044  else:
1045  imgs[channel][t].append(img)
1046  if create and queue:
1047  job = self._submit(batchname, script=script, tasks=tasks, step=step, queue=queue)
1048  else:
1049  job = (None, 0)
1050  if single_channel:
1051  imgs = imgs[channels[0]]
1052  return (job, imgs)
1053 
    def construct(self, start=-1, niter=10, outdir=None, force=False, queue=None):
        """Perform atlas construction.

        Args:
            start (int): Last completed iteration. (default: step)
            niter (int, optional): Number of atlas construction iterations. (default: 10)
            outdir (str, optional): Directory for final templates. (default: config.paths.outdir)
            force (bool, optional): Force re-creation of already existing files. (default: False)
            queue (str, dict, optional): Name of queues of batch queuing system.
                When not specified, the atlas construction runs on the local machine
                using the number of threads specified during construction. When a single `str`
                is given, both short and long running jobs are submitted to the same queue.
                Otherwise, separate environments can be specified for "short" or "long" running
                jobs using the respective dictionary keys. The supported environments are:
                - "local": Multi-threaded execution on host machine
                - "condor": Batch execution using HTCondor
                - "<other>": Batch execution using named SLURM partition
                (default: config.environment.queue)

        Raises:
            ValueError: When neither `start` nor `self.step` is >= 0.

        Note:
            A dict `queue` (including `self.queue`) missing the "short" or
            "long" key is modified in place.
        """
        if start < 0:
            start = self.step
        if start < 0:
            raise ValueError("Atlas to be constructed must have step index >= 0!")
        if start > 0:
            self.info("Performing {0} iterations starting with step {1}".format(niter, start))
        else:
            self.info("Performing {0} iterations".format(niter))
        self.info("Age-dependent image deformations = {}".format(self.age_specific_imgdof))
        self.info("Average age-dependent deformations = {}".format(self.age_specific_regdof))
        # Initialize dict of queues used for batch execution
        if not queue:
            queue = self.queue
        if not isinstance(queue, dict):
            queue = {"short": queue, "long": queue}
        else:
            if "short" not in queue:
                queue["short"] = "local"
            if "long" not in queue:
                queue["long"] = "local"
        # Save considerable amount of disk memory by not explicitly storing the
        # composite image to atlas transformations of type FluidFreeFormTransformation
        # (alternatively, approximate composition). The transform-image and/or
        # average-images commands can apply a sequence of transforms in order to
        # perform the composition quasi on-the-fly.
        decomposed_imgdofs = True
        rmtemp_regdofs = True
        # Iterate atlas construction steps
        weights = {}
        for t in self.ages():
            weights[t] = self.weights(t)
        for step in range(start + 1, start + niter + 1):
            # Deform images to atlas space
            if not decomposed_imgdofs:
                job = self.imgdofs(step=step - 1, force=force, queue=queue["short"])[0]
                self.wait(job, interval=30, verbose=1)
            job = self.defimgs(step=step - 1, decomposed=decomposed_imgdofs, force=force, queue=queue["short"])[0]
            self.wait(job, interval=30, verbose=1)
            # Average images in atlas space
            job = self.avgimgs(step=step - 1, force=force, queue=queue["short"])[0]
            self.wait(job, interval=60, verbose=2)
            # Register all images to the current template images
            job = self.regdofs(step=step, force=force, queue=queue["long"])[0]
            self.wait(job, interval=60, verbose=5)
            # Compute all required average deformations
            job = self.avgdofs(step=step, force=force, queue=queue["short"])[0]
            self.wait(job, interval=60, verbose=2)
            if rmtemp_regdofs and self.age_specific_regdof:
                # Remove time-specific registration deformations that differ from the
                # time-independent one of the same image (only after avgdofs finished)
                self.info("Deleting temporary deformation files", step=step)
                for t in weights:
                    for imgid in weights[t]:
                        dof1 = self.regdof(imgid, step=step, create=False)
                        dof2 = self.regdof(imgid, t=t, step=step, create=False)
                        if dof2 != "identity" and dof1 != dof2 and os.path.exists(dof2):
                            os.remove(dof2)
            # Compute all required longitudinal deformations
            if self.age_specific_regdof or self.age_specific_imgdof:
                job = self.compose(step=step, force=force, queue=queue["short"])[0]
                self.wait(job, interval=30, verbose=1)
        # Write final template images to specified directory
        self.step = start + niter
        self.info("Creating final mean shape templates")
        if outdir is None:
            outdir = self.outdir
        if outdir:
            # Final templates at the configured atlas time points
            ages = self.means
        else:
            # No output directory: average at observed time points in step subdirectory
            ages = self.ages()
            outdir = None
        if not decomposed_imgdofs:
            job = self.imgdofs(ages=ages, force=force, queue=queue["short"])[0]
            self.wait(job, interval=30, verbose=1)
        channels = [channel for channel in self.config["images"].keys() if channel not in ("default", "ref")]
        job = self.defimgs(channels=channels, ages=ages, decomposed=decomposed_imgdofs, force=force, queue=queue["short"])[0]
        self.wait(job, interval=30, verbose=1)
        job = self.avgimgs(channels=channels, ages=ages, force=force, queue=queue["short"], outdir=outdir)[0]
        self.wait(job, interval=60, verbose=2)
        self.info("Finished atlas construction!")
1152 
1153  def evaluate(self, ages=[], step=-1, force=False, queue=None):
1154  """Evaluate atlas sharpness measures."""
1155  if not ages:
1156  ages = self.means
1157  if isinstance(ages, int):
1158  ages = [float(ages)]
1159  elif isinstance(ages, float):
1160  ages = [ages]
1161  if isinstance(step, int):
1162  if step < 0:
1163  if self.step < 0:
1164  raise ValueError("Need to specify which step/iteration of atlas construction to evaluate!")
1165  step = self.step
1166  steps = [step]
1167  else:
1168  steps = step
1169  measures = self.config["evaluation"]["measures"]
1170  rois_spec = self.config["evaluation"].get("rois", {})
1171  roi_paths = {}
1172  roi_label = {}
1173  roi_labels = {}
1174  roi_channels = set()
1175  for roi_name, roi_path in rois_spec.items():
1176  labels = []
1177  if isinstance(roi_path, list):
1178  if len(roi_path) != 2:
1179  raise ValueError("Invalid evaluation ROI value, must be either path (format) string of individual ROI or [<path_format>, <range>]")
1180  labels = roi_path[1]
1181  roi_path = roi_path[0]
1182  if roi_path in self.config["images"] and isinstance(labels, basestring) and labels.lower() == "all":
1183  labels = parselabels(self.config["images"][roi_path].get("labels", ""))
1184  if not labels:
1185  raise ValueError("ROI channel {} must have 'labels' specified to use for 'all'!".format(roi_path))
1186  else:
1187  labels = parselabels(labels)
1188  roi_name_format = roi_name
1189  if roi_name_format.format(l=0) == roi_name_format:
1190  raise ValueError("Invalid evaluation ROI key name, name must include '{l}' format string!")
1191  for label in labels:
1192  roi_name = roi_name_format.format(l=label)
1193  roi_paths[roi_name] = roi_path
1194  roi_label[roi_name] = label
1195  else:
1196  roi_paths[roi_name] = roi_path
1197  if roi_path in self.config["images"]:
1198  roi_channels.add(roi_path)
1199  roi_labels[roi_path] = labels
1200  roi_names = roi_paths.keys()
1201  roi_names.sort()
1202  roi_channels = list(roi_channels)
1203  re_measure = re.compile("^\s*(\w+)\s*\((.*)\)\s*$")
1204  # Evaluate voxel-wise measures
1205  for step in steps:
1206  voxelwise_measures = {}
1207  for t in ages:
1208  voxelwise_measures[t] = []
1209  for channel in measures:
1210  channel_info = self.config["images"][channel]
1211  channel_measures = measures[channel]
1212  if isinstance(channel_measures, basestring):
1213  channel_measures = [channel_measures]
1214  if channel_measures:
1215  # Deform individual images to atlas time point
1216  name = "eval_defimgs_{channel}".format(channel=channel)
1217  job, imgs = self.defimgs(channels=channel, ages=ages, step=step, queue=queue, batchname=name)
1218  self.wait(job, interval=30, verbose=1)
1219  # Evaluate measures for this image channel/modality
1220  for measure in channel_measures:
1221  measure = measure.lower().strip()
1222  match = re_measure.match(measure)
1223  if match:
1224  measure = match.group(1)
1225  args = match.group(2).strip()
1226  else:
1227  args = ""
1228  # Evaluate gradient magnitude of average image
1229  if measure == "grad":
1230  sharpen = False
1231  if "labels" in channel_info:
1232  if not args:
1233  raise ValueError("Gradient magnitude of segmentation can only be computed for probability map of one or more label(s)!")
1234  labels = [args]
1235  else:
1236  if args and args not in ("mean", "avg", "average", "template"):
1237  raise ValueError("Gradient magnitude of intensity images can only be computed from 'mean'/'avg'/'average'/'template' image!")
1238  if args and args == "template":
1239  sharpen = True
1240  labels = []
1241  name = "eval_avgimgs_{channel}".format(channel=channel)
1242  job, avgs = self.avgimgs(channels=channel, labels={channel: labels}, ages=ages, sharpen=sharpen, step=step, queue=queue)
1243  self.wait(job, interval=60, verbose=2)
1244  if args:
1245  measure += "_" + args
1246  for t in ages:
1247  path = os.path.join(self.subdir(step), channel, measure, self.timename(t) + ".nii.gz")
1248  if force or not os.path.exists(path):
1249  makedirs(os.path.dirname(path))
1250  self._run("detect-edges", step=step, queue=queue, wait_kwargs={"interval": 10, "verbose": 3},
1251  name="eval_{channel}_{measure}_{age}".format(channel=channel, measure=measure, age=self.timename(t)),
1252  args=[avgs[t], path], opts={"padding": channel_info.get("bkgrnd", -1), "central": None})
1253  voxelwise_measures[t].append((channel, measure, path))
1254  else:
1255  if args:
1256  raise ValueError("Measure '{}' has no arguments!".format(measure))
1257  opts = {
1258  "bins": channel_info.get("bins", 0),
1259  "normalization": channel_info.get("normalization", "none")
1260  }
1261  if "rescaling" in channel_info:
1262  opts["rescale"] = channel_info["rescaling"]
1263  if "bkgrnd" in channel_info:
1264  opts["padding"] = channel_info["bkgrnd"]
1265  for t in ages:
1266  path = os.path.join(self.subdir(step), channel, measure, self.timename(t) + ".nii.gz")
1267  if force or not os.path.exists(path):
1268  makedirs(os.path.dirname(path))
1269  opts["output"] = os.path.relpath(path, self.topdir)
1270  mask = self.config["evaluation"].get("mask", "").format(t=t)
1271  if mask:
1272  opts["mask"] = mask
1273  self._run("aggregate-images", step=step, queue=queue, workdir=self.topdir, wait_kwargs={"interval": 30, "verbose": 1},
1274  name="eval_{channel}_{measure}_{age}".format(channel=channel, measure=measure, age=self.timename(t)),
1275  args=[measure] + [os.path.relpath(img, self.topdir) for img in imgs[t]], opts=opts)
1276  voxelwise_measures[t].append((channel, measure, path))
1277  # Average voxel-wise measures within each ROI
1278  for channel in roi_channels:
1279  # Deform individual images to atlas time point
1280  name = "eval_defimgs_{channel}".format(channel=channel)
1281  job, imgs = self.defimgs(channels=channel, ages=ages, step=step, queue=queue, batchname=name)
1282  self.wait(job, interval=30, verbose=1)
1283  name = "eval_avgrois_{channel}".format(channel=channel)
1284  job, rois = self.avgimgs(channels=[channel], labels=roi_labels, ages=ages, step=step, queue=queue, batchname=name)
1285  self.wait(job, interval=60, verbose=2)
1286  for t in ages:
1287  if voxelwise_measures[t] and roi_paths:
1288  subdir = self.subdir(step)
1289  mean_table = os.path.join(subdir, "qc-measures", self.timename(t) + "-mean.csv")
1290  sdev_table = os.path.join(subdir, "qc-measures", self.timename(t) + "-sdev.csv")
1291  wsum_table = os.path.join(subdir, "qc-measures", self.timename(t) + "-wsum.csv")
1292  if force or not os.path.exists(mean_table) or not os.path.exists(sdev_table) or not os.path.exists(wsum_table):
1293  outdir = os.path.dirname(mean_table)
1294  tmp_mean_table = os.path.join(outdir, "." + os.path.basename(mean_table))
1295  tmp_sdev_table = os.path.join(outdir, "." + os.path.basename(sdev_table))
1296  tmp_wsum_table = os.path.join(outdir, "." + os.path.basename(wsum_table))
1297  args = [x[2] for x in voxelwise_measures[t]]
1298  opts = {
1299  "name": [],
1300  "roi-name": [],
1301  "roi-path": [],
1302  "header": None,
1303  "preload": None,
1304  "mean": tmp_mean_table,
1305  "sdev": tmp_sdev_table,
1306  "size": tmp_wsum_table,
1307  "digits": self.config["evaluation"].get("digits", 9)
1308  }
1309  for voxelwise_measure in voxelwise_measures[t]:
1310  channel = voxelwise_measure[0]
1311  measure = voxelwise_measure[1]
1312  opts["name"].append("{}/{}".format(channel, measure))
1313  for roi_name in roi_names:
1314  roi_path = roi_paths[roi_name]
1315  if roi_path in roi_channels:
1316  roi_path = self.avgimg(t, channel=roi_path, label=roi_label.get(roi_name, 0), step=step, create=False)
1317  else:
1318  roi_path = roi_path.format(
1319  subdir=subdir, tmpdir=self.tmpdir, topdir=self.topdir,
1320  i=step, t=t, l=roi_label.get(roi_name, 0)
1321  )
1322  opts["roi-name"].append(roi_name)
1323  opts["roi-path"].append(roi_path)
1324  makedirs(outdir)
1325  try:
1326  self._run("average-measure", step=step, queue=queue,
1327  name="eval_average_{age}".format(age=self.timename(t)),
1328  args=args, opts=opts)
1329  except Exception as e:
1330  for tmp_table in [tmp_mean_table, tmp_sdev_table, tmp_wsum_table]:
1331  if os.path.exists(tmp_table):
1332  os.remove(tmp_table)
1333  raise e
1334  cur_wait_time = 0
1335  inc_wait_time = 10
1336  max_wait_time = 6 * inc_wait_time
1337  if queue and queue.lower() != "local":
1338  while True:
1339  missing = []
1340  for tmp_table in [tmp_mean_table, tmp_sdev_table, tmp_wsum_table]:
1341  if not os.path.exists(tmp_table):
1342  missing.append(tmp_table)
1343  if not missing:
1344  break
1345  if cur_wait_time < max_wait_time:
1346  time.sleep(inc_wait_time)
1347  cur_wait_time += inc_wait_time
1348  else:
1349  raise Exception("Job average-measure finished, but output files still missing after {}s: {}".format(cur_wait_time, missing))
1350  for src, dst in zip([tmp_mean_table, tmp_sdev_table, tmp_wsum_table], [mean_table, sdev_table, wsum_table]):
1351  try:
1352  os.rename(src, dst)
1353  except OSError as e:
1354  sys.stderr.write("Failed to rename '{}' to '{}'".format(src, dst))
1355  raise e
1356 
1357  def template(self, i, channel=None):
1358  """Get absolute path of i-th template image."""
1359  if i < 0 or i >= len(self.means):
1360  raise IndexError()
1361  if not channel:
1362  channel = self.channel
1363  img = self.avgimg(self.means[i], channel=channel, step=self.step, create=False, outdir=self.outdir)
1364  if self.step >= 0 and not os.path.exists(img):
1365  avg = self.avgimg(self.means[i], channel=channel, step=self.step, create=False)
1366  if os.path.exists(avg):
1367  img = avg
1368  return img
1369 
1370  def __len__(self):
1371  """Length of the atlas is the number of templates at discrete time points."""
1372  return len(self.means)
1373 
1374  def __getitem__(self, i):
1375  """Absolute file path of i-th atlas template."""
1376  return self.template(i)
1377 
1378  def deformation(self, i, t=None, force=False, create=True):
1379  """Get absolute path of longitudinal deformation from template i."""
1380  if i < 0 or i >= len(self.means):
1381  return "identity"
1382  if t is None:
1383  t = self.means[i + 1]
1384  return self.growth(self.means[i], t, step=self.step, force=force, create=create)
1385 
1386  def deform(self, i, t, path=None, channel=None, force=False, create=True):
1387  """Deform i-th template using longitudinal deformations to time point t."""
1388  if not channel:
1389  channel = self.channel
1390  source = self.template(i, channel=channel)
1391  t = self.normtime(t)
1392  if t == self.means[i]:
1393  if path and os.path.realpath(os.path.abspath(path)) != os.path.realpath(source):
1394  if force and (create or not os.path.exists(path)):
1395  shutil.copyfile(source, path)
1396  return path
1397  else:
1398  return source
1399  if not path:
1400  path = os.path.join(self.subdir(self.step), channel, "temp", "{}-{}.nii.gz".format(self.timename(self.means[i]), self.timename(t)))
1401  if create and (force or not os.path.exists(path)):
1402  dof = self.deformation(i, t)
1403  makedirs(os.path.dirname(path))
1404  self._run("transform-image", args=[source, path], opts={"dofin": dof, "target": self.refimg})
1405  return path
1406 
1407  def interpolate(self, t, path=None, channel=None, interp="default", deform=False, sigma=0, force=False, create=True):
1408  """Interpolate atlas volume for specified time from finite set of templates.
1409 
1410  Unlike avgimg, this function does not evaluate the continuous spatio-temporal function
1411  to construct a template image for the given age. Instead, it uses the specified
1412  interpolation kernel and computes the corresponding weighted average of previously
1413  constructed template images.
1414 
1415  Args:
1416  interp (str): Temporal interpolation kernel (weights). (default: 'gaussian')
1417  sigma (float): Standard deviation used for Gaussian interpolation. (default: adaptive)
1418  deform (bool): Whether to deform each template using the longitudinal deformation.
1419 
1420  """
1421  interp = interp.lower()
1422  if interp in ("kernel", "default"):
1423  interp = "gaussian"
1424  if not channel:
1425  channel = self.channel
1426  t = self.normtime(t)
1427  i = self.timeindex(t)
1428  if t == self.means[i] and interp.startswith("linear"):
1429  return self.template(i, channel=channel)
1430  if not path:
1431  path = os.path.join(self.outdir, self.timename(t) + ".nii.gz")
1432  if create and (force or not os.path.exists(path)):
1433  args = [path]
1434  # Linear interpolation
1435  if interp == "linear":
1436  w = (self.means[i + 1] - t) / (self.means[i + 1] - self.means[i])
1437  args.extend(["-image", self.template(i, channel=channel), w])
1438  if deform:
1439  args.extend(["-dof", self.deformation(i, t)])
1440  w = (t - self.means[i]) / (self.means[i + 1] - self.means[i])
1441  args.extend(["-image", self.template(i + 1, channel=channel), w])
1442  if deform:
1443  args.extend(["-dof", self.deformation(i + 1, t)])
1444  # Gaussian interpolation
1445  elif interp == "gaussian":
1446  for i in range(len(self.means)):
1447  w = self.weight(self.means[i], mean=t, sigma=sigma)
1448  if w > 0.:
1449  args.extend(["-image", self.template(i, channel=channel), w])
1450  if deform:
1451  dof = self.deformation(self.means[i], t)
1452  if dof != "identity":
1453  args.extend(["-dof", dof])
1454  if create:
1455  bkgrnd = self.config["images"][channel].get("bkgrnd", -1)
1456  self._run("average-images", args=args, opts={"reference": self.refimg, "datatype": "uchar", "padding": bkgrnd})
1457  return path
1458 
1459  def view(self, i=None, channel=None):
1460  """View template image at specified time point(s)."""
1461  if i is None:
1462  i = range(len(self.means))
1463  elif not isinstance(i, (list, tuple)):
1464  i = [i]
1465  if not channel:
1466  channel = self.channel
1467  imgs = [self.template(int(idx), channel=channel) for idx in i]
1468  mirtk.run("view", target=imgs)
1469 
1470  def rmtemp(self):
1471  """Delete temporary files.
1472 
1473  This function removes all temporary files which can be recomputed
1474  without the need for performing the more costly registrations.
1475  Intermediate template images, auxiliary CSV files, and composite
1476  transformations are among these temporary files to be deleted.
1477 
1478  """
1479  raise NotImplementedError()
1480 
1481 
1482 ##############################################################################
1483 # Auxiliaries
1484 
1485 # ----------------------------------------------------------------------------
def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):
    """Compare two floating point numbers for approximate equality.

    Behaves like `math.isclose` (available since Python 3.5) for finite and
    infinite values. Equal values, including equal infinities, are close. An
    infinity is never close to a different value. NaN is close to nothing.

    Args:
        a (float): First value.
        b (float): Second value.
        rel_tol (float): Maximum difference relative to the larger magnitude.
        abs_tol (float): Minimum absolute tolerance, useful near zero.

    Returns:
        bool: Whether `a` and `b` are close.
    """
    if a == b:
        # Also handles inf == inf, for which a - b would be nan
        return True
    if math.isinf(a) or math.isinf(b):
        # Otherwise inf <= rel_tol * inf would make inf close to any finite value
        return False
    return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
1489 
1490 
1491 # ----------------------------------------------------------------------------
def read_imgids(path):
    """Read subject/image IDs from text file.

    Empty lines and lines starting with '#' are ignored. Of every other line,
    only the first token is used, where tokens are separated by spaces, tabs,
    and/or commas (e.g., "sub1, 30.5" yields "sub1").

    Args:
        path (str): Path of text file.

    Returns:
        list: Image IDs in file order.
    """
    imgids = []
    with open(path, "rt") as f:
        for line in f:
            line = line.strip()
            if line and not line.startswith("#"):
                # The former pattern "^[ \t,]+" was anchored and thus never split mid-line
                imgids.append(re.split(r"[ \t,]+", line)[0])
    return imgids
1501 
1502 
1503 # ----------------------------------------------------------------------------
def read_ages(path, delimiter=None):
    """Read subject/image age from CSV file.

    Each row holds the image ID in the first and the age in the second
    column. Blank rows, such as a trailing empty line, are skipped.

    Args:
        path (str): Path of the table file.
        delimiter (str, optional): Column delimiter. When not given, it is
            derived from the file extension: ".csv" (comma), ".tsv" (tab), or
            ".txt" (space).

    Returns:
        dict: Mapping from image ID to age (float).

    Raises:
        ValueError: When the delimiter cannot be derived from the extension.
    """
    if not delimiter:
        ext = os.path.splitext(path)[1].lower()
        delimiter = {".csv": ",", ".tsv": "\t", ".txt": " "}.get(ext)
        if delimiter is None:
            raise ValueError("Cannot determine delimiter of {} file! Use file extension .csv, .tsv, or .txt.".format(path))
    ages = {}
    with open(path, "rt") as csvfile:
        reader = csv.reader(csvfile, delimiter=delimiter, quotechar='"')
        for row in reader:
            if not row or not "".join(row).strip():
                # Empty line; previously caused an IndexError
                continue
            ages[row[0]] = float(row[1])
    return ages
1522 
1523 
1524 # ----------------------------------------------------------------------------
def basename_without_ext(path):
    """Return the file name of `path` without its extension.

    A final ".gz" extension (case-insensitive) is removed together with the
    preceding extension, e.g., "image.nii.gz" yields "image".
    """
    stem, ext = os.path.splitext(os.path.basename(path))
    return os.path.splitext(stem)[0] if ext.lower() == '.gz' else stem
1531 
1532 
1533 # ----------------------------------------------------------------------------
def remove_or_makedirs(path):
    """Remove file when it exists or make directories otherwise.

    If `path` exists, it is deleted with `os.remove`. Otherwise, its parent
    directory is created via `makedirs`.
    """
    if not os.path.exists(path):
        makedirs(os.path.dirname(path))
        return
    os.remove(path)
1540 
1541 
1542 # ----------------------------------------------------------------------------
def parselabels(labels):
    """Parse labels specification.

    Supported specifications:
        - int: single label, e.g., 5
        - "5": single label
        - "1..5": inclusive range 1, 2, 3, 4, 5
        - "1:5": inclusive range 1, 2, 3, 4, 5
        - "1:2:9": inclusive range from 1 to 9 with step 2
        - comma-separated string of the above, e.g., "1, 3..5"
        - list/tuple of any of the above
    An empty string yields no labels.

    Args:
        labels (int, str, list, tuple): Labels specification.

    Returns:
        list: Label values (int).

    Raises:
        ValueError: On malformed specifications.
    """
    try:
        string_types = basestring  # Python 2
    except NameError:
        string_types = str  # Python 3
    values = []
    isstr = isinstance(labels, string_types)
    if isstr and "," in labels:
        labels = [arg.strip() for arg in labels.split(",")]
    if isinstance(labels, (tuple, list)):
        for label in labels:
            values.extend(parselabels(label))
    elif isstr:
        labels = labels.strip()
        # Dots must be escaped and the pattern anchored; otherwise "1234" was
        # parsed as the range 1..4 and "1000" as an empty range
        m = re.match(r"^([0-9]+)\.\.([0-9]+)$", labels)
        if m:
            values.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        elif ":" in labels:
            range_spec = [int(x) for x in labels.split(":")]
            if len(range_spec) == 3:
                # <start>:<step>:<stop>
                values.extend(range(range_spec[0], range_spec[2] + 1, range_spec[1]))
            elif len(range_spec) == 2:
                # <start>:<stop>
                values.extend(range(range_spec[0], range_spec[1] + 1))
            else:
                raise ValueError("Invalid label range specification: {}".format(labels))
        elif labels != "":
            values.append(int(labels))
    else:
        values.append(int(labels))
    return values
1567 
Dummy type used to distinguish split constructor from copy constructor.
Definition: Parallel.h:143
def interpolate(self, t, path=None, channel=None, interp="default", deform=False, sigma=0, force=False, create=True)
def parin(self, step=-1, force=False, create=True)
def _submit(self, name, command=None, args=[], opts={}, script=None, tasks=-1, group=1, step=-1, queue=None, memory=8 *1024, workdir=None)
def imgdofs(self, ages=[], step=-1, force=False, create=True, queue=None, batchname="imgdofs")
def compose(self, step=-1, ages=[], allpairs=False, force=False, create=True, queue=None, batchname="compose")
def weight(self, t, mean, sigma=0, normalize=True, epsilon=None)
def deform(self, i, t, path=None, channel=None, force=False, create=True)
def construct(self, start=-1, niter=10, outdir=None, force=False, queue=None)
def avgdofs(self, step=-1, force=False, create=True, queue=None, batchname="avgdofs")
def image(self, imgid, channel=None, label=0, force=False, create=True)
def imgtable(self, t, channel=None, step=-1, decomposed=True, force=False, create=True, batch=False)
def _run(self, command, args=[], opts={}, step=-1, workdir=None, queue=None, name=None, submit_kwargs={}, wait_kwargs={})
def __init__(self, config, root=None, step=-1, verbose=1, threads=-1, exit_on_error=False)
def weights(self, mean=None, sigma=0, zero=False)
def avgimgs(self, step=-1, ages=[], channels=[], labels={}, sharpen=True, outdir=None, decomposed=True, force=False, create=True, queue=None, batchname="avgimgs")
def defimgs(self, ages=[], channels=[], step=-1, decomposed=True, force=False, create=True, queue=None, batchname="defimgs")
def defimg(self, imgid, t, channel=None, path=None, step=-1, decomposed=True, force=False, create=True, batch=False)
def deformation(self, i, t=None, force=False, create=True)
def imgdof(self, imgid, t, step=-1, decomposed=False, force=False, create=True, batch=False)
def wait(self, jobs, interval=60, verbose=5)
def doftable(self, t, step=-1, force=False, create=True, batch=False)
def regdof(self, imgid, t=None, step=-1, path=None, force=False, create=True, batch=False)
def avgimg(self, t, channel=None, label=0, path=None, sharpen=True, outdir=None, step=-1, decomposed=True, force=False, create=True, batch=False)
def evaluate(self, ages=[], step=-1, force=False, queue=None)
def avgdof(self, t, path=None, step=-1, force=False, create=True, batch=False)
def _path(self, paths, name, default)
def regdofs(self, step=-1, force=False, create=True, queue=None, batchname="register")
def growth(self, t1, t2, step=-1, force=False, create=True, batch=False)