1   /* Copyright 2022-2025 Romain Serra
2    * Licensed to CS GROUP (CS) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * CS licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *   http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.orekit.control.indirect.adjoint.cost;
18  
19  import org.hipparchus.complex.Complex;
20  import org.hipparchus.geometry.euclidean.threed.FieldVector3D;
21  import org.hipparchus.geometry.euclidean.threed.Vector3D;
22  import org.hipparchus.util.Binary64;
23  import org.hipparchus.util.Binary64Field;
24  import org.hipparchus.util.MathArrays;
25  import org.junit.jupiter.api.Assertions;
26  import org.junit.jupiter.api.Test;
27  import org.junit.jupiter.params.ParameterizedTest;
28  import org.junit.jupiter.params.provider.ValueSource;
29  import org.mockito.Mockito;
30  import org.orekit.frames.FramesFactory;
31  import org.orekit.orbits.CartesianOrbit;
32  import org.orekit.orbits.Orbit;
33  import org.orekit.propagation.FieldSpacecraftState;
34  import org.orekit.propagation.SpacecraftState;
35  import org.orekit.propagation.events.EventDetectionSettings;
36  import org.orekit.propagation.events.EventDetector;
37  import org.orekit.propagation.events.FieldEventDetectionSettings;
38  import org.orekit.propagation.events.FieldEventDetector;
39  import org.orekit.propagation.events.handlers.FieldResetDerivativesOnEvent;
40  import org.orekit.time.AbsoluteDate;
41  import org.orekit.utils.TimeStampedPVCoordinates;
42  
43  import java.util.List;
44  import java.util.stream.Collectors;
45  import java.util.stream.Stream;
46  
47  class FieldCartesianFuelCostTest {
48  
49      @Test
50      void testConstructor() {
51          // GIVEN
52          final EventDetectionSettings expectedDetectionSettings = EventDetectionSettings.getDefaultEventDetectionSettings();
53          // WHEN
54          final FieldCartesianFuelCost<Binary64> cartesianFuel = new FieldCartesianFuelCost<>("", new Binary64(1),
55                  new Binary64(2));
56          // THEN
57          final FieldEventDetectionSettings<Binary64> actualDetectionSettings = cartesianFuel.getEventDetectionSettings();
58          Assertions.assertEquals(expectedDetectionSettings.getMaxIterationCount(),
59                  actualDetectionSettings.getMaxIterationCount());
60          Assertions.assertEquals(expectedDetectionSettings.getThreshold(), actualDetectionSettings.getThreshold().getReal());
61      }
62  
63      @ParameterizedTest
64      @ValueSource(doubles = {-1e1, 1, 1e1})
65      void testUpdateFieldAdjointDerivatives(final double adjointMass) {
66          // GIVEN
67          final Binary64 rateFactor = Binary64.ONE;
68          final Binary64 maximumThrust = new Binary64(2);
69          final FieldCartesianFuelCost<Binary64> cartesianFuel = new FieldCartesianFuelCost<>("", rateFactor, maximumThrust);
70          final double[] adjoint = new double[] {0, 0, 0, 0, 1, 0, adjointMass};
71          final Binary64[] fieldAdjoint = MathArrays.buildArray(Binary64Field.getInstance(), adjoint.length);
72          for (int i = 0; i < adjoint.length; i++) {
73              fieldAdjoint[i] = new Binary64(adjoint[i]);
74          }
75          final Binary64 mass = new Binary64(100);
76          final Binary64[] fieldDerivatives = MathArrays.buildArray(Binary64Field.getInstance(), adjoint.length);
77          final double[] derivatives = new double[adjoint.length];
78          // WHEN
79          cartesianFuel.updateFieldAdjointDerivatives(fieldAdjoint, mass, fieldDerivatives);
80          // THEN
81          cartesianFuel.toCartesianCost().updateAdjointDerivatives(adjoint, mass.getReal(), derivatives);
82          Assertions.assertEquals(derivatives[6], fieldDerivatives[6].getReal());
83      }
84  
85      @ParameterizedTest
86      @ValueSource(doubles = {-1e1, 1, 1e1})
87      void testGetFieldThrustAccelerationVector(final double adjointMass) {
88          // GIVEN
89          final Binary64 rateFactor = Binary64.ONE;
90          final Binary64 maximumThrust = new Binary64(2);
91          final FieldCartesianFuelCost<Binary64> cartesianFuel = new FieldCartesianFuelCost<>("", rateFactor, maximumThrust);
92          final double[] adjoint = new double[] {0, 0, 0, 0, 1, 0, adjointMass};
93          final Binary64[] fieldAdjoint = MathArrays.buildArray(Binary64Field.getInstance(), adjoint.length);
94          for (int i = 0; i < adjoint.length; i++) {
95              fieldAdjoint[i] = new Binary64(adjoint[i]);
96          }
97          final Binary64 mass = new Binary64(100);
98          // WHEN
99          final FieldVector3D<Binary64> actual = cartesianFuel.getFieldThrustAccelerationVector(fieldAdjoint, mass);
100         // THEN
101         Assertions.assertEquals(cartesianFuel.toCartesianCost().getThrustAccelerationVector(adjoint, mass.getReal()), actual.toVector3D());
102     }
103 
104     @Test
105     void testGetFieldHamiltonianContribution() {
106         // GIVEN
107         final FieldCartesianFuelCost<Binary64> cartesianFuel = Mockito.mock(FieldCartesianFuelCost.class);
108         final Binary64Field field = Binary64Field.getInstance();
109         final Binary64[] adjoint = MathArrays.buildArray(field, 0);
110         final Binary64 mass = Binary64.ONE;
111         final FieldVector3D<Binary64> accelerationVector = new FieldVector3D<>(field, new Vector3D(1, 2, 3));
112         Mockito.when(cartesianFuel.getFieldThrustAccelerationVector(adjoint, mass)).thenReturn(accelerationVector);
113         Mockito.when(cartesianFuel.getFieldHamiltonianContribution(adjoint, mass)).thenCallRealMethod();
114         // WHEN
115         final Binary64 actual = cartesianFuel.getFieldHamiltonianContribution(adjoint, mass);
116         // THEN
117         Assertions.assertEquals(accelerationVector.scalarMultiply(mass).getNorm(), actual.negate());
118     }
119 
120     @ParameterizedTest
121     @ValueSource(doubles = {0., 1.})
122     void testGetFieldEventDetectors(final double massFlowRateFactor) {
123         // GIVEN
124         final String adjointName = "1";
125         final FieldCartesianFuelCost<Binary64> cartesianFuel = new FieldCartesianFuelCost<>(adjointName,
126                 new Binary64(massFlowRateFactor), new Binary64(2));
127         final Binary64Field field = Binary64Field.getInstance();
128         final double mass = 3;
129         // WHEN
130         final Stream<FieldEventDetector<Binary64>> detectorStream = cartesianFuel.getFieldEventDetectors(field);
131         // THEN
132         final List<FieldEventDetector<Binary64>> fieldDetectorList = detectorStream.collect(Collectors.toList());
133         final List<EventDetector> detectorList = cartesianFuel.toCartesianCost().getEventDetectors().collect(Collectors.toList());
134         Assertions.assertEquals(detectorList.size(), fieldDetectorList.size());
135         final EventDetector detector = detectorList.get(0);
136         final FieldEventDetector<Binary64> fieldEventDetector = fieldDetectorList.get(0);
137         Assertions.assertInstanceOf(FieldResetDerivativesOnEvent.class, fieldEventDetector.getHandler());
138         final FieldSpacecraftState<Binary64> fieldState = buildFieldState(mass, adjointName);
139         Assertions.assertEquals(detector.g(fieldState.toSpacecraftState()), fieldEventDetector.g(fieldState).getReal());
140         Assertions.assertEquals(detector.getDetectionSettings().getThreshold(),
141                 fieldEventDetector.getDetectionSettings().getThreshold().getReal());
142         Assertions.assertEquals(detector.getDetectionSettings().getMaxIterationCount(),
143                 fieldEventDetector.getDetectionSettings().getMaxIterationCount());
144     }
145 
146     private static FieldSpacecraftState<Binary64> buildFieldState(final double mass, final String adjointName) {
147         final Orbit orbit = new CartesianOrbit(new TimeStampedPVCoordinates(AbsoluteDate.ARBITRARY_EPOCH,
148                 new Vector3D(4, 5, 6), Vector3D.MINUS_K), FramesFactory.getEME2000(), 1);
149         final Binary64Field field = Binary64Field.getInstance();
150         final Binary64[] adjoint = MathArrays.buildArray(field, 7);
151         for (int i = 0; i < adjoint.length; i++) {
152             adjoint[i] = new Binary64(i + 1);
153         }
154         return new FieldSpacecraftState<>(field, new SpacecraftState(orbit, mass))
155                 .addAdditionalData(adjointName, adjoint);
156     }
157 
158     @Test
159     void testToCartesianCost() {
160         // GIVEN
161         final Complex massRateFactor = Complex.ONE;
162         final FieldCartesianFuelCost<Complex> fieldCartesianEnergy = new FieldCartesianFuelCost<>("",
163                 massRateFactor, new Complex(2));
164         // WHEN
165         final CartesianFuelCost cartesianEnergy = fieldCartesianEnergy.toCartesianCost();
166         // THEN
167         Assertions.assertEquals(cartesianEnergy.getAdjointName(), fieldCartesianEnergy.getAdjointName());
168         Assertions.assertEquals(cartesianEnergy.getMaximumThrustMagnitude(),
169                 fieldCartesianEnergy.getMaximumThrustMagnitude().getReal());
170     }
171 }
172