Optimising Cheetah for speed

One of Cheetah’s standout features is its computational speed. This is achieved through a number of optimisations under the hood, which the user never needs to worry about. Often, however, further optimisations can be made when it is known in advance how the model will be used. For example, one might load a large lattice of an entire facility with thousands of elements, but then only ever change a handful of these elements for the experiment at hand. For cases like this, Cheetah offers opt-in optimisation features that can speed up simulations significantly, in some cases by an order of magnitude or more.

[1]:
import cheetah
import torch
[2]:
incoming_beam = cheetah.ParameterBeam.from_astra(
    "../../tests/resources/ACHIP_EA1_2021.1351.001"
)

Let’s define a large lattice, with many quadrupole magnets and drift sections in the centre and a pair of steerers at each end. We assume that the quadrupole magnets are at their design settings and will never be touched. Only the two steerers at each end are of interest to us, for example because we would like to train a neural network policy to steer the beam using them. Furthermore, like many lattices, this one contains a number of markers. These markers may be helpful for denoting certain positions along the beamline, but they don’t add anything to the physics of the simulation.

[3]:
original_segment = cheetah.Segment(
    elements=[
        cheetah.HorizontalCorrector(
            length=torch.tensor(0.1), angle=torch.tensor(0.0), name="HCOR_1"
        ),
        cheetah.Drift(length=torch.tensor(0.3)),
        cheetah.VerticalCorrector(
            length=torch.tensor(0.1), angle=torch.tensor(0.0), name="VCOR_1"
        ),
        cheetah.Drift(length=torch.tensor(0.3)),
    ]
    + [
        cheetah.Quadrupole(length=torch.tensor(0.1), k1=torch.tensor(4.2)),
        cheetah.Drift(length=torch.tensor(0.2)),
        cheetah.Quadrupole(length=torch.tensor(0.1), k1=torch.tensor(-4.2)),
        cheetah.Drift(length=torch.tensor(0.2)),
        cheetah.Marker(),
        cheetah.Quadrupole(length=torch.tensor(0.1), k1=torch.tensor(0.0)),
        cheetah.Drift(length=torch.tensor(0.2)),
    ]
    * 150
    + [
        cheetah.HorizontalCorrector(
            length=torch.tensor(0.1), angle=torch.tensor(0.0), name="HCOR_2"
        ),
        cheetah.Drift(length=torch.tensor(0.3)),
        cheetah.VerticalCorrector(
            length=torch.tensor(0.1), angle=torch.tensor(0.0), name="VCOR_2"
        ),
        cheetah.Drift(length=torch.tensor(0.3)),
    ]
)
[4]:
len(original_segment.elements)
[4]:
1058

First, we test how long it takes to track a beam through this segment without any optimisations beyond those applied automatically under the hood.

[5]:
%%timeit
original_segment.track(incoming_beam)
66.1 ms ± 178 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Just by removing the inactive markers, we already see a small performance improvement.

[6]:
markers_removed_segment = original_segment.without_inactive_markers()
[7]:
%%timeit
markers_removed_segment.track(incoming_beam)
65.3 ms ± 203 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
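
A quick sanity check: the lattice was constructed with 150 markers, so removing them should leave 1058 - 150 = 908 elements.

[ ]:
len(markers_removed_segment.elements)  # expected: 1058 - 150 = 908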

Drift sections tend to be the cheapest elements to compute. At the same time, many elements in a lattice may be switched off at any given time. When switched off, they behave almost exactly like drift sections, yet they still require additional computations to arrive at that result. We can therefore safely replace them with actual Drift elements, which speeds up computations noticeably.

[8]:
inactive_to_drifts_segment = original_segment.inactive_elements_as_drifts(
    except_for=["HCOR_1", "VCOR_1", "HCOR_2", "VCOR_2"]
)
len(inactive_to_drifts_segment.elements)
[8]:
1058
[9]:
%%timeit
inactive_to_drifts_segment.track(incoming_beam)
50 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
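
To illustrate what the conversion did, we can count how many elements are plain drifts before and after. This inspection snippet is just a sketch; the switched-off quadrupoles (k1=0) should now show up as Drift elements.

[ ]:
# Count plain Drift elements before and after the conversion; the
# difference corresponds to the inactive elements that were replaced.
print(sum(isinstance(element, cheetah.Drift) for element in original_segment.elements))
print(sum(isinstance(element, cheetah.Drift) for element in inactive_to_drifts_segment.elements))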

The most significant improvement comes from merging elements that are not expected to change in the future. For this, Cheetah offers the transfer_maps_merged method, which by default merges the transfer maps of all elements in the segment. In almost all realistic applications, however, there are some elements whose settings we do wish to change in the future. By passing a list of their names to except_for, we can instruct Cheetah to merge only the elements in between the passed ones.

NOTE: Transfer map merging can only be done for a constant incoming beam energy, because the transfer maps must be computed before they can be merged, and computing them may require the beam energy at the entrance of the element that each transfer map belongs to. If you want to try a different beam energy, you will need to reapply the optimisations to the original lattice, passing a beam with the desired energy (see the sketch below the merged segment).

[10]:
transfer_maps_merged_segment = original_segment.transfer_maps_merged(
    incoming_beam=incoming_beam, except_for=["HCOR_1", "VCOR_1", "HCOR_2", "VCOR_2"]
)
len(transfer_maps_merged_segment.elements)
[10]:
8
[11]:
%%timeit
transfer_maps_merged_segment.track(incoming_beam)
96.2 µs ± 121 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
[12]:
transfer_maps_merged_segment
[12]:
Segment(elements=ModuleList(
  (0): HorizontalCorrector(length=tensor(0.1000), angle=tensor(0.), name='HCOR_1', device='cpu')
  (1): Drift(length=tensor(0.3000), name='unnamed_element_0', device='cpu')
  (2): VerticalCorrector(length=tensor(0.1000), angle=tensor(0.), name='VCOR_1', device='cpu')
  (3): CustomTransferMap(name='unnamed_element_615', device='cpu')
  (4): HorizontalCorrector(length=tensor(0.1000), angle=tensor(0.), name='HCOR_2', device='cpu')
  (5): Drift(length=tensor(0.3000), name='unnamed_element_9', device='cpu')
  (6): VerticalCorrector(length=tensor(0.1000), angle=tensor(0.), name='VCOR_2', device='cpu')
  (7): CustomTransferMap(name='unnamed_element_616', device='cpu')
), name='unnamed', device='cpu')
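
As noted above, the merged transfer maps are only valid for the beam energy they were computed with. The following minimal sketch shows how one would redo the merge for a different energy, starting again from the original, unmerged lattice. The from_parameters call and the chosen energy value are illustrative assumptions.

[ ]:
# Hypothetical beam with a different energy (value chosen for illustration)
higher_energy_beam = cheetah.ParameterBeam.from_parameters(energy=torch.tensor(200e6))

# Re-merge from the original lattice, not from the already-merged segment
remerged_segment = original_segment.transfer_maps_merged(
    incoming_beam=higher_energy_beam,
    except_for=["HCOR_1", "VCOR_1", "HCOR_2", "VCOR_2"],
)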

It is also possible, and often advisable, to combine optimisations. Note, however, that this might not yield as much of an improvement as the individual gains would suggest. This is usually because the optimisations overlap in their effects: if the first optimisation has already made a change to the lattice that the second one would also have made, the second cannot deliver a further speed-up.

[13]:
fully_optimized_segment = (
    original_segment.without_inactive_markers()
    .inactive_elements_as_drifts(except_for=["HCOR_1", "VCOR_1", "HCOR_2", "VCOR_2"])
    .transfer_maps_merged(
        incoming_beam=incoming_beam, except_for=["HCOR_1", "VCOR_1", "HCOR_2", "VCOR_2"]
    )
)
len(fully_optimized_segment.elements)
[13]:
8
[14]:
fully_optimized_segment
[14]:
Segment(elements=ModuleList(
  (0): HorizontalCorrector(length=tensor(0.1000), angle=tensor(0.), name='HCOR_1', device='cpu')
  (1): Drift(length=tensor(0.3000), name='unnamed_element_617', device='cpu')
  (2): VerticalCorrector(length=tensor(0.1000), angle=tensor(0.), name='VCOR_1', device='cpu')
  (3): CustomTransferMap(name='unnamed_element_1221', device='cpu')
  (4): HorizontalCorrector(length=tensor(0.1000), angle=tensor(0.), name='HCOR_2', device='cpu')
  (5): Drift(length=tensor(0.3000), name='unnamed_element_1219', device='cpu')
  (6): VerticalCorrector(length=tensor(0.1000), angle=tensor(0.), name='VCOR_2', device='cpu')
  (7): CustomTransferMap(name='unnamed_element_1222', device='cpu')
), name='unnamed', device='cpu')
[15]:
%%timeit
fully_optimized_segment.track(incoming_beam)
96.9 µs ± 780 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
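
Crucially, the correctors excluded via except_for remain ordinary, changeable elements in the optimised segment, so a controller or neural network policy can still update them between tracking calls. A minimal sketch, assuming that elements remain accessible by name as segment attributes; the angle values are illustrative:

[ ]:
# Adjust the remaining steerers (illustrative values) and track again
fully_optimized_segment.HCOR_1.angle = torch.tensor(2e-5)
fully_optimized_segment.VCOR_2.angle = torch.tensor(-1e-5)
outgoing_beam = fully_optimized_segment.track(incoming_beam)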