1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.orekit.control.indirect.adjoint.cost;
18
19 import org.hipparchus.geometry.euclidean.threed.Vector3D;
20 import org.junit.jupiter.api.Assertions;
21 import org.junit.jupiter.api.Test;
22 import org.junit.jupiter.params.ParameterizedTest;
23 import org.junit.jupiter.params.provider.ValueSource;
24 import org.mockito.Mockito;
25 import org.orekit.TestUtils;
26 import org.orekit.propagation.SpacecraftState;
27 import org.orekit.propagation.integration.AdditionalDerivativesProvider;
28 import org.orekit.propagation.integration.CombinedDerivatives;
29 import org.orekit.time.AbsoluteDate;
30
31 class CartesianCostTest {
32
33 @Test
34 void getEventDetectorsTest() {
35
36 final TestCost cost = new TestCost();
37
38 Assertions.assertEquals(0., cost.getEventDetectors().count());
39 }
40
41 @ParameterizedTest
42 @ValueSource(booleans = {true, false})
43 void getCostDerivativeProviderTest(final boolean yields) {
44
45 final TestCost cost = new TestCost();
46 final String expectedName = "a";
47 final SpacecraftState mockedState = Mockito.mock();
48 final String adjointName = cost.getAdjointName();
49 Mockito.when(mockedState.hasAdditionalData(adjointName)).thenReturn(yields);
50
51 final AdditionalDerivativesProvider costDerivative = cost.getCostDerivativeProvider(expectedName);
52
53 Assertions.assertEquals(expectedName, costDerivative.getName());
54 Assertions.assertEquals(1, costDerivative.getDimension());
55 Assertions.assertNotEquals(yields, costDerivative.yields(mockedState));
56 }
57
58 @Test
59 void getCostDerivativeProviderCombinedDerivativesTest() {
60
61 final CartesianCost cost = new TestCost();
62 final String name = "a";
63 final double[] adjoint = new double[] {1, 2, 3, 4, 5, 6};
64 final SpacecraftState state = new SpacecraftState(TestUtils.getDefaultOrbit(AbsoluteDate.ARBITRARY_EPOCH))
65 .addAdditionalData(cost.getAdjointName(), adjoint);
66
67 final AdditionalDerivativesProvider costDerivative = cost.getCostDerivativeProvider(name);
68
69 final CombinedDerivatives combinedDerivatives = costDerivative.combinedDerivatives(state);
70 Assertions.assertNull(combinedDerivatives.getMainStateDerivativesIncrements());
71 Assertions.assertEquals(1, combinedDerivatives.getAdditionalDerivatives()[0]);
72 }
73
74 private static class TestCost implements CartesianCost {
75
76 @Override
77 public String getAdjointName() {
78 return "adjoint";
79 }
80
81 @Override
82 public int getAdjointDimension() {
83 return 6;
84 }
85
86 @Override
87 public double getMassFlowRateFactor() {
88 return 0;
89 }
90
91 @Override
92 public Vector3D getThrustAccelerationVector(final double[] adjointVariables, final double mass) {
93 return null;
94 }
95
96 @Override
97 public void updateAdjointDerivatives(final double[] adjointVariables, final double mass,
98 final double[] adjointDerivatives) {
99
100 }
101
102 @Override
103 public double getHamiltonianContribution(final double[] adjointVariables, final double mass) {
104 return -1;
105 }
106 }
107
108 }