@@ -358,19 +358,20 @@ class TrackFitterUKFBase {
358358 * @param vecY mean of the second set of sigma points
359359 * @return matPxy, the cross-correlation matrix
360360 */
361- template <int32_t SIGMA_DIM>
362- Matrix<DIM_X, DIM_Z> calculateCrossCorrelation (const Matrix<DIM_X, SIGMA_DIM> &sigmaX, const Vector<DIM_X> &vecX,
363- const Matrix<DIM_Z, SIGMA_DIM> &sigmaY, const Vector<DIM_Z> &vecY)
361+ template <int32_t STATE_DIM, int32_t MEAS_DIM, int32_t SIGMA_DIM>
362+ Matrix<STATE_DIM, MEAS_DIM>
363+ calculateCrossCorrelation (const Matrix<STATE_DIM, SIGMA_DIM> &sigmaX, const Vector<STATE_DIM> &vecX,
364+ const Matrix<MEAS_DIM, SIGMA_DIM> &sigmaY, const Vector<MEAS_DIM> &vecY)
364365 {
365- Vector<DIM_X > devXi{util::getColumnAt<DIM_X , SIGMA_DIM>(0 , sigmaX) - vecX}; // X[:, 0] - \bar{ x }
366- Vector<DIM_Z > devYi{util::getColumnAt<DIM_Z , SIGMA_DIM>(0 , sigmaY) - vecY}; // Y[:, 0] - \bar{ y }
366+ Vector<STATE_DIM > devXi{util::getColumnAt<STATE_DIM , SIGMA_DIM>(0 , sigmaX) - vecX}; // X[:, 0] - \bar{ x }
367+ Vector<MEAS_DIM > devYi{util::getColumnAt<MEAS_DIM , SIGMA_DIM>(0 , sigmaY) - vecY}; // Y[:, 0] - \bar{ y }
367368
368369 // P_0 = W[0, 0] (X[:, 0] - \bar{x}) (Y[:, 0] - \bar{y})^T
369- Matrix<DIM_X, DIM_Z > matPxy{m_weight0 * (devXi * devYi.transpose ())};
370+ Matrix<STATE_DIM, MEAS_DIM > matPxy{m_weight0 * (devXi * devYi.transpose ())};
370371
371372 for (int32_t i{1 }; i < SIGMA_DIM; ++i) {
372- devXi = util::getColumnAt<DIM_X , SIGMA_DIM>(i, sigmaX) - vecX; // X[:, i] - \bar{x}
373- devYi = util::getColumnAt<DIM_Z , SIGMA_DIM>(i, sigmaY) - vecY; // Y[:, i] - \bar{y}
373+ devXi = util::getColumnAt<STATE_DIM , SIGMA_DIM>(i, sigmaX) - vecX; // X[:, i] - \bar{x}
374+ devYi = util::getColumnAt<MEAS_DIM , SIGMA_DIM>(i, sigmaY) - vecY; // Y[:, i] - \bar{y}
374375
375376 matPxy += m_weighti * (devXi * devYi.transpose ()); // y += W[0, i] (Y[:, i] -
376377 // \bar{y}) (Y[:, i] - \bar{y})^T
@@ -397,6 +398,19 @@ class TrackFitterUKF : public TrackFitterUKFBase<6, 3, 1, 3> {
397398 AtTools::AtPropagator::StepState fMeanStep ; // / Holds the step information for POCA propagation of mean state
398399 ROOT::Math::Plane3D fMeasurementPlane ; // /< Holds the measurement plane for the track fitter
399400
401+ // vectors to hold the information needed for smoothing the UKF
402+ std::vector<Vector<TF_DIM_X>> m_vecXPredHist; // / @brief History of predicted state vectors at k+1
403+ std::vector<Matrix<TF_DIM_X, TF_DIM_X>> m_matPPredHist; // / @brief History of predicted state covariances at k+1
404+ // / History of cross correlation between filtered state at k and predicted at k+1
405+ std::vector<Matrix<TF_DIM_X, TF_DIM_X>> m_matCPredHist;
406+ // / History of filtered (after correction) state vectors at k
407+ std::vector<Vector<TF_DIM_X>> m_vecXHist;
408+ // / History of filtered (after correction) state covariances at k
409+ std::vector<Matrix<TF_DIM_X, TF_DIM_X>> m_matPHist;
410+
411+ // / The sigma points after propagation for the last prediction step.
412+ Matrix<TF_DIM_X, SIGMA_DIM_A> m_matSigmaXPred{Matrix<TF_DIM_X, SIGMA_DIM_A>::Zero ()};
413+
400414public:
401415 bool fEnableEnStraggling {true }; // /< @brief Flag to enable/disable energy straggling
402416 double fMaxStragglingFactor {1 . / 3 .}; // /< @brief Maximum straggling factor for energy loss
@@ -415,7 +429,7 @@ class TrackFitterUKF : public TrackFitterUKFBase<6, 3, 1, 3> {
415429 : TrackFitterUKFBase(), fPropagator (std::move(propagator)), fStepper (std::move(stepper))
416430 {
417431 }
418-
432+ void Reset ();
419433 void SetInitialState (const ROOT::Math::XYZPoint &initialPosition, const ROOT::Math::XYZVector &initialMomentum,
420434 const TMatrixD &initialCovariance);
421435
@@ -428,56 +442,16 @@ class TrackFitterUKF : public TrackFitterUKFBase<6, 3, 1, 3> {
428442 TMatrixD GetAugStateCovariance () const ;
429443 std::array<double , DIM_A> GetAugStateVector () const ;
430444
431- kf::Vector<TF_DIM_X>
432- funcF (const kf::Vector<TF_DIM_X> &x, const kf::Vector<TF_DIM_V> &v, const kf::Vector<TF_DIM_Z> &z);
433- kf::Vector<TF_DIM_Z> funcH (const kf::Vector<TF_DIM_X> &x);
434-
435- void predictUKF (const ROOT::Math::XYZPoint &z)
436- {
437- using namespace ROOT ::Math;
438-
439- // First we need to propagate the mean state vector to the next measurement point.
440- XYZPoint startingPosition{m_vecX[0 ], m_vecX[1 ], m_vecX[2 ]}; // Get the starting position from the state vector
441- Polar3DVector startingMomentum{m_vecX[3 ], m_vecX[4 ],
442- m_vecX[5 ]}; // Get the starting momentum from the state vector
443-
444- LOG (info) << " Propagating reference state from position: " << startingPosition
445- << " with momentum: " << XYZVector (startingMomentum);
446-
447- fPropagator .SetState (startingPosition, XYZVector (startingMomentum));
448- fPropagator .PropagateToMeasurementSurface (AtTools::AtMeasurementPoint (z), *fStepper );
449- fMeanStep = fPropagator .GetState (); // Get the mean step information from the propagator
450- fMeanStep .fLastPos = startingPosition; // Store the last position
451- fMeanStep .fLastMom = startingMomentum; // Store the last momentum
452-
453- LOG (info) << " Propagated to position: " << fMeanStep .fPos << " with momentum: " << fMeanStep .fMom ;
454-
455- // Now we can construct the reference plane.
456- fMeasurementPlane = Plane3D (fMeanStep .fMom .Unit (),
457- XYZPoint (z)); // Create a plane using the momentum direction and position
458- Vector<TF_DIM_Z> zVec; // Initialize the measurement vector
459- zVec[0 ] = z.X ();
460- zVec[1 ] = z.Y ();
461- zVec[2 ] = z.Z ();
462- auto callback = [this ](const kf::Vector<TF_DIM_X> &x_, const kf::Vector<TF_DIM_V> &v_,
463- const kf::Vector<TF_DIM_Z> &z_) { return funcF (x_, v_, z_); };
464- TrackFitterUKFBase::predictUKF (callback, zVec);
465- }
466-
467- void correctUKF (const ROOT::Math::XYZPoint &z)
468- {
469- Vector<TF_DIM_Z> zVec; // Initialize the measurement vector
470- zVec[0 ] = z.X ();
471- zVec[1 ] = z.Y ();
472- zVec[2 ] = z.Z ();
473- auto callback = [this ](const kf::Vector<TF_DIM_X> &x_) { return funcH (x_); };
474- TrackFitterUKFBase::correctUKF (callback, zVec);
475- }
445+ void predictUKF (const ROOT::Math::XYZPoint &z);
446+ void correctUKF (const ROOT::Math::XYZPoint &z);
476447
477448protected:
478449 std::array<float32_t , TF_DIM_V> calculateProcessNoiseMean () override ;
479-
480450 Matrix<TF_DIM_V, TF_DIM_V> calculateProcessNoiseCovariance () override ;
451+
452+ kf::Vector<TF_DIM_X>
453+ funcF (const kf::Vector<TF_DIM_X> &x, const kf::Vector<TF_DIM_V> &v, const kf::Vector<TF_DIM_Z> &z);
454+ kf::Vector<TF_DIM_Z> funcH (const kf::Vector<TF_DIM_X> &x);
481455};
482456} // namespace kf
483457
0 commit comments