37 template <
class DCFactorType>
40 std::vector<DCFactorType> factors_;
41 std::vector<double> log_weights_;
49 explicit DCEMFactor(
const gtsam::KeyVector& continuousKeys,
51 const std::vector<DCFactorType> factors,
52 const std::vector<double> weights,
const bool normalized)
55 for (
size_t i = 0; i < weights.size(); i++) {
56 log_weights_.push_back(log(weights[i]));
60 explicit DCEMFactor(
const gtsam::KeyVector& continuousKeys,
62 const std::vector<DCFactorType> factors,
63 const bool normalized)
66 for (
size_t i = 0; i < factors_.size(); i++) {
67 log_weights_.push_back(0);
72 this->factors_ = rhs.factors_;
73 this->log_weights_ = rhs.log_weights_;
74 this->normalized_ = rhs.normalized_;
79 double error(
const gtsam::Values& continuousVals,
82 std::vector<double> logprobs =
86 std::vector<double> componentWeights =
expNormalize(logprobs);
89 double total_error = 0.0;
90 for (
size_t i = 0; i < logprobs.size(); i++) {
91 total_error += componentWeights[i] * (-logprobs[i]);
97 const gtsam::Values& continuousVals,
101 std::vector<double> logprobs;
102 for (
size_t i = 0; i < factors_.size(); i++) {
104 factors_[i].error(continuousVals, discreteVals) - log_weights_[i];
106 error += factors_[i].logNormalizingConstant(continuousVals);
107 logprobs.push_back(-
error);
114 double min_error = std::numeric_limits<double>::infinity();
115 size_t min_error_idx;
116 for (
size_t i = 0; i < factors_.size(); i++) {
118 factors_[i].error(continuousVals, discreteVals) - log_weights_[i];
120 error += factors_[i].logNormalizingConstant(continuousVals);
122 if (
error < min_error) {
127 return min_error_idx;
130 size_t dim()
const override {
134 for (
size_t i = 0; i < factors_.size(); i++) {
135 total += factors_[i].dim();
141 if (!
dynamic_cast<const DCEMFactor*
>(&other))
return false;
143 if (factors_.size() != f.factors_.size())
return false;
144 for (
size_t i = 0; i < factors_.size(); i++) {
145 if (!factors_[i].
equals(f.factors_[i]))
return false;
147 return ((log_weights_ == f.log_weights_) && (normalized_ == f.normalized_));
154 const gtsam::Values& continuousVals,
156 std::vector<boost::shared_ptr<gtsam::GaussianFactor>> gfs;
159 std::vector<double> errors =
163 std::vector<double> componentWeights =
expNormalize(errors);
167 gtsam::GaussianFactorGraph gfg;
169 for (
size_t i = 0; i < factors_.size(); i++) {
172 boost::shared_ptr<gtsam::GaussianFactor> gf =
173 factors_[i].linearize(continuousVals, discreteVals);
175 gtsam::JacobianFactor jf_component(*gf);
179 gtsam::VerticalBlockMatrix Ab = jf_component.matrixObject();
182 gtsam::VerticalBlockMatrix Ab_weighted = Ab;
186 double sqrt_weight = sqrt(componentWeights[i]);
188 for (
size_t k = 0; k < Ab_weighted.nBlocks(); k++) {
189 Ab_weighted(k) = sqrt_weight * Ab(k);
194 gtsam::JacobianFactor jf(factors_[i].keys(), Ab_weighted);
200 return boost::make_shared<gtsam::JacobianFactor>(gfg);
204 const gtsam::Values& continuousVals,
207 std::vector<double> logprobs =
211 std::vector<double> componentWeights =
expNormalize(logprobs);
213 std::vector<gtsam::DecisionTreeFactor> unary_factors;
214 for (
size_t i = 0; i < factors_.size(); i++) {
215 gtsam::DiscreteKeys factor_dkeys = factors_[i].discreteKeys();
216 assert(factor_dkeys.size() == 1);
217 std::vector<double> factor_probs =
218 factors_[i].evalProbs(factor_dkeys[0], continuousVals);
219 std::vector<double> log_weighted_factor_probs;
220 for (
size_t k = 0; k < factor_probs.size(); k++) {
221 log_weighted_factor_probs.push_back(componentWeights[i] *
222 log(factor_probs[k]));
224 std::vector<double> new_probs =
expNormalize(log_weighted_factor_probs);
225 gtsam::DecisionTreeFactor unary(factor_dkeys[0], new_probs);
226 unary_factors.push_back(unary);
228 gtsam::DecisionTreeFactor converted;
229 for (
size_t i = 0; i < unary_factors.size(); i++) {
230 converted = converted * unary_factors[i];
236 const gtsam::Values& continuousVals,
239 return factors_[min_error_idx].keys();
243 if (weights.size() != log_weights_.size()) {
244 std::cerr <<
"Attempted to update weights with incorrectly sized vector."
248 for (
size_t i = 0; i < weights.size(); i++) {
249 log_weights_[i] = log(weights[i]);
Custom discrete-continuous factor.
Some utilities for DCSAM.
Implementation of a discrete-continuous EM factor.
Definition: DCEMFactor.h:38
size_t getActiveFactorIdx(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const
Definition: DCEMFactor.h:112
DCEMFactor(const gtsam::KeyVector &continuousKeys, const gtsam::DiscreteKeys &discreteKeys, const std::vector< DCFactorType > factors, const std::vector< double > weights, const bool normalized)
Definition: DCEMFactor.h:49
gtsam::DecisionTreeFactor toDecisionTreeFactor(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const override
Definition: DCEMFactor.h:203
bool equals(const DCFactor &other, double tol=1e-9) const override
Definition: DCEMFactor.h:140
gtsam::FastVector< gtsam::Key > getAssociationKeys(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const
Definition: DCEMFactor.h:235
size_t dim() const override
Definition: DCEMFactor.h:130
void updateWeights(const std::vector< double > &weights)
Definition: DCEMFactor.h:242
std::vector< double > computeComponentLogProbs(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const
Definition: DCEMFactor.h:96
DCEMFactor(const gtsam::KeyVector &continuousKeys, const gtsam::DiscreteKeys &discreteKeys, const std::vector< DCFactorType > factors, const bool normalized)
Definition: DCEMFactor.h:60
virtual ~DCEMFactor()=default
boost::shared_ptr< gtsam::GaussianFactor > linearize(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const override
Definition: DCEMFactor.h:153
DCEMFactor & operator=(const DCEMFactor &rhs)
Definition: DCEMFactor.h:71
double error(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const override
Definition: DCEMFactor.h:79
Abstract class implementing a discrete-continuous factor.
Definition: DCFactor.h:33
gtsam::DiscreteKeys discreteKeys() const
Definition: DCFactor.h:135
Definition: DCContinuousFactor.h:24
std::vector< double > expNormalize(const std::vector< double > &logProbs)
Definition: DCSAM_utils.h:20
gtsam::DiscreteFactor::Values DiscreteValues
Definition: DCSAM_types.h:19
const double tol
Definition: testDCSAM.cpp:40