Skip to content

Commit bd62e90

Browse files
committed
Save the information needed for smoothing
1 parent 796d893 commit bd62e90

File tree

2 files changed

+120
-55
lines changed

2 files changed

+120
-55
lines changed

AtReconstruction/AtFitter/OpenKF/kalman_filter/TrackFitterUKF.cxx

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,31 @@
88
#include <AtPropagator.h>
99

1010
namespace kf {
11+
void TrackFitterUKF::Reset()
12+
{
13+
// Reset the state vector and covariance matrix
14+
m_vecX.setZero();
15+
m_matP.setZero();
16+
m_vecXa.setZero();
17+
m_matPa.setZero();
18+
m_matQ.setZero();
19+
m_matR.setZero();
20+
m_matSigmaXa.setZero();
21+
22+
// Clear the history vectors
23+
m_vecXPredHist.clear();
24+
m_matPPredHist.clear();
25+
m_matCPredHist.clear();
26+
m_vecXHist.clear();
27+
m_matPHist.clear();
28+
fMeanStep = AtTools::AtPropagator::StepState(); // Reset the step state
29+
}
1130

1231
void TrackFitterUKF::SetInitialState(const ROOT::Math::XYZPoint &initialPosition,
1332
const ROOT::Math::XYZVector &initialMomentum, const TMatrixD &initialCovariance)
1433
{
34+
// If we are setting the initial state, then we should clear the history.
35+
Reset();
1536
fPropagator.SetState(initialPosition, initialMomentum); // Set the initial state in the propagator
1637
m_vecX[0] = initialPosition.X(); // X position
1738
m_vecX[1] = initialPosition.Y(); // Y position
@@ -26,6 +47,19 @@ void TrackFitterUKF::SetInitialState(const ROOT::Math::XYZPoint &initialPosition
2647
m_matP(i, j) = initialCovariance(i, j);
2748
}
2849
}
50+
51+
// Save the initial state in our history vectors
52+
m_vecXHist.push_back(m_vecX);
53+
m_matPHist.push_back(m_matP);
54+
m_vecXPredHist.push_back(m_vecX);
55+
m_matPPredHist.push_back(m_matP);
56+
m_matCPredHist.push_back(Matrix<TF_DIM_X, TF_DIM_X>::Zero()); // Cross-correlation is not defined for the first point
57+
58+
// We need to calculate the sigma points for the initial state
59+
updateAugmentedStateAndCovariance(); // Update the augmented state vector and covariance matrix
60+
m_matSigmaXa = calculateSigmaPoints(m_vecXa, m_matPa); // Calculate the sigma points for the initial state
61+
// Now we grab the sigma points only for the state.
62+
m_matSigmaXPred = m_matSigmaXa.block(0, 0, TF_DIM_X, SIGMA_DIM_A); // Extract the state sigma points
2963
}
3064

3165
TMatrixD TrackFitterUKF::GetStateCovariance() const
@@ -156,4 +190,61 @@ Vector<TrackFitterUKF::TF_DIM_Z> TrackFitterUKF::funcH(const Vector<TrackFitterU
156190
return vecZ; // Return the measurement vector
157191
}
158192

