1   /* Copyright 2002-2026 CS GROUP
2    * Licensed to CS GROUP (CS) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * CS licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *   http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.orekit.propagation.numerical;
18  
19  import org.hipparchus.analysis.differentiation.Gradient;
20  import org.hipparchus.exception.LocalizedCoreFormats;
21  import org.hipparchus.linear.DecompositionSolver;
22  import org.hipparchus.linear.MatrixUtils;
23  import org.hipparchus.linear.QRDecomposition;
24  import org.hipparchus.linear.RealMatrix;
25  import org.hipparchus.util.Precision;
26  import org.orekit.attitudes.AttitudeProvider;
27  import org.orekit.attitudes.AttitudeProviderModifier;
28  import org.orekit.errors.OrekitException;
29  import org.orekit.forces.ForceModel;
30  import org.orekit.orbits.Orbit;
31  import org.orekit.orbits.OrbitType;
32  import org.orekit.orbits.PositionAngleType;
33  import org.orekit.propagation.FieldSpacecraftState;
34  import org.orekit.propagation.SpacecraftState;
35  import org.orekit.propagation.integration.AdditionalDerivativesProvider;
36  import org.orekit.propagation.integration.CombinedDerivatives;
37  import org.orekit.utils.DataDictionary;
38  import org.orekit.utils.ParameterDriver;
39  import org.orekit.utils.TimeSpanMap;
40  
41  import java.io.IOException;
42  import java.io.ObjectInputStream;
43  import java.util.HashMap;
44  import java.util.List;
45  import java.util.Map;
46  
47  /** Abstract generator for numerical State Transition Matrix.
48   * @author Luc Maisonobe
49   * @author Melina Vanel
50   * @author Romain Serra
51   * @since 13.1
52   */
53  abstract class AbstractStateTransitionMatrixGenerator implements AdditionalDerivativesProvider {
54  
55      /** Space dimension. */
56      protected static final int SPACE_DIMENSION = 3;
57  
58      /** Threshold for matrix solving. */
59      private static final double THRESHOLD = Precision.SAFE_MIN;
60  
61      /** Name of the Cartesian STM additional state. */
62      private final String stmName;
63  
64      /** Force models used in propagation. */
65      private final List<ForceModel> forceModels;
66  
67      /** Attitude provider used in propagation. */
68      private final AttitudeProvider attitudeProvider;
69  
70      /** Observers for partial derivatives. */
71      private final Map<String, PartialsObserver> partialsObservers;
72  
73      /** Number of state variables. */
74      private final int stateDimension;
75  
76      /** Dimension of flatten STM. */
77      private final int dimension;
78  
79      /** Simple constructor.
80       * @param stmName name of the Cartesian STM additional state
81       * @param forceModels force models used in propagation
82       * @param attitudeProvider attitude provider used in propagation
83       * @param stateDimension dimension of state vector
84       */
85      AbstractStateTransitionMatrixGenerator(final String stmName, final List<ForceModel> forceModels,
86                                             final AttitudeProvider attitudeProvider, final int stateDimension) {
87          this.stmName           = stmName;
88          this.forceModels       = forceModels;
89          this.attitudeProvider  = attitudeProvider;
90          this.stateDimension    = stateDimension;
91          this.dimension         = stateDimension * stateDimension;
92          this.partialsObservers = new HashMap<>();
93      }
94  
95      /** Register an observer for partial derivatives.
96       * <p>
97       * The observer {@link PartialsObserver#partialsComputed(SpacecraftState, double[], double[])} partialsComputed}
98       * method will be called when partial derivatives are computed, as a side effect of
99       * calling {@link #computePartials(SpacecraftState)} (SpacecraftState)}
100      * </p>
101      * @param name name of the parameter driver this observer is interested in (may be null)
102      * @param observer observer to register
103      */
104     void addObserver(final String name, final PartialsObserver observer) {
105         partialsObservers.put(name, observer);
106     }
107 
108     /** {@inheritDoc} */
109     @Override
110     public String getName() {
111         return stmName;
112     }
113 
114     /** {@inheritDoc} */
115     @Override
116     public int getDimension() {
117         return dimension;
118     }
119 
120     /**
121      * Getter for the number of state variables.
122      * @return state vector dimension
123      */
124     public int getStateDimension() {
125         return stateDimension;
126     }
127 
128     /**
129      * Protected getter for the force models.
130      * @return forces
131      */
132     protected List<ForceModel> getForceModels() {
133         return forceModels;
134     }
135 
136     /**
137      * Protected getter for the partials observers map.
138      * @return map
139      */
140     protected Map<String, PartialsObserver> getPartialsObservers() {
141         return partialsObservers;
142     }
143 
144     /**
145      * Method to build a linear system solver.
146      * @param matrix equations matrix
147      * @return solver
148      */
149     private DecompositionSolver getDecompositionSolver(final RealMatrix matrix) {
150         return new QRDecomposition(matrix, THRESHOLD).getSolver();
151     }
152 
153     /** Set the initial value of the State Transition Matrix.
154      * <p>
155      * The returned state must be added to the propagator.
156      * </p>
157      * @param state initial state
158      * @param dYdY0 initial State Transition Matrix ∂Y/∂Y₀,
159      * if null (which is the most frequent case), assumed to be 6x6 identity
160      * @param orbitType orbit type used for states Y and Y₀ in {@code dYdY0}
161      * @param positionAngleType position angle used states Y and Y₀ in {@code dYdY0}
162      * @return state with initial STM (converted to Cartesian ∂C/∂Y₀) added
163      */
164     SpacecraftState setInitialStateTransitionMatrix(final SpacecraftState state, final RealMatrix dYdY0,
165                                                     final OrbitType orbitType,
166                                                     final PositionAngleType positionAngleType) {
167 
168         final RealMatrix nonNullDYdY0;
169         if (dYdY0 == null) {
170             nonNullDYdY0 = MatrixUtils.createRealIdentityMatrix(getStateDimension());
171         } else {
172             if (dYdY0.getRowDimension() != getStateDimension() ||
173                     dYdY0.getColumnDimension() != getStateDimension()) {
174                 throw new OrekitException(LocalizedCoreFormats.DIMENSIONS_MISMATCH_2x2,
175                         dYdY0.getRowDimension(), dYdY0.getColumnDimension(),
176                         getStateDimension(), getStateDimension());
177             }
178             nonNullDYdY0 = dYdY0;
179         }
180 
181         // convert to Cartesian STM
182         final RealMatrix dCdY0;
183         if (state.isOrbitDefined()) {
184             final RealMatrix dYdC = MatrixUtils.createRealIdentityMatrix(getStateDimension());
185             final Orbit orbit = orbitType.convertType(state.getOrbit());
186             final double[][] jacobian = new double[6][6];
187             orbit.getJacobianWrtCartesian(positionAngleType, jacobian);
188             dYdC.setSubMatrix(jacobian, 0, 0);
189             final DecompositionSolver decomposition = getDecompositionSolver(dYdC);
190             dCdY0 = decomposition.solve(nonNullDYdY0);
191         } else {
192             dCdY0 = nonNullDYdY0;
193         }
194 
195         // set additional state
196         return state.addAdditionalData(getName(), flatten(dCdY0));
197 
198     }
199 
200     /**
201      * Flattens a matrix into an 1-D array.
202      * @param matrix matrix to be flatten
203      * @return array
204      */
205     double[] flatten(final RealMatrix matrix) {
206         final double[] flat = new double[getDimension()];
207         int k = 0;
208         for (int i = 0; i < getStateDimension(); ++i) {
209             for (int j = 0; j < getStateDimension(); ++j) {
210                 flat[k++] = matrix.getEntry(i, j);
211             }
212         }
213         return flat;
214     }
215 
216     /** {@inheritDoc} */
217     @Override
218     public boolean yields(final SpacecraftState state) {
219         return !state.hasAdditionalData(getName());
220     }
221 
222     /** {@inheritDoc} */
223     public CombinedDerivatives combinedDerivatives(final SpacecraftState state) {
224         final double[] factor = computePartials(state);
225 
226         // retrieve current State Transition Matrix
227         final double[] p    = state.getAdditionalState(getName());
228         final double[] pDot = new double[p.length];
229 
230         // perform multiplication
231         multiplyMatrix(factor, p, pDot, getStateDimension());
232 
233         return new CombinedDerivatives(pDot, null);
234 
235     }
236 
237     /** Compute evolution matrix product.
238      * @param factor factor matrix
239      * @param x right factor of the multiplication, as a flatten array in row major order
240      * @param y placeholder where to put the result, as a flatten array in row major order
241      * @param columns number of columns of both x and y (so their dimensions are the state one times the columns)
242      */
243     abstract void multiplyMatrix(double[] factor, double[] x, double[] y, int columns);
244 
245     /** Compute the various partial derivatives.
246      * @param state current spacecraft state
247      * @return factor matrix
248      */
249     double[] computePartials(final SpacecraftState state) {
250 
251         // set up containers for partial derivatives
252         final double[]              factor               = new double[(stateDimension - SPACE_DIMENSION) * stateDimension];
253         final Map<String, double[]> partialsDictionary = new HashMap<>();
254 
255         // evaluate contribution of all force models
256         final AttitudeProvider equivalentAttitudeProvider = wrapAttitudeProviderIfPossible();
257         final boolean isThereAnyForceNotDependingOnlyOnPosition = getForceModels().stream().anyMatch(force -> !force.dependsOnPositionOnly());
258         final NumericalGradientConverter posOnlyConverter = new NumericalGradientConverter(state, SPACE_DIMENSION, equivalentAttitudeProvider);
259         final NumericalGradientConverter fullConverter = isThereAnyForceNotDependingOnlyOnPosition ?
260                 new NumericalGradientConverter(state, getStateDimension(), equivalentAttitudeProvider) : posOnlyConverter;
261         final SpacecraftState stateForParameters = state.withAdditionalData(new LocalDoubleArrayDictionary(state.getAdditionalDataValues()));
262 
263         for (final ForceModel forceModel : getForceModels()) {
264 
265             final NumericalGradientConverter     converter    = forceModel.dependsOnPositionOnly() ? posOnlyConverter : fullConverter;
266             final FieldSpacecraftState<Gradient> dsState      = converter.getState(forceModel);
267             final Gradient[]                     parameters   = converter.getParametersAtStateDate(dsState, forceModel);
268 
269             // update partial derivatives w.r.t. state variables
270             final Gradient[] ratesPartials = computeRatesPartialsAndUpdateFactor(forceModel, dsState, parameters, factor);
271 
272             // partials derivatives with respect to parameters
273             updateFactorForParameters(forceModel, converter, ratesPartials, partialsDictionary, stateForParameters, factor);
274 
275         }
276 
277         return factor;
278 
279     }
280 
281     /**
282      * Compute with automatic differentiation the partial derivatives of state variables' rate
283      * that are not part of the position vector.
284      * @param forceModel force model
285      * @param fieldState state in Taylor differential algebra
286      * @param parameters force parameters in Taylor differential algebra
287      * @param factor factor matrix to update
288      * @return array of rates in Taylor differential algebra
289      */
290     abstract Gradient[] computeRatesPartialsAndUpdateFactor(ForceModel forceModel,
291                                                             FieldSpacecraftState<Gradient> fieldState,
292                                                             Gradient[] parameters, double[] factor);
293 
294     /**
295      * Update factor regarding partials of force model parameters.
296      * @param forceModel force
297      * @param converter gradient converter
298      * @param ratesPartials state variables' rates evaluated in the Taylor differential algebra
299      * @param partialsDictionary dictionary storing the partials
300      * @param state spacecraft state
301      * @param factor factor matrix (flattened)
302      */
303     private void updateFactorForParameters(final ForceModel forceModel, final NumericalGradientConverter converter,
304                                            final Gradient[] ratesPartials, final Map<String, double[]> partialsDictionary,
305                                            final SpacecraftState state, final double[] factor) {
306         int paramsIndex = converter.getFreeStateParameters();
307         for (ParameterDriver driver : forceModel.getParametersDrivers()) {
308             if (driver.isSelected()) {
309 
310                 // for each span (for each estimated value) corresponding name is added
311                 for (TimeSpanMap.Span<String> span = driver.getNamesSpanMap().getFirstSpan(); span != null; span = span.next()) {
312                     updateDictionaryEntry(partialsDictionary, span, ratesPartials, paramsIndex);
313                     ++paramsIndex;
314                 }
315             }
316         }
317 
318         // notify observers
319         for (Map.Entry<String, PartialsObserver> observersEntry : getPartialsObservers().entrySet()) {
320             observersEntry.getValue().partialsComputed(state, factor,
321                     partialsDictionary.getOrDefault(observersEntry.getKey(), new double[ratesPartials.length]));
322         }
323     }
324 
325     /**
326      * Update entry of dictionary with derivative information.
327      * @param partialsDictionary dictionary
328      * @param span time span
329      * @param ratesPartials state variables' rates evaluated in the Taylor differential algebra
330      * @param paramsIndex index of parameter as an independent variable of the differential algebra
331      */
332     private void updateDictionaryEntry(final Map<String, double[]> partialsDictionary, final TimeSpanMap.Span<String> span,
333                                        final Gradient[] ratesPartials, final int paramsIndex) {
334         // get the partials derivatives for this driver
335         partialsDictionary.putIfAbsent(span.getData(), new double[ratesPartials.length]);
336 
337         // add the contribution of the current force model
338         final double[] increment = partialsDictionary.get(span.getData());
339         for (int i = 0; i < ratesPartials.length; ++i) {
340             increment[i] += ratesPartials[i].getGradient()[paramsIndex];
341         }
342         partialsDictionary.replace(span.getData(), increment);
343     }
344 
345     /**
346      * Method that first checks if it is possible to replace the attitude provider with a computationally cheaper one
347      * to evaluate. If applicable, the new provider only computes the rotation and uses dummy rate and acceleration,
348      * since they should not be used later on.
349      * @return same provider if at least one forces used attitude derivatives, otherwise one wrapping the old one for
350      * the rotation
351      */
352     AttitudeProvider wrapAttitudeProviderIfPossible() {
353         if (forceModels.stream().anyMatch(ForceModel::dependsOnAttitudeRate)) {
354             // at least one force uses an attitude rate, need to keep the original provider
355             return attitudeProvider;
356         } else {
357             // the original provider can be replaced by a lighter one for performance
358             return AttitudeProviderModifier.getFrozenAttitudeProvider(attitudeProvider);
359         }
360     }
361 
362     /** Interface for observing partials derivatives. */
363     @FunctionalInterface
364     public interface PartialsObserver {
365 
366         /** Callback called when partial derivatives have been computed.
367          * @param state current spacecraft state
368          * @param factor factor matrix, flattened along rows
369          * @param partials partials derivatives of all state variables' rates (except from position) w.r.t. the parameter driver
370          * that was registered (zero if no parameters were not selected or parameter is unknown)
371          */
372         void partialsComputed(SpacecraftState state, double[] factor, double[] partials);
373 
374     }
375 
376     /**
377      * Local override of data dictionary using HashMap for performance.
378      */
379     private static class LocalDoubleArrayDictionary extends DataDictionary {
380 
381         /** Serialization UID. */
382         private static final long serialVersionUID = 1L;
383 
384         /** Map for quick access. */
385         private transient Map<String, Object> objectMap;
386 
387         /**
388          * Constructor.
389          * @param inputDictionary dictionary whose content is to reproduce
390          */
391         LocalDoubleArrayDictionary(final DataDictionary inputDictionary) {
392             super(inputDictionary);
393             objectMap = toMap();
394         }
395 
396         /**
397          * Deserializes the object from a stream and restores the transient fields.
398          *
399          * @param ois the input stream from which the object is read
400          * @throws IOException if an I/O error occurs during deserialization
401          * @throws ClassNotFoundException if the class of a serialized object cannot be found
402          */
403         private void readObject(final ObjectInputStream ois) throws IOException, ClassNotFoundException {
404             ois.defaultReadObject();
405             objectMap = toMap();
406         }
407 
408         @Override
409         public Object get(final String key) {
410             return objectMap.get(key);
411         }
412     }
413 }
414