10 #include <gtsam/discrete/DiscreteFactor.h>
11 #include <gtsam/discrete/DiscreteKey.h>
12 #include <gtsam/nonlinear/NonlinearFactor.h>
13 #include <gtsam/nonlinear/Symbol.h>
39 using Base = gtsam::Factor;
50 DCFactor(
const gtsam::KeyVector& continuousKeys,
86 const gtsam::Values& continuousVals,
87 const gtsam::DiscreteFactor::Values& discreteVals)
const = 0;
102 virtual boost::shared_ptr<gtsam::GaussianFactor>
linearize(
103 const gtsam::Values& continuousVals,
130 virtual size_t dim()
const = 0;
157 const gtsam::Values& continuousVals,
159 gtsam::DecisionTreeFactor converted;
161 std::vector<double> probs =
evalProbs(dkey, continuousVals);
163 assert(probs.size() == dkey.second);
164 gtsam::DecisionTreeFactor unary(dkey, probs);
165 converted = converted * unary;
179 throw std::logic_error(
180 "Normalizing constant not implemented."
181 "One or more of the factors in use requires access to the normalization"
182 "constant for a child class of DCFactor, but`logNormalizingConstant` "
183 "has not been overridden.");
192 template <
typename NonlinearFactorType>
194 const NonlinearFactorType& factor,
const gtsam::Values& values)
const {
196 gtsam::Matrix infoMat;
199 boost::shared_ptr<NonlinearFactorType> fPtr =
200 boost::make_shared<NonlinearFactorType>(factor);
201 boost::shared_ptr<NonlinearFactorType> factorPtr(fPtr);
205 boost::shared_ptr<gtsam::NoiseModelFactor> noiseModelFactor =
206 boost::dynamic_pointer_cast<gtsam::NoiseModelFactor>(factorPtr);
207 if (noiseModelFactor) {
210 gtsam::noiseModel::Base::shared_ptr noiseModel =
211 noiseModelFactor->noiseModel();
213 boost::shared_ptr<gtsam::noiseModel::Gaussian> gaussianNoiseModel =
214 boost::dynamic_pointer_cast<gtsam::noiseModel::Gaussian>(noiseModel);
215 if (gaussianNoiseModel) {
217 infoMat = gaussianNoiseModel->information();
223 boost::shared_ptr<gtsam::GaussianFactor> gaussianFactor =
224 factor.linearize(values);
225 infoMat = gaussianFactor->information();
230 return (factor.dim() * log(2.0 * M_PI) / 2.0) -
231 (log(infoMat.determinant()) / 2.0);
254 std::vector<double>
evalProbs(
const gtsam::DiscreteKey& dk,
255 const gtsam::Values& continuousVals)
const {
256 std::vector<double> logProbs;
257 for (
size_t i = 0; i < dk.second; i++) {
259 testDiscreteVals[dk.first] = i;
262 double logProb = -
error(continuousVals, testDiscreteVals);
263 logProbs.push_back(logProb);
281 const gtsam::DecisionTreeFactor& f,
const gtsam::Values& continuousVals,
Some convenient types for DCSAM.
Some utilities for DCSAM.
Abstract class implementing a discrete-continuous factor.
Definition: DCFactor.h:33
virtual boost::shared_ptr< gtsam::GaussianFactor > linearize(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const =0
gtsam::DiscreteKeys discreteKeys_
Definition: DCFactor.h:36
std::vector< double > evalProbs(const gtsam::DiscreteKey &dk, const gtsam::Values &continuousVals) const
Definition: DCFactor.h:254
DCFactor & operator=(const DCFactor &rhs)
Definition: DCFactor.h:58
virtual gtsam::DecisionTreeFactor toDecisionTreeFactor(const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const
Definition: DCFactor.h:156
gtsam::DiscreteKeys discreteKeys() const
Definition: DCFactor.h:135
virtual size_t dim() const =0
DCFactor(const gtsam::KeyVector &continuousKeys, const gtsam::DiscreteKeys &discreteKeys)
Definition: DCFactor.h:50
gtsam::DecisionTreeFactor conditionalTimes(const gtsam::DecisionTreeFactor &f, const gtsam::Values &continuousVals, const DiscreteValues &discreteVals) const
Definition: DCFactor.h:280
DCFactor(const gtsam::DiscreteKeys &discreteKeys)
Definition: DCFactor.h:55
virtual double logNormalizingConstant(const gtsam::Values &values) const
Definition: DCFactor.h:178
gtsam::Factor Base
Definition: DCFactor.h:39
double nonlinearFactorLogNormalizingConstant(const NonlinearFactorType &factor, const gtsam::Values &values) const
Definition: DCFactor.h:193
virtual bool equals(const DCFactor &other, double tol=1e-9) const =0
virtual double error(const gtsam::Values &continuousVals, const gtsam::DiscreteFactor::Values &discreteVals) const =0
virtual ~DCFactor()=default
Definition: DCContinuousFactor.h:24
std::vector< double > expNormalize(const std::vector< double > &logProbs)
Definition: DCSAM_utils.h:20
gtsam::DiscreteFactor::Values DiscreteValues
Definition: DCSAM_types.h:19
const double tol
Definition: testDCSAM.cpp:40