193+
void TrackFitterUKF::predictUKF(const ROOT::Math::XYZPoint &z)
194+
{
195+
using namespace ROOT::Math;
196+
197+
// First we need to propagate the mean state vector to the next measurement point.
198+
XYZPoint startingPosition{m_vecX[0], m_vecX[1], m_vecX[2]}; // Get the starting position from the state vector
199+
Polar3DVector startingMomentum{m_vecX[3], m_vecX[4], m_vecX[5]}; // Get the starting momentum from the state vector
200+
201+
LOG(info) << "Propagating reference state from position: " << startingPosition
202+
<< " with momentum: " << XYZVector(startingMomentum);
203+
204+
fPropagator.SetState(startingPosition, XYZVector(startingMomentum));
205+
fPropagator.PropagateToMeasurementSurface(AtTools::AtMeasurementPoint(z), *fStepper);
206+
fMeanStep = fPropagator.GetState(); // Get the mean step information from the propagator
207+
fMeanStep.fLastPos = startingPosition; // Store the last position
208+
fMeanStep.fLastMom = startingMomentum; // Store the last momentum
209+
210+
LOG(info) << "Propagated to position: " << fMeanStep.fPos << " with momentum: " << fMeanStep.fMom;
211+
212+
// Now we can construct the reference plane.
213+
fMeasurementPlane = Plane3D(fMeanStep.fMom.Unit(),
214+
XYZPoint(z)); // Create a plane using the momentum direction and position
215+
Vector<TF_DIM_Z> zVec; // Initialize the measurement vector
216+
zVec[0] = z.X();
217+
zVec[1] = z.Y();
218+
zVec[2] = z.Z();
219+
auto callback = [this](const kf::Vector<TF_DIM_X> &x_, const kf::Vector<TF_DIM_V> &v_,
220+
const kf::Vector<TF_DIM_Z> &z_) { return funcF(x_, v_, z_); };
221+
TrackFitterUKFBase::predictUKF(callback, zVec);
222+
223+
// Now we need to store the predicted state and covariance for smoothing later.
224+
m_vecXPredHist.push_back(m_vecX); // Store the predicted state vector
225+
m_matPPredHist.push_back(m_matP); // Store the predicted covariance matrix
226+
227+
// Get the sigma points belonging to the predicted state
228+
Matrix<TF_DIM_X, SIGMA_DIM_A> sigmaXx{m_matSigmaXa.block(0, 0, TF_DIM_X, SIGMA_DIM_A)};
229+
230+
// Calculate the cross-corelation between the filtered state at k and predicted state at k+1
231+
auto matCPred =
232+
calculateCrossCorrelation<TF_DIM_X>(m_matSigmaXPred, m_vecXHist.back(), sigmaXx, m_vecXPredHist.back());
233+
m_matCPredHist.push_back(matCPred); // Store the cross-correlation matrix
234+
}
235+
236+
void TrackFitterUKF::correctUKF(const ROOT::Math::XYZPoint &z)
237+
{
238+
Vector<TF_DIM_Z> zVec; // Initialize the measurement vector
239+
zVec[0] = z.X();
240+
zVec[1] = z.Y();
241+
zVec[2] = z.Z();
242+
auto callback = [this](const kf::Vector<TF_DIM_X> &x_) { return funcH(x_); };
243+
TrackFitterUKFBase::correctUKF(callback, zVec);
244+
245+
// After correction we need to save the filtered state
246+
m_vecXHist.push_back(m_vecX); // Store the filtered state vector
247+
m_matPHist.push_back(m_matP); // Store the filtered covariance matrix
248+
}
249+
159250
} // namespace kf

AtReconstruction/AtFitter/OpenKF/kalman_filter/TrackFitterUKF.h

