DrawEM.h
1 /*
2  * Developing brain Region Annotation With Expectation-Maximization (Draw-EM)
3  *
4  * Copyright 2013-2016 Imperial College London
5  * Copyright 2013-2016 Christian Ledig
6  * Copyright 2013-2016 Antonios Makropoulos
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef DrawEM_H_
22 #define DrawEM_H_
23 
24 #include "mirtk/EMBase.h"
25 
26 #include "mirtk/PolynomialBiasField.h"
27 #include "mirtk/Image.h"
28 #include "mirtk/HashProbabilisticAtlas.h"
29 #include "mirtk/BiasField.h"
30 #include "mirtk/BiasCorrection.h"
31 #include "mirtk/Image.h"
32 #include "mirtk/GaussianBlurring.h"
33 #include "mirtk/EuclideanDistanceTransform.h"
34 #include "mirtk/ConnectedComponents.h"
35 
36 #include <set>
37 #include <map>
38 #include <vector>
39 #include <utility>
40 
41 namespace mirtk {
42 
43 class DrawEM : public EMBase
44 {
45  mirtkObjectMacro(DrawEM);
46 
47 protected:
48  // Data members (protected, so they are accessible to subclasses)
49 
50  /// Local weights for the MRF field, held as an image
51  RealImage _MRF_weights;
52 
53  /// Uncorrected image (NOTE(review): presumably the input before bias correction - confirm)
54  RealImage _uncorrected;
55 
56  /// Bias field correction filter
57  BiasCorrection _biascorrection;
58 
59  /// Bias field (raw pointer, see SetBiasField(); ownership is not visible in this header)
60  BiasField *_biasfield;
61 
62  /// MRF connectivity matrix
63  Matrix _connectivity;
64 
65  /// Partial volume (PV) classes; see AddPartialVolumeClass()
66  map<int,int> pv_classes;
67  vector< pair<int, int> > pv_connections; ///< NOTE(review): presumably the (classA, classB) pair of each PV class - confirm
68  vector<double> pv_fc; ///< NOTE(review): meaning of "fc" is not visible in this header - confirm in the implementation
69 
70  /// Use the 26-neighborhood MRF (set via setbignn())
71  bool bignn;
72  /// Do PV correction similar to Hui et al. (set via setHui())
73  bool huipvcorr;
74  /// Weight of the MRF (set via setMRFstrength())
75  double mrfweight;
76  /// MRF betas: beta is the intra term (setBeta()), betainter the inter term (setBetaInter())
77  double beta,betainter;
78 
79  /// Inter MRF: intermrf is switched on by setMRFInterAtlas(), which also stores _MRF_inter
80  bool intermrf;
81  RealImage **_MRF_inter; ///< Caller-supplied pointer, stored as-is by the setter (nothing is copied)
82 
83  /// The tissue class of each label; the array is allocated and filled by setTissueLabels()
84  int *tissuelabels;
85  int csflabel,wmlabel,gmlabel,outlabel; ///< NOTE(review): presumably the tissue class ids for CSF, WM, GM and outlier - confirm
86 
87 private: // internal helpers; only declared in the visible part of this header
88  bool isPVclass(int pvclass);
89  double getMRFenergy(int index, int tissue);
90  double getMRFInterEnergy(int index, int tissue);
91 
92 public:
93  /// Constructors; the templated overloads take noTissues and an ImageType** atlas, optionally followed by a background or initposteriors argument
94  DrawEM();
95  template <class ImageType>
96  DrawEM(int noTissues, ImageType **atlas, ImageType *background);
97  template <class ImageType>
98  DrawEM(int noTissues, ImageType **atlas);
99  template <class ImageType>
100  DrawEM(int noTissues, ImageType **atlas, ImageType **initposteriors);
101 
102  /// Initialize parameters
103  void InitialiseParameters();
104 
105  /// Get the bias field (void return; NOTE(review): presumably written into the image argument - confirm)
106  void GetBiasField(RealImage &image);
107 
108  /// Add a partial volume class between classes classA and classB (NOTE(review): the int result is presumably the new class index - confirm)
109  int AddPartialVolumeClass(int classA, int classB, int huiclass=0);
110 
111  /// Estimate probabilities (the E-step variant that uses the MRF, going by the method name)
112  void EStepMRF(void);
113 
114  /// Relax priors; the second overload takes a factor rf (NOTE(review): presumably the relaxation factor - confirm)
115  void RStep(void);
116  void RStep(double rf);
117 
118  /// Estimates the bias field
119  virtual void BStep();
120 
121 protected:
122 
123  using EMBase::SetInput; // keeps the EMBase::SetInput overloads visible (with protected access) next to the SetInput declared below
124 
125 public:
126 
127  /// Set image; the Matrix parameter is unnamed here (NOTE(review): presumably the MRF connectivity - confirm)
128  virtual void SetInput(const RealImage &, const Matrix &);
129 
130  /// Set the bias field
131  virtual void SetBiasField(BiasField *);
132 
133  /// Execute one iteration and return the log likelihood
134  virtual double Iterate(int iteration);
135 
136  /// Compute the bias corrected image
137  virtual void GetBiasCorrectedImage(RealImage &);
138 
139 
140  /// Set the MRF strength (stored in mrfweight)
141  virtual void setMRFstrength(double mrfw);
142  /// Enable or disable the 26-neighborhood in the MRF (stored in bignn)
143  virtual void setbignn(bool bnn);
144  /// Computes the MRF energy with the 26-neighborhood
145  double getMRFenergy_diag(int index, int tissue);
146 
147  /// Removes the PV classes! (currently disabled: the declaration below is commented out)
148  //void removePVclasses(double threshold = 0.1);
149 
150  /// Sets the tissue class of each label by copying num entries from atisslabels into a newly allocated array
151  void setTissueLabels(int num,int *atisslabels);
152 
153  /// Set hui-style PV correction (stored in huipvcorr)
154  void setHui(bool hui);
155  /// Hui-style PV correction - HACKY!!
156  void huiPVCorrection(bool changePosterior=false);
157  /// Construct the segmentation with the tissue class of each label instead of the label itself
158  void ConstructSegmentationHui(IntegerImage &segmentation);
159  /// Finds the overall tissue probabilities at the voxel (x,y,z)
160  /// by adding the probability of the different labels belonging to the tissue class
161  void getHuiValues(double &outval,double &csfval,double &gmval,double &wmval,int x,int y,int z,bool atlas);
162  /// Modifies the probability of the labels according to the overall tissue probabilities at the voxel (x,y,z);
163  /// the overall tissue probability is divided among its labels according to their "contribution" to the tissue
164  void setHuiValues(double &outval,double &csfval,double &gmval,double &wmval,int x,int y,int z,bool atlas);
165 
166  /// Set beta_intra (stored in beta)
167  virtual void setBeta(double beta);
168  /// Set beta_inter (stored in betainter)
169  virtual void setBetaInter(double betainter);
170  /// Set the mrf_inter term: stores the atlas pointer in _MRF_inter and sets intermrf to true
171  virtual void setMRFInterAtlas(RealImage **&atlas);
172 
173 };
174 
175 inline void DrawEM::setHui(bool hui) { this->huipvcorr = hui; } // stores the flag in huipvcorr
176 inline void DrawEM::setbignn(bool bnn) { this->bignn = bnn; } // stores the flag in bignn
177 inline void DrawEM::setMRFstrength(double mrfw) { this->mrfweight = mrfw; } // stores the weight in mrfweight
178 inline void DrawEM::setMRFInterAtlas(RealImage **&atlas) { this->intermrf = true; this->_MRF_inter = atlas; } // keeps the caller's pointer (no copy) and enables the inter-MRF flag
179 inline void DrawEM::setBeta(double b) { this->beta = b; } // stores b in beta
180 inline void DrawEM::setBetaInter(double b) { this->betainter = b; } // stores b in betainter
181 inline void DrawEM::setTissueLabels(int num,int *atisslabels){ // copies num entries of atisslabels into a newly allocated array held by tissuelabels
182  int *dst = tissuelabels = new int[num]; // NOTE(review): an array previously held by tissuelabels is not released here
183  for (const int *src = atisslabels, *end = atisslabels + num; src < end; ) *dst++ = *src++; // element-wise copy
184 }
185 
186 }
187 
188 
189 
190 
191 #endif /* DrawEM_H_ */
Definition: IOConfig.h:41