1   /* Copyright 2022-2025 Romain Serra
2    * Licensed to CS GROUP (CS) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * CS licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *   http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.orekit.control.indirect.adjoint;
18  
19  import org.hipparchus.CalculusFieldElement;
20  import org.hipparchus.geometry.euclidean.threed.Vector3D;
21  import org.hipparchus.ode.nonstiff.ClassicalRungeKuttaIntegrator;
22  import org.junit.jupiter.api.Assertions;
23  import org.junit.jupiter.api.Test;
24  import org.junit.jupiter.params.ParameterizedTest;
25  import org.junit.jupiter.params.provider.ValueSource;
26  import org.mockito.Mockito;
27  import org.orekit.control.indirect.adjoint.cost.CartesianCost;
28  import org.orekit.control.indirect.adjoint.cost.TestCost;
29  import org.orekit.control.indirect.adjoint.cost.UnboundedCartesianEnergyNeglectingMass;
30  import org.orekit.errors.OrekitException;
31  import org.orekit.errors.OrekitMessages;
32  import org.orekit.frames.Frame;
33  import org.orekit.frames.FramesFactory;
34  import org.orekit.orbits.CartesianOrbit;
35  import org.orekit.orbits.Orbit;
36  import org.orekit.orbits.OrbitType;
37  import org.orekit.propagation.SpacecraftState;
38  import org.orekit.propagation.integration.CombinedDerivatives;
39  import org.orekit.propagation.numerical.NumericalPropagator;
40  import org.orekit.time.AbsoluteDate;
41  import org.orekit.time.FieldAbsoluteDate;
42  import org.orekit.utils.Constants;
43  import org.orekit.utils.PVCoordinates;
44  
45  class CartesianAdjointDerivativesProviderTest {
46  
47      @Test
48      void testInitException() {
49          // GIVEN
50          final String name = "name";
51          final double mu = Constants.EGM96_EARTH_MU;
52          final CartesianAdjointDerivativesProvider derivativesProvider = new CartesianAdjointDerivativesProvider(
53                  new UnboundedCartesianEnergyNeglectingMass(name), new CartesianAdjointKeplerianTerm(mu));
54          final SpacecraftState mockedState = Mockito.mock(SpacecraftState.class);
55          Mockito.when(mockedState.isOrbitDefined()).thenReturn(true);
56          final Orbit mockedOrbit = Mockito.mock(Orbit.class);
57          Mockito.when(mockedOrbit.getType()).thenReturn(OrbitType.EQUINOCTIAL);
58          Mockito.when(mockedState.getOrbit()).thenReturn(mockedOrbit);
59          // WHEN
60          final Exception exception = Assertions.assertThrows(OrekitException.class,
61                  () -> derivativesProvider.init(mockedState, null));
62          Assertions.assertEquals(OrekitMessages.WRONG_COORDINATES_FOR_ADJOINT_EQUATION.getSourceString(),
63                  exception.getMessage());
64      }
65  
66      @Test
67      void testIntegration() {
68          // GIVEN
69          final String name = "name";
70          final double mu = Constants.EGM96_EARTH_MU;
71          final CartesianAdjointDerivativesProvider derivativesProvider = new CartesianAdjointDerivativesProvider(
72                  new UnboundedCartesianEnergyNeglectingMass(name), new CartesianAdjointKeplerianTerm(mu));
73          final NumericalPropagator propagator = new NumericalPropagator(new ClassicalRungeKuttaIntegrator(100.));
74          final Orbit orbit = new CartesianOrbit(new PVCoordinates(new Vector3D(7e6, 1e3, 0), new Vector3D(10., 7e3, -200)),
75                  FramesFactory.getGCRF(), AbsoluteDate.ARBITRARY_EPOCH, mu);
76          propagator.setOrbitType(OrbitType.CARTESIAN);
77          propagator.setInitialState(new SpacecraftState(orbit).addAdditionalData(name, new double[6]));
78          propagator.addAdditionalDerivativesProvider(derivativesProvider);
79          // WHEN
80          final SpacecraftState state = propagator.propagate(orbit.getDate().shiftedBy(1000.));
81          // THEN
82          Assertions.assertTrue(propagator.isAdditionalDataManaged(name));
83          final double[] finalAdjoint = state.getAdditionalState(name);
84          Assertions.assertEquals(0, finalAdjoint[0]);
85          Assertions.assertEquals(0, finalAdjoint[1]);
86          Assertions.assertEquals(0, finalAdjoint[2]);
87          Assertions.assertEquals(0, finalAdjoint[3]);
88          Assertions.assertEquals(0, finalAdjoint[4]);
89          Assertions.assertEquals(0, finalAdjoint[5]);
90      }
91  
92      @Test
93      void testGetCost() {
94          // GIVEN
95          final CartesianCost expectedCost = Mockito.mock(CartesianCost.class);
96          final CartesianAdjointDerivativesProvider derivativesProvider = new CartesianAdjointDerivativesProvider(expectedCost);
97          // WHEN
98          final CartesianCost actualCost = derivativesProvider.getCost();
99          // THEN
100         Assertions.assertEquals(expectedCost, actualCost);
101     }
102 
103     @ParameterizedTest
104     @ValueSource(booleans = {true, false})
105     void testEvaluateHamiltonian(final boolean withMassAdjoint) {
106         // GIVEN
107         final CartesianCost cost = new TestCost();
108         final SpacecraftState state = getState(cost.getAdjointName(), withMassAdjoint);
109         final CartesianAdjointEquationTerm mockedTerm = Mockito.mock(CartesianAdjointEquationTerm.class);
110         final double[] cartesian = new double[6];
111         OrbitType.CARTESIAN.mapOrbitToArray(state.getOrbit(), null, cartesian, null);
112         Mockito.when(mockedTerm.getHamiltonianContribution(state.getDate(), cartesian,
113                         state.getAdditionalState(cost.getAdjointName()), state.getFrame())).thenReturn(0.);
114         final CartesianAdjointDerivativesProvider derivativesProvider = new CartesianAdjointDerivativesProvider(cost,
115                 mockedTerm);
116         // WHEN
117         final double hamiltonian = derivativesProvider.evaluateHamiltonian(state);
118         // THEN
119         final Vector3D velocity = state.getPVCoordinates().getVelocity();
120         Assertions.assertEquals(velocity.dotProduct(new Vector3D(1, 1, 1)), hamiltonian);
121     }
122 
123     @Test
124     void testCombinedDerivatives() {
125         // GIVEN
126         final CartesianCost cost = new TestCost();
127         final CartesianAdjointDerivativesProvider derivativesProvider = new CartesianAdjointDerivativesProvider(cost);
128         final SpacecraftState state = getState(derivativesProvider.getName(), false);
129         // WHEN
130         final CombinedDerivatives combinedDerivatives = derivativesProvider.combinedDerivatives(state);
131         // THEN
132         final double[] increment = combinedDerivatives.getMainStateDerivativesIncrements();
133         for (int i = 0; i < 3; i++) {
134             Assertions.assertEquals(0., increment[i]);
135         }
136         Assertions.assertEquals(1., increment[3]);
137         Assertions.assertEquals(2., increment[4]);
138         Assertions.assertEquals(3., increment[5]);
139         Assertions.assertEquals(-10. * state.getMass() * new Vector3D(1., 2., 3).getNorm(), increment[6], 1e-10);
140     }
141 
142     @Test
143     void testCombinedDerivativesWithEquationTerm() {
144         // GIVEN
145         final CartesianCost cost = new TestCost();
146         final CartesianAdjointEquationTerm equationTerm = new TestAdjointTerm();
147         final CartesianAdjointDerivativesProvider derivativesProvider = new CartesianAdjointDerivativesProvider(cost, equationTerm);
148         final SpacecraftState state = getState(derivativesProvider.getName(), false);
149         // WHEN
150         final CombinedDerivatives combinedDerivatives = derivativesProvider.combinedDerivatives(state);
151         // THEN
152         final double[] adjointDerivatives = combinedDerivatives.getAdditionalDerivatives();
153         Assertions.assertEquals(1., adjointDerivatives[0]);
154         Assertions.assertEquals(10., adjointDerivatives[1]);
155         Assertions.assertEquals(100., adjointDerivatives[2]);
156         Assertions.assertEquals(-1, adjointDerivatives[3]);
157         Assertions.assertEquals(-1, adjointDerivatives[4]);
158         Assertions.assertEquals(-1, adjointDerivatives[5]);
159     }
160 
161     private static SpacecraftState getState(final String name, final boolean withMassAdjoint) {
162         final Orbit orbit = new CartesianOrbit(new PVCoordinates(Vector3D.MINUS_I, Vector3D.PLUS_K),
163                 FramesFactory.getGCRF(), AbsoluteDate.ARBITRARY_EPOCH, 1.);
164         final SpacecraftState stateWithoutAdditional = new SpacecraftState(orbit);
165         final double[] adjoint = withMassAdjoint ? new double[7] : new double[6];
166         for (int i = 0; i < 6; i++) {
167             adjoint[i] = 1;
168         }
169         return stateWithoutAdditional.addAdditionalData(name, adjoint);
170     }
171 
172     private static class TestAdjointTerm implements CartesianAdjointEquationTerm {
173 
174         @Override
175         public double[] getRatesContribution(AbsoluteDate date, double[] stateVariables, double[] adjointVariables, Frame frame) {
176             return new double[] { 1., 10., 100., 0., 0., 0. };
177         }
178 
179         @Override
180         public <T extends CalculusFieldElement<T>> T[] getFieldRatesContribution(FieldAbsoluteDate<T> date, T[] stateVariables, T[] adjointVariables, Frame frame) {
181             return null;
182         }
183 
184         @Override
185         public double getHamiltonianContribution(AbsoluteDate date, double[] stateVariables, double[] adjointVariables, Frame frame) {
186             return 0;
187         }
188 
189         @Override
190         public <T extends CalculusFieldElement<T>> T getFieldHamiltonianContribution(FieldAbsoluteDate<T> date, T[] stateVariables, T[] adjointVariables, Frame frame) {
191             return date.getField().getZero();
192         }
193     }
194 
195 }