Lines changed: 29 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -358,19 +358,20 @@ class TrackFitterUKFBase {
358358
* @param vecY mean of the second set of sigma points
359359
* @return matPxy, the cross-correlation matrix
360360
*/
361-
template <int32_t SIGMA_DIM>
362-
Matrix<DIM_X, DIM_Z> calculateCrossCorrelation(const Matrix<DIM_X, SIGMA_DIM> &sigmaX, const Vector<DIM_X> &vecX,
363-
const Matrix<DIM_Z, SIGMA_DIM> &sigmaY, const Vector<DIM_Z> &vecY)
361+
template <int32_t STATE_DIM, int32_t MEAS_DIM, int32_t SIGMA_DIM>
362+
Matrix<STATE_DIM, MEAS_DIM>
363+
calculateCrossCorrelation(const Matrix<STATE_DIM, SIGMA_DIM> &sigmaX, const Vector<STATE_DIM> &vecX,
364+
const Matrix<MEAS_DIM, SIGMA_DIM> &sigmaY, const Vector<MEAS_DIM> &vecY)
364365
{
365-
Vector<DIM_X> devXi{util::getColumnAt<DIM_X, SIGMA_DIM>(0, sigmaX) - vecX}; // X[:, 0] - \bar{ x }
366-
Vector<DIM_Z> devYi{util::getColumnAt<DIM_Z, SIGMA_DIM>(0, sigmaY) - vecY}; // Y[:, 0] - \bar{ y }
366+
Vector<STATE_DIM> devXi{util::getColumnAt<STATE_DIM, SIGMA_DIM>(0, sigmaX) - vecX}; // X[:, 0] - \bar{ x }
367+
Vector<MEAS_DIM> devYi{util::getColumnAt<MEAS_DIM, SIGMA_DIM>(0, sigmaY) - vecY}; // Y[:, 0] - \bar{ y }
367368

368369
// P_0 = W[0, 0] (X[:, 0] - \bar{x}) (Y[:, 0] - \bar{y})^T
369-
Matrix<DIM_X, DIM_Z> matPxy{m_weight0 * (devXi * devYi.transpose())};
370+
Matrix<STATE_DIM, MEAS_DIM> matPxy{m_weight0 * (devXi * devYi.transpose())};
370371

371372
for (int32_t i{1}; i < SIGMA_DIM; ++i) {
372-
devXi = util::getColumnAt<DIM_X, SIGMA_DIM>(i, sigmaX) - vecX; // X[:, i] - \bar{x}
373-
devYi = util::getColumnAt<DIM_Z, SIGMA_DIM>(i, sigmaY) - vecY; // Y[:, i] - \bar{y}
373+
devXi = util::getColumnAt<STATE_DIM, SIGMA_DIM>(i, sigmaX) - vecX; // X[:, i] - \bar{x}
374+
devYi = util::getColumnAt<MEAS_DIM, SIGMA_DIM>(i, sigmaY) - vecY; // Y[:, i] - \bar{y}
374375

375376
matPxy += m_weighti * (devXi * devYi.transpose()); // y += W[0, i] (Y[:, i] -
376377
// \bar{y}) (Y[:, i] - \bar{y})^T
@@ -397,6 +398,19 @@ class TrackFitterUKF : public TrackFitterUKFBase<6, 3, 1, 3> {
397398
AtTools::AtPropagator::StepState fMeanStep; /// Holds the step information for POCA propagation of mean state
398399
ROOT::Math::Plane3D fMeasurementPlane; ///< Holds the measurement plane for the track fitter
399400

401+
// vectors to hold the information needed for smoothing the UKF
402+
std::vector<Vector<TF_DIM_X>> m_vecXPredHist; /// @brief History of predicted state vectors at k+1
403+
std::vector<Matrix<TF_DIM_X, TF_DIM_X>> m_matPPredHist; /// @brief History of predicted state covariances at k+1
404+
/// History of cross correlation between filtered state at k and predicted at k+1
405+
std::vector<Matrix<TF_DIM_X, TF_DIM_X>> m_matCPredHist;
406+
/// History of filtered (after correction) state vectors at k
407+
std::vector<Vector<TF_DIM_X>> m_vecXHist;
408+
/// History of filtered (after correction) state covariances at k
409+
std::vector<Matrix<TF_DIM_X, TF_DIM_X>> m_matPHist;
410+
411+
/// The sigma points after propagation for the last prediction step.
412+
Matrix<TF_DIM_X, SIGMA_DIM_A> m_matSigmaXPred{Matrix<TF_DIM_X, SIGMA_DIM_A>::Zero()};
413+
400414
public:
401415
bool fEnableEnStraggling{true}; ///< @brief Flag to enable/disable energy straggling
402416
double fMaxStragglingFactor{1. / 3.}; ///< @brief Maximum straggling factor for energy loss
@@ -415,7 +429,7 @@ class TrackFitterUKF : public TrackFitterUKFBase<6, 3, 1, 3> {
415429
: TrackFitterUKFBase(), fPropagator(std::move(propagator)), fStepper(std::move(stepper))
416430
{
417431
}
418-
432+
void Reset();
419433
void SetInitialState(const ROOT::Math::XYZPoint &initialPosition, const ROOT::Math::XYZVector &initialMomentum,
420434
const TMatrixD &initialCovariance);
421435

@@ -428,56 +442,16 @@ class TrackFitterUKF : public TrackFitterUKFBase<6, 3, 1, 3> {
428442
TMatrixD GetAugStateCovariance() const;
429443
std::array<double, DIM_A> GetAugStateVector() const;
430444

431-
kf::Vector<TF_DIM_X>
432-
funcF(const kf::Vector<TF_DIM_X> &x, const kf::Vector<TF_DIM_V> &v, const kf::Vector<TF_DIM_Z> &z);
433-
kf::Vector<TF_DIM_Z> funcH(const kf::Vector<TF_DIM_X> &x);
434-
435-
void predictUKF(const ROOT::Math::XYZPoint &z)
436-
{
437-
using namespace ROOT::Math;
438-
439-
// First we need to propagate the mean state vector to the next measurement point.
440-
XYZPoint startingPosition{m_vecX[0], m_vecX[1], m_vecX[2]}; // Get the starting position from the state vector
441-
Polar3DVector startingMomentum{m_vecX[3], m_vecX[4],
442-
m_vecX[5]}; // Get the starting momentum from the state vector
443-
444-
LOG(info) << "Propagating reference state from position: " << startingPosition
445-
<< " with momentum: " << XYZVector(startingMomentum);
446-
447-
fPropagator.SetState(startingPosition, XYZVector(startingMomentum));
448-
fPropagator.PropagateToMeasurementSurface(AtTools::AtMeasurementPoint(z), *fStepper);
449-
fMeanStep = fPropagator.GetState(); // Get the mean step information from the propagator
450-
fMeanStep.fLastPos = startingPosition; // Store the last position
451-
fMeanStep.fLastMom = startingMomentum; // Store the last momentum
452-
453-
LOG(info) << "Propagated to position: " << fMeanStep.fPos << " with momentum: " << fMeanStep.fMom;
454-
455-
// Now we can construct the reference plane.
456-
fMeasurementPlane = Plane3D(fMeanStep.fMom.Unit(),
457-
XYZPoint(z)); // Create a plane using the momentum direction and position
458-
Vector<TF_DIM_Z> zVec; // Initialize the measurement vector
459-
zVec[0] = z.X();
460-
zVec[1] = z.Y();
461-
zVec[2] = z.Z();
462-
auto callback = [this](const kf::Vector<TF_DIM_X> &x_, const kf::Vector<TF_DIM_V> &v_,
463-
const kf::Vector<TF_DIM_Z> &z_) { return funcF(x_, v_, z_); };
464-
TrackFitterUKFBase::predictUKF(callback, zVec);
465-
}
466-
467-
void correctUKF(const ROOT::Math::XYZPoint &z)
468-
{
469-
Vector<TF_DIM_Z> zVec; // Initialize the measurement vector
470-
zVec[0] = z.X();
471-
zVec[1] = z.Y();
472-
zVec[2] = z.Z();
473-
auto callback = [this](const kf::Vector<TF_DIM_X> &x_) { return funcH(x_); };
474-
TrackFitterUKFBase::correctUKF(callback, zVec);
475-
}
445+
void predictUKF(const ROOT::Math::XYZPoint &z);
446+
void correctUKF(const ROOT::Math::XYZPoint &z);
476447

477448
protected:
478449
std::array<float32_t, TF_DIM_V> calculateProcessNoiseMean() override;
479-
480450
Matrix<TF_DIM_V, TF_DIM_V> calculateProcessNoiseCovariance() override;
451+
452+
kf::Vector<TF_DIM_X>
453+
funcF(const kf::Vector<TF_DIM_X> &x, const kf::Vector<TF_DIM_V> &v, const kf::Vector<TF_DIM_Z> &z);
454+
kf::Vector<TF_DIM_Z> funcH(const kf::Vector<TF_DIM_X> &x);
481455
};
482456
} // namespace kf
483457

0 commit comments

Comments
 (0)