TransformationApproximationError.h
1 /*
2  * Medical Image Registration ToolKit (MIRTK)
3  *
4  * Copyright 2013-2015 Imperial College London
5  * Copyright 2013-2015 Andreas Schuh
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  * http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19 
20 #ifndef MIRTK_TransformationApproximationError_H
21 #define MIRTK_TransformationApproximationError_H
22 
23 #include "mirtk/ObjectiveFunction.h"
24 
25 #include "mirtk/Array.h"
26 #include "mirtk/Vector3D.h"
27 #include "mirtk/Transformation.h"
28 
29 
30 namespace mirtk {
31 
32 
33 /**
34  * Mean-squared-error of transformation approximation
35  *
36  * This objective function is minimized by Transformation::ApproximateDOFs
37  * to find the transformation parameters which minimize the mean squared error
38  * of the approximation.
39  */
41 {
42  mirtkObjectMacro(TransformationApproximationError);
43 
44  // ---------------------------------------------------------------------------
45  // Attributes
46 
47  /// Transformation
48  mirtkReadOnlyAggregateMacro(class Transformation, Transformation);
49 
50  /// Number of distinct time points
51  mirtkReadOnlyAttributeMacro(int, NumberOfTimePoints);
52 
53  /// Number of points
54  mirtkReadOnlyAttributeMacro(int, NumberOfPoints);
55 
56  /// Centroid of target points
57  mirtkReadOnlyAttributeMacro(Point, TargetCenter);
58 
59  /// Centroid of source points
60  mirtkReadOnlyAttributeMacro(Point, SourceCenter);
61 
62 protected:
63 
64  Array<PointSet> _Target;
65  Array<double> _TargetTime;
66  Array<PointSet> _Current;
67  Array<PointSet> _Source;
68  Vector3D<double> *_Gradient;
69 
70 public:
71 
72  // ---------------------------------------------------------------------------
73  // Construction/destruction
74 
75  /// Constructor (parameter names are omitted in this declaration)
76  TransformationApproximationError(class Transformation *,
77  const double *, const double *, const double *, const double *,
78  const double *, const double *, const double *, int);
79 
80  /// Destructor
81  virtual ~TransformationApproximationError();
82 
83  /// Subtract centroid from each point set
84  void CenterPoints();
85 
86  // ---------------------------------------------------------------------------
87  // Function parameters (DoFs)
88 
89  /// Get number of DoFs
90  ///
91  /// \returns Number of free function parameters.
92  virtual int NumberOfDOFs() const;
93 
94  /// Set function parameter values
95  ///
96  /// This function can be used to set the parameters of the objective function
97  /// to particular values. In particular, it can be used to restore the function
98  /// parameters after a failed incremental update which did not result in the
99  /// desired improvement.
100  ///
101  /// \param[in] x Function parameter (DoF) values.
102  virtual void Put(const double *x);
103 
104  /// Get function parameter value
105  ///
106  /// \param[in] i Function parameter (DoF) index.
107  ///
108  /// \returns Value of specified function parameter (DoF).
109  virtual double Get(int i) const;
110 
111  /// Get function parameter values
112  ///
113  /// This function can be used to store a backup of the current function parameter
114  /// values before an update such that these can be restored using the Put
115  /// member function if the update did not result in the desired change of the
116  /// overall objective function value.
117  ///
118  /// \param[out] x Current function parameter (DoF) values.
119  virtual void Get(double *x) const;
120 
121  /// Add change (i.e., scaled gradient) to each parameter value
122  ///
123  /// This function updates each DoF of the objective function given a vector
124  /// of corresponding changes, i.e., the computed gradient of the objective
125  /// function w.r.t. these parameters or a desired change computed otherwise.
126  ///
127  /// \param[in] dx Change of each function parameter (DoF) as computed by the
128  /// Gradient member function and scaled by a chosen step length.
129  ///
130  /// \returns Maximum change of function parameter.
131  virtual double Step(double *dx);
132 
133  /// Update internal state after change of parameters
134  ///
135  /// \param[in] gradient Update also internal state required for evaluation of
136  /// gradient of objective function.
137  virtual void Update(bool gradient = true);
138 
139  // ---------------------------------------------------------------------------
140  // Evaluation
141 
142  /// Evaluate objective function value
143  virtual double Value();
144 
145  /// Evaluate gradient of objective function w.r.t. its DoFs
146  ///
147  /// \param[out] dx Gradient of objective function.
148  /// \param[in] step Step length for finite differences.
149  /// \param[out] sgn_chg Whether function parameter value is allowed to
150  /// change sign when stepping along the computed gradient.
151  virtual void Gradient(double *dx, double step = .0, bool *sgn_chg = NULL);
152 
153  /// Compute norm of gradient of objective function
154  ///
155  /// This norm is used to define a unit for the step length used by gradient
156  /// descent methods. It is, for example, the maximum absolute value norm for
157  /// linear transformations and the maximum control point displacement for FFDs.
158  /// The computation of the norm may be done after conjugating the gradient
159  /// vector obtained using the Gradient member function.
160  ///
161  /// \param[in] dx Gradient of objective function.
162  virtual double GradientNorm(const double *dx) const;
163 
164 };
165 
166 
167 } // namespace mirtk
168 
169 #endif // MIRTK_TransformationApproximationError_H
void CenterPoints()
Subtract centroid from each point set.
virtual void Update(bool gradient=true)
virtual void Put(const double *x)
TransformationApproximationError(class Transformation *, const double *, const double *, const double *, const double *, const double *, const double *, const double *, int)
Constructor.
virtual double Value()
Evaluate objective function value.
virtual void Gradient(double *dx, double step=.0, bool *sgn_chg=NULL)
virtual double GradientNorm(const double *dx) const
Definition: IOConfig.h:41
virtual double Step(double *dx)
virtual ~TransformationApproximationError()
Destructor.
virtual double Get(int i) const
int NumberOfPoints(vtkDataSet *)
Number of points.