From e976b5eb3960beec14faab5a57b1261c7ee78964 Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Sun, 28 Aug 2022 16:07:53 -0500 Subject: [PATCH 01/11] Create data module --- src/time_series.rs | 1 + 1 file changed, 1 insertion(+) create mode 100644 src/time_series.rs diff --git a/src/time_series.rs b/src/time_series.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/time_series.rs @@ -0,0 +1 @@ + From c94f009823c41a2fd4e8d219d22c7bdb2b2236ff Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Sun, 25 Dec 2022 16:12:44 -0600 Subject: [PATCH 02/11] Move PeriodogramPeaks to a new module --- src/feature.rs | 2 +- src/features/_periodogram_peaks.rs | 151 ++++++++++++++++++++++++++ src/features/mod.rs | 6 +- src/features/periodogram.rs | 167 +++-------------------------- src/lib.rs | 3 + src/number_ending.rs | 12 +++ 6 files changed, 184 insertions(+), 157 deletions(-) create mode 100644 src/features/_periodogram_peaks.rs create mode 100644 src/number_ending.rs diff --git a/src/feature.rs b/src/feature.rs index 7377269..ac284b0 100644 --- a/src/feature.rs +++ b/src/feature.rs @@ -50,7 +50,7 @@ where PercentAmplitude, PercentDifferenceMagnitudePercentile, Periodogram(Periodogram), - _PeriodogramPeaks, + _PeriodogramPeaks(PeriodogramPeaks), ReducedChi2, Skew, StandardDeviation, diff --git a/src/features/_periodogram_peaks.rs b/src/features/_periodogram_peaks.rs new file mode 100644 index 0000000..172e877 --- /dev/null +++ b/src/features/_periodogram_peaks.rs @@ -0,0 +1,151 @@ +use crate::evaluator::*; +use crate::evaluator::{Deserialize, EvaluatorInfo, EvaluatorProperties, Serialize}; +use crate::peak_indices::peak_indices_reverse_sorted; +use crate::{ + number_ending, EvaluatorError, EvaluatorInfoTrait, FeatureEvaluator, + FeatureNamesDescriptionsTrait, Float, TimeSeries, +}; + +use schemars::JsonSchema; +use std::iter; + +macro_const! 
{ + const PERIODOGRAM_PEAKS_DOC: &'static str = r#" +Peak evaluator for [Periodogram] +"#; +} + +#[doc(hidden)] +#[doc = PERIODOGRAM_PEAKS_DOC!()] +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde( + from = "PeriodogramPeaksParameters", + into = "PeriodogramPeaksParameters" +)] +pub struct PeriodogramPeaks { + peaks: usize, + properties: Box, +} + +impl PeriodogramPeaks { + pub fn new(peaks: usize) -> Self { + assert!(peaks > 0, "Number of peaks should be at least one"); + let info = EvaluatorInfo { + size: 2 * peaks, + min_ts_length: 1, + t_required: true, + m_required: true, + w_required: false, + sorting_required: true, + variability_required: false, + }; + let names = (0..peaks) + .flat_map(|i| vec![format!("period_{}", i), format!("period_s_to_n_{}", i)]) + .collect(); + let descriptions = (0..peaks) + .flat_map(|i| { + vec![ + format!( + "period of the {}{} highest peak of periodogram", + i + 1, + number_ending(i + 1), + ), + format!( + "Spectral density to spectral density standard deviation ratio of \ + the {}{} highest peak of periodogram", + i + 1, + number_ending(i + 1) + ), + ] + }) + .collect(); + Self { + properties: EvaluatorProperties { + info, + names, + descriptions, + } + .into(), + peaks, + } + } + + pub fn get_peaks(&self) -> usize { + self.peaks + } + + #[inline] + pub fn default_peaks() -> usize { + 1 + } + + pub const fn doc() -> &'static str { + PERIODOGRAM_PEAKS_DOC + } +} + +impl Default for PeriodogramPeaks { + fn default() -> Self { + Self::new(Self::default_peaks()) + } +} + +impl EvaluatorInfoTrait for PeriodogramPeaks { + fn get_info(&self) -> &EvaluatorInfo { + &self.properties.info + } +} + +impl FeatureNamesDescriptionsTrait for PeriodogramPeaks { + fn get_names(&self) -> Vec<&str> { + self.properties.names.iter().map(String::as_str).collect() + } + + fn get_descriptions(&self) -> Vec<&str> { + self.properties + .descriptions + .iter() + .map(String::as_str) + .collect() + } +} + +impl FeatureEvaluator for PeriodogramPeaks 
+where + T: Float, +{ + fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { + let peak_indices = peak_indices_reverse_sorted(&ts.m.sample); + Ok(peak_indices + .iter() + .flat_map(|&i| { + iter::once(T::two() * T::PI() / ts.t.sample[i]) + .chain(iter::once(ts.m.signal_to_noise(ts.m.sample[i]))) + }) + .chain(iter::repeat(T::zero())) + .take(2 * self.peaks) + .collect()) + } +} + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(rename = "PeriodogramPeaks")] +struct PeriodogramPeaksParameters { + peaks: usize, +} + +impl From for PeriodogramPeaksParameters { + fn from(f: PeriodogramPeaks) -> Self { + Self { peaks: f.peaks } + } +} + +impl From for PeriodogramPeaks { + fn from(p: PeriodogramPeaksParameters) -> Self { + Self::new(p.peaks) + } +} + +impl JsonSchema for PeriodogramPeaks { + json_schema!(PeriodogramPeaksParameters, false); +} diff --git a/src/features/mod.rs b/src/features/mod.rs index d5dea9e..37a71e4 100644 --- a/src/features/mod.rs +++ b/src/features/mod.rs @@ -1,5 +1,8 @@ //! 
Feature sctructs implements [crate::FeatureEvaluator] trait +mod _periodogram_peaks; +pub(crate) use _periodogram_peaks::PeriodogramPeaks; + mod amplitude; pub use amplitude::Amplitude; @@ -82,8 +85,8 @@ mod percent_difference_magnitude_percentile; pub use percent_difference_magnitude_percentile::PercentDifferenceMagnitudePercentile; mod periodogram; +pub use _periodogram_peaks::PeriodogramPeaks as _PeriodogramPeaks; pub use periodogram::Periodogram; -pub use periodogram::PeriodogramPeaks as _PeriodogramPeaks; mod reduced_chi2; pub use reduced_chi2::ReducedChi2; @@ -110,4 +113,5 @@ mod villar_fit; pub use villar_fit::{VillarFit, VillarInitsBounds, VillarLnPrior}; mod weighted_mean; + pub use weighted_mean::WeightedMean; diff --git a/src/features/periodogram.rs b/src/features/periodogram.rs index 407dd93..1f5f61a 100644 --- a/src/features/periodogram.rs +++ b/src/features/periodogram.rs @@ -1,162 +1,11 @@ use crate::evaluator::*; use crate::extractor::FeatureExtractor; -use crate::peak_indices::peak_indices_reverse_sorted; +use crate::features::_periodogram_peaks::PeriodogramPeaks; use crate::periodogram; use crate::periodogram::{AverageNyquistFreq, NyquistFreq, PeriodogramPower, PeriodogramPowerFft}; use std::convert::TryInto; use std::fmt::Debug; -use std::iter; - -fn number_ending(i: usize) -> &'static str { - #[allow(clippy::match_same_arms)] - match (i % 10, i % 100) { - (1, 11) => "th", - (1, _) => "st", - (2, 12) => "th", - (2, _) => "nd", - (3, 13) => "th", - (3, _) => "rd", - (_, _) => "th", - } -} - -macro_const! 
{ - const PERIODOGRAM_PEAK_DOC: &'static str = r#" -Peak evaluator for [Periodogram] -"#; -} - -#[doc(hidden)] -#[doc = PERIODOGRAM_PEAK_DOC!()] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde( - from = "PeriodogramPeaksParameters", - into = "PeriodogramPeaksParameters" -)] -pub struct PeriodogramPeaks { - peaks: usize, - properties: Box, -} - -impl PeriodogramPeaks { - pub fn new(peaks: usize) -> Self { - assert!(peaks > 0, "Number of peaks should be at least one"); - let info = EvaluatorInfo { - size: 2 * peaks, - min_ts_length: 1, - t_required: true, - m_required: true, - w_required: false, - sorting_required: true, - variability_required: false, - }; - let names = (0..peaks) - .flat_map(|i| vec![format!("period_{}", i), format!("period_s_to_n_{}", i)]) - .collect(); - let descriptions = (0..peaks) - .flat_map(|i| { - vec![ - format!( - "period of the {}{} highest peak of periodogram", - i + 1, - number_ending(i + 1), - ), - format!( - "Spectral density to spectral density standard deviation ratio of \ - the {}{} highest peak of periodogram", - i + 1, - number_ending(i + 1) - ), - ] - }) - .collect(); - Self { - properties: EvaluatorProperties { - info, - names, - descriptions, - } - .into(), - peaks, - } - } - - #[inline] - pub fn default_peaks() -> usize { - 1 - } - - pub const fn doc() -> &'static str { - PERIODOGRAM_PEAK_DOC - } -} - -impl Default for PeriodogramPeaks { - fn default() -> Self { - Self::new(Self::default_peaks()) - } -} - -impl EvaluatorInfoTrait for PeriodogramPeaks { - fn get_info(&self) -> &EvaluatorInfo { - &self.properties.info - } -} - -impl FeatureNamesDescriptionsTrait for PeriodogramPeaks { - fn get_names(&self) -> Vec<&str> { - self.properties.names.iter().map(String::as_str).collect() - } - - fn get_descriptions(&self) -> Vec<&str> { - self.properties - .descriptions - .iter() - .map(String::as_str) - .collect() - } -} - -impl FeatureEvaluator for PeriodogramPeaks -where - T: Float, -{ - fn eval_no_ts_check(&self, ts: 
&mut TimeSeries) -> Result, EvaluatorError> { - let peak_indices = peak_indices_reverse_sorted(&ts.m.sample); - Ok(peak_indices - .iter() - .flat_map(|&i| { - iter::once(T::two() * T::PI() / ts.t.sample[i]) - .chain(iter::once(ts.m.signal_to_noise(ts.m.sample[i]))) - }) - .chain(iter::repeat(T::zero())) - .take(2 * self.peaks) - .collect()) - } -} - -#[derive(Serialize, Deserialize, JsonSchema)] -#[serde(rename = "PeriodogramPeaks")] -struct PeriodogramPeaksParameters { - peaks: usize, -} - -impl From for PeriodogramPeaksParameters { - fn from(f: PeriodogramPeaks) -> Self { - Self { peaks: f.peaks } - } -} - -impl From for PeriodogramPeaks { - fn from(p: PeriodogramPeaksParameters) -> Self { - Self::new(p.peaks) - } -} - -impl JsonSchema for PeriodogramPeaks { - json_schema!(PeriodogramPeaksParameters, false); -} macro_const! { const DOC: &str = r#" @@ -300,8 +149,16 @@ where /// New [Periodogram] that finds given number of peaks pub fn new(peaks: usize) -> Self { let peaks = PeriodogramPeaks::new(peaks); - let peak_names = peaks.properties.names.clone(); - let peak_descriptions = peaks.properties.descriptions.clone(); + let peak_names = peaks + .get_names() + .into_iter() + .map(ToOwned::to_owned) + .collect(); + let peak_descriptions = peaks + .get_descriptions() + .into_iter() + .map(ToOwned::to_owned) + .collect(); let peaks_size_hint = peaks.size_hint(); let peaks_min_ts_length = peaks.min_ts_length(); let info = EvaluatorInfo { @@ -437,7 +294,7 @@ where let mut features = feature_extractor.into_vec(); let rest_of_features = features.split_off(1); let periodogram_peaks: PeriodogramPeaks = features.pop().unwrap().try_into().unwrap(); - let peaks = periodogram_peaks.peaks; + let peaks = periodogram_peaks.get_peaks(); Self { resolution, diff --git a/src/lib.rs b/src/lib.rs index 36c6c79..7bf202d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,6 +44,9 @@ pub use nl_fit::LmsderCurveFit; pub use nl_fit::{prior, LnPrior, LnPrior1D}; pub use 
nl_fit::{CurveFitAlgorithm, McmcCurveFit}; +mod number_ending; +pub(crate) use number_ending::number_ending; + #[doc(hidden)] pub mod periodogram; pub use periodogram::recurrent_sin_cos::RecurrentSinCos; diff --git a/src/number_ending.rs b/src/number_ending.rs new file mode 100644 index 0000000..7af7c04 --- /dev/null +++ b/src/number_ending.rs @@ -0,0 +1,12 @@ +pub(crate) fn number_ending(i: usize) -> &'static str { + #[allow(clippy::match_same_arms)] + match (i % 10, i % 100) { + (1, 11) => "th", + (1, _) => "st", + (2, 12) => "th", + (2, _) => "nd", + (3, 13) => "th", + (3, _) => "rd", + (_, _) => "th", + } +} From d1d2b6ff7af8ca33e99ee6144ab2ba046f38a621 Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Thu, 29 Dec 2022 12:02:07 -0600 Subject: [PATCH 03/11] Initial impl of MiltiColorPeriodogram --- Cargo.toml | 1 + src/data/mod.rs | 3 +- src/data/multi_color_time_series.rs | 44 ++- src/error.rs | 25 ++ src/features/periodogram.rs | 36 ++- src/lib.rs | 2 +- src/multicolor/features/mod.rs | 3 + .../features/multi_color_periodogram.rs | 255 ++++++++++++++++++ src/multicolor/mod.rs | 2 +- src/periodogram/mod.rs | 73 ++--- src/periodogram/power_direct.rs | 3 +- src/periodogram/power_fft.rs | 5 +- src/periodogram/power_trait.rs | 3 +- 13 files changed, 393 insertions(+), 62 deletions(-) create mode 100644 src/multicolor/features/multi_color_periodogram.rs diff --git a/Cargo.toml b/Cargo.toml index 5d60028..c3d4718 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,6 +66,7 @@ clap = { version = "3.2.6", features = ["std", "color", "suggestions", "derive", criterion = "0.4" hyperdual = "1.1" light-curve-common = "0.1.0" +ndarray = { version = "^0.15", features = ["approx-0_5"] } plotters = { version = "0.3.5", default-features = false, features = ["errorbar", "line_series", "ttf"] } plotters-bitmap = "0.3.3" rand = "0.7" diff --git a/src/data/mod.rs b/src/data/mod.rs index aafaeb9..469dee0 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -1,12 +1,11 @@ mod 
data_sample; pub use data_sample::DataSample; -mod multi_color_time_series; +pub(crate) mod multi_color_time_series; pub use multi_color_time_series::MultiColorTimeSeries; mod sorted_array; pub use sorted_array::SortedArray; mod time_series; - pub use time_series::TimeSeries; diff --git a/src/data/multi_color_time_series.rs b/src/data/multi_color_time_series.rs index 535543a..fa1bd4d 100644 --- a/src/data/multi_color_time_series.rs +++ b/src/data/multi_color_time_series.rs @@ -3,6 +3,7 @@ use crate::float_trait::Float; use crate::multicolor::PassbandTrait; use crate::{DataSample, PassbandSet}; +use conv::prelude::*; use itertools::Either; use itertools::EitherOrBoth; use itertools::Itertools; @@ -23,6 +24,22 @@ where P: PassbandTrait + 'p, T: Float, { + pub fn total_lenu(&self) -> usize { + match self { + Self::Mapping(mapping) => mapping.total_lenu(), + Self::Flat(flat) => flat.total_lenu(), + Self::MappingFlat { flat, .. } => flat.total_lenu(), + } + } + + pub fn total_lenf(&self) -> T { + match self { + Self::Mapping(mapping) => mapping.total_lenf(), + Self::Flat(flat) => flat.total_lenf(), + Self::MappingFlat { flat, .. 
} => flat.total_lenf(), + } + } + pub fn from_map(map: impl Into>>) -> Self { Self::Mapping(MappedMultiColorTimeSeries::new(map)) } @@ -148,13 +165,30 @@ where ) } + pub fn total_lenu(&self) -> usize { + self.0.values().map(|ts| ts.lenu()).sum() + } + + pub fn total_lenf(&self) -> T { + self.total_lenu().value_as::().unwrap() + } + pub fn passbands<'slf>( &'slf self, ) -> std::collections::btree_map::Keys<'slf, P, TimeSeries<'a, T>> where 'a: 'slf, { - self.keys() + self.0.keys() + } + + pub fn iter_ts<'slf>( + &'slf self, + ) -> std::collections::btree_map::Values<'slf, P, TimeSeries<'a, T>> + where + 'a: 'slf, + { + self.0.values() } pub fn iter_passband_set<'slf, 'ps>( @@ -305,4 +339,12 @@ where passband_set: mapping.keys().cloned().collect(), } } + + pub fn total_lenu(&self) -> usize { + self.t.sample.len() + } + + pub fn total_lenf(&self) -> T { + self.t.sample.len().value_as::().unwrap() + } } diff --git a/src/error.rs b/src/error.rs index 15c6f0a..6cd5234 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,5 @@ +use crate::data::multi_color_time_series::MappedMultiColorTimeSeries; +use crate::float_trait::Float; use crate::PassbandTrait; use std::collections::BTreeSet; @@ -28,6 +30,15 @@ pub enum MultiColorEvaluatorError { actual: BTreeSet, desired: BTreeSet, }, + + #[error("No time-series long enough: maximum length found is {maximum_actual}, while minimum required is {minimum_required}")] + AllTimeSeriesAreShort { + maximum_actual: usize, + minimum_required: usize, + }, + + #[error(r#"Underlying feature caused an error: "{0:?}""#)] + UnderlyingEvaluatorError(#[from] EvaluatorError), } impl MultiColorEvaluatorError { @@ -43,6 +54,20 @@ impl MultiColorEvaluatorError { desired: desired.map(|p| p.name().into()).collect(), } } + + pub fn all_time_series_short( + mapped: &MappedMultiColorTimeSeries, + minimum_required: usize, + ) -> Self + where + P: PassbandTrait, + T: Float, + { + Self::AllTimeSeriesAreShort { + maximum_actual: mapped.iter_ts().map(|ts| 
ts.lenu()).max().unwrap_or(0), + minimum_required, + } + } } #[derive(Debug, thiserror::Error, PartialEq, Eq)] diff --git a/src/features/periodogram.rs b/src/features/periodogram.rs index 1f5f61a..27506cb 100644 --- a/src/features/periodogram.rs +++ b/src/features/periodogram.rs @@ -4,6 +4,7 @@ use crate::features::_periodogram_peaks::PeriodogramPeaks; use crate::periodogram; use crate::periodogram::{AverageNyquistFreq, NyquistFreq, PeriodogramPower, PeriodogramPowerFft}; +use ndarray::Array1; use std::convert::TryInto; use std::fmt::Debug; @@ -32,7 +33,7 @@ series without observation errors (unity weights are used if required). You can #[doc = DOC!()] #[derive(Clone, Debug, Deserialize, Serialize)] #[serde( - bound = "T: Float, F: FeatureEvaluator + From + TryInto, >::Error: Debug,", + bound = "T: Float, F: FeatureEvaluator + From + TryInto, >::Error: Debug,", from = "PeriodogramParameters", into = "PeriodogramParameters" )] @@ -43,7 +44,7 @@ where resolution: f32, max_freq_factor: f32, nyquist: NyquistFreq, - feature_extractor: FeatureExtractor, + pub(crate) feature_extractor: FeatureExtractor, periodogram_algorithm: PeriodogramPower, properties: Box, } @@ -119,24 +120,24 @@ where self } - fn periodogram(&self, ts: &mut TimeSeries) -> periodogram::Periodogram { + pub(crate) fn periodogram(&self, t: &[T]) -> periodogram::Periodogram { periodogram::Periodogram::from_t( self.periodogram_algorithm.clone(), - ts.t.as_slice(), + t, self.resolution, self.max_freq_factor, self.nyquist.clone(), ) } - pub fn power(&self, ts: &mut TimeSeries) -> Vec { - self.periodogram(ts).power(ts) + pub fn power(&self, ts: &mut TimeSeries) -> Array1 { + self.periodogram(ts.t.as_slice()).power(ts) } - pub fn freq_power(&self, ts: &mut TimeSeries) -> (Vec, Vec) { - let p = self.periodogram(ts); + pub fn freq_power(&self, ts: &mut TimeSeries) -> (Array1, Array1) { + let p = self.periodogram(ts.t.as_slice()); let power = p.power(ts); - let freq = (0..power.len()).map(|i| 
p.freq(i)).collect::>(); + let freq = (0..power.len()).map(|i| p.freq_by_index(i)).collect(); (freq, power) } } @@ -190,15 +191,12 @@ impl Periodogram where T: Float, F: FeatureEvaluator + From + TryInto, - >::Error: Debug, + >::Error: Debug, { fn transform_ts(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - self.check_ts_length(ts)?; + self.check_ts(ts)?; let (freq, power) = self.freq_power(ts); - Ok(TmArrays { - t: freq.into(), - m: power.into(), - }) + Ok(TmArrays { t: freq, m: power }) } } @@ -225,7 +223,7 @@ impl EvaluatorInfoTrait for Periodogram where T: Float, F: FeatureEvaluator + From + TryInto, - >::Error: Debug, + >::Error: Debug, { fn get_info(&self) -> &EvaluatorInfo { &self.properties.info @@ -236,7 +234,7 @@ impl FeatureNamesDescriptionsTrait for Periodogram where T: Float, F: FeatureEvaluator + From + TryInto, - >::Error: Debug, + >::Error: Debug, { fn get_names(&self) -> Vec<&str> { self.properties.names.iter().map(String::as_str).collect() @@ -255,7 +253,7 @@ impl FeatureEvaluator for Periodogram where T: Float, F: FeatureEvaluator + From + TryInto, - >::Error: Debug, + >::Error: Debug, { transformer_eval!(); } @@ -279,7 +277,7 @@ impl From> for PeriodogramParameters where T: Float, F: FeatureEvaluator + From + TryInto, - >::Error: Debug, + >::Error: Debug, { fn from(f: Periodogram) -> Self { let Periodogram { diff --git a/src/lib.rs b/src/lib.rs index 7bf202d..8dcbc91 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,7 +32,7 @@ pub use float_trait::Float; mod lnerfc; -mod multicolor; +pub mod multicolor; pub use multicolor::*; mod nl_fit; diff --git a/src/multicolor/features/mod.rs b/src/multicolor/features/mod.rs index 54b5a4c..4afb029 100644 --- a/src/multicolor/features/mod.rs +++ b/src/multicolor/features/mod.rs @@ -6,3 +6,6 @@ pub use color_of_median::ColorOfMedian; mod color_of_minimum; pub use color_of_minimum::ColorOfMinimum; + +mod multi_color_periodogram; +pub use multi_color_periodogram::MultiColorPeriodogram; diff --git 
a/src/multicolor/features/multi_color_periodogram.rs b/src/multicolor/features/multi_color_periodogram.rs new file mode 100644 index 0000000..f601973 --- /dev/null +++ b/src/multicolor/features/multi_color_periodogram.rs @@ -0,0 +1,255 @@ +use crate::data::MultiColorTimeSeries; +use crate::error::MultiColorEvaluatorError; +use crate::evaluator::TmArrays; +use crate::evaluator::{ + EvaluatorInfo, EvaluatorInfoTrait, FeatureEvaluator, FeatureNamesDescriptionsTrait, OwnedArrays, +}; + +use crate::features::{Periodogram, PeriodogramPeaks}; +use crate::float_trait::Float; +use crate::multicolor::multicolor_evaluator::*; +use crate::multicolor::{PassbandSet, PassbandTrait}; +use crate::periodogram::{self, NyquistFreq, PeriodogramPower}; + +use ndarray::Array1; + +use std::fmt::Debug; + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde( + bound = "T: Float, F: FeatureEvaluator + From + TryInto, >::Error: Debug," +)] +pub struct MultiColorPeriodogram +where + T: Float, + F: FeatureEvaluator, +{ + // We use it to not reimplement some internals + monochrome: Periodogram, + properties: Box, +} + +impl MultiColorPeriodogram +where + T: Float, + F: FeatureEvaluator, +{ + #[inline] + pub fn default_peaks() -> usize { + PeriodogramPeaks::default_peaks() + } + + #[inline] + pub fn default_resolution() -> f32 { + Periodogram::::default_resolution() + } + + #[inline] + pub fn default_max_freq_factor() -> f32 { + Periodogram::::default_max_freq_factor() + } + + /// Set frequency resolution + /// + /// The larger frequency resolution allows to find peak period with better precision + pub fn set_freq_resolution(&mut self, resolution: f32) -> &mut Self { + self.monochrome.set_freq_resolution(resolution); + self + } + + /// Multiply maximum (Nyquist) frequency + /// + /// Maximum frequency is Nyquist frequncy multiplied by this factor. The larger factor allows + /// to find larger frequency and makes [PeriodogramPowerFft] more precise. 
However large + /// frequencies can show false peaks + pub fn set_max_freq_factor(&mut self, max_freq_factor: f32) -> &mut Self { + self.monochrome.set_max_freq_factor(max_freq_factor); + self + } + + /// Define Nyquist frequency + pub fn set_nyquist(&mut self, nyquist: NyquistFreq) -> &mut Self { + self.monochrome.set_nyquist(nyquist); + self + } + + /// Extend a feature to extract from periodogram + pub fn add_feature(&mut self, feature: F) -> &mut Self { + self.monochrome.add_feature(feature); + self + } + + pub fn set_periodogram_algorithm( + &mut self, + periodogram_power: PeriodogramPower, + ) -> &mut Self { + self.monochrome.set_periodogram_algorithm(periodogram_power); + self + } +} + +impl MultiColorPeriodogram +where + T: Float, + F: FeatureEvaluator + From + TryInto, + >::Error: Debug, +{ + fn power_from_periodogram<'slf, 'a, 'mcts, P>( + &self, + p: &periodogram::Periodogram, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + P: PassbandTrait, + { + let unnormed_power = mcts + .mapping_mut() + .values_mut() + .filter(|ts| self.monochrome.check_ts_length(ts).is_ok()) + .map(|ts| p.power(ts) * ts.lenf()) + .reduce(|acc, x| acc + x) + .ok_or_else(|| { + MultiColorEvaluatorError::all_time_series_short( + mcts.mapping_mut(), + self.monochrome.min_ts_length(), + ) + })?; + Ok(unnormed_power / mcts.total_lenf()) + } + + pub fn power<'slf, 'a, 'mcts, P>( + &self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + P: PassbandTrait, + { + self.power_from_periodogram( + &self.monochrome.periodogram(mcts.flat_mut().t.as_slice()), + mcts, + ) + } + + pub fn freq_power<'slf, 'a, 'mcts, P>( + &self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result<(Array1, Array1), MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + P: PassbandTrait, + { + let p = 
self.monochrome.periodogram(mcts.flat_mut().t.as_slice()); + let power = self.power_from_periodogram(&p, mcts)?; + let freq = (0..power.len()).map(|i| p.freq_by_index(i)).collect(); + Ok((freq, power)) + } +} + +impl MultiColorPeriodogram +where + T: Float, + F: FeatureEvaluator + From + TryInto, + >::Error: Debug, +{ + fn transform_mcts_to_ts

( + &self, + mcts: &mut MultiColorTimeSeries, + ) -> Result, MultiColorEvaluatorError> + where + P: PassbandTrait, + { + let (freq, power) = self.freq_power(mcts)?; + Ok(TmArrays { t: freq, m: power }) + } +} + +impl EvaluatorInfoTrait for MultiColorPeriodogram +where + T: Float, + F: FeatureEvaluator + From + TryInto, + >::Error: Debug, +{ + fn get_info(&self) -> &EvaluatorInfo { + &self.properties.info + } +} + +impl FeatureNamesDescriptionsTrait for MultiColorPeriodogram +where + T: Float, + F: FeatureEvaluator + From + TryInto, + >::Error: Debug, +{ + fn get_names(&self) -> Vec<&str> { + self.properties.names.iter().map(String::as_str).collect() + } + + fn get_descriptions(&self) -> Vec<&str> { + self.properties + .descriptions + .iter() + .map(String::as_str) + .collect() + } +} + +impl MultiColorPassbandSetTrait

for MultiColorPeriodogram +where + T: Float, + P: PassbandTrait, + F: FeatureEvaluator, +{ + fn get_passband_set(&self) -> &PassbandSet

{ + &PassbandSet::AllAvailable + } +} + +impl MultiColorEvaluator for MultiColorPeriodogram +where + P: PassbandTrait, + T: Float, + F: FeatureEvaluator + From + TryInto, + >::Error: Debug, +{ + fn eval_multicolor_no_mcts_check<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { + let arrays = self.transform_mcts_to_ts(mcts)?; + let mut ts = arrays.ts(); + self.monochrome + .feature_extractor + .eval(&mut ts) + .map_err(From::from) + } + + /// Returns vector of feature values and fill invalid components with given value + fn eval_or_fill_multicolor<'slf, 'a, 'mcts>( + &'slf self, + mcts: &'mcts mut MultiColorTimeSeries<'a, P, T>, + fill_value: T, + ) -> Result, MultiColorEvaluatorError> + where + 'slf: 'a, + 'a: 'mcts, + { + let arrays = match self.transform_mcts_to_ts(mcts) { + Ok(arrays) => arrays, + Err(_) => return Ok(vec![fill_value; self.size_hint()]), + }; + let mut ts = arrays.ts(); + Ok(self + .monochrome + .feature_extractor + .eval_or_fill(&mut ts, fill_value)) + } +} diff --git a/src/multicolor/mod.rs b/src/multicolor/mod.rs index fb09f71..bcb785b 100644 --- a/src/multicolor/mod.rs +++ b/src/multicolor/mod.rs @@ -1,4 +1,4 @@ -mod features; +pub mod features; mod monochrome_feature; pub use monochrome_feature::MonochromeFeature; diff --git a/src/periodogram/mod.rs b/src/periodogram/mod.rs index 6bc477e..b83407e 100644 --- a/src/periodogram/mod.rs +++ b/src/periodogram/mod.rs @@ -5,6 +5,7 @@ use crate::float_trait::Float; use conv::ConvAsUtil; use enum_dispatch::enum_dispatch; +use ndarray::Array1; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -92,11 +93,11 @@ where ) } - pub fn freq(&self, i: usize) -> T { + pub fn freq_by_index(&self, i: usize) -> T { self.freq_grid.step * (i + 1).approx().unwrap() } - pub fn power(&self, ts: &mut TimeSeries) -> Vec { + pub fn power(&self, ts: &mut TimeSeries) -> Array1 { 
self.periodogram_power.power(&self.freq_grid, ts) } } @@ -110,16 +111,17 @@ mod tests { use crate::data::SortedArray; use crate::peak_indices::peak_indices_reverse_sorted; - use light_curve_common::{all_close, linspace}; + use approx::assert_relative_eq; + use ndarray::{arr1, s}; use rand::prelude::*; #[test] fn compr_direct_with_scipy() { const OMEGA_SIN: f64 = 0.07; const N: usize = 100; - let t = linspace(0.0, 99.0, N); - let m: Vec<_> = t.iter().map(|&x| f64::sin(OMEGA_SIN * x)).collect(); - let mut ts = TimeSeries::new_without_weight(&t, &m); + let t = Array1::linspace(0.0, 99.0, N); + let m = t.mapv(|x| f64::sin(OMEGA_SIN * x)); + let mut ts = TimeSeries::new_without_weight(t, m); let periodogram = Periodogram::new( PeriodogramPowerDirect.into(), FreqGrid { @@ -127,10 +129,10 @@ mod tests { size: 1, }, ); - all_close( - &[periodogram.power(&mut ts)[0] * 2.0 / (N as f64 - 1.0)], - &[1.0], - 1.0 / (N as f64), + assert_relative_eq!( + periodogram.power(&mut ts)[0] * 2.0 / (N as f64 - 1.0), + 1.0, + epsilon = 1.0 / (N as f64), ); // import numpy as np @@ -147,26 +149,28 @@ mod tests { size: 5, }; let periodogram = Periodogram::new(PeriodogramPowerDirect.into(), freq_grid.clone()); - all_close( - &linspace( + assert_relative_eq!( + Array1::linspace( freq_grid.step, freq_grid.step * freq_grid.size as f64, freq_grid.size, - ), - &(0..freq_grid.size) - .map(|i| periodogram.freq(i)) - .collect::>(), - 1e-12, + ) + .view(), + (0..freq_grid.size) + .map(|i| periodogram.freq_by_index(i)) + .collect::>() + .view(), + epsilon = 1e-12, ); - let desired = [ + let desired = arr1(&[ 16.99018018, 18.57722516, 21.96049738, 28.15056806, 36.66519435, - ]; + ]); let actual = periodogram.power(&mut ts); - all_close(&actual[..], &desired[..], 1e-6); + assert_relative_eq!(actual, desired, epsilon = 1e-6); } #[test] @@ -176,14 +180,14 @@ mod tests { const RESOLUTION: f32 = 1.0; const MAX_FREQ_FACTOR: f32 = 1.0; - let t = linspace(0.0, (N - 1) as f64, N); - let m: Vec<_> = 
t.iter().map(|&x| f64::sin(OMEGA * x)).collect(); - let mut ts = TimeSeries::new_without_weight(&t, &m); + let t = Array1::linspace(0.0, (N - 1) as f64, N); + let m = t.mapv(|x| f64::sin(OMEGA * x)); + let mut ts = TimeSeries::new_without_weight(t, m); let nyquist: NyquistFreq = AverageNyquistFreq.into(); let direct = Periodogram::from_t( PeriodogramPowerDirect.into(), - &t, + ts.t.as_slice(), RESOLUTION, MAX_FREQ_FACTOR, nyquist.clone(), @@ -191,13 +195,17 @@ mod tests { .power(&mut ts); let fft = Periodogram::from_t( PeriodogramPowerFft::new().into(), - &t, + ts.t.as_slice(), RESOLUTION, MAX_FREQ_FACTOR, nyquist, ) .power(&mut ts); - all_close(&fft[..direct.len() - 1], &direct[..direct.len() - 1], 1e-8); + assert_relative_eq!( + fft.slice(s![..direct.len() - 1]), + direct.slice(s![..direct.len() - 1]), + epsilon = 1e-8 + ); } #[test] @@ -209,17 +217,14 @@ mod tests { const RESOLUTION: f32 = 4.0; const MAX_FREQ_FACTOR: f32 = 1.0; - let t = linspace(0.0, (N - 1) as f64, N); - let m: Vec<_> = t - .iter() - .map(|&x| f64::sin(OMEGA1 * x) + AMPLITUDE2 * f64::cos(OMEGA2 * x)) - .collect(); - let mut ts = TimeSeries::new_without_weight(&t, &m); + let t = Array1::linspace(0.0, (N - 1) as f64, N); + let m = t.mapv(|x| f64::sin(OMEGA1 * x) + AMPLITUDE2 * f64::cos(OMEGA2 * x)); + let mut ts = TimeSeries::new_without_weight(t, m); let nyquist: NyquistFreq = AverageNyquistFreq.into(); let direct = Periodogram::from_t( PeriodogramPowerDirect.into(), - &t, + ts.t.as_slice(), RESOLUTION, MAX_FREQ_FACTOR, nyquist.clone(), @@ -227,7 +232,7 @@ mod tests { .power(&mut ts); let fft = Periodogram::from_t( PeriodogramPowerFft::new().into(), - &t, + ts.t.as_slice(), RESOLUTION, MAX_FREQ_FACTOR, nyquist, diff --git a/src/periodogram/power_direct.rs b/src/periodogram/power_direct.rs index a802031..7009314 100644 --- a/src/periodogram/power_direct.rs +++ b/src/periodogram/power_direct.rs @@ -4,6 +4,7 @@ use crate::periodogram::freq::FreqGrid; use crate::periodogram::power_trait::*; use 
crate::periodogram::recurrent_sin_cos::*; +use ndarray::Array1; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -22,7 +23,7 @@ impl PeriodogramPowerTrait for PeriodogramPowerDirect where T: Float, { - fn power(&self, freq: &FreqGrid, ts: &mut TimeSeries) -> Vec { + fn power(&self, freq: &FreqGrid, ts: &mut TimeSeries) -> Array1 { let m_mean = ts.m.get_mean(); let sin_cos_omega_tau = SinCosOmegaTau::new(freq.step, ts.t.as_slice().iter()); diff --git a/src/periodogram/power_fft.rs b/src/periodogram/power_fft.rs index 40526df..dadd6b0 100644 --- a/src/periodogram/power_fft.rs +++ b/src/periodogram/power_fft.rs @@ -5,6 +5,7 @@ use crate::periodogram::freq::FreqGrid; use crate::periodogram::power_trait::*; use conv::{ConvAsUtil, RoundToNearest}; +use ndarray::Array1; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::cell::RefCell; @@ -73,11 +74,11 @@ impl PeriodogramPowerTrait for PeriodogramPowerFft where T: Float, { - fn power(&self, freq: &FreqGrid, ts: &mut TimeSeries) -> Vec { + fn power(&self, freq: &FreqGrid, ts: &mut TimeSeries) -> Array1 { let m_std2 = ts.m.get_std2(); if m_std2.is_zero() { - return vec![T::zero(); freq.size.next_power_of_two()]; + return Array1::zeros(freq.size.next_power_of_two()); } let grid = TimeGrid::from_freq_grid(freq); diff --git a/src/periodogram/power_trait.rs b/src/periodogram/power_trait.rs index 46c271b..92e0361 100644 --- a/src/periodogram/power_trait.rs +++ b/src/periodogram/power_trait.rs @@ -3,6 +3,7 @@ use crate::float_trait::Float; use crate::periodogram::freq::FreqGrid; use enum_dispatch::enum_dispatch; +use ndarray::Array1; use std::fmt::Debug; /// Periodogram execution algorithm @@ -11,5 +12,5 @@ pub trait PeriodogramPowerTrait: Debug + Clone + Send where T: Float, { - fn power(&self, freq: &FreqGrid, ts: &mut TimeSeries) -> Vec; + fn power(&self, freq: &FreqGrid, ts: &mut TimeSeries) -> Array1; } From bfff11d7b14d47b75194940dcf47196bf5924b3a Mon Sep 17 00:00:00 2001 From: 
Konstantin Malanchev Date: Tue, 25 Apr 2023 12:58:03 -0500 Subject: [PATCH 04/11] Test for number_ending --- src/number_ending.rs | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/number_ending.rs b/src/number_ending.rs index 7af7c04..f994e03 100644 --- a/src/number_ending.rs +++ b/src/number_ending.rs @@ -1,3 +1,4 @@ +/// Return a suffix for a number, like "st", "nd", or "th". pub(crate) fn number_ending(i: usize) -> &'static str { #[allow(clippy::match_same_arms)] match (i % 10, i % 100) { @@ -10,3 +11,41 @@ pub(crate) fn number_ending(i: usize) -> &'static str { (_, _) => "th", } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test() { + assert_eq!(number_ending(0), "th"); + assert_eq!(number_ending(1), "st"); + assert_eq!(number_ending(2), "nd"); + assert_eq!(number_ending(3), "rd"); + assert_eq!(number_ending(4), "th"); + assert_eq!(number_ending(5), "th"); + assert_eq!(number_ending(6), "th"); + assert_eq!(number_ending(7), "th"); + assert_eq!(number_ending(8), "th"); + assert_eq!(number_ending(9), "th"); + assert_eq!(number_ending(10), "th"); + assert_eq!(number_ending(11), "th"); + assert_eq!(number_ending(12), "th"); + assert_eq!(number_ending(13), "th"); + assert_eq!(number_ending(14), "th"); + assert_eq!(number_ending(15), "th"); + assert_eq!(number_ending(16), "th"); + assert_eq!(number_ending(17), "th"); + assert_eq!(number_ending(18), "th"); + assert_eq!(number_ending(19), "th"); + assert_eq!(number_ending(20), "th"); + assert_eq!(number_ending(21), "st"); + assert_eq!(number_ending(22), "nd"); + assert_eq!(number_ending(23), "rd"); + assert_eq!(number_ending(24), "th"); + assert_eq!(number_ending(25), "th"); + assert_eq!(number_ending(100), "th"); + assert_eq!(number_ending(101), "st"); + assert_eq!(number_ending(102), "nd"); + } +} From 3bc26fda8a7c7efe685d7607c1f930b4130f11fc Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Tue, 25 Apr 2023 12:59:01 -0500 Subject: [PATCH 05/11] 
Test for PeriodogramPeaks --- src/features/_periodogram_peaks.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/features/_periodogram_peaks.rs b/src/features/_periodogram_peaks.rs index 172e877..39c8829 100644 --- a/src/features/_periodogram_peaks.rs +++ b/src/features/_periodogram_peaks.rs @@ -12,6 +12,10 @@ use std::iter; macro_const! { const PERIODOGRAM_PEAKS_DOC: &'static str = r#" Peak evaluator for [Periodogram] + +- Depends on: **time**, **magnitude** (which have meaning of frequency and spectral density) +- Minimum number of observations: **1** +- Number of features: **2 * npeaks** "#; } @@ -149,3 +153,11 @@ impl From for PeriodogramPeaks { impl JsonSchema for PeriodogramPeaks { json_schema!(PeriodogramPeaksParameters, false); } + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::*; + + check_feature!(PeriodogramPeaks); +} From ff4b6d4f60fc8aab4e81105045ed026780d91185 Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Tue, 25 Apr 2023 15:14:24 -0500 Subject: [PATCH 06/11] Docs for multi-color stuff --- src/multicolor/features/color_of_maximum.rs | 8 +++++ src/multicolor/features/color_of_median.rs | 3 ++ src/multicolor/features/color_of_minimum.rs | 8 +++++ .../features/multi_color_periodogram.rs | 3 +- src/multicolor/monochrome_feature.rs | 35 +++++++++++++++++++ src/multicolor/multicolor_evaluator.rs | 10 ++++++ src/multicolor/multicolor_extractor.rs | 5 +++ src/multicolor/passband/dump_passband.rs | 12 +++++++ .../passband/monochrome_passband.rs | 18 ++++++++++ 9 files changed, 100 insertions(+), 2 deletions(-) diff --git a/src/multicolor/features/color_of_maximum.rs b/src/multicolor/features/color_of_maximum.rs index 9639f87..cd20ee0 100644 --- a/src/multicolor/features/color_of_maximum.rs +++ b/src/multicolor/features/color_of_maximum.rs @@ -11,6 +11,10 @@ pub use serde::{Deserialize, Serialize}; use std::collections::BTreeSet; use std::fmt::Debug; +/// Difference maximum value magnitudes of two passbands +/// 
+/// Note that maximum is calculated for each passband separately, and maximum has mathematical +/// meaning, not "magnitudial" (astronomical) one. #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] pub struct ColorOfMaximum

@@ -27,6 +31,10 @@ impl

ColorOfMaximum

where P: PassbandTrait, { + /// Create new [ColorOfMaximum] evaluator + /// + /// # Arguments + /// - `passbands` - two passbands pub fn new(passbands: [P; 2]) -> Self { let set: BTreeSet<_> = passbands.clone().into(); Self { diff --git a/src/multicolor/features/color_of_median.rs b/src/multicolor/features/color_of_median.rs index a6779e4..83a0e78 100644 --- a/src/multicolor/features/color_of_median.rs +++ b/src/multicolor/features/color_of_median.rs @@ -14,6 +14,9 @@ pub use serde::{Deserialize, Serialize}; use std::collections::BTreeSet; use std::fmt::Debug; +/// Difference of median magnitudes in two passbands +/// +/// Note that median is calculated for each passband separately #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] pub struct ColorOfMedian

diff --git a/src/multicolor/features/color_of_minimum.rs b/src/multicolor/features/color_of_minimum.rs index 72de5c7..3129544 100644 --- a/src/multicolor/features/color_of_minimum.rs +++ b/src/multicolor/features/color_of_minimum.rs @@ -11,6 +11,10 @@ pub use serde::{Deserialize, Serialize}; use std::collections::BTreeSet; use std::fmt::Debug; +/// Difference of minimum magnitudes of two passbands +/// +/// Note that minimum is calculated for each passband separately, and minimum has mathematical +/// meaning, not "magnitudial" (astronomical) one. #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] pub struct ColorOfMinimum

@@ -27,6 +31,10 @@ impl

ColorOfMinimum

where P: PassbandTrait, { + /// Create new [ColorOfMinimum] evaluator + /// + /// # Arguments + /// - `passbands` - two passbands pub fn new(passbands: [P; 2]) -> Self { let set: BTreeSet<_> = passbands.clone().into(); Self { diff --git a/src/multicolor/features/multi_color_periodogram.rs b/src/multicolor/features/multi_color_periodogram.rs index f601973..8f646d4 100644 --- a/src/multicolor/features/multi_color_periodogram.rs +++ b/src/multicolor/features/multi_color_periodogram.rs @@ -4,7 +4,6 @@ use crate::evaluator::TmArrays; use crate::evaluator::{ EvaluatorInfo, EvaluatorInfoTrait, FeatureEvaluator, FeatureNamesDescriptionsTrait, OwnedArrays, }; - use crate::features::{Periodogram, PeriodogramPeaks}; use crate::float_trait::Float; use crate::multicolor::multicolor_evaluator::*; @@ -12,9 +11,9 @@ use crate::multicolor::{PassbandSet, PassbandTrait}; use crate::periodogram::{self, NyquistFreq, PeriodogramPower}; use ndarray::Array1; - use std::fmt::Debug; +/// Multi-passband periodogram #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde( bound = "T: Float, F: FeatureEvaluator + From + TryInto, >::Error: Debug," diff --git a/src/multicolor/monochrome_feature.rs b/src/multicolor/monochrome_feature.rs index 2b483d2..7e9e411 100644 --- a/src/multicolor/monochrome_feature.rs +++ b/src/multicolor/monochrome_feature.rs @@ -15,6 +15,7 @@ use std::collections::BTreeSet; use std::fmt::Debug; use std::marker::PhantomData; +/// Multi-color feature which evaluates non-color dependent feature for each passband. #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(bound( deserialize = "P: PassbandTrait + Deserialize<'de>, T: Float, F: FeatureEvaluator" @@ -35,6 +36,11 @@ where T: Float, F: FeatureEvaluator, { + /// Creates a new instance of `MonochromeFeature`. + /// + /// # Arguments + /// - `feature` - non-multi-color feature to evaluate for each passband. + /// - `passband_set` - set of passbands to evaluate the feature for. 
pub fn new(feature: F, passband_set: BTreeSet

) -> Self { let names = passband_set .iter() @@ -130,3 +136,32 @@ where } } } + +#[cfg(test)] +mod tests { + use super::*; + + use crate::features::Mean; + use crate::multicolor::passband::MonochromePassband; + use crate::Feature; + + #[test] + fn test_monochrome_feature() { + let feature: MonochromeFeature, f64, Feature<_>> = + MonochromeFeature::new( + Mean::default().into(), + [ + MonochromePassband::new(4700e-8, "g"), + MonochromePassband::new(6200e-8, "r"), + ] + .into_iter() + .collect(), + ); + assert_eq!(feature.get_names(), vec!["mean_g", "mean_r"]); + assert_eq!( + feature.get_descriptions(), + vec!["mean magnitude, passband g", "mean magnitude, passband r"] + ); + assert_eq!(feature.get_info().size, 2); + } +} diff --git a/src/multicolor/multicolor_evaluator.rs b/src/multicolor/multicolor_evaluator.rs index c60dddf..6d67b4b 100644 --- a/src/multicolor/multicolor_evaluator.rs +++ b/src/multicolor/multicolor_evaluator.rs @@ -16,14 +16,20 @@ pub use serde::{Deserialize, Serialize}; use std::collections::BTreeSet; use std::fmt::Debug; +/// Trait for getting alphabetically sorted passbands #[enum_dispatch] pub trait MultiColorPassbandSetTrait

where P: PassbandTrait, { + /// Get passband set for this evaluator fn get_passband_set(&self) -> &PassbandSet

; } +/// Enum for passband set, which can be either fixed set or all available passbands. +/// This is used for [MultiColorEvaluator]s, which can be evaluated on all available passbands +/// (for example [MultiColorPeriodogram](super::features::MultiColorPeriodogram)) or on fixed set of +/// passbands (for example [ColorOfMaximum](super::ColorOfMaximum)). #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde(bound(deserialize = "P: PassbandTrait + Deserialize<'de>"))] #[non_exhaustive] @@ -31,7 +37,9 @@ pub enum PassbandSet

where P: Ord, { + /// Fixed set of passbands FixedSet(BTreeSet

), + /// All available passbands AllAvailable, } @@ -44,6 +52,7 @@ where } } +/// Helper error for [MultiColorEvaluator] enum InternalMctsError { MultiColorEvaluatorError(MultiColorEvaluatorError), InternalWrongPassbandSet, @@ -78,6 +87,7 @@ impl InternalMctsError { } } +/// Trait for multi-color feature evaluators #[enum_dispatch] pub trait MultiColorEvaluator: FeatureNamesDescriptionsTrait diff --git a/src/multicolor/multicolor_extractor.rs b/src/multicolor/multicolor_extractor.rs index 5fe4d70..3ef1e49 100644 --- a/src/multicolor/multicolor_extractor.rs +++ b/src/multicolor/multicolor_extractor.rs @@ -12,6 +12,7 @@ use std::collections::BTreeSet; use std::fmt::Debug; use std::marker::PhantomData; +/// Bulk feature evaluator. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde( into = "MultiColorExtractorParameters", @@ -37,6 +38,10 @@ where T: Float, MCF: MultiColorEvaluator, { + /// Create a new [MultiColorExtractor] + /// + /// # Arguments + /// `features` - A vector of multi-color features to be evaluated pub fn new(features: Vec) -> Self { let passband_set = { let set: BTreeSet<_> = features diff --git a/src/multicolor/passband/dump_passband.rs b/src/multicolor/passband/dump_passband.rs index ffd73cd..675d7e6 100644 --- a/src/multicolor/passband/dump_passband.rs +++ b/src/multicolor/passband/dump_passband.rs @@ -4,6 +4,7 @@ pub use schemars::JsonSchema; pub use serde::{Deserialize, Serialize}; use std::fmt::Debug; +/// A passband for the cases where we don't care about the actual passband. 
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema)] pub struct DumpPassband {} @@ -12,3 +13,14 @@ impl PassbandTrait for DumpPassband { "" } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dump_passband() { + let passband = DumpPassband {}; + assert_eq!(passband.name(), ""); + } +} diff --git a/src/multicolor/passband/monochrome_passband.rs b/src/multicolor/passband/monochrome_passband.rs index 987bc79..09f6f29 100644 --- a/src/multicolor/passband/monochrome_passband.rs +++ b/src/multicolor/passband/monochrome_passband.rs @@ -7,6 +7,7 @@ pub use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use std::fmt::Debug; +/// A passband specified by a single wavelength. #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct MonochromePassband<'a, T> { pub name: &'a str, @@ -17,6 +18,11 @@ impl<'a, T> MonochromePassband<'a, T> where T: Float, { + /// Create a new `MonochromePassband`. + /// + /// # Arguments + /// - `wavelength`: The wavelength of the passband, panic if it is not a positive normal number. + /// - `name`: The name of the passband. 
pub fn new(wavelength: T, name: &'a str) -> Self { assert!( wavelength.is_normal(), @@ -67,3 +73,15 @@ where self.name } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_monochrome_passband() { + let passband = MonochromePassband::new(1.0, "test"); + assert_eq!(passband.name(), "test"); + assert_eq!(passband.wavelength, 1.0); + } +} From 318ef0291afdb65d13835926fad8f816511f109b Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Tue, 25 Apr 2023 15:30:53 -0500 Subject: [PATCH 07/11] TimeSeries::m_chi2 --- CHANGELOG.md | 2 +- src/data/time_series.rs | 22 ++++++++++++++-------- src/features/stetson_k.rs | 3 +-- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ab65c4..743cade 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added --- +- `m_chi2` attribute and `get_m_chi2` method for `TimeSeries` ### Changed diff --git a/src/data/time_series.rs b/src/data/time_series.rs index 855e781..473ad27 100644 --- a/src/data/time_series.rs +++ b/src/data/time_series.rs @@ -21,6 +21,7 @@ where pub m: DataSample<'a, T>, pub w: DataSample<'a, T>, m_weighted_mean: Option, + m_chi2: Option, m_reduced_chi2: Option, t_max_m: Option, t_min_m: Option, @@ -84,6 +85,7 @@ where m, w, m_weighted_mean: None, + m_chi2: None, m_reduced_chi2: None, t_max_m: None, t_min_m: None, @@ -116,6 +118,7 @@ where m, w, m_weighted_mean: None, + m_chi2: None, m_reduced_chi2: None, t_max_m: None, t_min_m: None, @@ -140,20 +143,23 @@ where |ts: &mut TimeSeries| { ts.m.sample.weighted_mean(&ts.w.sample).unwrap() } ); - time_series_getter!(m_reduced_chi2, get_m_reduced_chi2, |ts: &mut TimeSeries< - T, - >| { + time_series_getter!(m_chi2, get_m_chi2, |ts: &mut TimeSeries| { let m_weighed_mean = ts.get_m_weighted_mean(); - let m_reduced_chi2 = Zip::from(&ts.m.sample) + let m_chi2 = Zip::from(&ts.m.sample) .and(&ts.w.sample) .fold(T::zero(), 
|chi2, &m, &w| { chi2 + (m - m_weighed_mean).powi(2) * w - }) - / (ts.lenf() - T::one()); - if m_reduced_chi2.is_zero() { + }); + if m_chi2.is_zero() { ts.plateau = Some(true); } - m_reduced_chi2 + m_chi2 + }); + + time_series_getter!(m_reduced_chi2, get_m_reduced_chi2, |ts: &mut TimeSeries< + T, + >| { + ts.get_m_chi2() / (ts.lenf() - T::one()) }); time_series_getter!(bool, plateau, is_plateau, |ts: &mut TimeSeries| { diff --git a/src/features/stetson_k.rs b/src/features/stetson_k.rs index a28d3c0..7110e23 100644 --- a/src/features/stetson_k.rs +++ b/src/features/stetson_k.rs @@ -62,12 +62,11 @@ where T: Float, { fn eval_no_ts_check(&self, ts: &mut TimeSeries) -> Result, EvaluatorError> { - let chi2 = ts.get_m_reduced_chi2() * (ts.lenf() - T::one()); let mean = ts.get_m_weighted_mean(); let value = Zip::from(&ts.m.sample) .and(&ts.w.sample) .fold(T::zero(), |acc, &y, &w| acc + T::abs(y - mean) * T::sqrt(w)) - / T::sqrt(ts.lenf() * chi2); + / T::sqrt(ts.lenf() * ts.get_m_chi2()); Ok(vec![value]) } } From a2c483d76d23b96092d58692a8ac5f698b54aefa Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Thu, 27 Apr 2023 16:22:59 -0500 Subject: [PATCH 08/11] PartialEq for DataSample and TimeSeries --- src/data/data_sample.rs | 9 +++++++++ src/data/time_series.rs | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/src/data/data_sample.rs b/src/data/data_sample.rs index 44e5073..f7753a5 100644 --- a/src/data/data_sample.rs +++ b/src/data/data_sample.rs @@ -21,6 +21,15 @@ where std2: Option, } +impl<'a, T> PartialEq for DataSample<'a, T> +where + T: Float, +{ + fn eq(&self, other: &Self) -> bool { + self.sample == other.sample + } +} + macro_rules! 
data_sample_getter { ($attr: ident, $getter: ident, $func: expr, $method_sorted: ident) => { // This lint is false-positive in macros diff --git a/src/data/time_series.rs b/src/data/time_series.rs index 473ad27..aaf8230 100644 --- a/src/data/time_series.rs +++ b/src/data/time_series.rs @@ -28,6 +28,15 @@ where plateau: Option, } +impl<'a, T> PartialEq for TimeSeries<'a, T> +where + T: Float, +{ + fn eq(&self, other: &Self) -> bool { + self.t == other.t && self.m == other.m && self.w == other.w + } +} + macro_rules! time_series_getter { ($t: ty, $attr: ident, $getter: ident, $func: expr) => { // This lint is false-positive in macros From 8abf92df676020afaaa465cdc2350362de50de4a Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Thu, 27 Apr 2023 16:24:01 -0500 Subject: [PATCH 09/11] More variants converters for MCTS --- CHANGELOG.md | 1 + Cargo.toml | 1 + src/data/multi_color_time_series.rs | 237 ++++++++++++++++++++++++++-- src/error.rs | 3 + 4 files changed, 229 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 743cade..9089e8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `m_chi2` attribute and `get_m_chi2` method for `TimeSeries` +- `take_mut` dependency ### Changed diff --git a/Cargo.toml b/Cargo.toml index c3d4718..9c73545 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,6 +54,7 @@ num-traits = "^0.2" paste = "1" schemars = "^0.8" serde = { version = "1", features = ["derive"] } +take_mut = "0.2.2" thiserror = "1" thread_local = "1.1" unzip3 = "1" diff --git a/src/data/multi_color_time_series.rs b/src/data/multi_color_time_series.rs index fa1bd4d..787d3bc 100644 --- a/src/data/multi_color_time_series.rs +++ b/src/data/multi_color_time_series.rs @@ -10,6 +10,7 @@ use itertools::Itertools; use std::collections::{BTreeMap, BTreeSet}; use std::ops::{Deref, DerefMut}; +#[derive(Clone, Debug)] pub enum 
MultiColorTimeSeries<'a, P: PassbandTrait, T: Float> { Mapping(MappedMultiColorTimeSeries<'a, P, T>), Flat(FlatMultiColorTimeSeries<'a, P, T>), @@ -40,6 +41,15 @@ where } } + pub fn passband_count(&self) -> usize { + match self { + Self::Mapping(mapping) => mapping.passband_count(), + Self::Flat(flat) => flat.passband_count(), + // Both flat and mapping have the same number of passbands and should be equally fast + Self::MappingFlat { flat, .. } => flat.passband_count(), + } + } + pub fn from_map(map: impl Into>>) -> Self { Self::Mapping(MappedMultiColorTimeSeries::new(map)) } @@ -53,21 +63,43 @@ where Self::Flat(FlatMultiColorTimeSeries::new(t, m, w, passbands)) } - pub fn mapping_mut(&mut self) -> &mut MappedMultiColorTimeSeries<'a, P, T> { + fn ensure_mapping(&mut self) -> &mut Self { if matches!(self, MultiColorTimeSeries::Flat(_)) { - let dummy_self = Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new())); - *self = match std::mem::replace(self, dummy_self) { + take_mut::take(self, |slf| match slf { Self::Flat(mut flat) => { let mapping = MappedMultiColorTimeSeries::from_flat(&mut flat); Self::MappingFlat { mapping, flat } } - _ => unreachable!(), - } + _ => unreachable!("We just checked that we are in ::Flat variant"), + }); } + self + } + + fn enforce_mapping(&mut self) -> &mut Self { match self { + Self::Mapping(_) => {} + Self::Flat(_flat) => take_mut::take(self, |slf| match slf { + Self::Flat(flat) => Self::Mapping(flat.into()), + _ => unreachable!("We just checked that we are in ::Flat variant"), + }), + Self::MappingFlat { .. } => { + take_mut::take(self, |slf| match slf { + Self::MappingFlat { mapping, .. 
} => Self::Mapping(mapping), + _ => unreachable!("We just checked that we are in ::MappingFlat variant"), + }); + } + } + self + } + + pub fn mapping_mut(&mut self) -> &mut MappedMultiColorTimeSeries<'a, P, T> { + match self.ensure_mapping() { Self::Mapping(mapping) => mapping, Self::Flat(_flat) => { - unreachable!("::Flat variant is already transofrmed to ::MappingFlat") + unreachable!( + "::Flat variant is already transformed to ::MappingFlat in ensure_mapping" + ) } Self::MappingFlat { mapping, .. } => mapping, } @@ -81,20 +113,25 @@ where } } - pub fn flat_mut(&mut self) -> &mut FlatMultiColorTimeSeries<'a, P, T> { + fn ensure_flat(&mut self) -> &mut Self { if matches!(self, MultiColorTimeSeries::Mapping(_)) { - let dummy_self = Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new())); - *self = match std::mem::replace(self, dummy_self) { + take_mut::take(self, |slf| match slf { Self::Mapping(mut mapping) => { let flat = FlatMultiColorTimeSeries::from_mapping(&mut mapping); Self::MappingFlat { mapping, flat } } - _ => unreachable!(), - } + _ => unreachable!("We just checked that we are in ::Mapping variant"), + }); } - match self { + self + } + + pub fn flat_mut(&mut self) -> &mut FlatMultiColorTimeSeries<'a, P, T> { + match self.ensure_flat() { Self::Mapping(_mapping) => { - unreachable!("::Mapping veriant is already transformed to ::MappingFlat") + unreachable!( + "::Mapping variant is already transformed to ::MappingFlat in ensure_flat" + ) } Self::Flat(flat) => flat, Self::MappingFlat { flat, .. } => flat, @@ -124,12 +161,45 @@ where Self::MappingFlat { mapping, .. } => Either::Left(mapping.passbands()), } } + + /// Inserts new pair of passband and time series into the multicolor time series. + /// + /// It always converts [MultiColorTimeSeries] to [MultiColorTimeSeries::Mapping] variant. + /// Also it replaces existing time series if passband is already present, and returns old time + /// series. 
+ pub fn insert(&mut self, passband: P, ts: TimeSeries<'a, T>) -> Option> { + match self.enforce_mapping() { + Self::Mapping(mapping) => mapping.0.insert(passband, ts), + _ => unreachable!("We just converted self to ::Mapping variant"), + } + } } +impl<'a, P, T> Default for MultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn default() -> Self { + Self::Mapping(MappedMultiColorTimeSeries::new(BTreeMap::new())) + } +} + +#[derive(Debug, Clone)] pub struct MappedMultiColorTimeSeries<'a, P: PassbandTrait, T: Float>( BTreeMap>, ); +impl<'a, P, T> PartialEq for MappedMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) + } +} + impl<'a, 'p, P, T> MappedMultiColorTimeSeries<'a, P, T> where P: PassbandTrait + 'p, @@ -173,6 +243,10 @@ where self.total_lenu().value_as::().unwrap() } + pub fn passband_count(&self) -> usize { + self.0.len() + } + pub fn passbands<'slf>( &'slf self, ) -> std::collections::btree_map::Keys<'slf, P, TimeSeries<'a, T>> @@ -267,6 +341,7 @@ impl<'a, P: PassbandTrait, T: Float> DerefMut for MappedMultiColorTimeSeries<'a, } } +#[derive(Debug, Clone)] pub struct FlatMultiColorTimeSeries<'a, P: PassbandTrait, T: Float> { pub t: DataSample<'a, T>, pub m: DataSample<'a, T>, @@ -275,6 +350,19 @@ pub struct FlatMultiColorTimeSeries<'a, P: PassbandTrait, T: Float> { passband_set: BTreeSet

, } +impl<'a, P, T> PartialEq for FlatMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn eq(&self, other: &Self) -> bool { + self.t == other.t + && self.m == other.m + && self.w == other.w + && self.passbands == other.passbands + } +} + impl<'a, P, T> FlatMultiColorTimeSeries<'a, P, T> where P: PassbandTrait, @@ -347,4 +435,127 @@ where pub fn total_lenf(&self) -> T { self.t.sample.len().value_as::().unwrap() } + + pub fn passband_count(&self) -> usize { + self.passband_set.len() + } +} + +impl<'a, P, T> From> for MappedMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn from(mut flat: FlatMultiColorTimeSeries<'a, P, T>) -> Self { + Self::from_flat(&mut flat) + } +} + +impl<'a, P, T> From> for FlatMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn from(mut mapped: MappedMultiColorTimeSeries<'a, P, T>) -> Self { + Self::from_mapping(&mut mapped.0) + } +} + +impl<'a, P, T> From> for MultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn from(flat: FlatMultiColorTimeSeries<'a, P, T>) -> Self { + Self::Flat(flat) + } +} + +impl<'a, P, T> From> for MultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn from(mapped: MappedMultiColorTimeSeries<'a, P, T>) -> Self { + Self::Mapping(mapped) + } +} + +impl<'a, P, T> From> for FlatMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn from(mcts: MultiColorTimeSeries<'a, P, T>) -> Self { + match mcts { + MultiColorTimeSeries::Flat(flat) => flat, + MultiColorTimeSeries::Mapping(mapped) => mapped.into(), + MultiColorTimeSeries::MappingFlat { flat, .. 
} => flat, + } + } +} + +impl<'a, P, T> From> for MappedMultiColorTimeSeries<'a, P, T> +where + P: PassbandTrait, + T: Float, +{ + fn from(mcts: MultiColorTimeSeries<'a, P, T>) -> Self { + match mcts { + MultiColorTimeSeries::Flat(flat) => flat.into(), + MultiColorTimeSeries::Mapping(mapping) => mapping, + MultiColorTimeSeries::MappingFlat { mapping, .. } => mapping, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::MonochromePassband; + + use ndarray::Array1; + + #[test] + fn multi_color_ts_insert() { + let mut mcts = MultiColorTimeSeries::default(); + mcts.insert( + MonochromePassband::new(4700.0, "g"), + TimeSeries::new_without_weight(Array1::linspace(0.0, 1.0, 11), Array1::zeros(11)), + ); + assert_eq!(mcts.passband_count(), 1); + assert_eq!(mcts.total_lenu(), 11); + mcts.insert( + MonochromePassband::new(6200.0, "r"), + TimeSeries::new_without_weight(Array1::linspace(0.0, 1.0, 6), Array1::zeros(6)), + ); + assert_eq!(mcts.passband_count(), 2); + assert_eq!(mcts.total_lenu(), 17); + } + + fn compare_variants(mcts: MultiColorTimeSeries) { + let flat: FlatMultiColorTimeSeries<_, _> = mcts.clone().into(); + let mapped: MappedMultiColorTimeSeries<_, _> = mcts.clone().into(); + let mapped_from_flat: MappedMultiColorTimeSeries<_, _> = flat.clone().into(); + let flat_from_mapped: FlatMultiColorTimeSeries<_, _> = mapped.clone().into(); + assert_eq!(mapped, mapped_from_flat); + assert_eq!(flat, flat_from_mapped); + } + + #[test] + fn convert_between_variants() { + let mut mcts = MultiColorTimeSeries::default(); + compare_variants(mcts.clone()); + mcts.insert( + MonochromePassband::new(4700.0, "g"), + TimeSeries::new_without_weight(Array1::linspace(0.0, 1.0, 11), Array1::zeros(11)), + ); + compare_variants(mcts.clone()); + mcts.insert( + MonochromePassband::new(6200.0, "r"), + TimeSeries::new_without_weight(Array1::linspace(0.0, 1.0, 6), Array1::zeros(6)), + ); + compare_variants(mcts.clone()); + } } diff --git a/src/error.rs b/src/error.rs index 
6cd5234..9f11830 100644 --- a/src/error.rs +++ b/src/error.rs @@ -39,6 +39,9 @@ pub enum MultiColorEvaluatorError { #[error(r#"Underlying feature caused an error: "{0:?}""#)] UnderlyingEvaluatorError(#[from] EvaluatorError), + + #[error("All time-series are flat")] + AllTimeSeriesAreFlat, } impl MultiColorEvaluatorError { From 97ffcffc96c18a98f953b371c23d4d6b3da5f262 Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Mon, 1 May 2023 10:24:02 -0500 Subject: [PATCH 10/11] Component norm for MC periodogram --- src/features/_periodogram_peaks.rs | 4 +- src/features/periodogram.rs | 53 ++++++----- src/multicolor/features/mod.rs | 2 +- .../features/multi_color_periodogram.rs | 90 +++++++++++++++---- 4 files changed, 106 insertions(+), 43 deletions(-) diff --git a/src/features/_periodogram_peaks.rs b/src/features/_periodogram_peaks.rs index 39c8829..aab3f6f 100644 --- a/src/features/_periodogram_peaks.rs +++ b/src/features/_periodogram_peaks.rs @@ -50,13 +50,13 @@ impl PeriodogramPeaks { .flat_map(|i| { vec![ format!( - "period of the {}{} highest peak of periodogram", + "period of the {}{} highest peak", i + 1, number_ending(i + 1), ), format!( "Spectral density to spectral density standard deviation ratio of \ - the {}{} highest peak of periodogram", + the {}{} highest peak", i + 1, number_ending(i + 1) ), diff --git a/src/features/periodogram.rs b/src/features/periodogram.rs index 27506cb..bc71b78 100644 --- a/src/features/periodogram.rs +++ b/src/features/periodogram.rs @@ -45,6 +45,10 @@ where max_freq_factor: f32, nyquist: NyquistFreq, pub(crate) feature_extractor: FeatureExtractor, + // In can be re-defined in MultiColorPeriodogram + pub(crate) name_prefix: String, + // In can be re-defined in MultiColorPeriodogram + pub(crate) description_suffix: String, periodogram_algorithm: PeriodogramPower, properties: Box, } @@ -100,13 +104,13 @@ where feature .get_names() .iter() - .map(|name| "periodogram_".to_owned() + name), + .map(|name| format!("{}_{}", 
self.name_prefix, name)), ); self.properties.descriptions.extend( feature .get_descriptions() .into_iter() - .map(|desc| format!("{} of periodogram", desc)), + .map(|desc| format!("{} {}", desc, self.description_suffix)), ); self.feature_extractor.add_feature(feature); self @@ -149,41 +153,44 @@ where { /// New [Periodogram] that finds given number of peaks pub fn new(peaks: usize) -> Self { - let peaks = PeriodogramPeaks::new(peaks); - let peak_names = peaks - .get_names() - .into_iter() - .map(ToOwned::to_owned) - .collect(); - let peak_descriptions = peaks - .get_descriptions() - .into_iter() - .map(ToOwned::to_owned) - .collect(); - let peaks_size_hint = peaks.size_hint(); - let peaks_min_ts_length = peaks.min_ts_length(); + Self::with_name_description( + peaks, + "periodogram", + "of periodogram (interpreting frequency as time, power as magnitude)", + ) + } + + pub(crate) fn with_name_description( + peaks: usize, + name_prefix: impl ToString, + description_suffix: impl ToString, + ) -> Self { let info = EvaluatorInfo { - size: peaks_size_hint, - min_ts_length: usize::max(peaks_min_ts_length, 2), + size: 0, + min_ts_length: 2, t_required: true, m_required: true, w_required: false, sorting_required: true, variability_required: false, }; - Self { + let mut slf = Self { properties: EvaluatorProperties { info, - names: peak_names, - descriptions: peak_descriptions, + names: vec![], + descriptions: vec![], } .into(), resolution: Self::default_resolution(), + name_prefix: name_prefix.to_string(), + description_suffix: description_suffix.to_string(), max_freq_factor: Self::default_max_freq_factor(), nyquist: AverageNyquistFreq.into(), - feature_extractor: FeatureExtractor::new(vec![peaks.into()]), + feature_extractor: FeatureExtractor::new(vec![]), periodogram_algorithm: PeriodogramPowerFft::new().into(), - } + }; + slf.add_feature(PeriodogramPeaks::new(peaks).into()); + slf } } @@ -286,7 +293,7 @@ where nyquist, feature_extractor, periodogram_algorithm, - properties: 
_, + .. } = f; let mut features = feature_extractor.into_vec(); diff --git a/src/multicolor/features/mod.rs b/src/multicolor/features/mod.rs index 4afb029..10c77d7 100644 --- a/src/multicolor/features/mod.rs +++ b/src/multicolor/features/mod.rs @@ -8,4 +8,4 @@ mod color_of_minimum; pub use color_of_minimum::ColorOfMinimum; mod multi_color_periodogram; -pub use multi_color_periodogram::MultiColorPeriodogram; +pub use multi_color_periodogram::{MultiColorPeriodogram, MultiColorPeriodogramNormalisation}; diff --git a/src/multicolor/features/multi_color_periodogram.rs b/src/multicolor/features/multi_color_periodogram.rs index 8f646d4..4829335 100644 --- a/src/multicolor/features/multi_color_periodogram.rs +++ b/src/multicolor/features/multi_color_periodogram.rs @@ -13,6 +13,20 @@ use crate::periodogram::{self, NyquistFreq, PeriodogramPower}; use ndarray::Array1; use std::fmt::Debug; +/// Normalisation of periodogram across passbands +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub enum MultiColorPeriodogramNormalisation { + /// Weight individual periodograms by the number of observations in each passband. + /// Useful if no weight is given to observations + Count, + /// Weight individual periodograms by $\chi^2 = \sum \left(\frac{m_i - \bar{m}}{\delta_i}\right)^2$ + /// + /// Be aware that if no weight are given to observations + /// (i.e. via [TimeSeries::new_without_weight]) unity weights are assumed and this is NOT + /// equivalent to [::Count], but weighting by magnitude variance. 
+ Chi2, +} + /// Multi-passband periodogram #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] #[serde( @@ -25,14 +39,26 @@ where { // We use it to not reimplement some internals monochrome: Periodogram, - properties: Box, + normalization: MultiColorPeriodogramNormalisation, } impl MultiColorPeriodogram where T: Float, - F: FeatureEvaluator, + F: FeatureEvaluator + From, { + pub fn new(peaks: usize, normalization: MultiColorPeriodogramNormalisation) -> Self { + let monochrome = Periodogram::with_name_description( + peaks, + "multicolor_periodogram", + "of multi-color periodogram (interpreting frequency as time, power as magnitude)", + ); + Self { + monochrome, + normalization, + } + } + #[inline] pub fn default_peaks() -> usize { PeriodogramPeaks::default_peaks() @@ -103,19 +129,53 @@ where 'a: 'mcts, P: PassbandTrait, { - let unnormed_power = mcts - .mapping_mut() + let ts_weights = { + let mut a: Array1<_> = match self.normalization { + MultiColorPeriodogramNormalisation::Count => { + mcts.mapping_mut().values().map(|ts| ts.lenf()).collect() + } + MultiColorPeriodogramNormalisation::Chi2 => mcts + .mapping_mut() + .values_mut() + .map(|ts| ts.get_m_chi2()) + .collect(), + }; + let norm = a.sum(); + if norm.is_zero() { + match self.normalization { + MultiColorPeriodogramNormalisation::Count => { + return Err(MultiColorEvaluatorError::all_time_series_short( + mcts.mapping_mut(), + self.min_ts_length(), + )); + } + MultiColorPeriodogramNormalisation::Chi2 => { + return Err(MultiColorEvaluatorError::AllTimeSeriesAreFlat); + } + } + } + a /= norm; + a + }; + mcts.mapping_mut() .values_mut() - .filter(|ts| self.monochrome.check_ts_length(ts).is_ok()) - .map(|ts| p.power(ts) * ts.lenf()) - .reduce(|acc, x| acc + x) + .zip(ts_weights.iter()) + .filter(|(ts, _ts_weight)| self.monochrome.check_ts_length(ts).is_ok()) + .map(|(ts, &ts_weight)| { + let mut power = p.power(ts); + power *= ts_weight; + power + }) + .reduce(|mut acc, power| { + acc += &power; + acc + }) 
.ok_or_else(|| { MultiColorEvaluatorError::all_time_series_short( mcts.mapping_mut(), - self.monochrome.min_ts_length(), + self.min_ts_length(), ) - })?; - Ok(unnormed_power / mcts.total_lenf()) + }) } pub fn power<'slf, 'a, 'mcts, P>( @@ -174,7 +234,7 @@ where >::Error: Debug, { fn get_info(&self) -> &EvaluatorInfo { - &self.properties.info + self.monochrome.get_info() } } @@ -185,15 +245,11 @@ where >::Error: Debug, { fn get_names(&self) -> Vec<&str> { - self.properties.names.iter().map(String::as_str).collect() + self.monochrome.get_names() } fn get_descriptions(&self) -> Vec<&str> { - self.properties - .descriptions - .iter() - .map(String::as_str) - .collect() + self.monochrome.get_descriptions() } } From 21f0d66880c16f097ec6e4968ede81a5ec98ef6a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Sep 2023 18:34:44 +0000 Subject: [PATCH 11/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/time_series.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/time_series.rs b/src/time_series.rs index 8b13789..e69de29 100644 --- a/src/time_series.rs +++ b/src/time_series.rs @@ -1 +0,0 @@ -