ConjugateGradientDescent.h
1 /*
2  * Medical Image Registration ToolKit (MIRTK)
3  *
4  * Copyright 2013-2015 Imperial College London
5  * Copyright 2013-2015 Andreas Schuh
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  * http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19 
20 #ifndef MIRTK_ConjugateGradientDescent_H
21 #define MIRTK_ConjugateGradientDescent_H
22 
23 #include "mirtk/GradientDescent.h"
24 #include "mirtk/Math.h"
25 
26 
27 namespace mirtk {
28 
29 
30 /**
31  * Minimizes objective function using conjugate gradient descent
32  */
33 class ConjugateGradientDescent : public GradientDescent
34 {
35  mirtkOptimizerMacro(ConjugateGradientDescent, OM_ConjugateGradientDescent);
36 
37  // ---------------------------------------------------------------------------
38  // Attributes
39 
40  /// Enable/disable conjugation of gradient
41  ///
42  /// Can be used to switch to steepest gradient descent once the conjugate
43  /// gradient descent has converged. Call Run after calling ConjugateGradientOff.
44  mirtkPublicAttributeMacro(bool, UseConjugateGradient);
45 
46  /// Whether to compute conjugate gradient of total energy gradient
47  ///
48  /// When disabled, only the gradient of the data fidelity term is conjugated
49  /// before adding the gradient of the model constraint term.
50  mirtkPublicAttributeMacro(bool, ConjugateTotalGradient);
51 
52 protected:
53 
54  double *_g;
55  double *_h;
56 
57  /// Copy attributes of this class from another instance
58  void CopyAttributes(const ConjugateGradientDescent &);
59 
60  // ---------------------------------------------------------------------------
61  // Construction/Destruction
62 
63 public:
64 
65  /// Constructor
66  ConjugateGradientDescent(ObjectiveFunction * = NULL);
67 
68  /// Copy constructor
69  ConjugateGradientDescent(const ConjugateGradientDescent &);
70 
71  /// Assignment operator
72  ConjugateGradientDescent &operator =(const ConjugateGradientDescent &);
73 
74  /// Destructor
75  virtual ~ConjugateGradientDescent();
76 
77  // ---------------------------------------------------------------------------
78  // Parameters
80 
81  /// Set parameter value from string
82  virtual bool Set(const char *, const char *);
83 
84  /// Get parameters as key/value as string map
85  virtual ParameterList Parameter() const;
86 
87  // ---------------------------------------------------------------------------
88  // Optimization
89 
90  /// Reset conjugate gradient to objective function gradient
91  void ResetConjugateGradient();
92 
93  /// Enable conjugation of data fidelity gradient
94  void ConjugateGradientOn();
95 
96  /// Enable conjugation of objective function gradient
97  void ConjugateTotalGradientOn();
98 
99  /// Disable conjugation of objective function gradient
100  /// When conjugation was enabled before, the data fidelity gradient is still conjugated
101  void ConjugateTotalGradientOff();
102 
103  /// Disable conjugation of objective function gradient
104  void ConjugateGradientOff();
105 
106 protected:
107 
108  /// Initialize gradient descent
109  virtual void Initialize();
110 
111  /// Finalize gradient descent
112  virtual void Finalize();
113 
114  /// Compute gradient of objective function and make it conjugate
115  virtual void Gradient(double *, double = .0, bool * = NULL);
116 
117  /// Compute conjugate gradient
118  void ConjugateGradient(double *);
119 
120 };
121 
122 ////////////////////////////////////////////////////////////////////////////////
123 // Inline definitions
124 ////////////////////////////////////////////////////////////////////////////////
125 
126 // -----------------------------------------------------------------------------
127 inline void ConjugateGradientDescent::ResetConjugateGradient()
128 {
129  // Set g[0] to NaN to indicate that vectors _g and _h are uninitialized
130  if (_g) _g[0] = NaN;
131 }
132 
133 // -----------------------------------------------------------------------------
134 inline void ConjugateGradientDescent::ConjugateGradientOn()
135 {
136  this->UseConjugateGradient(true);
137 }
138 
139 // -----------------------------------------------------------------------------
140 inline void ConjugateGradientDescent::ConjugateTotalGradientOn()
141 {
142  this->UseConjugateGradient(true);
143  this->ConjugateTotalGradient(true);
144 }
145 
146 // -----------------------------------------------------------------------------
147 inline void ConjugateGradientDescent::ConjugateTotalGradientOff()
148 {
149  this->ConjugateTotalGradient(false);
150 }
151 
152 // -----------------------------------------------------------------------------
153 inline void ConjugateGradientDescent::ConjugateGradientOff()
154 {
155  this->UseConjugateGradient(false);
156 }
157 
158 
159 } // namespace mirtk
160 
161 #endif // MIRTK_ConjugateGradientDescent_H
void ConjugateGradientOn()
Enable conjugation of data fidelity gradient.
virtual ~ConjugateGradientDescent()
Destructor.
void ConjugateTotalGradientOn()
Enable conjugation of objective function gradient.
void CopyAttributes(const ConjugateGradientDescent &)
Copy attributes of this class from another instance.
virtual bool Set(const char *, const char *)
Set parameter value from string.
virtual ParameterList Parameter() const
Get parameters as key/value as string map.
Array< Pair< string, string > > ParameterList
Ordered list of parameter name/value pairs.
Definition: Object.h:38
void ResetConjugateGradient()
Reset conjugate gradient to objective function gradient.
ConjugateGradientDescent(ObjectiveFunction *=NULL)
Constructor.
virtual void Initialize()
Initialize gradient descent.
virtual ParameterList Parameter() const
Get parameters as key/value as string map.
Definition: IOConfig.h:41
ConjugateGradientDescent & operator=(const ConjugateGradientDescent &)
Assignment operator.
virtual void Gradient(double *, double=.0, bool *=NULL)
Compute gradient of objective function and make it conjugate.
void ConjugateGradientOff()
Disable conjugation of objective function gradient.
virtual void Finalize()
Finalize gradient descent.
void ConjugateGradient(double *)
Compute conjugate gradient.