1   /* Copyright 2022-2025 Romain Serra
2    * Licensed to CS GROUP (CS) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * CS licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *   http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.orekit.control.indirect.adjoint.cost;
18  
19  import org.hipparchus.geometry.euclidean.threed.FieldVector3D;
20  import org.hipparchus.geometry.euclidean.threed.Vector3D;
21  import org.hipparchus.util.Binary64;
22  import org.hipparchus.util.Binary64Field;
23  import org.hipparchus.util.FastMath;
24  import org.hipparchus.util.MathArrays;
25  import org.junit.jupiter.api.Assertions;
26  import org.junit.jupiter.api.Test;
27  import org.junit.jupiter.params.ParameterizedTest;
28  import org.junit.jupiter.params.provider.ValueSource;
29  import org.orekit.control.indirect.adjoint.CartesianAdjointDerivativesProvider;
30  import org.orekit.control.indirect.adjoint.FieldCartesianAdjointDerivativesProvider;
31  import org.orekit.frames.FramesFactory;
32  import org.orekit.orbits.CartesianOrbit;
33  import org.orekit.propagation.FieldSpacecraftState;
34  import org.orekit.propagation.SpacecraftState;
35  import org.orekit.propagation.events.FieldEventDetector;
36  import org.orekit.propagation.integration.CombinedDerivatives;
37  import org.orekit.propagation.integration.FieldAdditionalDerivativesProvider;
38  import org.orekit.propagation.integration.FieldCombinedDerivatives;
39  import org.orekit.time.AbsoluteDate;
40  import org.orekit.utils.PVCoordinates;
41  
42  import java.util.List;
43  import java.util.stream.Collectors;
44  
45  class FieldQuadraticPenaltyCartesianFuelTest {
46  
47      private static final String ADJOINT_NAME = "adjoint";
48  
49      @ParameterizedTest
50      @ValueSource(booleans = {false, true})
51      void testUpdateFieldAdjointDerivatives(final boolean withMass) {
52          // GIVEN
53          final Binary64 massFlowRateFactor = withMass ? Binary64.ONE : Binary64.ZERO;
54          final FieldQuadraticPenaltyCartesianFuel<Binary64> cost = new FieldQuadraticPenaltyCartesianFuel<>("adjoint", massFlowRateFactor, Binary64.PI, Binary64.ONE);
55          final Binary64[] adjoint = MathArrays.buildArray(Binary64Field.getInstance(), withMass ? 7 : 6);
56          final Binary64[] derivatives = adjoint.clone();
57          adjoint[3] = Binary64.ONE;
58          // WHEN
59          cost.updateFieldAdjointDerivatives(adjoint, Binary64.ONE, derivatives);
60          // THEN
61          final Binary64 zero = Binary64.ZERO;
62          for (int i = 0; i < 6; ++i) {
63              Assertions.assertEquals(zero, derivatives[i]);
64          }
65          if (withMass) {
66              Assertions.assertNotEquals(zero, derivatives[derivatives.length - 1]);
67          } else {
68              Assertions.assertEquals(zero, derivatives[derivatives.length - 1]);
69          }
70      }
71  
72      @ParameterizedTest
73      @ValueSource(doubles = {0, 0.1, 0.5, 0.9})
74      void testEvaluateFieldPenaltyFunction(final double norm) {
75          // GIVEN
76          final Binary64 unitMaximumThrust = Binary64.ONE;
77          final FieldQuadraticPenaltyCartesianFuel<Binary64> penalizedCartesianFuel = new FieldQuadraticPenaltyCartesianFuel<>(
78                  ADJOINT_NAME, Binary64.ONE, unitMaximumThrust, Binary64.ZERO);
79          // WHEN
80          final Binary64 actualPenalty = penalizedCartesianFuel.evaluateFieldPenaltyFunction(Binary64.ONE.newInstance(norm));
81          // THEN
82          Assertions.assertEquals(norm * norm / 2 - norm, actualPenalty.getReal(), 1e-15);
83      }
84  
85      @ParameterizedTest
86      @ValueSource(doubles = {1e-3, 1e-2, 0.5, 0.999})
87      void testGetFieldHamiltonianContribution(final double epsilon) {
88          // GIVEN
89          final FieldQuadraticPenaltyCartesianFuel<Binary64> fieldCost = new FieldQuadraticPenaltyCartesianFuel<>(
90                  ADJOINT_NAME, Binary64.ONE, Binary64.PI, new Binary64(epsilon));
91          final double[] adjoint = new double[] {1, 2, 3, 4, 5, 6, 7};
92          final Binary64[] fieldAdjoint = MathArrays.buildArray(Binary64Field.getInstance(), adjoint.length);
93          for (int i = 0; i < adjoint.length; i++) {
94              fieldAdjoint[i] = fieldCost.getEpsilon().newInstance(adjoint[i]);
95          }
96          final Binary64 mass = new Binary64(100);
97          // WHEN
98          final Binary64 actualPenalty = fieldCost.getFieldHamiltonianContribution(fieldAdjoint, mass);
99          // THEN
100         final QuadraticPenaltyCartesianFuel cost = fieldCost.toCartesianCost();
101         Assertions.assertEquals(cost.getHamiltonianContribution(adjoint, mass.getReal()), actualPenalty.getReal());
102     }
103 
104     @Test
105     void testGetFieldEventDetectors() {
106         // GIVEN
107         final FieldQuadraticPenaltyCartesianFuel<Binary64> penalizedCartesianFuel = new FieldQuadraticPenaltyCartesianFuel<>(
108                 ADJOINT_NAME, Binary64.ONE, Binary64.PI, new Binary64(0.5));
109         // WHEN
110         final List<FieldEventDetector<Binary64>> actualDetectors = penalizedCartesianFuel
111                 .getFieldEventDetectors(Binary64Field.getInstance()).collect(Collectors.toList());
112         // THEN
113         Assertions.assertEquals(2, actualDetectors.size());
114         final SpacecraftState state = buildState(10);
115         final FieldSpacecraftState<Binary64> fieldState = new FieldSpacecraftState<>(Binary64Field.getInstance(), state);
116         final Binary64 g1 = actualDetectors.get(0).g(fieldState);
117         final Binary64 g2 = actualDetectors.get(1).g(fieldState);
118         final Binary64 difference = FastMath.abs(g2.subtract(g1));
119         Assertions.assertEquals(0., penalizedCartesianFuel.getMaximumThrustMagnitude().subtract(difference).getReal(), 1e-12);
120     }
121 
122     @Test
123     void testGetThrustAccelerationVectorEpsilonCloseToZero() {
124         // GIVEN
125         final Binary64 massFlowRateFactor = new Binary64(1);
126         final Binary64 maximumThrustMagnitude = new Binary64(10);
127         final Binary64 epsilon = new Binary64(1e-6);
128         final FieldQuadraticPenaltyCartesianFuel<Binary64> penalizedCartesianFuel = new FieldQuadraticPenaltyCartesianFuel<>(ADJOINT_NAME,
129                 massFlowRateFactor, maximumThrustMagnitude, epsilon);
130         final Binary64 mass = new Binary64(100);
131         final double[] adjoint = new double[] {1, 2, 3, 4, 5, 6, 7};
132         final Binary64[] fieldAdjoint = MathArrays.buildArray(Binary64Field.getInstance(), adjoint.length);
133         for (int i = 0; i < adjoint.length; i++) {
134             fieldAdjoint[i] = epsilon.newInstance(adjoint[i]);
135         }
136         // WHEN
137         final FieldVector3D<Binary64> actualThrustVector = penalizedCartesianFuel.getFieldThrustAccelerationVector(fieldAdjoint, mass);
138         // THEN
139         final FieldCartesianFuelCost<Binary64> fuelCost = new FieldCartesianFuelCost<>(ADJOINT_NAME, massFlowRateFactor, maximumThrustMagnitude);
140         final FieldVector3D<Binary64> expectedThrustVector = fuelCost.getFieldThrustAccelerationVector(fieldAdjoint, mass);
141         Assertions.assertEquals(0., expectedThrustVector.subtract(actualThrustVector).toVector3D().getNorm(), 1e-10);
142     }
143 
144     @Test
145     void testGetThrustAccelerationVectorEpsilonEqualToOne() {
146         // GIVEN
147         final Binary64 massFlowRateFactor = new Binary64(1);
148         final Binary64 maximumThrustMagnitude = new Binary64(10);
149         final Binary64 epsilon = Binary64.ONE;
150         final FieldQuadraticPenaltyCartesianFuel<Binary64> penalizedCartesianFuel = new FieldQuadraticPenaltyCartesianFuel<>(ADJOINT_NAME,
151                 massFlowRateFactor, maximumThrustMagnitude, epsilon);
152         final Binary64 mass = new Binary64(100);
153         final double[] adjoint = new double[] {1, 2, 3, 4, 5, 6, 7};
154         final Binary64[] fieldAdjoint = MathArrays.buildArray(Binary64Field.getInstance(), adjoint.length);
155         for (int i = 0; i < adjoint.length; i++) {
156             fieldAdjoint[i] = epsilon.newInstance(adjoint[i]);
157         }
158         // WHEN
159         final FieldVector3D<Binary64> actualThrustVector = penalizedCartesianFuel.getFieldThrustAccelerationVector(fieldAdjoint, mass);
160         // THEN
161         final FieldBoundedCartesianEnergy<Binary64> fuelCost = new FieldBoundedCartesianEnergy<>(ADJOINT_NAME, massFlowRateFactor, maximumThrustMagnitude);
162         final FieldVector3D<Binary64> expectedThrustVector = fuelCost.getFieldThrustAccelerationVector(fieldAdjoint, mass);
163         Assertions.assertEquals(0., expectedThrustVector.subtract(actualThrustVector).toVector3D().getNorm());
164     }
165 
166     @Test
167     void testToCartesianCost() {
168         // GIVEN
169         final FieldQuadraticPenaltyCartesianFuel<Binary64> fieldCost = new FieldQuadraticPenaltyCartesianFuel<>(
170                 ADJOINT_NAME, Binary64.ONE, Binary64.PI, Binary64.ZERO);
171         // WHEN
172         final QuadraticPenaltyCartesianFuel cost = fieldCost.toCartesianCost();
173         // THEN
174         Assertions.assertEquals(fieldCost.getEpsilon().getReal(), cost.getEpsilon());
175         Assertions.assertEquals(fieldCost.getMaximumThrustMagnitude().getReal(), cost.getMaximumThrustMagnitude());
176         Assertions.assertEquals(fieldCost.getMassFlowRateFactor().getReal(), cost.getMassFlowRateFactor());
177         Assertions.assertEquals(fieldCost.getAdjointName(), cost.getAdjointName());
178     }
179 
180     @ParameterizedTest
181     @ValueSource(doubles = {1, 1e2, 1e4})
182     void testAgainstNonField(final double mass) {
183         // GIVEN
184         final double massFlowRateFactor = 1.e-2;
185         final double maximumThrustMagnitude = 1e-3;
186         final double epsilon = 0.5;
187         final Binary64 zero = Binary64.ZERO;
188         final FieldQuadraticPenaltyCartesianFuel<Binary64> fieldCost = new FieldQuadraticPenaltyCartesianFuel<>(ADJOINT_NAME,
189                 zero.newInstance(massFlowRateFactor), zero.newInstance(maximumThrustMagnitude), zero.newInstance(epsilon));
190         final SpacecraftState state = buildState(mass);
191         final FieldSpacecraftState<Binary64> fieldState = new FieldSpacecraftState<>(Binary64Field.getInstance(), state);
192         final FieldAdditionalDerivativesProvider<Binary64> derivativesProvider = new FieldCartesianAdjointDerivativesProvider<>(fieldCost);
193         // WHEN
194         final FieldCombinedDerivatives<Binary64> actualDerivatives = derivativesProvider.combinedDerivatives(fieldState);
195         // THEN
196         final QuadraticPenaltyCartesianFuel cost = new QuadraticPenaltyCartesianFuel(ADJOINT_NAME,
197                 massFlowRateFactor, maximumThrustMagnitude, epsilon);
198         final CombinedDerivatives expectedDerivatives = new CartesianAdjointDerivativesProvider(cost)
199                 .combinedDerivatives(state);
200         for (int i = 0; i < expectedDerivatives.getMainStateDerivativesIncrements().length; i++) {
201             Assertions.assertEquals(expectedDerivatives.getMainStateDerivativesIncrements()[i],
202                     actualDerivatives.getMainStateDerivativesIncrements()[i].getReal(), 1e-12);
203         }
204         for (int i = 0; i < expectedDerivatives.getAdditionalDerivatives().length; i++) {
205             Assertions.assertEquals(expectedDerivatives.getAdditionalDerivatives()[i],
206                     actualDerivatives.getAdditionalDerivatives()[i].getReal());
207         }
208     }
209 
210     private SpacecraftState buildState(final double mass) {
211         final double[] adjoint = new double[] {1, 2, 3, 4, 5, 6, 7};
212         final CartesianOrbit orbit = new CartesianOrbit(new PVCoordinates(Vector3D.MINUS_I, Vector3D.MINUS_K),
213                 FramesFactory.getEME2000(), AbsoluteDate.ARBITRARY_EPOCH, 1.);
214         return new SpacecraftState(orbit, mass).addAdditionalData(ADJOINT_NAME, adjoint);
215     }
216 }