EnergyTerm.h
1 /*
2  * Medical Image Registration ToolKit (MIRTK)
3  *
4  * Copyright 2013-2015 Imperial College London
5  * Copyright 2013-2015 Andreas Schuh
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  * http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19 
20 #ifndef MIRTK_EnergyTerm_H
21 #define MIRTK_EnergyTerm_H
22 
23 #include "mirtk/Configurable.h"
24 
25 #include "mirtk/Indent.h"
26 #include "mirtk/EnergyMeasure.h"
27 #include "mirtk/Transformation.h"
28 #include "mirtk/ObjectFactory.h"
29 
30 
31 namespace mirtk {
32 
33 
34 /**
35  * Base class for one term of an objective function
36  *
37  * In particular, this is the base class for both the data similarity term
38  * and transformation regularization term commonly seen in objective functions
39  * used for image/surface registration.
40  */
41 class EnergyTerm : public Configurable
42 {
43  mirtkAbstractMacro(EnergyTerm);
44 
45  // ---------------------------------------------------------------------------
46  // Attributes
47 
48  /// Weight of energy term
49  mirtkPublicAttributeMacro(double, Weight);
50 
51  /// Transformation with free parameters of energy function
52  mirtkPublicAggregateMacro(class Transformation, Transformation);
53 
54  /// Whether to divide energy term by its initial value
55  mirtkPublicAttributeMacro(bool, DivideByInitialValue);
56 
57  /// Initial unweighted value of energy term
58  double _InitialValue;
59 
60  /// Cached unweighted value of energy term
61  double _Value;
62 
63  /// Copy attributes of this class from another instance
64  void CopyAttributes(const EnergyTerm &);
65 
66  // ---------------------------------------------------------------------------
67  // Construction/Destruction
68 protected:
69 
70  /// Constructor
71  EnergyTerm(const char * = "", double = 1.0);
72 
73  /// Copy constructor
74  EnergyTerm(const EnergyTerm &);
75 
76  /// Assignment operator
78 
79 public:
80 
81  /// Type of energy term factory
83 
84  /// Get global energy term factory instance
85  static FactoryType &Factory();
86 
87  /// Construct new energy term or return nullptr if term not available
88  static EnergyTerm *TryNew(EnergyMeasure, const char * = "", double = 1.0);
89 
90  /// Construct new energy term
91  static EnergyTerm *New(EnergyMeasure, const char * = "", double = 1.0);
92 
93  /// Destructor
94  virtual ~EnergyTerm();
95 
96  /// Energy measure implemented by this term
97  virtual enum EnergyMeasure EnergyMeasure() const = 0;
98 
99  // ---------------------------------------------------------------------------
100  // Parameters
101 
102 protected:
103 
104  /// Set parameter value from string
105  virtual bool SetWithPrefix(const char *, const char *);
106 
107  /// Set parameter value from string
108  virtual bool SetWithoutPrefix(const char *, const char *);
109 
110 public:
111 
112  // Import other overloads
114 
115  /// Get parameter key/value as string map
116  virtual ParameterList Parameter() const;
117 
118  // ---------------------------------------------------------------------------
119  // Evaluation
120 
121  /// Initialize energy term once input and parameters have been set
122  virtual void Initialize();
123 
124  /// Update internal state after change of DoFs
125  ///
126  /// \param[in] gradient Whether to also update internal state for evaluation
127  /// of energy gradient. If \c false, only the internal state
128  /// required for the energy evaluation need to be updated.
129  virtual void Update(bool gradient = true);
130 
131  /// Update energy term after convergence
132  virtual bool Upgrade();
133 
134  /// Reset initial value of energy term
135  void ResetInitialValue();
136 
137  /// Reset cached value of energy term
138  void ResetValue();
139 
140  /// Returns initial value of energy term
141  double InitialValue();
142 
143  /// Evaluate energy term
144  double Value();
145 
146  /// Evaluate gradient of energy term
147  ///
148  /// \param[in,out] gradient Gradient to which the evaluated gradient of this
149  /// energy term is added to with its resp. weight.
150  /// \param[in] step Step length for finite differences.
151  void Gradient(double *gradient, double step);
152 
153  /// Evaluate and normalize gradient of energy term
154  ///
155  /// \param[in,out] gradient Gradient to which the evaluated normalized gradient
156  /// of this energy term is added to with its resp. weight.
157  /// \param[in] step Step length for finite differences.
158  void NormalizedGradient(double *gradient, double step);
159 
160  /// Adjust step length range
161  ///
162  /// \param[in] gradient Gradient of objective function.
163  /// \param[in,out] min Minimum step length.
164  /// \param[in,out] max Maximum step length.
165  virtual void GradientStep(const double *gradient, double &min, double &max) const;
166 
167 protected:
168 
169  /// Evaluate unweighted energy term
170  virtual double Evaluate() = 0;
171 
172  /// Evaluate and add gradient of energy term
173  ///
174  /// \param[in,out] gradient Gradient to which the computed gradient of the
175  /// energy term should be added to.
176  /// \param[in] step Step length for finite differences.
177  /// \param[in] weight Weight to use when adding the gradient.
178  virtual void EvaluateGradient(double *gradient, double step, double weight) = 0;
179 
180  // ---------------------------------------------------------------------------
181  // Debugging
182 
183 public:
184 
185  /// Return unweighted and unnormalized raw energy term value
186  /// \remarks Use for progress reporting only.
187  virtual double RawValue(double) const;
188 
189  /// Return unweighted and unnormalized raw energy term value
190  /// \remarks Use for progress reporting only.
191  double RawValue();
192 
193  /// Print debug information
194  virtual void Print(Indent = 0) const;
195 
196  /// Prefix to be used for debug output files
197  string Prefix(const char * = NULL) const;
198 
199  /// Write input of data fidelity term
200  virtual void WriteDataSets(const char *, const char *, bool = true) const;
201 
202  /// Write gradient of data fidelity term w.r.t each transformed input
203  virtual void WriteGradient(const char *, const char *) const;
204 
205 };
206 
207 ////////////////////////////////////////////////////////////////////////////////
208 // Auxiliary macros for optimizer implementation
209 ////////////////////////////////////////////////////////////////////////////////
210 
211 // -----------------------------------------------------------------------------
/// Declares the object type information of a concrete energy term class
/// together with its static ID() and the EnergyMeasure() override, both of
/// which report the given energy measure. The expansion ends in private access.
#define mirtkEnergyTermMacro(class_name, measure_id)                           \
  mirtkObjectMacro(class_name);                                               \
