1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.orekit.control.indirect.adjoint;
18
19 import org.hipparchus.CalculusFieldElement;
20 import org.hipparchus.geometry.euclidean.threed.Vector3D;
21 import org.hipparchus.ode.nonstiff.ClassicalRungeKuttaIntegrator;
22 import org.junit.jupiter.api.Assertions;
23 import org.junit.jupiter.api.Test;
24 import org.junit.jupiter.params.ParameterizedTest;
25 import org.junit.jupiter.params.provider.ValueSource;
26 import org.mockito.Mockito;
27 import org.orekit.control.indirect.adjoint.cost.CartesianCost;
28 import org.orekit.control.indirect.adjoint.cost.TestCost;
29 import org.orekit.control.indirect.adjoint.cost.UnboundedCartesianEnergyNeglectingMass;
30 import org.orekit.errors.OrekitException;
31 import org.orekit.errors.OrekitMessages;
32 import org.orekit.frames.Frame;
33 import org.orekit.frames.FramesFactory;
34 import org.orekit.orbits.CartesianOrbit;
35 import org.orekit.orbits.Orbit;
36 import org.orekit.orbits.OrbitType;
37 import org.orekit.propagation.SpacecraftState;
38 import org.orekit.propagation.integration.CombinedDerivatives;
39 import org.orekit.propagation.numerical.NumericalPropagator;
40 import org.orekit.time.AbsoluteDate;
41 import org.orekit.time.FieldAbsoluteDate;
42 import org.orekit.utils.Constants;
43 import org.orekit.utils.PVCoordinates;
44
45 class CartesianAdjointDerivativesProviderTest {
46
47 @Test
48 void testInitException() {
49
50 final String name = "name";
51 final double mu = Constants.EGM96_EARTH_MU;
52 final CartesianAdjointDerivativesProvider derivativesProvider = new CartesianAdjointDerivativesProvider(
53 new UnboundedCartesianEnergyNeglectingMass(name), new CartesianAdjointKeplerianTerm(mu));
54 final SpacecraftState mockedState = Mockito.mock(SpacecraftState.class);
55 Mockito.when(mockedState.isOrbitDefined()).thenReturn(true);
56 final Orbit mockedOrbit = Mockito.mock(Orbit.class);
57 Mockito.when(mockedOrbit.getType()).thenReturn(OrbitType.EQUINOCTIAL);
58 Mockito.when(mockedState.getOrbit()).thenReturn(mockedOrbit);
59
60 final Exception exception = Assertions.assertThrows(OrekitException.class,
61 () -> derivativesProvider.init(mockedState, null));
62 Assertions.assertEquals(OrekitMessages.WRONG_COORDINATES_FOR_ADJOINT_EQUATION.getSourceString(),
63 exception.getMessage());
64 }
65
66 @Test
67 void testIntegration() {
68
69 final String name = "name";
70 final double mu = Constants.EGM96_EARTH_MU;
71 final CartesianAdjointDerivativesProvider derivativesProvider = new CartesianAdjointDerivativesProvider(
72 new UnboundedCartesianEnergyNeglectingMass(name), new CartesianAdjointKeplerianTerm(mu));
73 final NumericalPropagator propagator = new NumericalPropagator(new ClassicalRungeKuttaIntegrator(100.));
74 final Orbit orbit = new CartesianOrbit(new PVCoordinates(new Vector3D(7e6, 1e3, 0), new Vector3D(10., 7e3, -200)),
75 FramesFactory.getGCRF(), AbsoluteDate.ARBITRARY_EPOCH, mu);
76 propagator.setOrbitType(OrbitType.CARTESIAN);
77 propagator.setInitialState(new SpacecraftState(orbit).addAdditionalData(name, new double[6]));
78 propagator.addAdditionalDerivativesProvider(derivativesProvider);
79
80 final SpacecraftState state = propagator.propagate(orbit.getDate().shiftedBy(1000.));
81
82 Assertions.assertTrue(propagator.isAdditionalDataManaged(name));
83 final double[] finalAdjoint = state.getAdditionalState(name);
84 Assertions.assertEquals(0, finalAdjoint[0]);
85 Assertions.assertEquals(0, finalAdjoint[1]);
86 Assertions.assertEquals(0, finalAdjoint[2]);
87 Assertions.assertEquals(0, finalAdjoint[3]);
88 Assertions.assertEquals(0, finalAdjoint[4]);
89 Assertions.assertEquals(0, finalAdjoint[5]);
90 }
91
92 @Test
93 void testGetCost() {
94
95 final CartesianCost expectedCost = Mockito.mock(CartesianCost.class);
96 final CartesianAdjointDerivativesProvider derivativesProvider = new CartesianAdjointDerivativesProvider(expectedCost);
97
98 final CartesianCost actualCost = derivativesProvider.getCost();
99
100 Assertions.assertEquals(expectedCost, actualCost);
101 }
102
103 @ParameterizedTest
104 @ValueSource(booleans = {true, false})
105 void testEvaluateHamiltonian(final boolean withMassAdjoint) {
106
107 final CartesianCost cost = new TestCost();
108 final SpacecraftState state = getState(cost.getAdjointName(), withMassAdjoint);
109 final CartesianAdjointEquationTerm mockedTerm = Mockito.mock(CartesianAdjointEquationTerm.class);
110 final double[] cartesian = new double[6];
111 OrbitType.CARTESIAN.mapOrbitToArray(state.getOrbit(), null, cartesian, null);
112 Mockito.when(mockedTerm.getHamiltonianContribution(state.getDate(), cartesian,
113 state.getAdditionalState(cost.getAdjointName()), state.getFrame())).thenReturn(0.);
114 final CartesianAdjointDerivativesProvider derivativesProvider = new CartesianAdjointDerivativesProvider(cost,
115 mockedTerm);
116
117 final double hamiltonian = derivativesProvider.evaluateHamiltonian(state);
118
119 final Vector3D velocity = state.getPVCoordinates().getVelocity();
120 Assertions.assertEquals(velocity.dotProduct(new Vector3D(1, 1, 1)), hamiltonian);
121 }
122
123 @Test
124 void testCombinedDerivatives() {
125
126 final CartesianCost cost = new TestCost();
127 final CartesianAdjointDerivativesProvider derivativesProvider = new CartesianAdjointDerivativesProvider(cost);
128 final SpacecraftState state = getState(derivativesProvider.getName(), false);
129
130 final CombinedDerivatives combinedDerivatives = derivativesProvider.combinedDerivatives(state);
131
132 final double[] increment = combinedDerivatives.getMainStateDerivativesIncrements();
133 for (int i = 0; i < 3; i++) {
134 Assertions.assertEquals(0., increment[i]);
135 }
136 Assertions.assertEquals(1., increment[3]);
137 Assertions.assertEquals(2., increment[4]);
138 Assertions.assertEquals(3., increment[5]);
139 Assertions.assertEquals(-10. * state.getMass() * new Vector3D(1., 2., 3).getNorm(), increment[6], 1e-10);
140 }
141
142 @Test
143 void testCombinedDerivativesWithEquationTerm() {
144
145 final CartesianCost cost = new TestCost();
146 final CartesianAdjointEquationTerm equationTerm = new TestAdjointTerm();
147 final CartesianAdjointDerivativesProvider derivativesProvider = new CartesianAdjointDerivativesProvider(cost, equationTerm);
148 final SpacecraftState state = getState(derivativesProvider.getName(), false);
149
150 final CombinedDerivatives combinedDerivatives = derivativesProvider.combinedDerivatives(state);
151
152 final double[] adjointDerivatives = combinedDerivatives.getAdditionalDerivatives();
153 Assertions.assertEquals(1., adjointDerivatives[0]);
154 Assertions.assertEquals(10., adjointDerivatives[1]);
155 Assertions.assertEquals(100., adjointDerivatives[2]);
156 Assertions.assertEquals(-1, adjointDerivatives[3]);
157 Assertions.assertEquals(-1, adjointDerivatives[4]);
158 Assertions.assertEquals(-1, adjointDerivatives[5]);
159 }
160
161 private static SpacecraftState getState(final String name, final boolean withMassAdjoint) {
162 final Orbit orbit = new CartesianOrbit(new PVCoordinates(Vector3D.MINUS_I, Vector3D.PLUS_K),
163 FramesFactory.getGCRF(), AbsoluteDate.ARBITRARY_EPOCH, 1.);
164 final SpacecraftState stateWithoutAdditional = new SpacecraftState(orbit);
165 final double[] adjoint = withMassAdjoint ? new double[7] : new double[6];
166 for (int i = 0; i < 6; i++) {
167 adjoint[i] = 1;
168 }
169 return stateWithoutAdditional.addAdditionalData(name, adjoint);
170 }
171
172 private static class TestAdjointTerm implements CartesianAdjointEquationTerm {
173
174 @Override
175 public double[] getRatesContribution(AbsoluteDate date, double[] stateVariables, double[] adjointVariables, Frame frame) {
176 return new double[] { 1., 10., 100., 0., 0., 0. };
177 }
178
179 @Override
180 public <T extends CalculusFieldElement<T>> T[] getFieldRatesContribution(FieldAbsoluteDate<T> date, T[] stateVariables, T[] adjointVariables, Frame frame) {
181 return null;
182 }
183
184 @Override
185 public double getHamiltonianContribution(AbsoluteDate date, double[] stateVariables, double[] adjointVariables, Frame frame) {
186 return 0;
187 }
188
189 @Override
190 public <T extends CalculusFieldElement<T>> T getFieldHamiltonianContribution(FieldAbsoluteDate<T> date, T[] stateVariables, T[] adjointVariables, Frame frame) {
191 return date.getField().getZero();
192 }
193 }
194
195 }