import numpy as np
from bouter.utilities import (
extract_segments_above_threshold,
fill_out_segments,
)
def test_continue_curvature():
    """Check that ``fill_out_segments`` fills trailing NaN segments per row.

    Two scenarios are exercised:

    * Default call: trailing NaNs in each row are replaced by the last
      valid value of that row.
    * ``continue_curvature=2``: a row holding a linear ramp (``0, 1, 2, 3``)
      is extended along that ramp to the full segment count, and the
      returned ``n_segments_missing`` matches the number of trailing NaN
      segments in each input row.
    """
    n_segments = 6

    # Default behaviour: rows are padded with their last valid value.
    curvature = np.full((2, n_segments), np.nan)
    curvature[0, 0:4] = np.arange(4)
    curvature[1, 0:3] = np.arange(3)
    continued, _ = fill_out_segments(curvature)
    np.testing.assert_equal(continued[0, 4:], 3)
    np.testing.assert_equal(continued[1, 3:], 2)

    # continue_curvature=2: row 0's linear ramp should be extrapolated.
    curvature = np.full((4, n_segments), np.nan)
    curvature[0, 0:4] = np.arange(4)
    curvature[1, 0:3] = np.arange(3)
    curvature[2, 0:2] = np.arange(2)
    curvature[3, 0] = 1
    continued, n_segments_missing = fill_out_segments(
        curvature, continue_curvature=2
    )
    np.testing.assert_equal(continued[0, :], np.arange(n_segments))
    # Rows hold 4, 3, 2 and 1 valid segments out of 6 -> 2..5 missing.
    np.testing.assert_equal(n_segments_missing, [2, 3, 4, 5])