dcsam
Factored inference for discrete-continuous smoothing and mapping
DCEMFactor.h
Go to the documentation of this file.
1 
8 #pragma once
9 
#include <math.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <limits>
#include <utility>
#include <vector>

#include "dcsam/DCFactor.h"
#include "dcsam/DCSAM_utils.h"
19 
20 namespace dcsam {
21 
37 template <class DCFactorType>
38 class DCEMFactor : public DCFactor {
39  private:
40  std::vector<DCFactorType> factors_;
41  std::vector<double> log_weights_;
42  bool normalized_;
43 
44  public:
45  using Base = DCFactor;
46 
47  DCEMFactor() = default;
48 
49  explicit DCEMFactor(const gtsam::KeyVector& continuousKeys,
50  const gtsam::DiscreteKeys& discreteKeys,
51  const std::vector<DCFactorType> factors,
52  const std::vector<double> weights, const bool normalized)
53  : Base(continuousKeys, discreteKeys), normalized_(normalized) {
54  factors_ = factors;
55  for (size_t i = 0; i < weights.size(); i++) {
56  log_weights_.push_back(log(weights[i]));
57  }
58  }
59 
60  explicit DCEMFactor(const gtsam::KeyVector& continuousKeys,
61  const gtsam::DiscreteKeys& discreteKeys,
62  const std::vector<DCFactorType> factors,
63  const bool normalized)
64  : Base(continuousKeys, discreteKeys), normalized_(normalized) {
65  factors_ = factors;
66  for (size_t i = 0; i < factors_.size(); i++) {
67  log_weights_.push_back(0);
68  }
69  }
70 
72  this->factors_ = rhs.factors_;
73  this->log_weights_ = rhs.log_weights_;
74  this->normalized_ = rhs.normalized_;
75  }
76 
77  virtual ~DCEMFactor() = default;
78 
79  double error(const gtsam::Values& continuousVals,
80  const DiscreteValues& discreteVals) const override {
81  // Retrieve the log prob for each component.
82  std::vector<double> logprobs =
83  computeComponentLogProbs(continuousVals, discreteVals);
84 
85  // Weights for each component are obtained by normalizing the errors.
86  std::vector<double> componentWeights = expNormalize(logprobs);
87 
88  // Compute the total error as the weighted sum of component errors.
89  double total_error = 0.0;
90  for (size_t i = 0; i < logprobs.size(); i++) {
91  total_error += componentWeights[i] * (-logprobs[i]);
92  }
93  return total_error;
94  }
95 
96  std::vector<double> computeComponentLogProbs(
97  const gtsam::Values& continuousVals,
98  const DiscreteValues& discreteVals) const {
99  // Container for errors, where:
100  // error_i = error of component factor i - log_weights_i
101  std::vector<double> logprobs;
102  for (size_t i = 0; i < factors_.size(); i++) {
103  double error =
104  factors_[i].error(continuousVals, discreteVals) - log_weights_[i];
105  if (!normalized_)
106  error += factors_[i].logNormalizingConstant(continuousVals);
107  logprobs.push_back(-error);
108  }
109  return logprobs;
110  }
111 
112  size_t getActiveFactorIdx(const gtsam::Values& continuousVals,
113  const DiscreteValues& discreteVals) const {
114  double min_error = std::numeric_limits<double>::infinity();
115  size_t min_error_idx;
116  for (size_t i = 0; i < factors_.size(); i++) {
117  double error =
118  factors_[i].error(continuousVals, discreteVals) - log_weights_[i];
119  if (!normalized_)
120  error += factors_[i].logNormalizingConstant(continuousVals);
121 
122  if (error < min_error) {
123  min_error = error;
124  min_error_idx = i;
125  }
126  }
127  return min_error_idx;
128  }
129 
130  size_t dim() const override {
131  size_t total = 0;
132  // Each component factor `i` requires `factors_[i].dim()` rows in the
133  // overall Jacobian.
134  for (size_t i = 0; i < factors_.size(); i++) {
135  total += factors_[i].dim();
136  }
137  return total;
138  }
139 
140  bool equals(const DCFactor& other, double tol = 1e-9) const override {
141  if (!dynamic_cast<const DCEMFactor*>(&other)) return false;
142  const DCEMFactor& f(static_cast<const DCEMFactor&>(other));
143  if (factors_.size() != f.factors_.size()) return false;
144  for (size_t i = 0; i < factors_.size(); i++) {
145  if (!factors_[i].equals(f.factors_[i])) return false;
146  }
147  return ((log_weights_ == f.log_weights_) && (normalized_ == f.normalized_));
148  }
149 
150  /*
151  * Jacobian magic
152  */
153  boost::shared_ptr<gtsam::GaussianFactor> linearize(
154  const gtsam::Values& continuousVals,
155  const DiscreteValues& discreteVals) const override {
156  std::vector<boost::shared_ptr<gtsam::GaussianFactor>> gfs;
157 
158  // Start by computing all errors, so we can get the component weights.
159  std::vector<double> errors =
160  computeComponentLogProbs(continuousVals, discreteVals);
161 
162  // Weights for each component are obtained by normalizing the errors.
163  std::vector<double> componentWeights = expNormalize(errors);
164 
165  // We want to temporarily build a GaussianFactorGraph to construct the
166  // Jacobian for this whole factor.
167  gtsam::GaussianFactorGraph gfg;
168 
169  for (size_t i = 0; i < factors_.size(); i++) {
170  // std::cout << "i = " << i << std::endl;
171  // First get the GaussianFactor obtained by linearizing `factors_[i]`
172  boost::shared_ptr<gtsam::GaussianFactor> gf =
173  factors_[i].linearize(continuousVals, discreteVals);
174 
175  gtsam::JacobianFactor jf_component(*gf);
176 
177  // Recover the [A b] matrix with Jacobian A and right-hand side vector b,
178  // with noise models "baked in," as a vertical block matrix.
179  gtsam::VerticalBlockMatrix Ab = jf_component.matrixObject();
180 
181  // Copy Ab so we can reweight it appropriately.
182  gtsam::VerticalBlockMatrix Ab_weighted = Ab;
183 
184  // Populate Ab_weighted with weighted Jacobian sqrt(w)*A and right-hand
185  // side vector sqrt(w)*b.
186  double sqrt_weight = sqrt(componentWeights[i]);
187 
188  for (size_t k = 0; k < Ab_weighted.nBlocks(); k++) {
189  Ab_weighted(k) = sqrt_weight * Ab(k);
190  }
191 
192  // Create a `JacobianFactor` from the system [A b] and add it to the
193  // `GaussianFactorGraph`.
194  gtsam::JacobianFactor jf(factors_[i].keys(), Ab_weighted);
195  gfg.add(jf);
196  }
197 
198  // Stack Jacobians to build combined factor.
199 
200  return boost::make_shared<gtsam::JacobianFactor>(gfg);
201  }
202 
203  gtsam::DecisionTreeFactor toDecisionTreeFactor(
204  const gtsam::Values& continuousVals,
205  const DiscreteValues& discreteVals) const override {
206  // Start by computing all log probs, so we can get the component weights.
207  std::vector<double> logprobs =
208  computeComponentLogProbs(continuousVals, discreteVals);
209 
210  // Weights for each component are obtained by normalizing the errors.
211  std::vector<double> componentWeights = expNormalize(logprobs);
212 
213  std::vector<gtsam::DecisionTreeFactor> unary_factors;
214  for (size_t i = 0; i < factors_.size(); i++) {
215  gtsam::DiscreteKeys factor_dkeys = factors_[i].discreteKeys();
216  assert(factor_dkeys.size() == 1);
217  std::vector<double> factor_probs =
218  factors_[i].evalProbs(factor_dkeys[0], continuousVals);
219  std::vector<double> log_weighted_factor_probs;
220  for (size_t k = 0; k < factor_probs.size(); k++) {
221  log_weighted_factor_probs.push_back(componentWeights[i] *
222  log(factor_probs[k]));
223  }
224  std::vector<double> new_probs = expNormalize(log_weighted_factor_probs);
225  gtsam::DecisionTreeFactor unary(factor_dkeys[0], new_probs);
226  unary_factors.push_back(unary);
227  }
228  gtsam::DecisionTreeFactor converted;
229  for (size_t i = 0; i < unary_factors.size(); i++) {
230  converted = converted * unary_factors[i];
231  }
232  return converted;
233  }
234 
235  gtsam::FastVector<gtsam::Key> getAssociationKeys(
236  const gtsam::Values& continuousVals,
237  const DiscreteValues& discreteVals) const {
238  size_t min_error_idx = getActiveFactorIdx(continuousVals, discreteVals);
239  return factors_[min_error_idx].keys();
240  }
241 
242  void updateWeights(const std::vector<double>& weights) {
243  if (weights.size() != log_weights_.size()) {
244  std::cerr << "Attempted to update weights with incorrectly sized vector."
245  << std::endl;
246  return;
247  }
248  for (size_t i = 0; i < weights.size(); i++) {
249  log_weights_[i] = log(weights[i]);
250  }
251  }
252 };
253 } // namespace dcsam
Custom discrete-continuous factor.
Some utilities for DCSAM.
Implementation of a discrete-continuous EM factor.
Definition: DCEMFactor.h:38
size_t getActiveFactorIdx(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const
Definition: DCEMFactor.h:112
DCEMFactor(const gtsam::KeyVector &continuousKeys, const gtsam::DiscreteKeys &discreteKeys, const std::vector< DCFactorType > factors, const std::vector< double > weights, const bool normalized)
Definition: DCEMFactor.h:49
gtsam::DecisionTreeFactor toDecisionTreeFactor(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const override
Definition: DCEMFactor.h:203
bool equals(const DCFactor &other, double tol=1e-9) const override
Definition: DCEMFactor.h:140
gtsam::FastVector< gtsam::Key > getAssociationKeys(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const
Definition: DCEMFactor.h:235
size_t dim() const override
Definition: DCEMFactor.h:130
void updateWeights(const std::vector< double > &weights)
Definition: DCEMFactor.h:242
std::vector< double > computeComponentLogProbs(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const
Definition: DCEMFactor.h:96
DCEMFactor(const gtsam::KeyVector &continuousKeys, const gtsam::DiscreteKeys &discreteKeys, const std::vector< DCFactorType > factors, const bool normalized)
Definition: DCEMFactor.h:60
virtual ~DCEMFactor()=default
DCEMFactor()=default
boost::shared_ptr< gtsam::GaussianFactor > linearize(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const override
Definition: DCEMFactor.h:153
DCEMFactor & operator=(const DCEMFactor &rhs)
Definition: DCEMFactor.h:71
double error(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const override
Definition: DCEMFactor.h:79
Abstract class implementing a discrete-continuous factor.
Definition: DCFactor.h:33
gtsam::DiscreteKeys discreteKeys() const
Definition: DCFactor.h:135
DCFactor()=default
Definition: DCContinuousFactor.h:24
std::vector< double > expNormalize(const std::vector< double > &logProbs)
Definition: DCSAM_utils.h:20
gtsam::DiscreteFactor::Values DiscreteValues
Definition: DCSAM_types.h:19
const double tol
Definition: testDCSAM.cpp:40