dcsam
Factored inference for discrete-continuous smoothing and mapping
DCMaxMixtureFactor.h
Go to the documentation of this file.
1 
10 #pragma once
11 
#include <math.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <utility>
#include <vector>

#include "DCFactor.h"
19 
20 namespace dcsam {
21 
31 template <class DCFactorType>
32 class DCMaxMixtureFactor : public DCFactor {
33  private:
34  std::vector<DCFactorType> factors_;
35  std::vector<double> log_weights_;
36  bool normalized_;
37 
38  public:
39  using Base = DCFactor;
40 
41  DCMaxMixtureFactor() = default;
42 
43  explicit DCMaxMixtureFactor(const gtsam::KeyVector& continuousKeys,
44  const gtsam::DiscreteKeys& discreteKeys,
45  const std::vector<DCFactorType> factors,
46  const std::vector<double> weights,
47  const bool normalized)
48  : Base(continuousKeys, discreteKeys), normalized_(normalized) {
49  factors_ = factors;
50  for (size_t i = 0; i < weights.size(); i++) {
51  log_weights_.push_back(log(weights[i]));
52  }
53  }
54 
55  explicit DCMaxMixtureFactor(const gtsam::KeyVector& continuousKeys,
56  const gtsam::DiscreteKeys& discreteKeys,
57  const std::vector<DCFactorType> factors,
58  const bool normalized)
59  : Base(continuousKeys, discreteKeys), normalized_(normalized) {
60  factors_ = factors;
61  for (size_t i = 0; i < factors_.size(); i++) {
62  log_weights_.push_back(0);
63  }
64  }
65 
67  this->factors_ = rhs.factors_;
68  this->log_weights_ = rhs.log_weights_;
69  this->normalized_ = rhs.normalized_;
70  }
71 
72  virtual ~DCMaxMixtureFactor() = default;
73 
74  double error(const gtsam::Values& continuousVals,
75  const DiscreteValues& discreteVals) const override {
76  size_t min_error_idx = getActiveFactorIdx(continuousVals, discreteVals);
77  assert(0 <= min_error_idx);
78  assert(min_error_idx < factors_.size());
79  double min_error =
80  factors_[min_error_idx].error(continuousVals, discreteVals);
81  if (normalized_) return min_error - log_weights_[min_error_idx];
82  return min_error +
83  factors_[min_error_idx].logNormalizingConstant(continuousVals) -
84  log_weights_[min_error_idx];
85  }
86 
87  size_t getActiveFactorIdx(const gtsam::Values& continuousVals,
88  const DiscreteValues& discreteVals) const {
89  double min_error = std::numeric_limits<double>::infinity();
90  size_t min_error_idx = 0;
91  for (size_t i = 0; i < factors_.size(); i++) {
92  double error =
93  factors_[i].error(continuousVals, discreteVals) - log_weights_[i];
94  if (!normalized_)
95  error += factors_[i].logNormalizingConstant(continuousVals);
96 
97  if (error < min_error) {
98  min_error = error;
99  min_error_idx = i;
100  }
101  }
102  return min_error_idx;
103  }
104 
105  size_t dim() const override {
106  if (factors_.size() > 0) {
107  return factors_[0].dim();
108  } else {
109  return 0;
110  }
111  }
112 
113  bool equals(const DCFactor& other, double tol = 1e-9) const override {
114  if (!dynamic_cast<const DCMaxMixtureFactor*>(&other)) return false;
115  const DCMaxMixtureFactor& f(static_cast<const DCMaxMixtureFactor&>(other));
116  if (factors_.size() != f.factors_.size()) return false;
117  for (size_t i = 0; i < factors_.size(); i++) {
118  if (!factors_[i].equals(f.factors_[i])) return false;
119  }
120  return ((log_weights_ == f.log_weights_) && (normalized_ == f.normalized_));
121  }
122 
123  boost::shared_ptr<gtsam::GaussianFactor> linearize(
124  const gtsam::Values& continuousVals,
125  const DiscreteValues& discreteVals) const override {
126  size_t min_error_idx = getActiveFactorIdx(continuousVals, discreteVals);
127  return factors_[min_error_idx].linearize(continuousVals, discreteVals);
128  }
129 
130  gtsam::DecisionTreeFactor uniformDecisionTreeFactor(
131  const gtsam::DiscreteKey& dk) const {
132  std::vector<double> probs(dk.second, (1.0 / dk.second));
133  gtsam::DecisionTreeFactor uniform(dk, probs);
134  return uniform;
135  }
136 
137  gtsam::DecisionTreeFactor toDecisionTreeFactor(
138  const gtsam::Values& continuousVals,
139  const DiscreteValues& discreteVals) const override {
140  size_t min_error_idx = getActiveFactorIdx(continuousVals, discreteVals);
141  gtsam::DecisionTreeFactor converted;
142  for (size_t i = 0; i < factors_.size(); i++) {
143  if (i == min_error_idx) {
144  converted = converted * factors_[min_error_idx].toDecisionTreeFactor(
145  continuousVals, discreteVals);
146  } else {
147  for (const gtsam::DiscreteKey& dk : factors_[i].discreteKeys()) {
148  converted = converted * uniformDecisionTreeFactor(dk);
149  }
150  }
151  }
152  return converted;
153  }
154 
155  gtsam::FastVector<gtsam::Key> getAssociationKeys(
156  const gtsam::Values& continuousVals,
157  const DiscreteValues& discreteVals) const {
158  size_t min_error_idx = getActiveFactorIdx(continuousVals, discreteVals);
159  return factors_[min_error_idx].keys();
160  }
161 
162  void updateWeights(const std::vector<double>& weights) {
163  if (weights.size() != log_weights_.size()) {
164  std::cerr << "Attempted to update weights with incorrectly sized vector."
165  << std::endl;
166  return;
167  }
168  for (int i = 0; i < weights.size(); i++) {
169  log_weights_[i] = log(weights[i]);
170  }
171  }
172 };
173 } // namespace dcsam
Custom discrete-continuous factor.
Abstract class implementing a discrete-continuous factor.
Definition: DCFactor.h:33
gtsam::DiscreteKeys discreteKeys() const
Definition: DCFactor.h:135
gtsam::Factor Base
Definition: DCFactor.h:39
DCFactor()=default
Implementation of a discrete-continuous max-mixture factor.
Definition: DCMaxMixtureFactor.h:32
void updateWeights(const std::vector< double > &weights)
Definition: DCMaxMixtureFactor.h:162
DCMaxMixtureFactor(const gtsam::KeyVector &continuousKeys, const gtsam::DiscreteKeys &discreteKeys, const std::vector< DCFactorType > factors, const bool normalized)
Definition: DCMaxMixtureFactor.h:55
gtsam::DecisionTreeFactor toDecisionTreeFactor(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const override
Definition: DCMaxMixtureFactor.h:137
boost::shared_ptr< gtsam::GaussianFactor > linearize(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const override
Definition: DCMaxMixtureFactor.h:123
size_t dim() const override
Definition: DCMaxMixtureFactor.h:105
gtsam::FastVector< gtsam::Key > getAssociationKeys(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const
Definition: DCMaxMixtureFactor.h:155
double error(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const override
Definition: DCMaxMixtureFactor.h:74
DCMaxMixtureFactor & operator=(const DCMaxMixtureFactor &rhs)
Definition: DCMaxMixtureFactor.h:66
bool equals(const DCFactor &other, double tol=1e-9) const override
Definition: DCMaxMixtureFactor.h:113
virtual ~DCMaxMixtureFactor()=default
gtsam::DecisionTreeFactor uniformDecisionTreeFactor(const gtsam::DiscreteKey &dk) const
Definition: DCMaxMixtureFactor.h:130
DCMaxMixtureFactor(const gtsam::KeyVector &continuousKeys, const gtsam::DiscreteKeys &discreteKeys, const std::vector< DCFactorType > factors, const std::vector< double > weights, const bool normalized)
Definition: DCMaxMixtureFactor.h:43
size_t getActiveFactorIdx(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const
Definition: DCMaxMixtureFactor.h:87
Definition: DCContinuousFactor.h:24
gtsam::DiscreteFactor::Values DiscreteValues
Definition: DCSAM_types.h:19
const double tol
Definition: testDCSAM.cpp:40