AbstractPropagatorConverter.java

/* Copyright 2002-2024 CS GROUP
 * Licensed to CS GROUP (CS) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * CS licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.orekit.propagation.conversion;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.hipparchus.analysis.MultivariateVectorFunction;
import org.hipparchus.exception.MathRuntimeException;
import org.hipparchus.linear.DiagonalMatrix;
import org.hipparchus.optim.ConvergenceChecker;
import org.hipparchus.optim.SimpleVectorValueChecker;
import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresBuilder;
import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresFactory;
import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresOptimizer;
import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem;
import org.hipparchus.optim.nonlinear.vector.leastsquares.LevenbergMarquardtOptimizer;
import org.hipparchus.optim.nonlinear.vector.leastsquares.MultivariateJacobianFunction;
import org.hipparchus.util.FastMath;
import org.orekit.errors.OrekitException;
import org.orekit.errors.OrekitMessages;
import org.orekit.frames.Frame;
import org.orekit.propagation.Propagator;
import org.orekit.propagation.SpacecraftState;
import org.orekit.propagation.integration.AbstractIntegratedPropagator;
import org.orekit.time.AbsoluteDate;
import org.orekit.utils.PVCoordinates;
import org.orekit.utils.ParameterDriver;

/** Common handling of {@link PropagatorConverter} methods for propagators conversions.
 * <p>
 * This abstract class factors the common code for propagators conversion.
 * Two methods must be implemented by derived classes: {@link #getObjectiveFunction()},
 * which computes position/velocity at sample points, and {@link #getModel()},
 * which provides the model (value and Jacobian) used by the optimizer.
 * </p>
 * <p>
 * The converter uses the LevenbergMarquardtOptimizer from the <a
 * href="https://hipparchus.org/">Hipparchus</a> library.
 * Different implementations correspond to different methods for computing the Jacobian.
 * </p>
 * <p>
 * Instances hold per-conversion mutable state (sample, target, weights, last optimum,
 * adapted propagator). Each call to a {@code convert} method overwrites it, so a single
 * instance should not be shared between concurrent conversions.
 * </p>
 * @author Pascal Parraud
 * @since 6.0
 */
public abstract class AbstractPropagatorConverter implements PropagatorConverter {

    /** Spacecraft states sample. */
    private List<SpacecraftState> sample;

    /** Target position and velocities at sample points. */
    private double[] target;

    /** Weight for residuals. */
    private double[] weight;

    /** Auxiliary outputData: RMS of solution. */
    private double rms;

    /** Position use indicator. */
    private boolean onlyPosition;

    /** Adapted propagator. */
    private Propagator adapted;

    /** Propagator builder. */
    private final PropagatorBuilder builder;

    /** Frame. */
    private final Frame frame;

    /** Optimizer for fitting. */
    private final LevenbergMarquardtOptimizer optimizer;

    /** Optimum found. */
    private LeastSquaresOptimizer.Optimum optimum;

    /** Convergence checker for optimization algorithm. */
    private final ConvergenceChecker<LeastSquaresProblem.Evaluation> checker;

    /** Maximum number of iterations for optimization. */
    private final int maxIterations;

    /** Build a new instance.
     * @param builder propagator builder
     * @param threshold absolute convergence threshold for optimization algorithm
     * @param maxIterations maximum number of iterations for fitting
     */
    protected AbstractPropagatorConverter(final PropagatorBuilder builder,
                                          final double threshold,
                                          final int maxIterations) {
        this.builder       = builder;
        this.frame         = builder.getFrame();
        this.optimizer     = new LevenbergMarquardtOptimizer();
        this.maxIterations = maxIterations;
        this.sample        = new ArrayList<>();

        // The relative threshold is set to -1.0. Per Hipparchus SimpleVectorValueChecker
        // semantics, presumably only the absolute threshold then drives convergence.
        final SimpleVectorValueChecker svvc = new SimpleVectorValueChecker(-1.0, threshold);
        this.checker = LeastSquaresFactory.evaluationChecker(svvc);

    }