public:                                                                       \
  /** Energy measure implemented by this term */                              \
  static mirtk::EnergyMeasure ID() { return measure_id; }                     \
  /** Energy measure implemented by this term */                              \
  virtual mirtk::EnergyMeasure EnergyMeasure() const { return ID(); }         \
private:
220 
221 // -----------------------------------------------------------------------------
/// Register energy term type with the factory singleton returned by
/// mirtk::EnergyTerm::Factory(), keyed by its ID() and NameOfType()
#define mirtkRegisterEnergyTermMacro(term_type)                                \
  mirtk::EnergyTerm::Factory().Register(                                      \
    term_type::ID(),                                                          \
    term_type::NameOfType(),                                                  \
    mirtk::New<mirtk::EnergyTerm, term_type>                                  \
  )
226 
227 // -----------------------------------------------------------------------------
/// Register energy term type with the factory singleton returned by
/// mirtk::EnergyTerm::Factory() at static initialization time, delegating to
/// mirtkAutoRegisterObjectTypeMacro
#define mirtkAutoRegisterEnergyTermMacro(term_type)                            \
  mirtkAutoRegisterObjectTypeMacro(                                           \
    mirtk::EnergyTerm::Factory(),                                             \
    mirtk::EnergyMeasure, term_type::ID(),                                    \
    mirtk::EnergyTerm, term_type                                              \
  )
233 
234 
235 } // namespace mirtk
236 
237 #endif // MIRTK_EnergyTerm_H
double InitialValue()
Returns initial value of energy term.
string Prefix(const char *=NULL) const
Prefix to be used for debug output files.
static EnergyTerm * TryNew(EnergyMeasure, const char *="", double=1.0)
Construct new energy term or return nullptr if term not available.
void ResetValue()
Reset cached value of energy term.
ObjectFactory< enum EnergyMeasure, EnergyTerm > FactoryType
Type of energy term factory.
Definition: EnergyTerm.h:82
virtual bool Upgrade()
Update energy term after convergence.
virtual bool SetWithoutPrefix(const char *, const char *)
Set parameter value from string.
void ResetInitialValue()
Reset initial value of energy term.
double Value()
Evaluate energy term.
virtual void WriteGradient(const char *, const char *) const
Write gradient of data fidelity term w.r.t each transformed input.
virtual void Initialize()
Initialize energy term once input and parameters have been set.
virtual ~EnergyTerm()
Destructor.
Array< Pair< string, string > > ParameterList
Ordered list of parameter name/value pairs.
Definition: Object.h:38
void NormalizedGradient(double *gradient, double step)
Evaluate and normalize gradient of energy term.
virtual ParameterList Parameter() const
Get parameter key/value as string map.
Definition: IOConfig.h:41
virtual enum EnergyMeasure EnergyMeasure() const =0
Energy measure implemented by this term.
EnergyTerm & operator=(const EnergyTerm &)
Assignment operator.
static EnergyTerm * New(EnergyMeasure, const char *="", double=1.0)
Construct new energy term.
virtual bool SetWithPrefix(const char *, const char *)
Set parameter value from string.
static FactoryType & Factory()
Get global energy term factory instance.
virtual void GradientStep(const double *gradient, double &min, double &max) const
Adjust step length range.
virtual double Evaluate()=0
Evaluate unweighted energy term.
virtual void WriteDataSets(const char *, const char *, bool=true) const
Write input of data fidelity term.
virtual void Print(Indent=0) const
Print debug information.
EnergyMeasure
Enumeration of all available energy terms.
Definition: EnergyMeasure.h:31
EnergyTerm(const char *="", double=1.0)
Constructor.
virtual void Update(bool gradient=true)
Update internal state after change of DoFs.
virtual ParameterList Parameter() const
Get parameter name/value pairs.
Definition: Object.h:139
virtual void EvaluateGradient(double *gradient, double step, double weight)=0
Evaluate and add gradient of energy term.
void Gradient(double *gradient, double step)
Evaluate gradient of energy term.