GradientDescent.h
1 /*
2  * Medical Image Registration ToolKit (MIRTK)
3  *
4  * Copyright 2013-2015 Imperial College London
5  * Copyright 2013-2015 Andreas Schuh
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  * http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19 
20 #ifndef MIRTK_GradientDescent_H
21 #define MIRTK_GradientDescent_H
22 
23 #include "mirtk/LocalOptimizer.h"
24 #include "mirtk/LineSearch.h"
25 #include "mirtk/EventDelegate.h"
26 
27 
28 namespace mirtk {
29 
30 
31 
32 /**
33  * Minimizes objective function using gradient descent
34  */
36 {
37  mirtkOptimizerMacro(GradientDescent, OM_GradientDescent);
38 
39  // ---------------------------------------------------------------------------
40  // Attributes
41 
42  /// Maximum number of restarts after upgrade of energy function
43  mirtkPublicAttributeMacro(int, NumberOfRestarts);
44 
45  /// Maximum streak of unsuccessful restarts without improvement
46  mirtkPublicAttributeMacro(int, NumberOfFailedRestarts);
47 
48  /// Line search strategy
49  mirtkPublicAttributeMacro(enum LineSearchStrategy, LineSearchStrategy);
50 
51  /// (Possible) Line search parameter
52  mirtkAttributeMacro(ParameterList, LineSearchParameter);
53 
54  /// Line search optimization method
55  mirtkReadOnlyAggregateMacro(class LineSearch, LineSearch);
56 
57  /// Whether line search object is owned by this optimizer
58  mirtkAttributeMacro(bool, LineSearchOwner);
59 
60  /// Forwards line search event messages to observers of optimization
61  EventDelegate _EventDelegate;
62 
63  /// Copy attributes of this class from another instance
64  void CopyAttributes(const GradientDescent &);
65 
66 protected:
67 
68  /// Allocated memory for line search direction
69  double *_Gradient;
70 
71  /// Whether to allow function parameter sign to change
72  ///
73  /// If \c false for a particular function parameter, the line search sets
74  /// the function parameter to zero whenever the sign of the parameter would
75  /// change when taking a full step along the scaled gradient direction.
77 
78  // ---------------------------------------------------------------------------
79  // Construction/Destruction
80 public:
81 
82  /// Constructor
84 
85  /// Copy constructor
87 
88  /// Assignment operator
90 
91  /// Destructor
92  virtual ~GradientDescent();
93 
94  // Import overloads from base class
95  using LocalOptimizer::Function;
96 
97  /// Set objective function
98  virtual void Function(ObjectiveFunction *);
99 
100  /// Set line search object
101  virtual void LineSearch(class LineSearch *, bool = false);
102 
103  // ---------------------------------------------------------------------------
104  // Parameters
106 
107  /// Set parameter value from string
108  virtual bool Set(const char *, const char *);
109 
110  /// Get parameters as key/value as string map
111  virtual ParameterList Parameter() const;
112 
113  // ---------------------------------------------------------------------------
114  // Execution
115 
116  /// Initialize gradient descent
117  ///
118  /// This member funtion is implicitly called by Run. It can, however, be
119  /// called prior to Run explicitly in order to be able to set up the line
120  /// search instance. Otherwise, use the generic Set member function to
121  /// change the line search parameters and simply have Run call Initialize.
122  virtual void Initialize();
123 
124  /// Optimize objective function using gradient descent
125  virtual double Run();
126 
127 protected:
128 
129  /// Compute descent direction
130  virtual void Gradient(double *, double = .0, bool * = NULL);
131 
132  /// Finalize gradient descent
133  virtual void Finalize();
134 
135 };
136 
137 
138 } // namespace mirtk
139 
140 #endif // MIRTK_GradientDescent_H
virtual void Gradient(double *, double=.0, bool *=NULL)
Compute descent direction.
virtual ParameterList Parameter() const
Get parameters as key/value as string map.
GradientDescent(ObjectiveFunction *=NULL)
Constructor.
virtual void Initialize()
virtual ~GradientDescent()
Destructor.
Array< Pair< string, string > > ParameterList
Ordered list of parameter name/value pairs.
Definition: Object.h:38
virtual ParameterList Parameter() const
Get parameters as key/value as string map.
virtual bool Set(const char *, const char *)
Set parameter value from string.
Definition: IOConfig.h:41
virtual void Function(ObjectiveFunction *)
Set objective function.
virtual void LineSearch(class LineSearch *, bool=false)
Set line search object.
double * _Gradient
Allocated memory for line search direction.
GradientDescent & operator=(const GradientDescent &)
Assignment operator.
virtual void Finalize()
Finalize gradient descent.
LineSearchStrategy
Enumeration of available line search strategies.
Definition: LineSearch.h:32
virtual double Run()
Optimize objective function using gradient descent.