31 template <
class DCFactorType>
34 std::vector<DCFactorType> factors_;
35 std::vector<double> log_weights_;
45 const std::vector<DCFactorType> factors,
46 const std::vector<double> weights,
47 const bool normalized)
50 for (
size_t i = 0; i < weights.size(); i++) {
51 log_weights_.push_back(log(weights[i]));
57 const std::vector<DCFactorType> factors,
58 const bool normalized)
61 for (
size_t i = 0; i < factors_.size(); i++) {
62 log_weights_.push_back(0);
67 this->factors_ = rhs.factors_;
68 this->log_weights_ = rhs.log_weights_;
69 this->normalized_ = rhs.normalized_;
74 double error(
const gtsam::Values& continuousVals,
77 assert(0 <= min_error_idx);
78 assert(min_error_idx < factors_.size());
80 factors_[min_error_idx].error(continuousVals, discreteVals);
81 if (normalized_)
return min_error - log_weights_[min_error_idx];
83 factors_[min_error_idx].logNormalizingConstant(continuousVals) -
84 log_weights_[min_error_idx];
89 double min_error = std::numeric_limits<double>::infinity();
90 size_t min_error_idx = 0;
91 for (
size_t i = 0; i < factors_.size(); i++) {
93 factors_[i].error(continuousVals, discreteVals) - log_weights_[i];
95 error += factors_[i].logNormalizingConstant(continuousVals);
97 if (
error < min_error) {
102 return min_error_idx;
105 size_t dim()
const override {
106 if (factors_.size() > 0) {
107 return factors_[0].dim();
116 if (factors_.size() != f.factors_.size())
return false;
117 for (
size_t i = 0; i < factors_.size(); i++) {
118 if (!factors_[i].
equals(f.factors_[i]))
return false;
120 return ((log_weights_ == f.log_weights_) && (normalized_ == f.normalized_));
124 const gtsam::Values& continuousVals,
127 return factors_[min_error_idx].linearize(continuousVals, discreteVals);
131 const gtsam::DiscreteKey& dk)
const {
132 std::vector<double> probs(dk.second, (1.0 / dk.second));
133 gtsam::DecisionTreeFactor uniform(dk, probs);
138 const gtsam::Values& continuousVals,
141 gtsam::DecisionTreeFactor converted;
142 for (
size_t i = 0; i < factors_.size(); i++) {
143 if (i == min_error_idx) {
144 converted = converted * factors_[min_error_idx].toDecisionTreeFactor(
145 continuousVals, discreteVals);
147 for (
const gtsam::DiscreteKey& dk : factors_[i].
discreteKeys()) {
156 const gtsam::Values& continuousVals,
159 return factors_[min_error_idx].keys();
163 if (weights.size() != log_weights_.size()) {
164 std::cerr <<
"Attempted to update weights with incorrectly sized vector."
168 for (
int i = 0; i < weights.size(); i++) {
169 log_weights_[i] = log(weights[i]);
Custom discrete-continuous factor.
Abstract class implementing a discrete-continuous factor.
Definition: DCFactor.h:33
gtsam::DiscreteKeys discreteKeys() const
Definition: DCFactor.h:135
gtsam::Factor Base
Definition: DCFactor.h:39
Implementation of a discrete-continuous max-mixture factor.
Definition: DCMaxMixtureFactor.h:32
void updateWeights(const std::vector< double > &weights)
Definition: DCMaxMixtureFactor.h:162
DCMaxMixtureFactor(const gtsam::KeyVector &continuousKeys, const gtsam::DiscreteKeys &discreteKeys, const std::vector< DCFactorType > factors, const bool normalized)
Definition: DCMaxMixtureFactor.h:55
gtsam::DecisionTreeFactor toDecisionTreeFactor(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const override
Definition: DCMaxMixtureFactor.h:137
boost::shared_ptr< gtsam::GaussianFactor > linearize(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const override
Definition: DCMaxMixtureFactor.h:123
size_t dim() const override
Definition: DCMaxMixtureFactor.h:105
DCMaxMixtureFactor()=default
gtsam::FastVector< gtsam::Key > getAssociationKeys(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const
Definition: DCMaxMixtureFactor.h:155
double error(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const override
Definition: DCMaxMixtureFactor.h:74
DCMaxMixtureFactor & operator=(const DCMaxMixtureFactor &rhs)
Definition: DCMaxMixtureFactor.h:66
bool equals(const DCFactor &other, double tol=1e-9) const override
Definition: DCMaxMixtureFactor.h:113
virtual ~DCMaxMixtureFactor()=default
gtsam::DecisionTreeFactor uniformDecisionTreeFactor(const gtsam::DiscreteKey &dk) const
Definition: DCMaxMixtureFactor.h:130
DCMaxMixtureFactor(const gtsam::KeyVector &continuousKeys, const gtsam::DiscreteKeys &discreteKeys, const std::vector< DCFactorType > factors, const std::vector< double > weights, const bool normalized)
Definition: DCMaxMixtureFactor.h:43
size_t getActiveFactorIdx(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const
Definition: DCMaxMixtureFactor.h:87
Definition: DCContinuousFactor.h:24
gtsam::DiscreteFactor::Values DiscreteValues
Definition: DCSAM_types.h:19
const double tol
Definition: testDCSAM.cpp:40