Direct Graphical Models  v.1.5.1
TrainNodeCvRF.h
1 // Random Forest (based on OpenCV) training class interface
2 // Written by Sergey G. Kosov in 2012 for Project X
3 #pragma once
4 
5 #include "TrainNode.h"
6 
7 namespace DirectGraphicalModels
8 {
9  class CRForest;
10 
12  typedef struct TrainNodeCvRFParams {
13  int max_depth;
20  int maxCount;
21  double epsilon;
23  int maxSamples;
24 
26  TrainNodeCvRFParams(int _max_depth, int _min_sample_count, float _regression_accuracy, bool _use_surrogates, int _max_categories, bool _calc_var_importance, int _nactive_vars, int _maxCount, double _epsilon, int _term_criteria_type, int _maxSamples) : max_depth(_max_depth), min_sample_count(_min_sample_count), regression_accuracy(_regression_accuracy), use_surrogates(_use_surrogates), max_categories(_max_categories), calc_var_importance(_calc_var_importance), nactive_vars(_nactive_vars), maxCount(_maxCount), epsilon(_epsilon), term_criteria_type(_term_criteria_type), maxSamples(_maxSamples) {}
28 
30  25, // Max depth
31  5, // Min sample count (1% of all data)
32  0, // Regression accuracy (0 means N/A here)
33  false, // Compute surrogate split, no missing data
34  15, // Max number of categories (use sub-optimal algorithm for larger numbers)
35  false, // Calculate variable importance
36  4, // Number of variables randomly selected at node and used to find the best split(s). 0 means sqrt(nFeatures)
37  100, // Max number of trees in the forest (time / accuracy)
38  0.01, // Forest accuracy
39  TermCriteria::MAX_ITER | TermCriteria::EPS, // Termination cirteria (according the the two previous parameters)
40  0 // Maximum number of samples to be used in training. 0 means using all the samples
41  );
42 
43  // =========================== OpenCV RF Train Class ===========================
49  class CTrainNodeCvRF : public CTrainNode
50  {
51  public:
58  DllExport CTrainNodeCvRF(byte nStates, word nFeatures, TrainNodeCvRFParams params = TRAIN_NODE_CV_RF_PARAMS_DEFAULT);
67  DllExport CTrainNodeCvRF(byte nStates, word nFeatures, int maxSamples);
68  DllExport ~CTrainNodeCvRF(void);
69 
70  DllExport void reset(void);
71  DllExport void save(const std::string &path, const std::string &name = std::string(), short idx = -1) const;
72  DllExport void load(const std::string &path, const std::string &name = std::string(), short idx = -1);
73 
74  DllExport void addFeatureVec(const Mat &featureVector, byte gt);
75  DllExport void train(bool doClean = false);
76 
83  DllExport Mat getFeatureImportance(void) const;
84 
85 
86  protected:
87  DllExport void saveFile(FILE *pFile) const { }
88  DllExport void loadFile(FILE *pFile) { }
89  DllExport void calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const;
90 
91 
92  protected:
93  Ptr<CRForest> m_pRF;
94 
95 
96  private:
97  void init(TrainNodeCvRFParams params); // This function is called by both constructors
98 
99 
100  private:
101  vec_mat_t m_vSamplesAcc; // = vec_mat_t(nStates); // Samples container for all states
102  vec_int_t m_vNumInputSamples; // = vec_int_t(nStates, 0); // Amount of input samples for all states
103  int m_maxSamples; // = INFINITY; // for optimisation purposes
104  };
105 }
106 
void reset(void)
Resets class variables.
void calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
Calculates the node potential, based on the feature vector.
TrainNodeCvRFParams(int _max_depth, int _min_sample_count, float _regression_accuracy, bool _use_surrogates, int _max_categories, bool _calc_var_importance, int _nactive_vars, int _maxCount, double _epsilon, int _term_criteria_type, int _maxSamples)
Definition: TrainNodeCvRF.h:26
void load(const std::string &path, const std::string &name=std::string(), short idx=-1)
Loads the training data.
int term_criteria_type
Termination criteria type (according to the two previous parameters)
Definition: TrainNodeCvRF.h:22
bool use_surrogates
Compute surrogate split, no missing data.
Definition: TrainNodeCvRF.h:16
struct DirectGraphicalModels::TrainNodeCvRFParams TrainNodeCvRFParams
OpenCV Random Forest parameters.
CTrainNodeCvRF(byte nStates, word nFeatures, TrainNodeCvRFParams params=TRAIN_NODE_CV_RF_PARAMS_DEFAULT)
Constructor.
int min_sample_count
Min sample count (1% of all data)
Definition: TrainNodeCvRF.h:14
OpenCV Random Forest parameters.
Definition: TrainNodeCvRF.h:12
int maxCount
Max number of trees in the forest (time / accuracy)
Definition: TrainNodeCvRF.h:20
void train(bool doClean=false)
Random forest model training.
void save(const std::string &path, const std::string &name=std::string(), short idx=-1) const
Saves the training data.
void addFeatureVec(const Mat &featureVector, byte gt)
Adds new feature vector.
int maxSamples
Maximum number of samples to be used in training. 0 means using all the samples.
Definition: TrainNodeCvRF.h:23
void loadFile(FILE *pFile)
Loads the random forest model from the file.
Definition: TrainNodeCvRF.h:88
const TrainNodeCvRFParams TRAIN_NODE_CV_RF_PARAMS_DEFAULT
Definition: TrainNodeCvRF.h:29
Mat getFeatureImportance(void) const
Returns the feature importance vector.
Ptr< CRForest > m_pRF
Random Forest.
Definition: TrainNodeCvRF.h:93
Base abstract class for node potentials training.
Definition: TrainNode.h:31
int max_categories
Max number of categories (use sub-optimal algorithm for larger numbers)
Definition: TrainNodeCvRF.h:17
float regression_accuracy
Regression accuracy (0 means N/A here)
Definition: TrainNodeCvRF.h:15
OpenCV Random Forest training class.
Definition: TrainNodeCvRF.h:49
void saveFile(FILE *pFile) const
Saves the random forest model into the file.
Definition: TrainNodeCvRF.h:87
bool calc_var_importance
Calculate variable importance (must be true in order to use the CTrainNodeCvRF::getFeatureImportance function).
Definition: TrainNodeCvRF.h:18
int nactive_vars
Number of variables randomly selected at each node and used to find the best split(s) (0 means sqrt(nFeatures)).
Definition: TrainNodeCvRF.h:19