    /** Convert a propagator to another.
     * @param source initial propagator (the propagator will be used for sample
     * generation, if it is a numerical propagator, its initial state will
     * be reset unless {@link AbstractIntegratedPropagator#setResetAtEnd(boolean)}
     * has been called beforehand)
     * @param timeSpan time span for fitting (may be negative for a backward sample)
     * @param nbPoints number of fitting points over time span (must be at least 1)
     * @param freeParameters names of the free parameters
     * @return adapted propagator
     * @exception IllegalArgumentException if {@code nbPoints} is lower than 1
     * @exception OrekitException if one of the free parameters is not supported by the builder
     */
    public Propagator convert(final Propagator source,
                              final double timeSpan,
                              final int nbPoints,
                              final List<String> freeParameters)
        throws IllegalArgumentException {
        // validate the free parameters before the (possibly expensive) sample propagation
        setFreeParameters(freeParameters);
        final List<SpacecraftState> states = createSample(source, timeSpan, nbPoints);
        return convert(states, false, freeParameters);
    }

    /** Convert a propagator to another.
     * @param source initial propagator (the propagator will be used for sample
     * generation, if it is a numerical propagator, its initial state will
     * be reset unless {@link AbstractIntegratedPropagator#setResetAtEnd(boolean)}
     * has been called beforehand)
     * @param timeSpan time span for fitting (may be negative for a backward sample)
     * @param nbPoints number of fitting points over time span (must be at least 1)
     * @param freeParameters names of the free parameters
     * @return adapted propagator
     * @exception IllegalArgumentException if {@code nbPoints} is lower than 1
     * @exception OrekitException if one of the free parameters is not supported by the builder
     */
    public Propagator convert(final Propagator source,
                              final double timeSpan,
                              final int nbPoints,
                              final String... freeParameters)
        throws IllegalArgumentException {
        // validate the free parameters before the (possibly expensive) sample propagation
        setFreeParameters(Arrays.asList(freeParameters));
        final List<SpacecraftState> states = createSample(source, timeSpan, nbPoints);
        return convert(states, false, freeParameters);
    }

    /** Find the propagator that minimize the mean square error for a sample of {@link SpacecraftState states}.
     * @param states spacecraft states sample to fit
     * @param positionOnly if true, consider only position data otherwise both position and velocity are used
     * @param freeParameters names of the free parameters
     * @return adapted propagator
     * @exception OrekitException if one of the free parameters is not supported by the builder
     */
    public Propagator convert(final List<SpacecraftState> states,
                              final boolean positionOnly,
                              final List<String> freeParameters)
        throws IllegalArgumentException {
        setFreeParameters(freeParameters);
        return adapt(states, positionOnly);
    }

    /** Find the propagator that minimize the mean square error for a sample of {@link SpacecraftState states}.
     * @param states spacecraft states sample to fit
     * @param positionOnly if true, consider only position data otherwise both position and velocity are used
     * @param freeParameters names of the free parameters
     * @return adapted propagator
     * @exception OrekitException if one of the free parameters is not supported by the builder
     */
    public Propagator convert(final List<SpacecraftState> states,
                              final boolean positionOnly,
                              final String... freeParameters)
        throws IllegalArgumentException {
        setFreeParameters(Arrays.asList(freeParameters));
        return adapt(states, positionOnly);
    }

    /** Get the adapted propagator.
     * @return adapted propagator (null if no conversion has been performed yet)
     */
    public Propagator getAdaptedPropagator() {
        return adapted;
    }

    /** Get the Root Mean Square Deviation of the fitting.
     * @return RMSD (0 if no conversion has been performed yet)
     */
    public double getRMS() {
        return rms;
    }

    /** Get the number of objective function evaluations of the last fit.
     * <p>
     * After a successful conversion, this counts only the final fit over the whole
     * sample. Evaluations spent in the warm-up fit are not included.
     * </p>
     *  @return the number of objective function evaluations, or 0 if no fit has been performed yet
     */
    public int getEvaluations() {
        return optimum == null ? 0 : optimum.getEvaluations();
    }

    /** Get the function computing position/velocity at sample points.
     * @return function computing position/velocity at sample points
     */
    protected abstract MultivariateVectorFunction getObjectiveFunction();

