1   /* Copyright 2022-2025 Romain Serra
2    * Licensed to CS GROUP (CS) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * CS licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *   http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.orekit.propagation.numerical;
18  
19  import org.hipparchus.CalculusFieldElement;
20  import org.hipparchus.Field;
21  import org.hipparchus.analysis.differentiation.Gradient;
22  import org.hipparchus.analysis.differentiation.GradientField;
23  import org.hipparchus.geometry.euclidean.threed.FieldVector3D;
24  import org.hipparchus.geometry.euclidean.threed.Rotation;
25  import org.hipparchus.geometry.euclidean.threed.Vector3D;
26  import org.hipparchus.linear.MatrixUtils;
27  import org.hipparchus.linear.QRDecomposition;
28  import org.hipparchus.linear.RealMatrix;
29  import org.hipparchus.ode.events.Action;
30  import org.junit.jupiter.api.Test;
31  import org.junit.jupiter.params.ParameterizedTest;
32  import org.junit.jupiter.params.provider.EnumSource;
33  import org.mockito.Mockito;
34  import org.orekit.attitudes.AttitudeProvider;
35  import org.orekit.attitudes.FrameAlignedProvider;
36  import org.orekit.forces.ForceModel;
37  import org.orekit.forces.ForceModelModifier;
38  import org.orekit.frames.FramesFactory;
39  import org.orekit.orbits.CartesianOrbit;
40  import org.orekit.orbits.OrbitType;
41  import org.orekit.propagation.FieldSpacecraftState;
42  import org.orekit.propagation.SpacecraftState;
43  import org.orekit.propagation.events.DateDetector;
44  import org.orekit.propagation.events.DetectorModifier;
45  import org.orekit.propagation.events.EventDetectionSettings;
46  import org.orekit.propagation.events.EventDetector;
47  import org.orekit.propagation.events.FieldDateDetector;
48  import org.orekit.propagation.events.FieldDetectorModifier;
49  import org.orekit.propagation.events.FieldEventDetectionSettings;
50  import org.orekit.propagation.events.FieldEventDetector;
51  import org.orekit.propagation.events.handlers.ContinueOnEvent;
52  import org.orekit.propagation.events.handlers.EventHandler;
53  import org.orekit.propagation.events.handlers.FieldContinueOnEvent;
54  import org.orekit.propagation.events.handlers.FieldEventHandler;
55  import org.orekit.propagation.events.handlers.ResetDerivativesOnEvent;
56  import org.orekit.time.AbsoluteDate;
57  import org.orekit.time.FieldAbsoluteDate;
58  import org.orekit.utils.AbsolutePVCoordinates;
59  import org.orekit.utils.Constants;
60  import org.orekit.utils.FieldAbsolutePVCoordinates;
61  import org.orekit.utils.FieldPVCoordinates;
62  import org.orekit.utils.PVCoordinates;
63  import org.orekit.utils.ParameterDriver;
64  import org.orekit.utils.TimeStampedPVCoordinates;
65  
66  import java.util.Collections;
67  import java.util.List;
68  import java.util.stream.Collectors;
69  import java.util.stream.Stream;
70  
71  import static org.junit.jupiter.api.Assertions.*;
72  import static org.mockito.Mockito.mock;
73  import static org.mockito.Mockito.when;
74  
75  class SwitchEventHandlerTest {
76  
77      private static final String STM_NAME = "stm";
78  
79      @ParameterizedTest
80      @EnumSource(value = Action.class)
81      void testEventOccurred(final Action action) {
82          // GIVEN
83          final AbsoluteDate date = AbsoluteDate.ARBITRARY_EPOCH;
84          final NumericalPropagationHarvester harvester = mockHarvester();
85          final EventHandler handler = (s, e, i) -> action;
86          final SwitchEventHandler switchEventHandler = buildSwitchEventHandler(mock(ForceModel.class), harvester, handler);
87          final SpacecraftState stateAtSwitch = buildAbsoluteState(date, new double[0]);
88          // WHEN
89          final Action wrappedAction = switchEventHandler.eventOccurred(stateAtSwitch, new DateDetector(), true);
90          // THEN
91          assertEquals(action, wrappedAction);
92      }
93  
94      @Test
95      void testEventOccurredResetDerivativesAndMatchingDetector() {
96          // GIVEN
97          final AbsoluteDate date = AbsoluteDate.ARBITRARY_EPOCH;
98          final Vector3D acceleration = new Vector3D(1, 2, 3);
99          final ForceModel forceModel = new TestForce(acceleration, acceleration, date);
100         final ForceModel forceWithDetectors = getForceModelWithWrappedDateDetectors(forceModel, date);
101         final NumericalPropagationHarvester harvester = mockHarvester();
102         final SwitchEventHandler switchEventHandler = buildSwitchEventHandler(forceWithDetectors, harvester,
103                 new ResetDerivativesOnEvent());
104         final SpacecraftState stateAtSwitch = buildAbsoluteState(date, new double[0]);
105         final DateDetector dateDetector = new DateDetector(date);
106         // WHEN
107         final Action action = switchEventHandler.eventOccurred(stateAtSwitch, dateDetector, true);
108         // THEN
109         assertEquals(Action.RESET_STATE, action);
110     }
111 
112     @Test
113     void testResetStateTrivial() {
114         // GIVEN
115         final NumericalPropagationHarvester harvester = mockHarvester();
116         final SwitchEventHandler switchEventHandler = buildSwitchEventHandler(mock(ForceModel.class), harvester, new ContinueOnEvent());
117         final SpacecraftState stateAtSwitch = buildAbsoluteState(AbsoluteDate.ARBITRARY_EPOCH, new double[0]);
118         // WHEN
119         final SpacecraftState resetState = switchEventHandler.resetState(null, stateAtSwitch);
120         // THEN
121         compareStates(stateAtSwitch, resetState);
122     }
123 
124     private static void compareStatesWithoutAdditionalVariables(final SpacecraftState expectedState,
125                                                                 final SpacecraftState actualState) {
126         assertEquals(expectedState.getDate(), actualState.getDate());
127         assertEquals(expectedState.getPosition(), actualState.getPosition());
128         assertEquals(expectedState.getMass(), actualState.getMass());
129         assertEquals(expectedState.getAttitude(), actualState.getAttitude());
130     }
131 
132     private static void compareStates(final SpacecraftState expectedState, final SpacecraftState actualState) {
133         compareStatesWithoutAdditionalVariables(expectedState, actualState);
134         assertArrayEquals(expectedState.getAdditionalState(STM_NAME), actualState.getAdditionalState(STM_NAME));
135     }
136 
137     @Test
138     void testResetStateDetectorIndependentOfState() {
139         // GIVEN
140         final AbsoluteDate date = AbsoluteDate.ARBITRARY_EPOCH;
141         final Vector3D acceleration = Vector3D.MINUS_I;
142         final ForceModel forceModel = new TestForce(Vector3D.ZERO, acceleration, date);
143         final ForceModel forceWithDetectors = getForceModelWithWrappedDateDetectors(forceModel, date);
144         final NumericalPropagationHarvester harvester = mockHarvester();
145         final SwitchEventHandler switchEventHandler = buildSwitchEventHandler(forceWithDetectors, harvester,
146                 new ResetDerivativesOnEvent());
147         final RealMatrix stm = MatrixUtils.createRealIdentityMatrix(7);
148         final String jacobianName = "param0";
149         when(harvester.getJacobiansColumnsNames()).thenReturn(Collections.singletonList(jacobianName));
150         final double[][] transposedJacobian = new double[7][7];
151         transposedJacobian[0] = new double[] {1, 2, 3, 4, 5, 6, 7};
152         final SpacecraftState stateAtSwitch = buildAbsoluteState(date, harvester.toArray(stm.getData()))
153                 .addAdditionalData(jacobianName, transposedJacobian[0]);
154         when(harvester.getParametersJacobian(stateAtSwitch)).thenReturn(MatrixUtils.createRealMatrix(transposedJacobian).transpose());
155         final EventDetector eventDetector = forceWithDetectors.getEventDetectors().collect(Collectors.toList()).get(0);
156         // WHEN
157         preprocessSwitchHandler(switchEventHandler, stateAtSwitch, eventDetector);
158         final SpacecraftState resetState = switchEventHandler.resetState(eventDetector, stateAtSwitch);
159         // THEN
160         compareStates(stateAtSwitch, resetState);
161         assertArrayEquals(stateAtSwitch.getAdditionalState(jacobianName), resetState.getAdditionalState(jacobianName));
162     }
163 
164     @Test
165     void testResetStateNoSwitch() {
166         // GIVEN
167         final AbsoluteDate date = AbsoluteDate.ARBITRARY_EPOCH;
168         final RealMatrix stm = buildStm();
169         final NumericalPropagationHarvester harvester = mockHarvester();
170         final SpacecraftState stateAtSwitch = buildOrbitState(date, OrbitType.CARTESIAN, harvester.toArray(stm.getData()));
171         final ForceModel forceWithDetectors = getForceModelWithoutSwitch(stateAtSwitch.getPVCoordinates());
172         final SwitchEventHandler switchEventHandler = buildSwitchEventHandler(forceWithDetectors, harvester,
173                 new ResetDerivativesOnEvent());
174         final EventDetector eventDetector = forceWithDetectors.getEventDetectors().collect(Collectors.toList()).get(0);
175         // WHEN
176         preprocessSwitchHandler(switchEventHandler, stateAtSwitch, eventDetector);
177         final SpacecraftState resetState = switchEventHandler.resetState(mock(EventDetector.class), stateAtSwitch);
178         // THEN
179         compareStates(stateAtSwitch, resetState);
180     }
181 
182     private static RealMatrix buildStm() {
183         final RealMatrix matrix = MatrixUtils.createRealIdentityMatrix(7);
184         for (int i = 0; i < 7; i++) {
185             for (int j = 0; j < 7; j++) {
186                 if (i != j) {
187                     matrix.setEntry(i, j, i + j);
188                 }
189             }
190         }
191         return matrix;
192     }
193 
194     @ParameterizedTest
195     @EnumSource(OrbitType.class)
196     void testResetState(final OrbitType orbitType) {
197         // GIVEN
198         final AbsoluteDate date = AbsoluteDate.ARBITRARY_EPOCH;
199         final Vector3D accelerationBefore = Vector3D.MINUS_I;
200         final Vector3D accelerationAfter = new Vector3D(1, 2, 3);
201         final RealMatrix originalStm = buildStm();
202         final NumericalPropagationHarvester harvester = mockHarvester();
203         final SpacecraftState stateAtSwitch = buildOrbitState(date, orbitType, harvester.toArray(originalStm.getData()));
204         final ForceModel forceWithDetectors = getForceModel(accelerationBefore, accelerationAfter, stateAtSwitch.getPVCoordinates());
205         final SwitchEventHandler switchEventHandler = buildSwitchEventHandler(forceWithDetectors, harvester,
206                 new ResetDerivativesOnEvent());
207         final EventDetector eventDetector = new TimeStampedPVDetector(stateAtSwitch.getPVCoordinates());
208         // WHEN
209         preprocessSwitchHandler(switchEventHandler, stateAtSwitch, eventDetector);
210         final SpacecraftState resetState = switchEventHandler.resetState(eventDetector, stateAtSwitch);
211         // THEN
212         compareStatesWithoutAdditionalVariables(stateAtSwitch, resetState);
213         if (orbitType == OrbitType.CARTESIAN) {
214             final GradientField field = GradientField.getField(8);
215             final RealMatrix actualStm = harvester.toSquareMatrix(resetState.getAdditionalState(STM_NAME));
216             final FieldPTimeStampedVDetector<Gradient> fieldDetector = new FieldPTimeStampedVDetector<>(field, stateAtSwitch.getPVCoordinates());
217             final double[] deltaDerivatives = new double[7];
218             deltaDerivatives[3] = accelerationBefore.getX() - accelerationAfter.getX();
219             deltaDerivatives[4] = accelerationBefore.getY() - accelerationAfter.getY();
220             deltaDerivatives[5] = accelerationBefore.getZ() - accelerationAfter.getZ();
221             final RealMatrix updateMatrix = computeUpdateMatrix(stateAtSwitch, deltaDerivatives, fieldDetector);
222             final RealMatrix expectedStm = updateMatrix.multiply(originalStm);
223             for (int i = 0; i < 7; i++) {
224                 assertArrayEquals(expectedStm.getRow(i), actualStm.getRow(i), 1e-6);
225             }
226         }
227     }
228 
229     private RealMatrix computeUpdateMatrix(final SpacecraftState state, final double[] deltaDerivatives,
230                                            final FieldEventDetector<Gradient> switchFieldDetector) {
231         final Gradient g = evaluateG(state, 0., switchFieldDetector);
232         final RealMatrix matrixToInverse = MatrixUtils.createRealIdentityMatrix(8);
233         matrixToInverse.setRow(7, g.getGradient());
234         final RealMatrix inverted = new QRDecomposition(matrixToInverse).getSolver().getInverse();
235         final RealMatrix lhs = MatrixUtils.createRealIdentityMatrix(8);
236         for (int i = 0; i < 7; i++) {
237             lhs.setEntry(i, 7, deltaDerivatives[i]);
238         }
239         final RealMatrix product = lhs.multiply(inverted);
240         return product.getSubMatrix(0, 6, 0, 6);
241     }
242 
243     private Gradient evaluateG(final SpacecraftState state, final double massRate,
244                                final FieldEventDetector<Gradient> fieldEventDetector) {
245         final int freeParameters = 8;
246         final Gradient dt = Gradient.variable(freeParameters, 7, 0);
247         final GradientField field = dt.getField();
248         final Vector3D position = state.getPosition();
249         final Vector3D velocity = state.getVelocity();
250         final FieldVector3D<Gradient> fieldPosition = new FieldVector3D<>(Gradient.variable(freeParameters, 0, position.getX()),
251                 Gradient.variable(freeParameters, 1, position.getY()),
252                 Gradient.variable(freeParameters, 2, position.getZ())).add(new FieldVector3D<>(field,
253                 velocity).scalarMultiply(dt));
254         final FieldVector3D<Gradient> fieldVelocity = new FieldVector3D<>(Gradient.variable(freeParameters, 3, velocity.getX()),
255                 Gradient.variable(freeParameters, 4, velocity.getY()),
256                 Gradient.variable(freeParameters, 5, velocity.getZ())).add(new FieldVector3D<>(field,
257                 state.getPVCoordinates().getAcceleration()).scalarMultiply(dt));
258         final FieldAbsoluteDate<Gradient> fieldDate = new FieldAbsoluteDate<>(dt.getField(), state.getDate()).shiftedBy(dt);
259         final FieldAbsolutePVCoordinates<Gradient> fieldAbsolutePVCoordinates = new FieldAbsolutePVCoordinates<>(
260                 state.getFrame(), fieldDate, fieldPosition, fieldVelocity);
261         final Gradient fieldMass = Gradient.variable(freeParameters, 6, state.getMass()).add(dt.multiply(massRate));
262         final FieldSpacecraftState<Gradient> fieldState = new FieldSpacecraftState<>(fieldAbsolutePVCoordinates)
263                 .withMass(fieldMass);
264         return fieldEventDetector.g(fieldState);
265     }
266 
267     private static ForceModel getForceModelWithoutSwitch(TimeStampedPVCoordinates pvCoordinates) {
268         final Vector3D acceleration = new Vector3D(2, 1);
269         return getForceModel(acceleration, acceleration, pvCoordinates);
270     }
271 
272     private static ForceModel getForceModel(final Vector3D accelerationBefore, final Vector3D accelerationAfter,
273                                             final TimeStampedPVCoordinates pvCoordinates) {
274         final ForceModel forceModel = new TestForce(accelerationBefore, accelerationAfter, pvCoordinates.getDate());
275         final TimeStampedPVCoordinates trivialPV = new TimeStampedPVCoordinates(pvCoordinates.getDate(), new PVCoordinates());
276         return new ForceModelModifier() {
277             @Override
278             public ForceModel getUnderlyingModel() {
279                 return forceModel;
280             }
281 
282             @Override
283             public Stream<EventDetector> getEventDetectors() {
284                 return Stream.of(new TimeStampedPVDetector(trivialPV), new TimeStampedPVDetector(pvCoordinates));
285             }
286 
287             @Override
288             public <T extends CalculusFieldElement<T>> Stream<FieldEventDetector<T>> getFieldEventDetectors(Field<T> field) {
289                 return Stream.of(new FieldPTimeStampedVDetector<>(field, trivialPV),
290                         new FieldPTimeStampedVDetector<>(field,  pvCoordinates));
291             }
292         };
293     }
294 
295     private static void preprocessSwitchHandler(final SwitchEventHandler switchEventHandler, final SpacecraftState stateAtSwitch,
296                                                 final EventDetector eventDetector) {
297         switchEventHandler.init(stateAtSwitch, AbsoluteDate.FUTURE_INFINITY, eventDetector);
298         switchEventHandler.eventOccurred(stateAtSwitch, eventDetector, true);
299     }
300 
301     private static SwitchEventHandler buildSwitchEventHandler(final ForceModel forceModel,
302                                                               final NumericalPropagationHarvester harvester,
303                                                               final EventHandler handler) {
304         final NumericalTimeDerivativesEquations equations = new NumericalTimeDerivativesEquations(null,
305                 null, Collections.singletonList(forceModel));
306         final AttitudeProvider attitudeProvider = new FrameAlignedProvider(Rotation.IDENTITY);
307         return new SwitchEventHandler(handler, harvester, equations, attitudeProvider);
308     }
309 
310     private static SpacecraftState buildAbsoluteState(final AbsoluteDate date, final double[] stmArray) {
311         final PVCoordinates pvCoordinates = buildOrbitState(date, OrbitType.CARTESIAN, stmArray).getPVCoordinates();
312         final AbsolutePVCoordinates absolutePVCoordinates = new AbsolutePVCoordinates(FramesFactory.getEME2000(), date,
313                 pvCoordinates);
314         return new SpacecraftState(absolutePVCoordinates).addAdditionalData(STM_NAME, stmArray);
315     }
316 
317     private static SpacecraftState buildOrbitState(final AbsoluteDate date, final OrbitType orbitType,
318                                                    final double[] stmArray) {
319         final PVCoordinates pvCoordinates = new PVCoordinates(new Vector3D(7e6, 3e3, -1e2), new Vector3D(-1e2, 7e3, 1e1));
320         final CartesianOrbit cartesianOrbit = new CartesianOrbit(new TimeStampedPVCoordinates(date, pvCoordinates),
321                 FramesFactory.getEME2000(), Constants.EGM96_EARTH_MU);
322         return new SpacecraftState(orbitType.convertType(cartesianOrbit)).addAdditionalData(STM_NAME, stmArray);
323     }
324 
325     private static NumericalPropagationHarvester mockHarvester() {
326         final NumericalPropagationHarvester harvester = mock();
327         when(harvester.getStateDimension()).thenReturn(7);
328         when(harvester.toArray(Mockito.any())).thenCallRealMethod();
329         when(harvester.toSquareMatrix(Mockito.any())).thenCallRealMethod();
330         when(harvester.getStmName()).thenReturn(STM_NAME);
331         return harvester;
332     }
333 
334     private static class TimeStampedPVDetector implements EventDetector {
335         private final TimeStampedPVCoordinates pvCoordinates;
336 
337         TimeStampedPVDetector(final TimeStampedPVCoordinates pvCoordinates) {
338             this.pvCoordinates = pvCoordinates;
339         }
340 
341         @Override
342         public double g(SpacecraftState s) {
343             final PVCoordinates relativePV = new PVCoordinates(s.getPVCoordinates(), pvCoordinates);
344             return s.durationFrom(pvCoordinates) + relativePV.getPosition().getX() + relativePV.getPosition().getY()
345                     + relativePV.getPosition().getZ() + relativePV.getVelocity().getX() + relativePV.getVelocity().getY()
346                     + relativePV.getVelocity().getZ();
347         }
348 
349         @Override
350         public EventHandler getHandler() {
351             return new ContinueOnEvent();
352         }
353     }
354 
355     private static class FieldPTimeStampedVDetector<T extends CalculusFieldElement<T>> implements FieldEventDetector<T> {
356         private final TimeStampedPVCoordinates pvCoordinates;
357         private final Field<T> field;
358 
359         FieldPTimeStampedVDetector(final Field<T> field, final TimeStampedPVCoordinates pvCoordinates) {
360             this.field = field;
361             this.pvCoordinates = pvCoordinates;
362         }
363 
364         @Override
365         public T g(FieldSpacecraftState<T> s) {
366             final FieldPVCoordinates<T> relativePV = new FieldPVCoordinates<>(s.getPosition().subtract(pvCoordinates.getPosition()),
367                     s.getPVCoordinates().getVelocity().subtract(pvCoordinates.getVelocity()));
368             final FieldVector3D<T> relativePosition = relativePV.getPosition();
369             final FieldVector3D<T> relativeVelocity = relativePV.getVelocity();
370             return s.durationFrom(new FieldAbsoluteDate<>(s.getDate().getField(), pvCoordinates.getDate()))
371                     .add(relativePosition.getX()).add(relativePosition.getY()).add(relativePosition.getZ())
372                     .add(relativeVelocity.getX()).add(relativeVelocity.getY()).add(relativeVelocity.getZ());
373         }
374 
375         @Override
376         public FieldEventHandler<T> getHandler() {
377             return new FieldContinueOnEvent<>();
378         }
379 
380         @Override
381         public FieldEventDetectionSettings<T> getDetectionSettings() {
382             return new FieldEventDetectionSettings<>(field, EventDetectionSettings.getDefaultEventDetectionSettings());
383         }
384     }
385 
386     private static ForceModel getForceModelWithWrappedDateDetectors(final ForceModel forceModel, final AbsoluteDate date) {
387         return new ForceModelModifier() {
388             @Override
389             public ForceModel getUnderlyingModel() {
390                 return forceModel;
391             }
392 
393             @Override
394             public Stream<EventDetector> getEventDetectors() {
395                 return Stream.of(new DetectorModifier() {
396                     @Override
397                     public boolean dependsOnTimeOnly() {
398                         return false;
399                     }
400 
401                     @Override
402                     public EventDetector getDetector() {
403                         return new DateDetector(date);
404                     }
405                 });
406             }
407 
408             @Override
409             public <T extends CalculusFieldElement<T>> Stream<FieldEventDetector<T>> getFieldEventDetectors(Field<T> field) {
410                 return Stream.of(new FieldDetectorModifier<T>() {
411                     @Override
412                     public boolean dependsOnTimeOnly() {
413                         return false;
414                     }
415 
416                     @Override
417                     public FieldEventDetector<T> getDetector() {
418                         return new FieldDateDetector<>(new FieldAbsoluteDate<>(field, date));
419                     }
420                 });
421             }
422         };
423     }
424 
425     private static class TestForce implements ForceModel {
426 
427         private final Vector3D accelerationBefore;
428         private final Vector3D accelerationAfter;
429         private final AbsoluteDate switchDate;
430 
431         TestForce(final Vector3D accelerationBefore, final Vector3D accelerationAfter,
432                   final AbsoluteDate switchDate) {
433             this.accelerationBefore = accelerationBefore;
434             this.accelerationAfter = accelerationAfter;
435             this.switchDate = switchDate;
436         }
437 
438         @Override
439         public boolean dependsOnPositionOnly() {
440             return false;
441         }
442 
443         @Override
444         public Vector3D acceleration(SpacecraftState s, double[] parameters) {
445             return s.getDate().isBeforeOrEqualTo(switchDate) ? accelerationBefore : accelerationAfter;
446         }
447 
448         @Override
449         public <T extends CalculusFieldElement<T>> FieldVector3D<T> acceleration(FieldSpacecraftState<T> s, T[] parameters) {
450             return null;
451         }
452 
453         @Override
454         public List<ParameterDriver> getParametersDrivers() {
455             return Collections.emptyList();
456         }
457     }
458 }