1   /* Copyright 2022-2025 Romain Serra
2    * Licensed to CS GROUP (CS) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * CS licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *   http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.orekit.control.indirect.adjoint.cost;
18  
19  import org.hipparchus.geometry.euclidean.threed.FieldVector3D;
20  import org.hipparchus.util.Binary64;
21  import org.hipparchus.util.Binary64Field;
22  import org.hipparchus.util.MathArrays;
23  import org.junit.jupiter.api.Assertions;
24  import org.junit.jupiter.api.Test;
25  import org.junit.jupiter.params.ParameterizedTest;
26  import org.junit.jupiter.params.provider.ValueSource;
27  import org.mockito.Mockito;
28  import org.orekit.TestUtils;
29  import org.orekit.propagation.FieldSpacecraftState;
30  import org.orekit.propagation.SpacecraftState;
31  import org.orekit.propagation.integration.FieldAdditionalDerivativesProvider;
32  import org.orekit.propagation.integration.FieldCombinedDerivatives;
33  import org.orekit.time.AbsoluteDate;
34  
35  class FieldCartesianCostTest {
36  
37      @Test
38      void getFieldEventDetectorsTest() {
39          // GIVEN
40          final TestFieldCost fieldCost = new TestFieldCost();
41          // WHEN & THEN
42          Assertions.assertEquals(0., fieldCost.getFieldEventDetectors(Binary64Field.getInstance()).count());
43      }
44  
45      @ParameterizedTest
46      @ValueSource(booleans = {true, false})
47      @SuppressWarnings("unchecked")
48      void getCostDerivativeProviderTest(final boolean yields) {
49          // GIVEN
50          final TestFieldCost fieldCost = new TestFieldCost();
51          final String expectedName = "a";
52          final FieldSpacecraftState<Binary64> mockedState = Mockito.mock();
53          final String adjointName = fieldCost.getAdjointName();
54          Mockito.when(mockedState.hasAdditionalData(adjointName)).thenReturn(yields);
55          // WHEN
56          final FieldAdditionalDerivativesProvider<Binary64> fieldCostDerivative = fieldCost.getCostDerivativeProvider(expectedName);
57          // THEN
58          Assertions.assertEquals(expectedName, fieldCostDerivative.getName());
59          Assertions.assertEquals(1, fieldCostDerivative.getDimension());
60          Assertions.assertNotEquals(yields, fieldCostDerivative.yields(mockedState));
61      }
62  
63      @Test
64      void getCostDerivativeProviderCombinedDerivativesTest() {
65          // GIVEN
66          final FieldCartesianCost<Binary64> cost = new TestCost();
67          final String expectedName = "a";
68          final Binary64Field field = Binary64Field.getInstance();
69          final Binary64[] adjoint = MathArrays.buildArray(field, 6);
70          for (int i = 0; i < adjoint.length; ++i) {
71              adjoint[i] = new Binary64(i);
72          }
73          final FieldSpacecraftState<Binary64> state = new FieldSpacecraftState<>(field,
74                  new SpacecraftState(TestUtils.getDefaultOrbit(AbsoluteDate.ARBITRARY_EPOCH)))
75                  .addAdditionalData(cost.getAdjointName(), adjoint);
76          final Binary64 expectedDerivative = Binary64.ONE;
77          // WHEN
78          final FieldAdditionalDerivativesProvider<Binary64> fieldCostDerivative = cost.getCostDerivativeProvider(expectedName);
79          // THEN
80          final FieldCombinedDerivatives<Binary64> fieldCombinedDerivatives = fieldCostDerivative.combinedDerivatives(state);
81          Assertions.assertNull(fieldCombinedDerivatives.getMainStateDerivativesIncrements());
82          Assertions.assertEquals(expectedDerivative, fieldCombinedDerivatives.getAdditionalDerivatives()[0]);
83      }
84  
85      private static class TestCost implements FieldCartesianCost<Binary64> {
86  
87          @Override
88          public String getAdjointName() {
89              return "adjoint";
90          }
91  
92          @Override
93          public int getAdjointDimension() {
94              return 6;
95          }
96  
97          @Override
98          public Binary64 getMassFlowRateFactor() {
99              return Binary64.ZERO;
100         }
101 
102         @Override
103         public FieldVector3D<Binary64> getFieldThrustAccelerationVector(final Binary64[] adjointVariables, final Binary64 mass) {
104             return null;
105         }
106 
107         @Override
108         public void updateFieldAdjointDerivatives(final Binary64[] adjointVariables, final Binary64 mass,
109                                                   final Binary64[] adjointDerivatives) {
110 
111         }
112 
113         @Override
114         public Binary64 getFieldHamiltonianContribution(final Binary64[] adjointVariables, final Binary64 mass) {
115             return new Binary64(-1);
116         }
117 
118         @Override
119         public CartesianCost toCartesianCost() {
120             return null;
121         }
122     }
123 
124 }