    /** Get the Jacobian of the function computing position/velocity at sample points.
     * @return Jacobian of the function computing position/velocity at sample points
     */
    protected abstract MultivariateJacobianFunction getModel();

    /** Check if fitting uses only sample positions.
     * @return true if fitting uses only sample positions
     */
    protected boolean isOnlyPosition() {
        return onlyPosition;
    }

    /** Get the size of the target.
     * @return target size (3 values per state if only positions are fitted, 6 otherwise)
     */
    protected int getTargetSize() {
        return target.length;
    }

    /** Get the frame of the initial state.
     * @return the orbit frame
     */
    protected Frame getFrame() {
        return frame;
    }

    /** Get the states sample.
     * @return the states sample currently being fitted
     */
    protected List<SpacecraftState> getSample() {
        return sample;
    }

    /** Create a sample of {@link SpacecraftState}.
     * <p>
     * The sample contains exactly {@code nbPoints} states, evenly spaced between the
     * source initial date and that date shifted by {@code timeSpan}, with both ends
     * included. When {@code nbPoints} is 1, the single state is at the initial date.
     * </p>
     * @param source initial propagator
     * @param timeSpan time span for the sample
     * @param nbPoints number of points for the sample over the time span
     * @return a sample of {@link SpacecraftState}
     * @exception IllegalArgumentException if {@code nbPoints} is lower than 1
     */
    private List<SpacecraftState> createSample(final Propagator source,
                                               final double timeSpan,
                                               final int nbPoints) {

        // without this check, a non-positive count used to give a zero or negative step
        // and an endless sampling loop
        if (nbPoints < 1) {
            throw new IllegalArgumentException("number of sample points must be at least 1, got " + nbPoints);
        }

        final List<SpacecraftState> states = new ArrayList<>(nbPoints);

        final AbsoluteDate iniDate = source.getInitialState().getDate();
        for (int i = 0; i < nbPoints; ++i) {
            // Compute each offset from the index instead of accumulating a step. This makes
            // the point count exact, and the last point falls on timeSpan (up to rounding).
            final double dt = (nbPoints == 1) ? 0.0 : timeSpan * i / (nbPoints - 1);
            states.add(source.propagate(iniDate.shiftedBy(dt)));
        }

        return states;
    }

    /** Free some parameters.
     * <p>
     * All propagation parameters drivers are first unselected, then only the ones
     * whose name appears in {@code freeParameters} are selected.
     * </p>
     * @param freeParameters names of the free parameters
     * @exception OrekitException if one of the names does not match any driver of the builder
     */
    private void setFreeParameters(final Iterable<String> freeParameters) {

        // start by setting all parameters as not estimated
        for (final ParameterDriver driver : builder.getPropagationParametersDrivers().getDrivers()) {
            driver.setSelected(false);
        }

        // set only the selected parameters as estimated
        for (final String parameter : freeParameters) {
            boolean found = false;
            for (final ParameterDriver driver : builder.getPropagationParametersDrivers().getDrivers()) {
                if (driver.getName().equals(parameter)) {
                    found = true;
                    driver.setSelected(true);
                    break;
                }
            }
            if (!found) {
                // build the list of supported parameters for the error message
                final StringBuilder sBuilder = new StringBuilder();
                for (final ParameterDriver driver : builder.getPropagationParametersDrivers().getDrivers()) {
                    if (sBuilder.length() > 0) {
                        sBuilder.append(", ");
                    }
                    sBuilder.append(driver.getName());
                }
                throw new OrekitException(OrekitMessages.UNSUPPORTED_PARAMETER_NAME,
                                          parameter, sBuilder.toString());
            }
        }
    }

