AbstractStateTransitionMatrixGenerator.java

  1. /* Copyright 2002-2025 CS GROUP
  2.  * Licensed to CS GROUP (CS) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * CS licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *   http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.orekit.propagation.numerical;

  18. import org.hipparchus.analysis.differentiation.Gradient;
  19. import org.hipparchus.exception.LocalizedCoreFormats;
  20. import org.hipparchus.linear.DecompositionSolver;
  21. import org.hipparchus.linear.MatrixUtils;
  22. import org.hipparchus.linear.QRDecomposition;
  23. import org.hipparchus.linear.RealMatrix;
  24. import org.hipparchus.util.Precision;
  25. import org.orekit.attitudes.AttitudeProvider;
  26. import org.orekit.attitudes.AttitudeProviderModifier;
  27. import org.orekit.errors.OrekitException;
  28. import org.orekit.forces.ForceModel;
  29. import org.orekit.orbits.Orbit;
  30. import org.orekit.orbits.OrbitType;
  31. import org.orekit.orbits.PositionAngleType;
  32. import org.orekit.propagation.FieldSpacecraftState;
  33. import org.orekit.propagation.SpacecraftState;
  34. import org.orekit.propagation.integration.AdditionalDerivativesProvider;
  35. import org.orekit.propagation.integration.CombinedDerivatives;
  36. import org.orekit.utils.DoubleArrayDictionary;
  37. import org.orekit.utils.ParameterDriver;
  38. import org.orekit.utils.TimeSpanMap;

  39. import java.util.HashMap;
  40. import java.util.List;
  41. import java.util.Map;

  42. /** Abstract generator for numerical State Transition Matrix.
  43.  * @author Luc Maisonobe
  44.  * @author Melina Vanel
  45.  * @author Romain Serra
  46.  * @since 13.1
  47.  */
  48. abstract class AbstractStateTransitionMatrixGenerator implements AdditionalDerivativesProvider {

  49.     /** Space dimension. */
  50.     protected static final int SPACE_DIMENSION = 3;

  51.     /** Threshold for matrix solving. */
  52.     private static final double THRESHOLD = Precision.SAFE_MIN;

  53.     /** Name of the Cartesian STM additional state. */
  54.     private final String stmName;

  55.     /** Force models used in propagation. */
  56.     private final List<ForceModel> forceModels;

  57.     /** Attitude provider used in propagation. */
  58.     private final AttitudeProvider attitudeProvider;

  59.     /** Observers for partial derivatives. */
  60.     private final Map<String, PartialsObserver> partialsObservers;

  61.     /** Number of state variables. */
  62.     private final int stateDimension;

  63.     /** Dimension of flatten STM. */
  64.     private final int dimension;

  65.     /** Simple constructor.
  66.      * @param stmName name of the Cartesian STM additional state
  67.      * @param forceModels force models used in propagation
  68.      * @param attitudeProvider attitude provider used in propagation
  69.      * @param stateDimension dimension of state vector
  70.      */
  71.     AbstractStateTransitionMatrixGenerator(final String stmName, final List<ForceModel> forceModels,
  72.                                            final AttitudeProvider attitudeProvider, final int stateDimension) {
  73.         this.stmName           = stmName;
  74.         this.forceModels       = forceModels;
  75.         this.attitudeProvider  = attitudeProvider;
  76.         this.stateDimension    = stateDimension;
  77.         this.dimension         = stateDimension * stateDimension;
  78.         this.partialsObservers = new HashMap<>();
  79.     }

  80.     /** Register an observer for partial derivatives.
  81.      * <p>
  82.      * The observer {@link PartialsObserver#partialsComputed(SpacecraftState, double[], double[])} partialsComputed}
  83.      * method will be called when partial derivatives are computed, as a side effect of
  84.      * calling {@link #computePartials(SpacecraftState)} (SpacecraftState)}
  85.      * </p>
  86.      * @param name name of the parameter driver this observer is interested in (may be null)
  87.      * @param observer observer to register
  88.      */
  89.     void addObserver(final String name, final PartialsObserver observer) {
  90.         partialsObservers.put(name, observer);
  91.     }

  92.     /** {@inheritDoc} */
  93.     @Override
  94.     public String getName() {
  95.         return stmName;
  96.     }

  97.     /** {@inheritDoc} */
  98.     @Override
  99.     public int getDimension() {
  100.         return dimension;
  101.     }

  102.     /**
  103.      * Getter for the number of state variables.
  104.      * @return state vector dimension
  105.      */
  106.     public int getStateDimension() {
  107.         return stateDimension;
  108.     }

  109.     /**
  110.      * Protected getter for the force models.
  111.      * @return forces
  112.      */
  113.     protected List<ForceModel> getForceModels() {
  114.         return forceModels;
  115.     }

  116.     /**
  117.      * Protected getter for the partials observers map.
  118.      * @return map
  119.      */
  120.     protected Map<String, PartialsObserver> getPartialsObservers() {
  121.         return partialsObservers;
  122.     }

  123.     /**
  124.      * Method to build a linear system solver.
  125.      * @param matrix equations matrix
  126.      * @return solver
  127.      */
  128.     private DecompositionSolver getDecompositionSolver(final RealMatrix matrix) {
  129.         return new QRDecomposition(matrix, THRESHOLD).getSolver();
  130.     }

  131.     /** Set the initial value of the State Transition Matrix.
  132.      * <p>
  133.      * The returned state must be added to the propagator.
  134.      * </p>
  135.      * @param state initial state
  136.      * @param dYdY0 initial State Transition Matrix ∂Y/∂Y₀,
  137.      * if null (which is the most frequent case), assumed to be 6x6 identity
  138.      * @param orbitType orbit type used for states Y and Y₀ in {@code dYdY0}
  139.      * @param positionAngleType position angle used states Y and Y₀ in {@code dYdY0}
  140.      * @return state with initial STM (converted to Cartesian ∂C/∂Y₀) added
  141.      */
  142.     SpacecraftState setInitialStateTransitionMatrix(final SpacecraftState state, final RealMatrix dYdY0,
  143.                                                     final OrbitType orbitType,
  144.                                                     final PositionAngleType positionAngleType) {

  145.         final RealMatrix nonNullDYdY0;
  146.         if (dYdY0 == null) {
  147.             nonNullDYdY0 = MatrixUtils.createRealIdentityMatrix(getStateDimension());
  148.         } else {
  149.             if (dYdY0.getRowDimension() != getStateDimension() ||
  150.                     dYdY0.getColumnDimension() != getStateDimension()) {
  151.                 throw new OrekitException(LocalizedCoreFormats.DIMENSIONS_MISMATCH_2x2,
  152.                         dYdY0.getRowDimension(), dYdY0.getColumnDimension(),
  153.                         getStateDimension(), getStateDimension());
  154.             }
  155.             nonNullDYdY0 = dYdY0;
  156.         }

  157.         // convert to Cartesian STM
  158.         final RealMatrix dCdY0;
  159.         if (state.isOrbitDefined()) {
  160.             final RealMatrix dYdC = MatrixUtils.createRealIdentityMatrix(getStateDimension());
  161.             final Orbit orbit = orbitType.convertType(state.getOrbit());
  162.             final double[][] jacobian = new double[6][6];
  163.             orbit.getJacobianWrtCartesian(positionAngleType, jacobian);
  164.             dYdC.setSubMatrix(jacobian, 0, 0);
  165.             final DecompositionSolver decomposition = getDecompositionSolver(dYdC);
  166.             dCdY0 = decomposition.solve(nonNullDYdY0);
  167.         } else {
  168.             dCdY0 = nonNullDYdY0;
  169.         }

  170.         // set additional state
  171.         return state.addAdditionalData(getName(), flatten(dCdY0));

  172.     }

  173.     /**
  174.      * Flattens a matrix into an 1-D array.
  175.      * @param matrix matrix to be flatten
  176.      * @return array
  177.      */
  178.     double[] flatten(final RealMatrix matrix) {
  179.         final double[] flat = new double[getDimension()];
  180.         int k = 0;
  181.         for (int i = 0; i < getStateDimension(); ++i) {
  182.             for (int j = 0; j < getStateDimension(); ++j) {
  183.                 flat[k++] = matrix.getEntry(i, j);
  184.             }
  185.         }
  186.         return flat;
  187.     }

  188.     /** {@inheritDoc} */
  189.     @Override
  190.     public boolean yields(final SpacecraftState state) {
  191.         return !state.hasAdditionalData(getName());
  192.     }

  193.     /** {@inheritDoc} */
  194.     public CombinedDerivatives combinedDerivatives(final SpacecraftState state) {
  195.         final double[] factor = computePartials(state);

  196.         // retrieve current State Transition Matrix
  197.         final double[] p    = state.getAdditionalState(getName());
  198.         final double[] pDot = new double[p.length];

  199.         // perform multiplication
  200.         multiplyMatrix(factor, p, pDot, getStateDimension());

  201.         return new CombinedDerivatives(pDot, null);

  202.     }

  203.     /** Compute evolution matrix product.
  204.      * @param factor factor matrix
  205.      * @param x right factor of the multiplication, as a flatten array in row major order
  206.      * @param y placeholder where to put the result, as a flatten array in row major order
  207.      * @param columns number of columns of both x and y (so their dimensions are the state one times the columns)
  208.      */
  209.     abstract void multiplyMatrix(double[] factor, double[] x, double[] y, int columns);

  210.     /** Compute the various partial derivatives.
  211.      * @param state current spacecraft state
  212.      * @return factor matrix
  213.      */
  214.     double[] computePartials(final SpacecraftState state) {

  215.         // set up containers for partial derivatives
  216.         final double[]              factor               = new double[(stateDimension - SPACE_DIMENSION) * stateDimension];
  217.         final DoubleArrayDictionary partialsDictionary = new DoubleArrayDictionary();

  218.         // evaluate contribution of all force models
  219.         final AttitudeProvider equivalentAttitudeProvider = wrapAttitudeProviderIfPossible();
  220.         final boolean isThereAnyForceNotDependingOnlyOnPosition = getForceModels().stream().anyMatch(force -> !force.dependsOnPositionOnly());
  221.         final NumericalGradientConverter posOnlyConverter = new NumericalGradientConverter(state, SPACE_DIMENSION, equivalentAttitudeProvider);
  222.         final NumericalGradientConverter fullConverter = isThereAnyForceNotDependingOnlyOnPosition ?
  223.                 new NumericalGradientConverter(state, getStateDimension(), equivalentAttitudeProvider) : posOnlyConverter;

  224.         for (final ForceModel forceModel : getForceModels()) {

  225.             final NumericalGradientConverter     converter    = forceModel.dependsOnPositionOnly() ? posOnlyConverter : fullConverter;
  226.             final FieldSpacecraftState<Gradient> dsState      = converter.getState(forceModel);
  227.             final Gradient[]                     parameters   = converter.getParametersAtStateDate(dsState, forceModel);

  228.             // update partial derivatives w.r.t. state variables
  229.             final Gradient[] ratesPartials = computeRatesPartialsAndUpdateFactor(forceModel, dsState, parameters, factor);

  230.             // partials derivatives with respect to parameters
  231.             updateFactorForParameters(forceModel, converter, ratesPartials, partialsDictionary, state, factor);

  232.         }

  233.         return factor;

  234.     }

  235.     /**
  236.      * Compute with automatic differentiation the partial derivatives of state variables' rate
  237.      * that are not part of the position vector.
  238.      * @param forceModel force model
  239.      * @param fieldState state in Taylor differential algebra
  240.      * @param parameters force parameters in Taylor differential algebra
  241.      * @param factor factor matrix to update
  242.      * @return array of rates in Taylor differential algebra
  243.      */
  244.     abstract Gradient[] computeRatesPartialsAndUpdateFactor(ForceModel forceModel,
  245.                                                             FieldSpacecraftState<Gradient> fieldState,
  246.                                                             Gradient[] parameters, double[] factor);

  247.     /**
  248.      * Update factor regarding partials of force model parameters.
  249.      * @param forceModel force
  250.      * @param converter gradient converter
  251.      * @param ratesPartials state variables' rates evaluated in the Taylor differential algebra
  252.      * @param partialsDictionary dictionary storing the partials
  253.      * @param state spacecraft state
  254.      * @param factor factor matrix (flattened)
  255.      */
  256.     private void updateFactorForParameters(final ForceModel forceModel, final NumericalGradientConverter converter,
  257.                                            final Gradient[] ratesPartials, final DoubleArrayDictionary partialsDictionary,
  258.                                            final SpacecraftState state, final double[] factor) {
  259.         int paramsIndex = converter.getFreeStateParameters();
  260.         for (ParameterDriver driver : forceModel.getParametersDrivers()) {
  261.             if (driver.isSelected()) {

  262.                 // for each span (for each estimated value) corresponding name is added
  263.                 for (TimeSpanMap.Span<String> span = driver.getNamesSpanMap().getFirstSpan(); span != null; span = span.next()) {
  264.                     updateDictionaryEntry(partialsDictionary, span, ratesPartials, paramsIndex);
  265.                     ++paramsIndex;
  266.                 }
  267.             }
  268.         }

  269.         // notify observers
  270.         for (Map.Entry<String, PartialsObserver> observersEntry : getPartialsObservers().entrySet()) {
  271.             final DoubleArrayDictionary.Entry entry = partialsDictionary.getEntry(observersEntry.getKey());
  272.             observersEntry.getValue().partialsComputed(state, factor, entry == null ? new double[ratesPartials.length] : entry.getValue());
  273.         }
  274.     }

  275.     /**
  276.      * Update entry of dictionary with derivative information.
  277.      * @param partialsDictionary dictionary
  278.      * @param span time span
  279.      * @param ratesPartials state variables' rates evaluated in the Taylor differential algebra
  280.      * @param paramsIndex index of parameter as an independent variable of the differential algebra
  281.      */
  282.     private void updateDictionaryEntry(final DoubleArrayDictionary partialsDictionary, final TimeSpanMap.Span<String> span,
  283.                                        final Gradient[] ratesPartials, final int paramsIndex) {
  284.         // get the partials derivatives for this driver
  285.         DoubleArrayDictionary.Entry entry = partialsDictionary.getEntry(span.getData());
  286.         if (entry == null) {
  287.             // create an entry filled with zeroes
  288.             partialsDictionary.put(span.getData(), new double[ratesPartials.length]);
  289.             entry = partialsDictionary.getEntry(span.getData());
  290.         }

  291.         // add the contribution of the current force model
  292.         final double[] increment = new double[ratesPartials.length];
  293.         for (int i = 0; i < ratesPartials.length; ++i) {
  294.             increment[i] = ratesPartials[i].getGradient()[paramsIndex];
  295.         }
  296.         entry.increment(increment);
  297.     }

  298.     /**
  299.      * Method that first checks if it is possible to replace the attitude provider with a computationally cheaper one
  300.      * to evaluate. If applicable, the new provider only computes the rotation and uses dummy rate and acceleration,
  301.      * since they should not be used later on.
  302.      * @return same provider if at least one forces used attitude derivatives, otherwise one wrapping the old one for
  303.      * the rotation
  304.      */
  305.     AttitudeProvider wrapAttitudeProviderIfPossible() {
  306.         if (forceModels.stream().anyMatch(ForceModel::dependsOnAttitudeRate)) {
  307.             // at least one force uses an attitude rate, need to keep the original provider
  308.             return attitudeProvider;
  309.         } else {
  310.             // the original provider can be replaced by a lighter one for performance
  311.             return AttitudeProviderModifier.getFrozenAttitudeProvider(attitudeProvider);
  312.         }
  313.     }

  314.     /** Interface for observing partials derivatives. */
  315.     @FunctionalInterface
  316.     public interface PartialsObserver {

  317.         /** Callback called when partial derivatives have been computed.
  318.          * @param state current spacecraft state
  319.          * @param factor factor matrix, flattened along rows
  320.          * @param partials partials derivatives of all state variables' rates (except from position) w.r.t. the parameter driver
  321.          * that was registered (zero if no parameters were not selected or parameter is unknown)
  322.          */
  323.         void partialsComputed(SpacecraftState state, double[] factor, double[] partials);

  324.     }

  325. }