1   /* Copyright 2002-2025 CS GROUP
2    * Licensed to CS GROUP (CS) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * CS licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *   http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.orekit.propagation.conversion.osc2mean;
18  
import java.util.Objects;
import java.util.function.UnaryOperator;

import org.hipparchus.CalculusFieldElement;
import org.hipparchus.Field;
import org.hipparchus.analysis.differentiation.Gradient;
import org.hipparchus.analysis.differentiation.GradientField;
import org.hipparchus.geometry.euclidean.threed.FieldVector3D;
import org.hipparchus.geometry.euclidean.threed.Vector3D;
import org.hipparchus.linear.DiagonalMatrix;
import org.hipparchus.linear.FieldVector;
import org.hipparchus.linear.MatrixUtils;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.linear.RealVector;
import org.hipparchus.optim.ConvergenceChecker;
import org.hipparchus.optim.SimpleVectorValueChecker;
import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresBuilder;
import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresFactory;
import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresOptimizer;
import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem;
import org.hipparchus.optim.nonlinear.vector.leastsquares.MultivariateJacobianFunction;
import org.hipparchus.util.Pair;
import org.orekit.orbits.CartesianOrbit;
import org.orekit.orbits.FieldCartesianOrbit;
import org.orekit.orbits.FieldKeplerianOrbit;
import org.orekit.orbits.FieldOrbit;
import org.orekit.orbits.Orbit;
import org.orekit.time.FieldAbsoluteDate;
import org.orekit.utils.FieldPVCoordinates;
import org.orekit.utils.PVCoordinates;
48  
49  /**
50   * Class enabling conversion from osculating to mean orbit
51   * for a given theory using a least-squares algorithm.
52   *
53   * @author Pascal Parraud
54   * @since 13.0
55   */
56  public class LeastSquaresConverter implements OsculatingToMeanConverter {
57  
58      /** Default convergence threshold. */
59      public static final double DEFAULT_THRESHOLD   = 1e-4;
60  
61      /** Default maximum number of iterations. */
62      public static final int DEFAULT_MAX_ITERATIONS = 1000;
63  
64      /** Mean theory used. */
65      private MeanTheory theory;
66  
67      /** Convergence threshold. */
68      private double threshold;
69  
70      /** Maximum number of iterations. */
71      private int maxIterations;
72  
73      /** Optimizer used. */
74      private LeastSquaresOptimizer optimizer;
75  
76      /** Convergence checker for optimization algorithm. */
77      private ConvergenceChecker<LeastSquaresProblem.Evaluation> checker;
78  
79      /** RMS. */
80      private double rms;
81  
82      /** Number of iterations performed. */
83      private int iterationsNb;
84  
85      /**
86       * Default constructor.
87       * <p>
88       * The mean theory and the optimizer must be set before converting.
89       */
90      public LeastSquaresConverter() {
91          this(null, null, DEFAULT_THRESHOLD, DEFAULT_MAX_ITERATIONS);
92      }
93  
94      /**
95       * Constructor.
96       * <p>
97       * The optimizer must be set before converting.
98       *
99       * @param theory mean theory to be used
100      */
101     public LeastSquaresConverter(final MeanTheory theory) {
102         this(theory, null, DEFAULT_THRESHOLD, DEFAULT_MAX_ITERATIONS);
103     }
104 
105     /**
106      * Constructor.
107      * @param theory mean theory to be used
108      * @param optimizer optimizer to be used
109      */
110     public LeastSquaresConverter(final MeanTheory theory,
111                                  final LeastSquaresOptimizer optimizer) {
112         this(theory, optimizer, DEFAULT_THRESHOLD, DEFAULT_MAX_ITERATIONS);
113     }
114 
115     /**
116      * Constructor.
117      * <p>
118      * The mean theory and the optimizer must be set before converting.
119      *
120      * @param threshold     convergence threshold
121      * @param maxIterations maximum number of iterations
122      */
123     public LeastSquaresConverter(final double threshold,
124                                  final int maxIterations) {
125         this(null, null, threshold, maxIterations);
126     }
127 
128     /**
129      * Constructor.
130      * @param theory        mean theory to be used
131      * @param optimizer     optimizer to be used
132      * @param threshold     convergence threshold
133      * @param maxIterations maximum number of iterations
134      */
135     public LeastSquaresConverter(final MeanTheory theory,
136                                  final LeastSquaresOptimizer optimizer,
137                                  final double threshold,
138                                  final int maxIterations) {
139         setMeanTheory(theory);
140         setOptimizer(optimizer);
141         setThreshold(threshold);
142         setMaxIterations(maxIterations);
143     }
144 
145     /** {@inheritDoc} */
146     @Override
147     public MeanTheory getMeanTheory() {
148         return theory;
149     }
150 
151     /** {@inheritDoc} */
152     @Override
153     public void setMeanTheory(final MeanTheory meanTheory) {
154         this.theory = meanTheory;
155     }
156 
157     /**
158      * Gets the optimizer.
159      * @return the optimizer
160      */
161     public LeastSquaresOptimizer getOptimizer() {
162         return optimizer;
163     }
164 
165     /**
166      * Sets the optimizer.
167      * @param optimizer the optimizer
168      */
169     public void setOptimizer(final LeastSquaresOptimizer optimizer) {
170         this.optimizer = optimizer;
171     }
172 
173     /**
174      * Gets the convergence threshold.
175      * @return convergence threshold
176      */
177     public double getThreshold() {
178         return threshold;
179     }
180 
181     /**
182      * Sets the convergence threshold.
183      * @param threshold convergence threshold
184      */
185     public void setThreshold(final double threshold) {
186         this.threshold = threshold;
187         final SimpleVectorValueChecker svvc = new SimpleVectorValueChecker(-1.0, threshold);
188         this.checker = LeastSquaresFactory.evaluationChecker(svvc);
189     }
190 
191     /**
192      * Gets the maximum number of iterations.
193      * @return maximum number of iterations
194      */
195     public int getMaxIterations() {
196         return maxIterations;
197     }
198 
199     /**
200      * Sets maximum number of iterations.
201      * @param maxIterations maximum number of iterations
202      */
203     public void setMaxIterations(final int maxIterations) {
204         this.maxIterations = maxIterations;
205     }
206 
207     /**
208      * Gets the RMS for the last conversion.
209      * @return the RMS
210      */
211     public double getRMS() {
212         return rms;
213     }
214 
215     /**
216      * Gets the number of iterations performed by the last conversion.
217      * @return number of iterations
218      */
219     public int getIterationsNb() {
220         return iterationsNb;
221     }
222 
223     /** {@inheritDoc}
224      *  Uses a least-square algorithm.
225      */
226     @Override
227     public Orbit convertToMean(final Orbit osculating) {
228 
229         // Initialize conversion
230         final Orbit initialized = theory.preprocessing(osculating);
231 
232         // State vector
233         final RealVector stateVector = MatrixUtils.createRealVector(6);
234 
235         // Position/Velocity
236         final Vector3D position = initialized.getPVCoordinates().getPosition();
237         final Vector3D velocity = initialized.getPVCoordinates().getVelocity();
238 
239         // Fill state vector
240         stateVector.setEntry(0, position.getX());
241         stateVector.setEntry(1, position.getY());
242         stateVector.setEntry(2, position.getZ());
243         stateVector.setEntry(3, velocity.getX());
244         stateVector.setEntry(4, velocity.getY());
245         stateVector.setEntry(5, velocity.getZ());
246 
247         // Create the initial guess of the least squares problem
248         final RealVector startState = MatrixUtils.createRealVector(6);
249         startState.setSubVector(0, stateVector.getSubVector(0, 6));
250 
251         // Weights
252         final double[] weights = new double[6];
253         final double velocityWeight = initialized.getPVCoordinates().getVelocity().getNorm() *
254                                       initialized.getPVCoordinates().getPosition().getNormSq() / initialized.getMu();
255         for (int i = 0; i < 3; i++) {
256             weights[i] = 1.0;
257             weights[i + 3] = velocityWeight;
258         }
259 
260         // Constructs the least squares problem
261         final LeastSquaresProblem problem = new LeastSquaresBuilder().
262                                             maxIterations(maxIterations).
263                                             maxEvaluations(Integer.MAX_VALUE).
264                                             checker(checker).
265                                             model(new ModelFunction(initialized)).
266                                             weight(new DiagonalMatrix(weights)).
267                                             target(stateVector).
268                                             start(startState).
269                                             build();
270 
271         // Solve least squares
272         final LeastSquaresOptimizer.Optimum optimum = optimizer.optimize(problem);
273 
274         // Stores some results
275         rms = optimum.getRMS();
276         iterationsNb = optimum.getIterations();
277 
278         // Builds the estimated mean orbit
279         final Vector3D pEstimated = new Vector3D(optimum.getPoint().getSubVector(0, 3).toArray());
280         final Vector3D vEstimated = new Vector3D(optimum.getPoint().getSubVector(3, 3).toArray());
281         final Orbit mean = new CartesianOrbit(new PVCoordinates(pEstimated, vEstimated),
282                                               initialized.getFrame(), initialized.getDate(),
283                                               initialized.getMu());
284 
285         // Returns the mean orbit
286         return theory.postprocessing(osculating, mean);
287     }
288 
289     /** {@inheritDoc}
290      *  Uses a least-square algorithm.
291      */
292     @Override
293     public <T extends CalculusFieldElement<T>> FieldOrbit<T> convertToMean(final FieldOrbit<T> osculating) {
294         throw new UnsupportedOperationException();
295     }
296 
297     /** Model function for the least squares problem.
298      * Provides the Jacobian of the function computing position/velocity at the point.
299      */
300     private class ModelFunction implements MultivariateJacobianFunction {
301 
302         /** Osculating orbit as Cartesian. */
303         private final FieldCartesianOrbit<Gradient> fieldOsc;
304 
305         /**
306          * Constructor.
307          * @param osculating osculating orbit
308          */
309         ModelFunction(final Orbit osculating) {
310             // Conversion to field orbit
311             final Field<Gradient> field = GradientField.getField(6);
312             this.fieldOsc = new FieldCartesianOrbit<>(field, osculating);
313         }
314 
315         /**  {@inheritDoc} */
316         @Override
317         public Pair<RealVector, RealMatrix> value(final RealVector point) {
318             final RealVector objectiveOscState = MatrixUtils.createRealVector(6);
319             final RealMatrix objectiveJacobian = MatrixUtils.createRealMatrix(6, 6);
320             getTransformedAndJacobian(state -> mean2Osc(state), point,
321                                       objectiveOscState, objectiveJacobian);
322             return new Pair<>(objectiveOscState, objectiveJacobian);
323         }
324 
325         /**
326          * Fill model.
327          * @param operator state vector propagation
328          * @param state state vector
329          * @param transformed value to fill
330          * @param jacobian Jacobian to fill
331          */
332         private void getTransformedAndJacobian(final UnaryOperator<FieldVector<Gradient>> operator,
333                                                final RealVector state, final RealVector transformed,
334                                                final RealMatrix jacobian) {
335 
336             // State dimension
337             final int stateDim = state.getDimension();
338 
339             // Initialise the state as field to calculate the gradient
340             final GradientField field = GradientField.getField(stateDim);
341             final FieldVector<Gradient> fieldState = MatrixUtils.createFieldVector(field, stateDim);
342             for (int i = 0; i < stateDim; ++i) {
343                 fieldState.setEntry(i, Gradient.variable(stateDim, i, state.getEntry(i)));
344             }
345 
346             // Call operator
347             final FieldVector<Gradient> fieldTransformed = operator.apply(fieldState);
348 
349             // Output dimension
350             final int outDim = fieldTransformed.getDimension();
351 
352             // Extract transform and Jacobian as real values
353             for (int i = 0; i < outDim; ++i) {
354                 transformed.setEntry(i, fieldTransformed.getEntry(i).getReal());
355                 jacobian.setRow(i, fieldTransformed.getEntry(i).getGradient());
356             }
357 
358         }
359 
360         /**
361          * Operator to compute an osculating state from a mean state.
362          * @param mean mean state vector
363          * @return osculating state vector
364          */
365         private FieldVector<Gradient> mean2Osc(final FieldVector<Gradient> mean) {
366             // Epoch
367             final FieldAbsoluteDate<Gradient> epoch = fieldOsc.getDate();
368 
369             // Field
370             final Field<Gradient> field = epoch.getField();
371 
372             // Extract mean state
373             final FieldVector3D<Gradient> pos = new FieldVector3D<>(mean.getSubVector(0, 3).toArray());
374             final FieldVector3D<Gradient> vel = new FieldVector3D<>(mean.getSubVector(3, 3).toArray());
375             final FieldPVCoordinates<Gradient> pvMean = new FieldPVCoordinates<>(pos, vel);
376             final FieldKeplerianOrbit<Gradient> oMean = new FieldKeplerianOrbit<>(pvMean,
377                                                                                   fieldOsc.getFrame(),
378                                                                                   fieldOsc.getDate(),
379                                                                                   fieldOsc.getMu());
380 
381             // Propagate to epoch
382             final FieldOrbit<Gradient> oOsc = theory.meanToOsculating(oMean);
383             final FieldPVCoordinates<Gradient> pvOsc = oOsc.getPVCoordinates(oMean.getFrame());
384 
385             // Osculating
386             final FieldVector<Gradient> osculating = MatrixUtils.createFieldVector(field, 6);
387             osculating.setEntry(0, pvOsc.getPosition().getX());
388             osculating.setEntry(1, pvOsc.getPosition().getY());
389             osculating.setEntry(2, pvOsc.getPosition().getZ());
390             osculating.setEntry(3, pvOsc.getVelocity().getX());
391             osculating.setEntry(4, pvOsc.getVelocity().getY());
392             osculating.setEntry(5, pvOsc.getVelocity().getZ());
393 
394             // Return
395             return osculating;
396 
397         }
398 
399     }
400 
401 }