1   /* Copyright 2022-2025 Romain Serra
2    * Licensed to CS GROUP (CS) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * CS licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *   http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.orekit.control.indirect.adjoint;
18  
19  import org.hipparchus.geometry.euclidean.threed.FieldVector3D;
20  import org.hipparchus.geometry.euclidean.threed.Vector3D;
21  import org.hipparchus.ode.nonstiff.ClassicalRungeKuttaFieldIntegrator;
22  import org.hipparchus.util.Binary64;
23  import org.hipparchus.util.Binary64Field;
24  import org.hipparchus.util.MathArrays;
25  import org.junit.jupiter.api.Assertions;
26  import org.junit.jupiter.api.Test;
27  import org.junit.jupiter.params.ParameterizedTest;
28  import org.junit.jupiter.params.provider.ValueSource;
29  import org.mockito.Mockito;
30  import org.orekit.control.indirect.adjoint.cost.FieldUnboundedCartesianEnergyNeglectingMass;
31  import org.orekit.control.indirect.adjoint.cost.TestFieldCost;
32  import org.orekit.errors.OrekitException;
33  import org.orekit.frames.FramesFactory;
34  import org.orekit.orbits.*;
35  import org.orekit.propagation.FieldSpacecraftState;
36  import org.orekit.propagation.SpacecraftState;
37  import org.orekit.propagation.integration.FieldCombinedDerivatives;
38  import org.orekit.propagation.numerical.FieldNumericalPropagator;
39  import org.orekit.time.AbsoluteDate;
40  import org.orekit.utils.Constants;
41  import org.orekit.utils.PVCoordinates;
42  
43  class FieldCartesianAdjointDerivativesProviderTest {
44  
45      @Test
46      @SuppressWarnings("unchecked")
47      void testInitException() {
48          // GIVEN
49          final String name = "name";
50          final double mu = Constants.EGM96_EARTH_MU;
51          final FieldCartesianAdjointDerivativesProvider<Binary64> derivativesProvider = new FieldCartesianAdjointDerivativesProvider<>(
52                  new FieldUnboundedCartesianEnergyNeglectingMass<>(name, Binary64Field.getInstance()), new CartesianAdjointKeplerianTerm(mu));
53          final FieldSpacecraftState<Binary64> mockedState = Mockito.mock(FieldSpacecraftState.class);
54          Mockito.when(mockedState.isOrbitDefined()).thenReturn(true);
55          final FieldOrbit<Binary64> mockedOrbit = Mockito.mock(FieldOrbit.class);
56          Mockito.when(mockedOrbit.getType()).thenReturn(OrbitType.EQUINOCTIAL);
57          Mockito.when(mockedState.getOrbit()).thenReturn(mockedOrbit);
58          // WHEN
59          Assertions.assertThrows(OrekitException.class, () -> derivativesProvider.init(mockedState, null));
60      }
61  
62      @Test
63      void testIntegration() {
64          // GIVEN
65          final String name = "name";
66          final double mu = Constants.EGM96_EARTH_MU;
67          final Binary64Field field = Binary64Field.getInstance();
68          final FieldCartesianAdjointDerivativesProvider<Binary64> derivativesProvider = new FieldCartesianAdjointDerivativesProvider<>(
69                  new FieldUnboundedCartesianEnergyNeglectingMass<>(name, field), new CartesianAdjointKeplerianTerm(mu));
70          final ClassicalRungeKuttaFieldIntegrator<Binary64> integrator = new ClassicalRungeKuttaFieldIntegrator<>(field,
71                  Binary64.ONE.multiply(100.));
72          final FieldNumericalPropagator<Binary64> propagator = new FieldNumericalPropagator<>(field, integrator);
73          final Orbit orbit = new CartesianOrbit(new PVCoordinates(new Vector3D(7e6, 1e3, 0), new Vector3D(10., 7e3, -200)),
74                  FramesFactory.getGCRF(), AbsoluteDate.ARBITRARY_EPOCH, mu);
75          final FieldSpacecraftState<Binary64> initialState = new FieldSpacecraftState<>(field, new SpacecraftState(orbit));
76          propagator.setOrbitType(OrbitType.CARTESIAN);
77          propagator.setInitialState(initialState.addAdditionalData(name, MathArrays.buildArray(field, 6)));
78          propagator.addAdditionalDerivativesProvider(derivativesProvider);
79          // WHEN
80          final FieldSpacecraftState<Binary64> terminalState = propagator.propagate(initialState.getDate().shiftedBy(1000.));
81          // THEN
82          Assertions.assertTrue(propagator.isAdditionalDataManaged(name));
83          final Binary64[] adjoint = terminalState.getAdditionalState(name);
84          Assertions.assertEquals(0., adjoint[0].getReal());
85          Assertions.assertEquals(0., adjoint[1].getReal());
86          Assertions.assertEquals(0., adjoint[2].getReal());
87          Assertions.assertEquals(0., adjoint[3].getReal());
88          Assertions.assertEquals(0., adjoint[4].getReal());
89          Assertions.assertEquals(0., adjoint[5].getReal());
90      }
91  
92      @ParameterizedTest
93      @ValueSource(booleans = {true, false})
94      void testEvaluateHamiltonian(final boolean withMassAdjoint) {
95          // GIVEN
96          final TestFieldCost cost = new TestFieldCost();
97          final double mu = 1e-3;
98          final FieldCartesianAdjointDerivativesProvider<Binary64> derivativesProvider = new FieldCartesianAdjointDerivativesProvider<>(cost,
99                  new CartesianAdjointKeplerianTerm(mu));
100         final FieldSpacecraftState<Binary64> state = getState(derivativesProvider.getName(), withMassAdjoint);
101         // WHEN
102         final Binary64 hamiltonian = derivativesProvider.evaluateHamiltonian(state);
103         // THEN
104         final FieldVector3D<Binary64> velocity = state.getPVCoordinates().getVelocity();
105         final FieldVector3D<Binary64> vector = new FieldVector3D<>(Binary64.ONE, Binary64.ONE, Binary64.ONE);
106         Assertions.assertEquals(velocity.dotProduct(vector).add(mu), hamiltonian);
107     }
108 
109     @Test
110     void testCombinedDerivatives() {
111         // GIVEN
112         final TestFieldCost cost = new TestFieldCost();
113         final FieldCartesianAdjointDerivativesProvider<Binary64> derivativesProvider = new FieldCartesianAdjointDerivativesProvider<>(
114                 cost);
115         final FieldSpacecraftState<Binary64> state = getState(derivativesProvider.getName(), false);
116         // WHEN
117         final FieldCombinedDerivatives<Binary64> combinedDerivatives = derivativesProvider.combinedDerivatives(state);
118         // THEN
119         final Binary64[] increment = combinedDerivatives.getMainStateDerivativesIncrements();
120         for (int i = 0; i < 3; i++) {
121             Assertions.assertEquals(0., increment[i].getReal());
122         }
123         Assertions.assertEquals(1., increment[3].getReal());
124         Assertions.assertEquals(2., increment[4].getReal());
125         Assertions.assertEquals(3., increment[5].getReal());
126         Assertions.assertEquals(-10. * state.getMass().getReal() * new Vector3D(1., 2., 3).getNorm(),
127                 increment[6].getReal(), 1e-10);
128     }
129 
130     private static FieldSpacecraftState<Binary64> getState(final String name, final boolean withMassAdjoint) {
131         final Orbit orbit = new CartesianOrbit(new PVCoordinates(Vector3D.MINUS_I, Vector3D.PLUS_K),
132                 FramesFactory.getGCRF(), AbsoluteDate.ARBITRARY_EPOCH, 1.);
133         final FieldSpacecraftState<Binary64> stateWithoutAdditional = new FieldSpacecraftState<>(new FieldCartesianOrbit<>(Binary64Field.getInstance(), orbit));
134         final Binary64[] adjoint = MathArrays.buildArray(stateWithoutAdditional.getDate().getField(),
135                 withMassAdjoint ? 7 : 6);
136         for (int i = 0; i < 6; i++) {
137             adjoint[i] = Binary64.ONE;
138         }
139         return stateWithoutAdditional.addAdditionalData(name, adjoint);
140     }
141 }