Skip to content

Commit

Permalink
Change np.average to np.ma.average to properly handle cases with zero…
Browse files Browse the repository at this point in the history
… weights.
  • Loading branch information
SebastienJoly authored and SebastienJoly committed Oct 10, 2024
1 parent 9e6407d commit 93f1608
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions pySC/core/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def bpm_reading(SC: SimulatedCommissioning, bpm_ords: ndarray = None, calculate_
bpm_orbits_4d[:, :, :, shot_num], bpm_sums_4d[:, :, :, shot_num] = _real_bpm_reading(SC, tracking_4d, bpm_inds)

# mean_bpm_orbits_3d is 3D (dim, BPM, turn)
mean_bpm_orbits_3d = np.average(np.ma.array(bpm_orbits_4d, mask=np.isnan(bpm_orbits_4d)),
weights=np.ma.array(bpm_sums_4d, mask=np.isnan(bpm_sums_4d)), axis=3).filled(np.nan)
mean_bpm_orbits_3d = np.ma.average(np.ma.array(bpm_orbits_4d, mask=np.isnan(bpm_orbits_4d)),
weights=np.ma.array(bpm_sums_4d, mask=np.isnan(bpm_sums_4d)), axis=3).filled(np.nan)
# averaging "charge" also when the beam did not reach the location
mean_bpm_sums_3d = np.nansum(bpm_sums_4d, axis=3) / SC.INJ.nShots

Expand All @@ -61,7 +61,7 @@ def bpm_reading(SC: SimulatedCommissioning, bpm_ords: ndarray = None, calculate_
mean_bpm_sums_3d = np.nansum(mean_bpm_sums_3d, axis=2, keepdims=True) / SC.INJ.nTurns
if calculate_errors and SC.INJ.trackMode == TRACK_TBT:
bpm_orbits_4d[np.sum(np.isnan(bpm_orbits_4d), axis=3) > 0, :] = np.nan
squared_orbit_diffs = np.square(bpm_orbits_4d - mean_bpm_orbits_3d[:, :, :, np.newaxis])
squared_orbit_diffs = np.square(bpm_orbits_4d - mean_bpm_orbits_3d)
err_bpm_orbits_3d = np.sqrt(np.average(np.ma.array(squared_orbit_diffs, mask=np.isnan(bpm_orbits_4d)),
weights=np.ma.array(bpm_sums_4d, mask=np.isnan(bpm_orbits_4d)), axis=3)).filled(np.nan)
# Organising the array 2 x (nturns x nbpms) sorted by "arrival time"
Expand Down Expand Up @@ -128,8 +128,7 @@ def beam_transmission(SC: SimulatedCommissioning, nParticles: int = None, nTurns
if nTurns is None:
nTurns = SC.INJ.nTurns
LOGGER.debug(f'Calculating maximum beam transmission for {nParticles} particles and {nTurns} turns: ')
T = at_wrapper.lattice_track(SC.RING, generate_bunches(SC, nParticles=nParticles), nTurns, np.array([len(SC.RING)]),
keep_lattice=False, use_mp=True)
T = at_wrapper.patpass(SC.RING, generate_bunches(SC, nParticles=nParticles), nTurns, np.array([len(SC.RING)]), keep_lattice=False)
fraction_survived = np.mean(~np.isnan(T[0, :, :, :]), axis=(0, 1))
max_turns = np.sum(fraction_survived > 1 - SC.INJ.beamLostAt)
if plot:
Expand Down Expand Up @@ -186,7 +185,7 @@ def _real_bpm_reading(SC, track_mat, bpm_inds=None): # track_mat should be only
bpm_noise = bpm_noise[:, :, np.newaxis] * sc_tools.randnc(2, (2, n_bpms, nTurns))
bpm_offset = np.transpose(at_wrapper.atgetfieldvalues(SC.RING, bpm_ords, 'Offset') + at_wrapper.atgetfieldvalues(SC.RING, bpm_ords, 'SupportOffset'))
bpm_cal_error = np.transpose(at_wrapper.atgetfieldvalues(SC.RING, bpm_ords, 'CalError'))
bpm_roll = np.ravel(at_wrapper.atgetfieldvalues(SC.RING, bpm_ords, 'Roll')) + np.ravel(at_wrapper.atgetfieldvalues(SC.RING, bpm_ords, 'SupportRoll'))
bpm_roll = np.squeeze(at_wrapper.atgetfieldvalues(SC.RING, bpm_ords, 'Roll') + at_wrapper.atgetfieldvalues(SC.RING, bpm_ords, 'SupportRoll'), axis=1)
bpm_sum_error = np.transpose(at_wrapper.atgetfieldvalues(SC.RING, bpm_ords, 'SumError'))[:, np.newaxis] * sc_tools.randnc(2, (n_bpms, nTurns))
# averaging the X and Y positions at BPMs over particles
mean_orbit = np.nanmean(track_mat, axis=1)
Expand All @@ -212,7 +211,7 @@ def _tracking(SC: SimulatedCommissioning, refs: ndarray) -> ndarray:
if SC.INJ.trackMode == TRACK_ORB:
pos = np.transpose(at_wrapper.findorbit6(SC.RING, refs, keep_lattice=False)[1])[[0, 2], :].reshape(2, 1, len(refs), 1)
else:
pos = at_wrapper.lattice_track(SC.RING, generate_bunches(SC), SC.INJ.nTurns, refs, keep_lattice=False)[[0, 2], :, :, :]
pos = at_wrapper.atpass(SC.RING, generate_bunches(SC), SC.INJ.nTurns, refs, keep_lattice=False)[[0, 2], :, :, :]
pos[1, np.isnan(pos[0, :, :, :])] = np.nan
return pos

Expand Down

0 comments on commit 93f1608

Please sign in to comment.