    /** Adapt a propagator to minimize the mean square error for a set of {@link SpacecraftState states}.
     * <p>
     * The fit runs in two passes. A warm-up fit on the first few states starts from the
     * builder's current normalized parameters. A final fit on the whole sample starts
     * from the warm-up result.
     * </p>
     * @param states set of spacecraft states to fit
     * @param positionOnly if true, consider only position data otherwise both position and velocity are used
     * @return adapted propagator
     */
    private Propagator adapt(final List<SpacecraftState> states,
                             final boolean positionOnly) {

        this.onlyPosition = positionOnly;

        // very rough first guess using osculating parameters of first sample point
        final double[] initial = builder.getSelectedNormalizedParameters();

        // Warm-up iterations use only a few points: 2 position-only points or 1 full
        // position/velocity point. Presumably this gives at least 6 equations for the
        // orbital state (TODO confirm). It requires states to hold enough points.
        setSample(states.subList(0, onlyPosition ? 2 : 1));
        final double[] intermediate = fit(initial);

        // final search using all points
        setSample(states);
        final double[] result = fit(intermediate);

        rms = getRMS(result);
        adapted = buildAdaptedPropagator(result);

        return adapted;
    }

    /** Find the propagator that minimize the mean square error for a sample of {@link SpacecraftState states}.
     * <p>
     * The current {@code target} and {@code weight} arrays (set by the last sample
     * update) define the problem. The resulting optimum is stored for {@link #getEvaluations()}.
     * </p>
     * @param initial initial estimation parameters (position, velocity, free parameters)
     * @return fitted parameters
     * @exception MathRuntimeException if maximal number of iterations is exceeded
     */
    private double[] fit(final double[] initial)
        throws MathRuntimeException {

        final LeastSquaresProblem problem = new LeastSquaresBuilder().
                                            maxIterations(maxIterations).
                                            maxEvaluations(Integer.MAX_VALUE).
                                            model(getModel()).
                                            target(target).
                                            weight(new DiagonalMatrix(weight)).
                                            start(initial).
                                            checker(checker).
                                            build();

        optimum = optimizer.optimize(problem);
        return optimum.getPoint().toArray();

    }

    /** Get the Root Mean Square Deviation for a given parameters set.
     * <p>
     * Residuals are the unweighted differences between the target and the values
     * computed by the objective function.
     * </p>
     * @param parameterSet position/velocity parameters set
     * @return RMSD
     */
    private double getRMS(final double[] parameterSet) {
        // do not overwrite the array returned by the objective function: it may be reused by it
        final double[] computed = getObjectiveFunction().value(parameterSet);
        double sum2 = 0;
        for (int i = 0; i < computed.length; ++i) {
            final double residual = target[i] - computed[i];
            sum2 += residual * residual;
        }
        return FastMath.sqrt(sum2 / computed.length);
    }

    /** Build the adapted propagator for a given position/velocity(/free) parameters set.
     * @param parameterSet position/velocity(/free) parameters set
     * @return adapted propagator
     */
    private Propagator buildAdaptedPropagator(final double[] parameterSet) {
        return builder.buildPropagator(parameterSet);
    }

    /** Set the states sample.
     * <p>
     * This fills {@code target} with the positions (and velocities unless only
     * positions are fitted) of the states expressed in the converter frame. It fills
     * {@code weight} with the matching residual weights.
     * </p>
     * @param states spacecraft states sample
     */
    private void setSample(final List<SpacecraftState> states) {

        this.sample = states;

        // 3 components per state for position only, 6 for position and velocity
        final int stride = onlyPosition ? 3 : 6;
        target = new double[states.size() * stride];
        weight = new double[states.size() * stride];

        int k = 0;
        for (final SpacecraftState state : states) {

            final PVCoordinates pv = state.getPVCoordinates(frame);

            // position
            target[k]   = pv.getPosition().getX();
            weight[k++] = 1;
            target[k]   = pv.getPosition().getY();
            weight[k++] = 1;
            target[k]   = pv.getPosition().getZ();
            weight[k++] = 1;

            // velocity
            if (!onlyPosition) {
                // Velocity weight relative to position is v r² / mu. For a circular
                // orbit (v² = mu / r) this reduces to r / v, a time scale that makes
                // velocity residuals commensurate with position residuals.
                final double r2 = pv.getPosition().getNormSq();
                final double v  = pv.getVelocity().getNorm();
                final double vWeight = v * r2 / state.getMu();

                target[k]   = pv.getVelocity().getX();
                weight[k++] = vWeight;
                target[k]   = pv.getVelocity().getY();
                weight[k++] = vWeight;
                target[k]   = pv.getVelocity().getZ();
                weight[k++] = vWeight;
            }

        }

    }

}