diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..25aa79c --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +outputs/* +pyproject.toml +*/__pycache__/* +__pycache__/* +*.ipynb_checkpoints/ +poetry.lock +*.csv diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..fd7fe11 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,13 @@ +SOFTWARE DISCLAIMER / RELEASE + +This software was developed by employees of the National Telecommunications and Information Administration (NTIA), an agency of the Federal Government and is provided to you as a public service. Pursuant to Title 15 United States Code Section 105, works of NTIA employees are not subject to copyright protection within the United States. + +The software is provided by NTIA “AS IS.” NTIA MAKES NO WARRANTY OF ANY KIND, EXPRESS, IMPLIED OR STATUTORY, INCLUDING, WITHOUT LIMITATION, THE IMPLIED WARRANTY OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT AND DATA ACCURACY. NTIA does not warrant or make any representations regarding the use of the software or the results thereof, including but not limited to the correctness, accuracy, reliability or usefulness of the software. + +To the extent that NTIA holds rights in countries other than the United States, you are hereby granted the non-exclusive irrevocable and unconditional right to print, publish, prepare derivative works and distribute the NTIA software, in any medium, or authorize others to do so on your behalf, on a royalty-free basis throughout the World. + +You may improve, modify, and create derivative works of the software or any portion of the software, and you may copy and distribute such modifications or works. Modified works should carry a notice stating that you changed the software and should note the date and nature of any such change. + +You are solely responsible for determining the appropriateness of using and distributing the software and you assume all risks associated with its use, including but not limited to the risks and costs of program errors, compliance with applicable laws, damage to or loss of data, programs or equipment, and the unavailability or interruption of operation. This software is not intended to be used in any situation where a failure could cause risk of injury or damage to property. + +Please provide appropriate acknowledgments of NTIA’s creation of the software in any copies or derivative works of this software. diff --git a/README.md b/README.md index 35b8dd3..7feaec3 100644 --- a/README.md +++ b/README.md @@ -1 +1,168 @@ -# alignnet +# Dataset Alignment +This code corresponds to the paper "AlignNet: Learning dataset score alignment functions to enable better training of speech quality estimators," by Jaden Pieper, Stephen D. Voran, to appear in Proc. Interspeech 2024 and with [preprint available here](https://arxiv.org/abs/2406.10205). + +When training a no-reference (NR) speech quality estimator, multiple datasets provide more information and can thus lead to better training. But they often are inconsistent in the sense that they use different subjective testing scales, or the exact same scale is used differently by test subjects due to the corpus effect. +AlignNet improves the training of NR speech quality estimators with multiple, independent datasets. AlignNet uses an AudioNet to generate intermediate score estimates before using the Aligner to map intermediate estimates to the appropriate score range. +AlignNet is intentionally designed to be independent of the choice of AudioNet. 
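+
+At its core, AlignNet simply composes these two pieces. The sketch below is a simplified illustration only (the class and variable names are illustrative, not part of the package); the full implementation is `alignnet.AlignNet` in `alignnet/model.py`, which also supports temporarily freezing the AudioNet and delaying activation of the Aligner.
+```
+import torch
+
+
+class AlignNetSketch(torch.nn.Module):
+    def __init__(self, audio_net, aligner):
+        super().__init__()
+        self.audio_net = audio_net  # any network mapping audio to an intermediate quality estimate
+        self.aligner = aligner  # small network mapping (estimate, dataset index) to a per-dataset score
+
+    def forward(self, audio, dataset_index):
+        intermediate = self.audio_net(audio)  # estimate on the reference dataset's scale
+        return self.aligner(intermediate, dataset_index)  # estimate mapped to the indicated dataset's scale
+```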
+ +This repository contains implementations of two different AudioNet choices: [MOSNet](https://arxiv.org/abs/1904.08352) and a simple example of a novel multi-scale convolution approach. + +MOSNet demonstrates a network that takes the STFT of an audio signal as its input, and the multi-scale convolution network is provided primarily as an example of a network that takes raw audio as an input. + +# Installation +## Dependencies +There are two included environment files. `environment.yml` has the dependencies required to train with alignnet but does not impose version requirements. It is thus susceptible to issues in the future if packages deprecate methods or have major backwards compatibility breaks. On the other hand, `environment-paper.yml` contains the exact versions of the packages that were used for all the results reported in our paper. + +Create and activate the `alignnet` environment. +``` +conda env create -f environment.yml +conda activate alignnet +``` + +## Installing alignnet package +``` +pip install . +``` + +# Preparing data for training +When training with multiple datasets, some work must first be done to format them in a consistent manner so they can all be loaded in the same way. +For each dataset, one must first make a csv that has subjective score in column called `MOS` and path to audio file in column called `audio_path`. + +If your `audio_net` model requires transformed data, you can transform it prior to training with `pretransform_data.py` (see `python pretransform_data.py --help` for more information) and store paths to those transformed representation files in a column called `transform_path`. For example, MOSNet uses the STFT of audio as an input. For more efficient training, pretransforming the audio into STFT representations, saving them, and including a column called `stft_path` in the csv is recommended. +More generally, the column name must match the value of `data.pathcol`. +For examples, see [MOSNet](alignnet/config/models/pretrain-MOSNet.yaml) or [MultiScaleConvolution](alignnet/config/models/pretrain-msc.yaml). + + +For each dataset, split the data into training, validation, and testing portions with +``` +python split_labeled_data.py /path/to/data/file.csv --output-dir /datasetX/splits/path +``` +This generates `train.csv`, `valid.csv`, and `test.csv` in `/datasetX/splits/path`. +Additional options for splitting can be seen via `python split_labeled_data.py --help`, including creating multiple independent splits and changing the amount of data placed into each split. + +# Training with AlignNet +Setting up training runs is configured via [Hydra](https://hydra.cc/docs/intro/). +Basic examples of configuration files can be found in [model/config](alignnet/config/models). + +Some basic training help can be found with + +``` +python train.py --help +``` + +To see an example config file and all the overrideable parameters for training MOSNet with AlignNet, run +``` +python train.py --config-dir alignnet/config/models --config-name=alignnet-MOSNet --cfg job +``` +Here the `--cfg job` shows the configuration for this job without running the code. + +If you are not training with a [clearML](https://clear.ml/) server, be sure to set `logging=none`. +To change the number of workers used for data loading, override the `data.num_workers` parameter, which defaults to 6. 
+ +As an example, and to confirm you have appropriately overridden these parameters, you could run +``` +python train.py logging=none data.num_workers=4 --config-dir alignnet/config/models --config-name=alignnet-MOSNet --cfg job +``` + +### Pretraining MOSNet on a dataset +In order to pretrain on a dataset you run +``` +python path/to/alignnet/train.py \ +data.data_dirs=[/absolute/path/datasetX/splits/path] \ +--config-dir path/to/alignnet/alignnet/config/models/ --config-name pretrain-MOSNet.yaml +``` +Where `/absolute/path/datasetX/splits/path` contains `train.csv`, `valid.csv`, and `test.csv` for that dataset. + +### Training MOSNet with AlignNet +``` +python path/to/alignnet/train.py \ +data.data_dirs=[/absolute/path/dataset1/splits/path,/absolute/path/dataset2/splits/path] \ +--config-dir path/to/alignnet/alignnet/config/models/ --config-name alignnet-MOSNet.yaml +``` + +### Training MOSNet with AlignNet and MDF +``` +python path/to/alignnet/train.py \ +data.data_dirs=[/absolute/path/dataset1/splits/path,/absolute/path/dataset2/splits/path] \ +finetune.restore_file=/absolute/path/to/alignnet/pretrained/model \ +--config-dir path/to/alignnet/alignnet/config/models/ --config-name alignnet-MOSNet.yaml +``` + +### Training MOSNet in conventional way +Multiple datasets, no alignment. +``` +python path/to/alignnet/train.py \ +project.task=Conventional-MOSNet \ +data.data_dirs=[/absolute/path/dataset1/splits/path,/absolute/path/dataset2/splits/path] \ +--config-dir path/to/alignnet/alignnet/config/models/ --config-name pretrain-MOSNet.yaml +``` + +## Examples +## Training MOSNet with AlignNet and MDF starting with MOSNet that has been pretrained on Tencent dataset +``` +python path/to/alignnet/train.py \ +data.data_dirs=[/absolute/path/dataset1/splits/path,/absolute/path/dataset2/splits/path] \ +finetune.restore_file=/absolute/path/to/alignnet/trained_models/pretrained-MOSNet-tencent \ +--config-dir path/to/alignnet/alignnet/config/models/ --config-name alignnet-MOSNet.yaml +``` + +## MultiScaleConvolution example +Training NR speech estimators with AlignNet is intentionally designed to be agnostic to the choice of AudioNet. +To demonstrate this, we include code for a rudimentary network that takes in raw audio as an input and trains separate convolutional networks on multiple time scales that are then aggregated into a single network component. +This network is defined as `alignnet.MultiScaleConvolution` and can be trained via: +``` +python path/to/alignnet/train.py \ +data.data_dirs=[/absolute/path/dataset1/splits/path,/absolute/path/dataset2/splits/path] \ +--config-dir path/to/alignnet/alignnet/config/models/ --config-name alignnet-msc.yaml +``` + +# Using AlignNet models at inference +Trained AlignNet models can easily be used at inference via the CLI built into `inference.py`. +Some basic help can be seen via +``` +python inference.py --help +``` + +In general, three overrides must be set: +* `model.path` - path to a trained model +* `data.data_files` - list containing absolute paths to csv files that list audio files to perform inference on. +* `output.file` - path to file where inference output will be stored. + +After running inference, a csv will be created at `output.file` with the following columns: +* `file` - filenames where audio was loaded from +* `estimate` - estimate generated by the model +* `dataset` - index listing which file from `data.data_files` this file belongs to. +* `AlignNet dataset index` - index listing which dataset within the model the scores come from. 
This will be the same for every file in the csv. The default dataset will always be the reference dataset, but this can be overridden via `model.dataset_index`.
+
+For example, to run inference using the included AlignNet model trained on the smaller datasets, one would run
+```
+python inference.py \
+data.data_files=[/absolute/path/to/inference/data1.csv,/absolute/path/to/inference/data2.csv] \
+model.path=trained_models/alignnet_mdf-MOSNet-small_data \
+output.file=estimations.csv
+```
+
+
+# Gathering datasets used in the 2024 conference paper
+Here are links and references to help with locating the data we have used in the paper.
+
+* [Blizzard 2021](https://www.cstr.ed.ac.uk/projects/blizzard/data.html)
+  * Z.-H. Ling, X. Zhou, and S. King, "The Blizzard challenge 2021," in Proc. Blizzard Challenge Workshop, 2021.
+* [Blizzard 2008](https://www.cstr.ed.ac.uk/projects/blizzard/data.html)
+  * V. Karaiskos, S. King, R. A. J. Clark, and C. Mayo, "The Blizzard challenge 2008," in Proc. Blizzard Challenge Workshop, 2008.
+* [FFTNet](https://gfx.cs.princeton.edu/pubs/Jin_2018_FAR/clips/)
+  * Z. Jin, A. Finkelstein, G. J. Mysore, and J. Lu, "FFTNet: a real-time speaker-dependent neural vocoder," in Proc. IEEE International Conference on Acoustics, Speech and Signal Processing, 2018.
+* [NOIZEUS](https://ecs.utdallas.edu/loizou/speech/noizeus/)
+  * Y. Hu and P. Loizou, "Subjective comparison of speech enhancement algorithms," in Proc. IEEE International Conference on Acoustics, Speech and Signal Processing, 2006.
+* [VoiceMOS Challenge 2022](https://codalab.lisn.upsaclay.fr/competitions/695)
+  * W. C. Huang, E. Cooper, Y. Tsao, H.-M. Wang, T. Toda, and J. Yamagishi, "The VoiceMOS Challenge 2022," in Proc. Interspeech 2022, 2022, pp. 4536–4540.
+* [Tencent](https://github.com/ConferencingSpeech/ConferencingSpeech2022)
+  * G. Yi, W. Xiao, Y. Xiao, B. Naderi, S. Moller, W. Wardah, G. Mittag, R. Cutler, Z. Zhang, D. S. Williamson, F. Chen, F. Yang, and S. Shang, "ConferencingSpeech 2022 Challenge: Non-intrusive objective speech quality assessment challenge for online conferencing applications," in Proc. Interspeech, 2022, pp. 3308–3312.
+* [NISQA](https://github.com/gabrielmittag/NISQA/wiki/NISQA-Corpus)
+  * G. Mittag, B. Naderi, A. Chehadi, and S. Möller, "NISQA: A deep CNN-self-attention model for multidimensional speech quality prediction with crowdsourced datasets," in Proc. Interspeech, 2021, pp. 2127–2131.
+* [Voice Conversion Challenge 2018](https://datashare.ed.ac.uk/handle/10283/3257)
+  * J. Lorenzo-Trueba, J. Yamagishi, T. Toda, D. Saito, F. Villavicencio, T. Kinnunen, and Z. Ling, "The voice conversion challenge 2018: Promoting development of parallel and nonparallel methods," in Proc. Speaker Odyssey, 2018.
+* [Indiana U. MOS](https://github.com/ConferencingSpeech/ConferencingSpeech2022)
+  * X. Dong and D. S. Williamson, "A pyramid recurrent network for predicting crowdsourced speech-quality ratings of real-world signals," in Proc. Interspeech, 2020.
+* [PSTN](https://github.com/ConferencingSpeech/ConferencingSpeech2022)
+  * G. Mittag, R. Cutler, Y. Hosseinkashi, M. Revow, S. Srinivasan, N. Chande, and R. Aichner, "DNN no-reference PSTN speech quality prediction," in Proc. Interspeech, 2020.
diff --git a/alignnet-GitHubRepoPublicReleaseApproval_sv_acm.pdf b/alignnet-GitHubRepoPublicReleaseApproval_sv_acm.pdf new file mode 100755 index 0000000..405615a Binary files /dev/null and b/alignnet-GitHubRepoPublicReleaseApproval_sv_acm.pdf differ diff --git a/alignnet/__init__.py b/alignnet/__init__.py new file mode 100644 index 0000000..a987c04 --- /dev/null +++ b/alignnet/__init__.py @@ -0,0 +1,4 @@ +from .model import * +from .data import * +from .transforms import * +from .optimizer import * diff --git a/alignnet/config/conf.yaml b/alignnet/config/conf.yaml new file mode 100644 index 0000000..07991cc --- /dev/null +++ b/alignnet/config/conf.yaml @@ -0,0 +1,4 @@ +defaults: + - logging: none + - loss: L2 + - override hydra/help: train_help \ No newline at end of file diff --git a/alignnet/config/hydra/help/train_help.yaml b/alignnet/config/hydra/help/train_help.yaml new file mode 100644 index 0000000..72978b0 --- /dev/null +++ b/alignnet/config/hydra/help/train_help.yaml @@ -0,0 +1,39 @@ +app_name: AlignNet + +header: == Training ${hydra.help.app_name} == + +footer: |- + Powered by Hydra (https://hydra.cc) + Use --hydra-help to view Hydra specific help. + +template: |- + ${hydra.help.header} + + This is the ${hydra.help.app_name} training program! + + == Configuration groups == + Compose your configuration from those groups (db=mysql) + + $APP_CONFIG_GROUPS + + == Config == + This is the config generated for this run. + You can override everything, for example to set the logger to none and loss to L1 run: + + ``` + python train.py logging=none loss=L1 --help + ``` + + For another example, to see the config file and all overrideable parameters for + training MOSNet with AlignNet run: + ``` + python train.py --config-dir alignnet/config/models --config-name=alignnet-MOSNet --cfg job + ``` + ------- + $CONFIG + ------- + + To see the config of an example command directly without running it add + `--cfg job` to your command. 
+ + ${hydra.help.footer} \ No newline at end of file diff --git a/alignnet/config/logging/clearml.yaml b/alignnet/config/logging/clearml.yaml new file mode 100644 index 0000000..02e1880 --- /dev/null +++ b/alignnet/config/logging/clearml.yaml @@ -0,0 +1 @@ +logger: clearml \ No newline at end of file diff --git a/alignnet/config/logging/none.yaml b/alignnet/config/logging/none.yaml new file mode 100644 index 0000000..1ce6c82 --- /dev/null +++ b/alignnet/config/logging/none.yaml @@ -0,0 +1 @@ +logger: null \ No newline at end of file diff --git a/alignnet/config/loss/L1.yaml b/alignnet/config/loss/L1.yaml new file mode 100644 index 0000000..139167c --- /dev/null +++ b/alignnet/config/loss/L1.yaml @@ -0,0 +1 @@ +_target_: torch.nn.L1Loss \ No newline at end of file diff --git a/alignnet/config/loss/L2.yaml b/alignnet/config/loss/L2.yaml new file mode 100644 index 0000000..09ee678 --- /dev/null +++ b/alignnet/config/loss/L2.yaml @@ -0,0 +1 @@ +_target_: torch.nn.MSELoss \ No newline at end of file diff --git a/alignnet/config/models/alignnet-MOSNet.yaml b/alignnet/config/models/alignnet-MOSNet.yaml new file mode 100644 index 0000000..9d5eb08 --- /dev/null +++ b/alignnet/config/models/alignnet-MOSNet.yaml @@ -0,0 +1,80 @@ +hydra: + run: + dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + job: + chdir: True + +defaults: + - logging: clearml + - loss: L2 + - _self_ + +project: + name: Dataset Alignment + task: MOSNet-AlignNet + +finetune: + restore_file: null + +common: + seed: 1234 + auto_batch_size: false + lr: 0.0001 + +data: + _target_: alignnet.AudioDataModule + data_dirs: ??? + batch_size: 16 + num_workers: 6 + transform_time: get + cache: false + fs: null + flatten: false + pathcol: stft_path + +dataclass: + _target_: hydra.utils.get_class + path: alignnet.FeatureData + +optimization: + _target_: pytorch_lightning.Trainer + accelerator: gpu + devices: + - 0 + log_every_n_steps: 5 + max_epochs: 200 + precision: 16-mixed + +earlystop: + _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping + patience: 20 + +model: + _target_: alignnet.Model + loss_weights: 1 + +network: + _target_: alignnet.AlignNet + aligner_corr_threshold: null + audio_net: + _target_: alignnet.MOSNet + aligner: + _target_: alignnet.LinearSequenceAligner + reference_index: 0 + audio_net_freeze_epochs: 1 + +checkpoint: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val_loss + mode: min + filename: '{epoch}-{val_loss:.4f}' + save_top_k: 5 + every_n_epochs: 1 + every_n_train_steps: null + +optimizer: + _target_: alignnet.OptimizerWrapper + class_name: torch.optim.Adam + +transform: + _target_: alignnet.NoneTransform \ No newline at end of file diff --git a/alignnet/config/models/alignnet-msc.yaml b/alignnet/config/models/alignnet-msc.yaml new file mode 100644 index 0000000..4afc0b3 --- /dev/null +++ b/alignnet/config/models/alignnet-msc.yaml @@ -0,0 +1,160 @@ +hydra: + run: + dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + job: + chdir: True + +defaults: + - logging: clearml + - loss: L2 + - _self_ + +project: + name: Dataset Alignment + task: MultiScaleConvolution-AlignNet + +finetune: + restore_file: null + +common: + seed: 1234 + auto_batch_size: false + lr: 0.0001 + +data: + _target_: alignnet.AudioDataModule + data_dirs: ??? 
+ batch_size: 16 + num_workers: 6 + transform_time: get + cache: false + time_dim: 1 + fs: 16000 + pathcol: audio_path + +dataclass: + _target_: hydra.utils.get_class + path: alignnet.AudioData + +optimization: + _target_: pytorch_lightning.Trainer + accelerator: gpu + devices: + - 0 + log_every_n_steps: 5 + max_epochs: 200 + precision: 16-mixed + +earlystop: + _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping + patience: 20 + +model: + _target_: alignnet.Model + loss_weights: 1 + +network: + _target_: alignnet.AlignNet + aligner_corr_threshold: null + audio_net: + _target_: alignnet.MultiScaleConvolution + path1: + _target_: alignnet.ConvPath + kernels: + - 3 + - 3 + - 3 + - 3 + - 3 + strides: + - 1 + dilations: + - 1 + channels: + - 32 + paddings: + - 1 + pooling_kernels: + - 4 + - 4 + - 5 + - 5 + - 5 + path2: + _target_: alignnet.ConvPath + kernels: + - 11 + - 11 + - 11 + - 11 + - 11 + strides: + - 4 + dilations: + - 1 + paddings: + - 5 + channels: + - 32 + pooling_type: null + path3: + _target_: alignnet.ConvPath + rectify: True + kernels: + - 11 + - 11 + - 11 + - 11 + - 11 + strides: + - 4 + dilations: + - 1 + paddings: + - 5 + channels: + - 32 + pooling_type: null + path4: + _target_: alignnet.ConvPath + mu_law: True + kernels: + - 3 + - 3 + - 3 + - 3 + - 3 + strides: + - 1 + dilations: + - 1 + channels: + - 32 + paddings: + - 1 + pooling_kernels: + - 4 + - 4 + - 5 + - 5 + - 5 + aligner: + _target_: alignnet.LinearSequenceAligner + reference_index: 0 + audio_net_freeze_epochs: 1 + +checkpoint: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val_loss + mode: min + filename: '{epoch}-{val_loss:.4f}' + save_top_k: 5 + every_n_epochs: 1 + every_n_train_steps: null + +optimizer: + _target_: alignnet.OptimizerWrapper + class_name: torch.optim.Adam + +transform: + _target_: alignnet.NoneTransform \ No newline at end of file diff --git a/alignnet/config/models/pretrain-MOSNet.yaml b/alignnet/config/models/pretrain-MOSNet.yaml new file mode 100644 index 0000000..797f0e7 --- /dev/null +++ b/alignnet/config/models/pretrain-MOSNet.yaml @@ -0,0 +1,77 @@ +hydra: + run: + dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + job: + chdir: True + +defaults: + - logging: clearml + - loss: L2 + - _self_ + +project: + name: Dataset Alignment + task: Pretrain-MOSNet + +finetune: + restore_file: null + +common: + seed: 1234 + auto_batch_size: false + lr: 0.0001 + +data: + _target_: alignnet.AudioDataModule + data_dirs: ??? 
+ batch_size: 16 + num_workers: 6 + transform_time: get + cache: false + fs: null + flatten: false + pathcol: stft_path + +dataclass: + _target_: hydra.utils.get_class + path: alignnet.FeatureData + +optimization: + _target_: pytorch_lightning.Trainer + accelerator: gpu + devices: + - 0 + log_every_n_steps: 5 + max_epochs: 200 + precision: 16-mixed + +earlystop: + _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping + patience: 20 + +model: + _target_: alignnet.Model + +network: + _target_: alignnet.AlignNet + aligner_corr_threshold: null + audio_net: + _target_: alignnet.MOSNet + aligner: + _target_: alignnet.NoAligner + +checkpoint: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val_loss + mode: min + filename: '{epoch}-{val_loss:.4f}' + save_top_k: 5 + every_n_epochs: 1 + every_n_train_steps: null + +optimizer: + _target_: alignnet.OptimizerWrapper + class_name: torch.optim.Adam + +transform: + _target_: alignnet.NoneTransform \ No newline at end of file diff --git a/alignnet/config/models/pretrain-msc.yaml b/alignnet/config/models/pretrain-msc.yaml new file mode 100644 index 0000000..49ca720 --- /dev/null +++ b/alignnet/config/models/pretrain-msc.yaml @@ -0,0 +1,157 @@ +hydra: + run: + dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + job: + chdir: True + +defaults: + - logging: clearml + - loss: L2 + - _self_ + +project: + name: Dataset Alignment + task: Pretrain-MultiScaleConvolution + +finetune: + restore_file: null + +common: + seed: 1234 + auto_batch_size: false + lr: 0.0001 + +data: + _target_: alignnet.AudioDataModule + data_dirs: ??? + batch_size: 16 + num_workers: 6 + transform_time: get + cache: false + time_dim: 1 + fs: 16000 + pathcol: audio_path + +dataclass: + _target_: hydra.utils.get_class + path: alignnet.AudioData + +optimization: + _target_: pytorch_lightning.Trainer + accelerator: gpu + devices: + - 0 + log_every_n_steps: 5 + max_epochs: 200 + precision: 16-mixed + +earlystop: + _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping + patience: 20 + +model: + _target_: alignnet.Model + +network: + _target_: alignnet.AlignNet + aligner_corr_threshold: null + audio_net: + _target_: alignnet.MultiScaleConvolution + path1: + _target_: alignnet.ConvPath + kernels: + - 3 + - 3 + - 3 + - 3 + - 3 + strides: + - 1 + dilations: + - 1 + channels: + - 32 + paddings: + - 1 + pooling_kernels: + - 4 + - 4 + - 5 + - 5 + - 5 + path2: + _target_: alignnet.ConvPath + kernels: + - 11 + - 11 + - 11 + - 11 + - 11 + strides: + - 4 + dilations: + - 1 + paddings: + - 5 + channels: + - 32 + pooling_type: null + path3: + _target_: alignnet.ConvPath + rectify: True + kernels: + - 11 + - 11 + - 11 + - 11 + - 11 + strides: + - 4 + dilations: + - 1 + paddings: + - 5 + channels: + - 32 + pooling_type: null + path4: + _target_: alignnet.ConvPath + mu_law: True + kernels: + - 3 + - 3 + - 3 + - 3 + - 3 + strides: + - 1 + dilations: + - 1 + channels: + - 32 + paddings: + - 1 + pooling_kernels: + - 4 + - 4 + - 5 + - 5 + - 5 + aligner: + _target_: alignnet.NoAligner + +checkpoint: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val_loss + mode: min + filename: '{epoch}-{val_loss:.4f}' + save_top_k: 5 + every_n_epochs: 1 + every_n_train_steps: null + +optimizer: + _target_: alignnet.OptimizerWrapper + class_name: torch.optim.Adam + +transform: + _target_: alignnet.NoneTransform \ No newline at end of file diff --git a/alignnet/data.py b/alignnet/data.py new file mode 100644 index 0000000..7d080ba --- /dev/null +++ b/alignnet/data.py @@ -0,0 
+1,488 @@ +import os +import pickle +import re +from pytorch_lightning.utilities.types import TRAIN_DATALOADERS +import torch +import torchaudio +import warnings + +import numpy as np +import pandas as pd +import pytorch_lightning as pl + +from tqdm import tqdm +from torch.utils.data import Dataset, DataLoader + + +class AudioData(Dataset): + def __init__( + self, + data_files, + transform=None, + transform_time="get", + cache=False, + target="MOS", + target_transform=None, + pathcol="full_path", + data_path="", + fs=16000, + percent=1, + time_dim=0, + ): + """ + Class for loading audio data with quality (or any target) scores. + + Parameters + ---------- + data_files : list + List of paths to data files containing audio paths and target scores. + transform : _type_, optional + Transform object for applying transforms to audio, by default None + transform_time : str, optional + When to perform transforms, by default "get" + cache : bool, optional + Store audio rather than loading it each time, by default False + target : str, optional + Target to train to, must be column name in data files, by default "MOS" + pathcol : str, optional + Name of column containing audio path in data files, by default "full_path" + data_path : str, optional + Parent path to data for paths stored in pathcol, by default "". + fs : int, optional + Sampling rate required for the network. All audio will be resampled + to this sample rate if necessary. If fs is set to None, the resampling + step will be skipped. + percent : float, optional + Percentage of data to use. Data will be randomly partitioned according + to percent. Primarily intended for debugging. + """ + self.fs = fs + self.percent = percent + self.time_dim = time_dim + + score_files = list() + # Load each data file + for k, file in enumerate(data_files): + score_file = pd.read_csv(file) + if self.percent < 1: + n_items = len(score_file) + keep_items = np.ceil(n_items * self.percent).astype(int) + rng = np.random.default_rng() + keep_ix = rng.choice(len(score_file), keep_items, replace=False) + score_file = score_file.loc[keep_ix] + score_file["Dataset_Indicator"] = k + score_files.append(score_file) + + self.score_file = pd.concat(score_files, ignore_index=True) + # Initialize dictionary to store audio in if cache is True + if cache: + self.wavs = dict() + + self.data_path = data_path + self.pathcol = pathcol + self.target = target + self.target_transform = target_transform + self.transform = transform + self.transform_time = transform_time + + def __len__(self): + return len(self.score_file) + + def __getitem__(self, idx): + dataset = self.score_file.loc[idx, "Dataset_Indicator"] + if self.target is not None: + mos = self.score_file.loc[idx, self.target] + else: + mos = None + if self.target_transform is not None: + mos = self.target_transform(mos, dataset) + + audio_path = os.path.join( + self.data_path, self.score_file.loc[idx, self.pathcol] + ) + if hasattr(self, "wavs") and audio_path in self.wavs: + # Already loaded and cached the audio + audio = self.wavs[audio_path] + else: + # Load audio + audio, sample_rate = torchaudio.load(audio_path) + + # Resample as needed + if self.fs is not None and sample_rate != self.fs: + resampler = torchaudio.transforms.Resample( + sample_rate, self.fs, dtype=audio.dtype + ) + audio = resampler(audio) + sample_rate = self.fs + + if self.transform is not None and self.transform_time == "get": + # Apply transform + + if self.fs is None: + # We are not resampling which means the transform should handle this + audio = 
self.transform.transform(audio, fs=sample_rate) + else: + # We have resampled, so the transform doesn't need to account for it + audio = self.transform.transform(audio) + + audio = audio.float() + if hasattr(self, "wavs"): + self.wavs[audio_path] = audio + return audio, mos, dataset + + def padding(self, batch): + """ + Pad inputs in a batch so that they have the same dimensions. + + Parameters + ---------- + batch : tuple + Tuple of items in the batch. + + Returns + ------- + tuple + Batch with padding applied so that all of audio is same dimension. + """ + # Unpack data within batch + audio_files, mos, dataset = zip(*batch) + + mos = torch.tensor(np.array(mos)) + dataset = torch.tensor(dataset) + + # Find maximum length in time dimension + max_len = np.max([audio.shape[self.time_dim] for audio in audio_files]) + + audio_out = [] + for ix, audio in enumerate(audio_files): + if audio.shape[self.time_dim] < max_len: + repeat_samples = max_len - audio.shape[self.time_dim] + # Initialize pad width for each dimension to no padding + pad_width = [(0, 0) for i in range(len(audio.shape))] + # Dimension always time - pad end with repeat_samples + pad_width[self.time_dim] = (0, repeat_samples) + # Convert to tuple for input to np.pad + pad_width = tuple(pad_width) + + audio = np.pad(audio, pad_width=pad_width, mode="constant") + audio = torch.from_numpy(audio) + audio_out.append(audio) + + # Concatenate into one tensor + audio_out = torch.stack(audio_out, dim=0) + # If a transform is defined and the transform time is at collate, now is the time to apply it + if self.transform is not None and self.transform_time == "collate": + audio_out = self.transform.transform(audio_out) + audio_out = torch.unsqueeze(audio_out, dim=1) + return audio_out, mos, dataset + + +class FeatureData(AudioData): + """ + For loading pre-computed features for audio files. Only the __getitem__ method is changed + """ + + def __init__( + self, + flatten=True, + dim_cutoff=None, + dim=0, + **kwargs, + ): + """ + Class for pre-computed features of audio files. + + Inherits from AudioData. + + Parameters + ---------- + flatten : bool, optional + Flatten representation into a single dimension, by default True + dim_cutoff : _type_, optional + Max number of dimensions to consider. By default None. 
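+            The representation is truncated to this length along dimension `dim`
+            via `torch.narrow`.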
+ dim : int, optional + Dimension on which to perform cutoff using dim_cutoff, by default 0 + """ + super().__init__( + **kwargs, + ) + self.dim_cutoff = dim_cutoff + self.dim = dim + self.flatten = flatten + + def __getitem__(self, idx): + dataset = self.score_file.loc[idx, "Dataset_Indicator"] + + if self.target is not None: + mos = self.score_file.loc[idx, self.target] + else: + mos = None + + if self.target_transform is not None: + mos = self.target_transform(mos, int(dataset)) + + audio_path = os.path.join( + self.data_path, self.score_file.loc[idx, self.pathcol] + ) + if hasattr(self, "wavs") and audio_path in self.wavs: + # Already loaded and cached the audio + audio = self.wavs[audio_path] + else: + fname, ext = os.path.splitext(audio_path) + # If using same split csvs as audio, this may say wav and not pt + # (coming out of pretransform_data.py will save as pt) + if ext == ".wav": + audio_path = fname + ".pkl" + # Load audio + with open(audio_path, "rb") as feat_input: + audio = pickle.load(feat_input) + + if self.dim_cutoff is not None: + audio = torch.narrow( + audio, dim=self.dim, start=0, length=self.dim_cutoff + ) + + if self.flatten: + # Flatten by column + audio = audio.t().flatten() + + if self.transform is not None and self.transform_time == "get": + # Apply transform + audio = self.transform.transform(audio) + + audio = audio.float() + if hasattr(self, "wavs"): + self.wavs[audio_path] = audio + + return audio, mos, dataset + + +class AudioDataModule(pl.LightningDataModule): + def __init__( + self, + data_dirs, + batch_size=16, + num_workers=1, + persistent_workers=True, + DataClass=AudioData, + collate_type="padding", + data_percent=1, + **kwargs, + ): + """ + Primary audio data module that prepares data for training, testing, or predictions. + + Parameters + ---------- + data_dirs : list + List of paths to directories containing train.csv, valid.csv, and test.csv + for each dataset. + batch_size : int, optional + Number of items in each batch, by default 32 + num_workers : int, optional + Number of workers used during training, by default 1 + persistent_workers : bool, optional + Whether or not workers persist between epochs, by default True + DataClass : class, optional + Class that the data will be initialized with. Assumed to inherit from + torch.utils.data.Dataset, by default AudioData + collate_type : str, optional + String that determines what type of collate function is used, by default + "padding" + **kwargs : optional + Additional arguments are passed to the DataClass when instantiated in + AudioDataModule.setup() + + """ + super().__init__() + + # If this class sees batch_size=auto, it sets to default value and assumes a Tuner is being called in the main + # logic to update this later + if batch_size == "auto": + batch_size = 32 + self.batch_size = batch_size + self.collate_type = collate_type + self.data_dirs = data_dirs + self.DataClass = DataClass + self.num_workers = num_workers + self.persistent_workers = persistent_workers + + self.data_class_kwargs = kwargs + + def setup(self, stage: str): + """ + Load different datasubsets depending on stage. + + If stage == 'fit', then train, valid, and test data are loaded. + + If stage == 'test', then only test data is loaded. + + If stage == 'predict', then self.data_dirs should be full paths to the specific + csv files to run predictions on. + + Parameters + ---------- + stage : str + One of fit, test, or predict. 
+ """ + if stage == "fit": + train_paths = self.find_datasubsets(self.data_dirs, "train") + self.train = self.DataClass( + data_files=train_paths, + **self.data_class_kwargs, + ) + + valid_paths = self.find_datasubsets(self.data_dirs, "valid") + self.valid = self.DataClass( + data_files=valid_paths, + **self.data_class_kwargs, + ) + + test_paths = self.find_datasubsets(self.data_dirs, "test") + self.test = self.DataClass( + data_files=test_paths, + **self.data_class_kwargs, + ) + elif stage == "test": + test_paths = self.find_datasubsets(self.data_dirs, "test") + self.test = self.DataClass( + data_files=test_paths, + **self.data_class_kwargs, + ) + elif stage == "predict": + self.predict = self.DataClass( + data_files=self.data_dirs, + **self.data_class_kwargs, + ) + + def find_datasubsets(self, data_paths, subset): + """ + Find subsets as determined by data_paths and subset. + + Primarily relies on find_datasubset. Loads in train.csv, valid.csv, and + test.csv as necessary. + + Parameters + ---------- + data_paths : list + List of paths for each dataset. + subset : str + Which type of csv file to read in. + + Returns + ------- + list + Paths to that datasubset. + """ + outs = [] + for data_path in data_paths: + out = self.find_datasubset(data_path, subset) + outs.append(out) + return outs + + def find_datasubset(self, data_path, subset): + """ + Helper function for setup to find the different data subsets (train/valid/test) + + Parameters + ---------- + data_path : str + Path to folder containg subset.csv + subset : str + String for representative datasubset file in data_path + + Returns + ------- + str + path to subset.csv + + Raises + ------ + ValueError + subset could not be found in data_path. + """ + _, ext = os.path.splitext(data_path) + if ext == ".csv": + return data_path + files = os.listdir(data_path) + out = [] + for file in files: + basename, ext = os.path.splitext(file) + if basename == subset: + out.append(file) + if len(out) == 0: + raise ValueError(f"Unable to find {subset} in {data_path}") + elif len(out) > 1: + warnings.warn( + f"Multiple matches for {subset} in {data_path}.\nOf {out} using {out[0]}" + ) + out = out[0] + else: + out = out[0] + out = os.path.join(data_path, out) + return out + + def train_dataloader(self): + """ + Prepare DataLoader for training. + """ + if self.collate_type == "padding": + collate_fn = self.train.padding + else: + collate_fn = None + return DataLoader( + self.train, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=self.persistent_workers, + shuffle=True, + collate_fn=collate_fn, + ) + + def val_dataloader(self): + """ + Prepare DataLoader for validation. + """ + if self.collate_type == "padding": + collate_fn = self.valid.padding + else: + collate_fn = None + return DataLoader( + self.valid, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=self.persistent_workers, + collate_fn=collate_fn, + ) + + def test_dataloader(self): + """ + Prepare DataLoader for testing. + """ + if self.collate_type == "padding": + collate_fn = self.test.padding + else: + collate_fn = None + return DataLoader( + self.test, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=self.persistent_workers, + collate_fn=collate_fn, + ) + + def predict_dataloader(self): + """ + Prepare DataLoader for prediction. 
+ """ + if self.collate_type == "padding": + collate_fn = self.predict.padding + else: + collate_fn = None + return DataLoader( + self.predict, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=self.persistent_workers, + collate_fn=collate_fn, + ) diff --git a/alignnet/model.py b/alignnet/model.py new file mode 100644 index 0000000..d2db142 --- /dev/null +++ b/alignnet/model.py @@ -0,0 +1,926 @@ +from typing import Any +import hydra +import os +from pytorch_lightning.utilities.types import STEP_OUTPUT +import torch +import torchaudio +import yaml +import warnings + +import numpy as np +import pytorch_lightning as pl + +from omegaconf import DictConfig +from torch import nn + +from .optimizer import OptimizerWrapper +from .transforms import MelTransform + + +def load_model(trained_model_path): + """ + Load a model directory. + + Parameters + ---------- + trained_model_path : str + Path to trained_model directory containing a config.yaml and model.ckpt + required to load a pretrained model. + + Returns + ------- + Model + Pretrained alignnet.Model object. + """ + # Load config + cfg_path = os.path.join(trained_model_path, "config.yaml") + with open(cfg_path, "r") as f: + cfg_yaml = yaml.safe_load(f) + cfg = DictConfig(cfg_yaml) + # Initialize network + network = hydra.utils.instantiate(cfg.network) + + model_class = hydra.utils.get_class(cfg.model._target_) + + model_path = os.path.join(trained_model_path, "model.ckpt") + # Initialize model + model = model_class.load_from_checkpoint( + model_path, network=network, map_location=lambda storage, loc: storage + ) + + return model + + +def mean_pooling(frame_scores, dim=1): + """ + Time pooling method that averages frames. + + Parameters + ---------- + frame_scores : torch.tensor + Frame-wise estimates. + dim : int, optional + Dimension along which to average, by default 1 + + Returns + ------- + torch.tensor + Frame averaged estimates. + """ + mean_estimate = torch.mean(frame_scores, dim=dim) + mean_estimate = torch.squeeze(mean_estimate) + return mean_estimate + + +# ------------------------ +# Audio Processing Modules +# ------------------------ + + +class LinearSequence(nn.Module): + def __init__( + self, + in_features, + n_layers=2, + activation=nn.ReLU, + layer_dims=None, + last_activate=False, + ): + """ + Generate a sequence of n_layers Fully Connected (nn.linear) layers with activation. + + Parameters + ---------- + n_layers : int + Number of linear layers to include. + in_features : int + Number of features in input + activation : nn.Module + Activation to include between layers. There will always be n_layers - 1 activations in the sequence. + layer_dims : list + List of layer dimensions, not including input features (these are specified by in_features). + """ + super().__init__() + if layer_dims is not None and n_layers != len(layer_dims): + n_layers = len(layer_dims) + self.n_layers = n_layers + self.in_features = in_features + self.activation = activation + self.layer_dims = layer_dims + self.last_activate = last_activate + self.setup_layers(self.n_layers) + + def setup_layers(self, n_layers): + """ + Set up and store layers into `output_layers` attribute. + + If `self.layer_dims` is not None, linear layers are made that match the + dimension of that list. If it is None, layers are made such that the dimension + decreases by 1/2 for each layer. + + Parameters + ---------- + n_layers : int + Number of layers to make if `self.layer_dims` is not defined. 
+ """ + n_features = self.in_features + layers = [] + + if self.layer_dims is None: + for k in range(n_layers - 1): + next_feat = int(n_features / 2) + layer = nn.Linear(n_features, next_feat) + layers.append(layer) + layers.append(self.activation()) + n_features = next_feat + # Final layer to map to MOS + layers.append(nn.Linear(n_features, 1)) + else: + for k, layer_dim in enumerate(self.layer_dims): + # Map previous number of features to layer dim + layer = nn.Linear(n_features, layer_dim) + layers.append(layer) + if k < len(self.layer_dims) - 1: + layers.append(self.activation()) + elif k == len(self.layer_dims) - 1 and self.last_activate: + layers.append(self.activation()) + + # Save current layer dim as previous number of features + n_features = layer_dim + + self.output_layers = nn.ModuleList(layers) + + def forward(self, frame_scores): + """ + Forward method for fully connected linear sequence. + + Parameters + ---------- + frame_scores : torch.tensor + Input tensor for linear sequence. + + Returns + ------- + torch.Tensor + Frame-based representation of audio (e.g., features x frames tensor for each audio file). + """ + for k, layer in enumerate(self.output_layers): + frame_scores = layer(frame_scores) + return frame_scores + + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, dropout=0.3): + """ + Convolutional block used in MOSNet. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + dropout : float, optional + Dropout probability, by default 0.3 + """ + super().__init__() + self.block = nn.Sequential( + # Input shape: (B, T, in_channels, N) + # Output shape: (B, T, out_channels, ceil(N/3)) (stride 3 in freq of last convolutional layer causes decrease) + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + ), + nn.ReLU(), + nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + ), + nn.ReLU(), + nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 3), + padding=(1, 1), + ), + nn.Dropout(0.3), + nn.BatchNorm2d(num_features=out_channels), + nn.ReLU(), + ) + + def forward(self, x): + return self.block(x) + + +class MOSNet(nn.Module): + def __init__( + self, + ): + """ + Implement MOSNet architecture, mostly as described in "MOSNet: Deep Learning based + Objective Assessment for Voice Conversion" by Lo et al. (2019). + + Unlike the original, this implementation does not implement frame-level loss. 
+ """ + super().__init__() + self.convolutions = nn.Sequential( + # Input shape: (B, 1, T, 257) + # Output shape: (B, 16, T, 86) (stride 3 in freq of last convolutional layer causes decrease) + ConvBlock(1, 16), + # Input shape: (B, 16, T, 86) + # Output shape: (B, 32, T, 29) + ConvBlock(16, 32), + # Input shape: (B, 32, T, 29) + # Output shape: (B, 64, T, 10) + ConvBlock(32, 64), + # Input shape: (B, 64, T, 10) + # Output shape: (B, 128, T, 4) + ConvBlock(64, 128), + ) + self.blstm = nn.LSTM( + # input_size - number of expected features in input x + input_size=512, # 4*128 + # hidden_size - number of features in hidden state h + hidden_size=128, + # num_layers - number of recurrent layers + num_layers=1, + # bias - bool if bias weight used (defaults to True) + # batch_first - if True, then input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature) + batch_first=True, + # dropout + # bidirectional + bidirectional=True, + ) # (B, T, 256=2*128), 2 b/c bidirectional==True, 128 b/c hidden_size=128 and proj_size=0 + self.fc = nn.Sequential( + nn.Linear( + in_features=256, + out_features=128, + ), + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear( + in_features=128, + out_features=1, + ), + ) + self.time_pooling = torch.mean + + def forward(self, x): + # x dim: (B, C, T, F) = (B, T, 1, 257) + y = self.convolutions(x) + # y dim: (B, C, T, F) = (B, 128, T, 4) + # Swap dimensions to preserve frame-level time before flattening for BLSTM + y = torch.movedim(y, -2, -3).flatten(start_dim=-2, end_dim=-1) + # y dim: (B, T, F*C): (B, T, 512) + y, _ = self.blstm(y) + # y dim: (B, T, 2*H): (B, T, 256) -- H is hidden dimension, 2x b/c Bidirectional + y = self.fc(y) + # y dim: (B, T, 1) + y = self.time_pooling(y, dim=1) + # y dim: (B, 1) + # y = torch.squeeze(y) + # # y dim: B + return y + + +class AudioConvolutionalBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + padding=0, + pooling_type="average", + pooling_kernel=4, + batch_norm=True, + ): + """ + Audio convolutional block in the style of WAWENets. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int + Size of convolution kernel. + stride : int, optional + Convolution stride, by default 1 + dilation : int, optional + Convolution dilation, by default 1 + pooling_type : str, optional + Type of pooling to perform: "average", "blur", or "None", by default "average". + + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.padding = padding + self.pooling_type = pooling_type + self.pooling_kernel = pooling_kernel + self.batch_norm = batch_norm + + self.setup() + + def setup(self): + """ + Set up the modules. + + Sets up a convolution, batch norm (optional), ReLU, and pooling (optional). 
+ """ + model_list = [] + conv = nn.Conv1d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + dilation=self.dilation, + padding=self.padding, + ) + model_list.append(conv) + if self.batch_norm: + bn = nn.BatchNorm1d(self.out_channels) + model_list.append(bn) + + relu = nn.ReLU() + model_list.append(relu) + + if self.pooling_type is not None: + if self.pooling_type == "average": + pool = nn.AvgPool1d(kernel_size=self.pooling_kernel) + elif self.pooling_type == "blur": + raise ValueError(f"blur pooling not implemented yet.") + else: + raise ValueError(f"Unrecognized pooling_type `{self.pooling_type}`") + # add pool to the nn.ModuleList + model_list.append(pool) + self.model_list = nn.ModuleList(model_list) + + def forward(self, x): + for module in self.model_list: + x = module(x) + return x + + +class MuLaw(nn.Module): + def __init__(self): + """ + Learnable Mu-Law compression module. + """ + super().__init__() + self.mu = torch.nn.Parameter(torch.tensor([255.0])) + + def forward(self, x): + sign = torch.sign(x) + out = sign * torch.log(1 + self.mu * torch.abs(x)) / torch.log(1 + self.mu) + return out + + +class ConvPath(nn.Module): + def __init__( + self, + kernels, + strides, + dilations, + channels, + paddings, + pooling_kernels=[None], + in_channel=1, + rectify=False, + mu_law=False, + **kwargs, + ): + """ + Convolutional paths for multi-scale convolution. + + Parameters + ---------- + kernels : list + List of kernel sizes within the path. The length of kernels determines + the number of elements in the convolutional path. + strides : list + List of strides within the path. Can be one element list and will be repeated + to be the same length as kernels. + dilations : list + List of dilations within the path. Can be one element list and will be repeated + to be the same length as kernels. + channels : list + List of channels within the path. Can be one element list and will be repeated + to be the same length as kernels. + paddings : list + List of paddings within the path. Can be one element list and will be repeated + to be the same length as kernels. + pooling_kernels : list, optional + List of poolings within the path. 
Can be one element list and will be repeated + to be same length as kernels, by default [None] + in_channel : int, optional + Number of channels in first AudioConvolutionalBlock, by default 1 + rectify : bool, optional + Rectify signal at beginning of the path, by default False + mu_law : bool, optional + Apply learnable Mu-Law compression at beginning of path, by default False + """ + super().__init__() + n_blocks = len(kernels) + + # make sure any single element lists have same length as kernels + if len(strides) == 1: + strides = n_blocks * list(strides) + if len(dilations) == 1: + dilations = n_blocks * list(dilations) + if len(channels) == 1: + channels = n_blocks * list(channels) + if len(paddings) == 1: + paddings = n_blocks * list(paddings) + if len(pooling_kernels) == 1: + pooling_kernels = n_blocks * list(pooling_kernels) + + self.rectify = rectify + self.mu_law = mu_law + conv_blocks = [] + if mu_law: + mu = MuLaw() + conv_blocks.append(mu) + + for ix, ( + kernel_size, + stride, + dilation, + channel, + padding, + pooling_kernel, + ) in enumerate( + zip(kernels, strides, dilations, channels, paddings, pooling_kernels) + ): + if ix > 0: + in_channel = channel + # Initialize conv_block + conv_block = AudioConvolutionalBlock( + in_channels=in_channel, + out_channels=channel, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + pooling_kernel=pooling_kernel, + ) + + # Set any additional AudioConvolutionalBlock parameters that have been passed in through kwargs + for k, v in kwargs.items(): + if hasattr(conv_block, k): + setattr(conv_block, k, v) + # Reconfigure modules + conv_block.setup() + conv_blocks.append(conv_block) + self.conv_blocks = nn.ModuleList(conv_blocks) + + def forward(self, x): + if self.rectify: + x = torch.abs(x) + + for conv_block in self.conv_blocks: + x = conv_block(x) + return x + + +class IdentityBlock(nn.Module): + def __init__(self, **kwargs): + """ + Identity block for audio-style processing that returns the input in the forward method. + """ + super().__init__() + + def forward(x): + return x + + +class MultiScaleConvolution(nn.Module): + def __init__(self, path1, path2, path3, path4): + """ + Neural network that processes audio in up to four independent paths prior to + combining in a fully connected sequence. By means of simple statistical + aggregations, each path is compressed to the same size, regardless of audio length. + + Parameters + ---------- + path1 : nn.Module + First of the four independent paths. Will be ignored if set to IdentityBlock. + path2 : nn.Module + Second of the four independent paths. Will be ignored if set to IdentityBlock. + path3 : nn.Module + Third of the four independent paths. Will be ignored if set to IdentityBlock. + path4 : nn.Module + Fourth of the four independent paths. Will be ignored if set to IdentityBlock. + """ + super().__init__() + paths = [path1, path2, path3, path4] + + # Drop any identity paths. 
+ paths = [path for path in paths if not isinstance(path, IdentityBlock)] + + self.conv_paths = nn.ModuleList(paths) + + # Track total dimension of convolutional outputs from all the paths + conv_out_dimension = 0 + for path in self.conv_paths: + out_dim = path.conv_blocks[-1].model_list[0].out_channels + conv_out_dimension += 2 * out_dim + + # Sequence of fully connected layers + self.decoder = LinearSequence( + in_features=conv_out_dimension, + layer_dims=[int(conv_out_dimension / 4), int(conv_out_dimension / 16), 1], + ) + + def forward(self, x): + if len(x.shape) > 3 and x.shape[1] == 1: + # This may not be the best place/way to do this, but should work on mono audio + x = torch.squeeze(x, dim=1) + path_outs = [] + for conv_path in self.conv_paths: + conv_out = conv_path(x) + + conv_means = torch.mean(conv_out, dim=-1) + conv_stds = torch.std(conv_out, dim=-1) + conv_stat = torch.cat((conv_means, conv_stds), dim=-1) + path_outs.append(conv_stat) + + stats = torch.cat(path_outs, -1) + + out = self.decoder(stats) + return out + + +# ----------------------- +# Aligner - Dataset Alignment Modules +# ----------------------- +class NoAligner(nn.Module): + def __init__(self, reference_index=0, num_datasets=0, **kwargs): + """ + NoAligner acts as a dummy module so that the other AlignNet module code + can be used even when there is no dataset alignment being performed. + + Parameters + ---------- + reference_index : int, optional + Unused, but exists to easily replace other Aligner setups, by default 0 + num_datasets : int, optional + Unused, but exists to easily replace other Aligner setups, by default 0 + """ + super().__init__() + self.reference_index = reference_index + self.num_datasets = num_datasets + + def forward(self, stilde, dataset_index): + return stilde + + +class LinearSequenceAligner(nn.Module): + def __init__( + self, + reference_index, + num_datasets, + embedding_dim=10, + layer_dims=[16, 16, 16, 16, 1], + ): + """ + Aligner network for dataset alignment. + + The LinearSequenceAligner implements the Aligner as defined in "AlignNet: + Learning dataset score alignment functions to enable better training of + speech quality estimators." + + Parameters + ---------- + reference_index : int + Dataset index that should be treated as the reference. The Aligner acts + as the identity function on the reference dataset. + num_datasets : int + Number of datasets in training. + embedding_dim : int, optional + Size of the dataset index embedding, by default 10 + layer_dims : list, optional + Dimensions of the Aligner's fully connected layers, by default [16, 16, 16, 16, 1] + """ + super().__init__() + self.reference_index = reference_index + self.num_datasets = num_datasets + + self.embedding_dim = embedding_dim + self.embedding = torch.nn.Embedding( + num_datasets, embedding_dim=self.embedding_dim + ) + self.weights = LinearSequence( + in_features=self.embedding_dim + 1, + layer_dims=layer_dims, + ) + + def forward(self, stilde, dataset_index): + embed = self.embedding(dataset_index) + # Concatenate score with embedding + x = torch.cat((stilde, embed), 1) + score = self.weights(x) + # Override for reference dataset -- this might be weird! 
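+        # torch.where keeps the raw audio_net estimate (stilde) for items from the
+        # reference dataset and uses the aligned score for all other datasets.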
+ out = torch.where(dataset_index[:, None] == self.reference_index, stilde, score) + return out + + +# ------------------ +# Combination Module +# ------------------ +class AlignNet(nn.Module): + def __init__( + self, audio_net, aligner, aligner_corr_threshold=None, audio_net_freeze_epochs=0 + ): + """ + AlignNet module that uses an audio_net with an aligner to train on multiple datasets at once. + + Parameters + ---------- + audio_net : nn.Module + Network component that maps audio to quality on the reference dataset scale. + aligner : nn.Module + Network component that maps intermediate quality estimates and dataset + indicators to the appropriate dataset score. + aligner_corr_threshold : float, optional + Correlation threshold that determines when the aligner is activated. + If None, the aligner turns on immediately, by default None + audio_net_freeze_epochs : int, optional + Number of epochs to keep the audio_net frozen, by default 0 + """ + super().__init__() + self.audio_net = audio_net + self.aligner = aligner + self.reference_index = self.aligner.reference_index + + if aligner_corr_threshold is not None and aligner_corr_threshold > -1: + # We want to freeze aligner (and ideally ensure it is not changing + # estimations) until we see a validation correlation above + # aligner_corr_threshold. + self.use_aligner = False + self.aligner_corr_threshold = aligner_corr_threshold + + # Freeze aligner params + for p in self.aligner.parameters(): + p.requires_grad_(False) + + else: + self.use_aligner = True + self.aligner_corr_threshold = -1 + + self.audio_net_freeze_epochs = audio_net_freeze_epochs + + if audio_net_freeze_epochs > 0: + self.set_audio_net_update_status(False) + else: + self.update_audio_net = True + + def set_audio_net_update_status(self, status): + self.update_audio_net = status + for p in self.audio_net.parameters(): + p.requires_grad_(status) + + def forward(self, audio, dataset): + # Intermediate score representation + score = self.audio_net(audio) + if self.use_aligner: + # Aligned score estimate + score = self.aligner(score, dataset) + return score + + +# -------------- +# Primary Module +# -------------- +class Model(pl.LightningModule): + def __init__( + self, + network: nn.Module, + loss=torch.nn.MSELoss(), + optimizer=torch.optim.Adam, + loss_weights=None, + ): + """ + LightningModule to train AlignNet models. + + Module should be compatible with non-AlignNet architecture but includes + additional functionality specifically tailored to AlignNet. + + Parameters + ---------- + network : nn.Module + AlignNet model. + loss : func, optional + Loss function, by default torch.nn.MSELoss() + optimizer : OptimizerWrapper or torch.nn.optim class, optional + Optimizer class, by default torch.optim.Adam + loss_weights : list + List of weights to compute weighted average of loss over datasets. If None, then loss is computed without + respect to datasets. In the case where one dataset has significantly less data, a weighted average allows + more control to ensure it is properly learned. If loss_weights = 1, then all the datasets will receive equal weight. 
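+            For example, with three datasets, loss_weights = 1 is expanded to [1/3, 1/3, 1/3].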
+ """ + super().__init__() + # self.save_hyperparameters(ignore=["network", "loss"]) + self.network = network + + self.loss = loss + if loss_weights == 1: + n_datasets = self.network.aligner.num_datasets + loss_weights = n_datasets * [1 / n_datasets] + self.loss_weights = loss_weights + self.optimizer = optimizer + self.validation_step_info = { + "outputs": [], + "targets": [], + "datasets": [], + } + self.epoch = 0 + + def loss_calc(self, mean_estimate, mos, dataset): + """ + Perform loss calculation, taking into account loss weights. + + Parameters + ---------- + mean_estimate : torch.tensor + Network estimate. + mos : torch.tensor + Labeled truth value. + dataset : torch.tensor + Dataset indicators. + + Returns + ------- + torch.tensor + Loss. + """ + # If there are loss weights, use them + if self.loss_weights is not None: + loss = 0 + for dix in torch.unique(dataset): + dix = int(dix) + weight = self.loss_weights[dix] + sub_ests = mean_estimate[dataset == dix] + sub_mos = mos[dataset == dix] + loss += weight * self.loss(sub_ests, sub_mos) + + else: + loss = self.loss(mos, mean_estimate) + return loss + + def forward(self, audio, dataset): + mean_estimate = self.network(audio, dataset) + mean_estimate = torch.squeeze(mean_estimate, dim=1) + return mean_estimate + + def _forward(self, training_batch): + """ + Internal forward method with logic consistent across all training and test steps. + + Parameters + ---------- + training_batch : tuple + All data in a training batch. + """ + audio, mos, dataset = training_batch + mos = mos.float() + + mean_estimate = self.network(audio, dataset) + # If audio is 2-D (e.g., wav2vec representation), needs to be squeezed in diminsion 1 here + # If audio is raw wav, this won't do anything (dim 1 will be frames and != 1) + mean_estimate = torch.squeeze(mean_estimate, dim=1) + + loss = self.loss_calc(mean_estimate, mos, dataset) + + if mos.shape == torch.Size([1]): + warnings.warn(f"Batch only has one element, reporting correlation=0") + corrcoef = 0 + else: + corrcoef = self.pearsons_corr(mean_estimate, mos) + return loss, corrcoef + + def pearsons_corr(self, mean_estimate, mos): + """ + Simple wrapper for grabbing pearsons correlation coefficient + """ + mean_estimate = torch.unsqueeze(mean_estimate, dim=1) + mos = torch.unsqueeze(mos, dim=1) + + cat = torch.cat([mean_estimate, mos], dim=1) + cat = torch.transpose(cat, 0, 1) + + corrcoef = torch.corrcoef(cat)[0, 1] + + return corrcoef + + def training_step(self, training_batch, idx): + loss, corrcoef = self._forward(training_batch) + self.log("train_loss", loss) + self.log("train_pearsons", corrcoef) + return loss + + def validation_step(self, val_batch, idx): + """ + Validation step. Unlike the training and test steps, we need to store + per-dataset information here. + """ + audio, mos, dataset = val_batch + mos = mos.float() + + mean_estimate = self.network(audio, dataset) + mean_estimate = torch.squeeze(mean_estimate, dim=1) + + loss = self.loss_calc(mean_estimate, mos, dataset) + + # Store per dataset information to be used at epoch end. + self.validation_step_info["outputs"].append(mean_estimate) + self.validation_step_info["targets"].append(mos) + self.validation_step_info["datasets"].append(dataset) + + return loss + + def on_validation_epoch_end(self) -> None: + """ + At the end of validation epochs we calculate per-dataset statistics. 
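+
+        This is also where the aligner is switched on once the overall validation
+        correlation exceeds aligner_corr_threshold, and where audio_net is unfrozen
+        after audio_net_freeze_epochs validation epochs have passed.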
+ """ + # Concatenate stored epoch data into single tensor for each metric + estimates = torch.cat(self.validation_step_info["outputs"], dim=0) + targets = torch.cat(self.validation_step_info["targets"], dim=0) + datasets = torch.cat(self.validation_step_info["datasets"], dim=0) + + # Overall loss and correlation + loss = self.loss_calc(estimates, targets, datasets) + corrcoef = self.pearsons_corr(estimates, targets) + + # Check if network has a use_aligner flag + if hasattr(self.network, "use_aligner"): + # If aligner is off and we have passed the correlation threshold, do the updates + if ( + not self.network.use_aligner + and corrcoef > self.network.aligner_corr_threshold + ): + # Start using alignment network in forward + self.network.use_aligner = True + # Turn on gradients for aligner parameters + for p in self.network.aligner.parameters(): + p.requires_grad_(True) + print( + f"Correlation threshold of {self.network.aligner_corr_threshold} reached with {corrcoef:.4f}. Turning on aligner." + ) + + self.log("val_loss", loss) + self.log("val_pearsons", corrcoef) + + # Per dataset losses and correlations + for k, ds in enumerate(torch.unique(datasets)): + ds_ix = datasets == ds + + ds_est = estimates[ds_ix] + ds_tar = targets[ds_ix] + + ds_loss = self.loss(ds_est, ds_tar) + ds_corr = self.pearsons_corr(ds_est, ds_tar) + self.log(f"val_loss/dataset {k}", ds_loss) + self.log(f"val_pearsons/dataset {k}", ds_corr) + + # Clear epoch validation info + for k, v in self.validation_step_info.items(): + v.clear() + + # If we aren't updating audio-net and our epoch has passed the wait time, turn it on! + if ( + not self.network.update_audio_net + and self.epoch >= self.network.audio_net_freeze_epochs + ): + # Turn on audio_net, set + print(f"Turning audio_net on after {self.epoch} epochs.") + self.network.set_audio_net_update_status(True) + self.epoch += 1 + + return super().on_validation_epoch_end() + + def test_step(self, test_batch, idx): + loss, corrcoef = self._forward(test_batch) + self.log("test_loss", loss) + self.log("test_pearsons", corrcoef) + return loss + + def configure_optimizers(self): + if isinstance(self.optimizer, OptimizerWrapper): + optimizer = self.optimizer.optimizer(self.parameters()) + else: + optimizer = self.optimizer(self.parameters()) + return optimizer diff --git a/alignnet/optimizer.py b/alignnet/optimizer.py new file mode 100644 index 0000000..888fcd2 --- /dev/null +++ b/alignnet/optimizer.py @@ -0,0 +1,11 @@ +import torch + + +class OptimizerWrapper: + def __init__(self, class_name, lr, **kwargs): + self._optimizer = eval(class_name) + self.kwargs = kwargs + self.kwargs["lr"] = lr + + def optimizer(self, params): + return self._optimizer(params, **self.kwargs) diff --git a/alignnet/transforms.py b/alignnet/transforms.py new file mode 100644 index 0000000..c87500e --- /dev/null +++ b/alignnet/transforms.py @@ -0,0 +1,127 @@ +import contextlib +from typing import Any +import torch +import torchaudio + +import numpy as np + + +# ------------------ +# Audio Transforms +# ----------------- +class NoneTransform: + def __init__(self): + pass + + def transform(self, audio): + return audio + + +class MelTransform: + def __init__(self, fft_win_length, win_overlap, n_mels): + """ + Mel Spectrogram transform. + + Parameters + ---------- + sample_rate : int + Sample rate of audio. + fft_win_length : int + Window length, in samples. + win_overlap : int + Window overlap, in samples. 
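+        n_mels : int
+            Number of mel bands in the mel spectrogram.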
+
+        """
+        self.win_length = fft_win_length
+        self.win_overlap = win_overlap
+        self.n_mels = n_mels
+
+    def transform(self, audio, sample_rate, n_mels=None, device="cpu", **kwargs):
+        """
+        Perform mel spectrogram transform.
+
+        Parameters
+        ----------
+        audio : torch.tensor
+            Audio to transform.
+        sample_rate : int
+            Sample rate of audio.
+        n_mels : int, optional
+            Number of mel bands in transform, by default None (falls back to self.n_mels).
+        device : torch.device
+            Device to perform transform on, by default "cpu".
+
+        Returns
+        -------
+        torch.tensor
+            Transformed audio.
+        """
+
+        if n_mels is None:
+            n_mels = self.n_mels
+
+        transform = torchaudio.transforms.MelSpectrogram(
+            sample_rate=sample_rate,
+            n_fft=self.win_length,
+            # win_length=self.win_length,
+            hop_length=self.win_length - self.win_overlap,
+            n_mels=n_mels,
+            center=False,
+            norm="slaney",
+            mel_scale="slaney",
+        )
+        transform = transform.to(device)
+
+        mel = transform(audio)
+        mel = torch.squeeze(mel)
+
+        return mel
+
+
+class STFTTransform:
+    def __init__(self, fft_win_length=512, win_overlap=256):
+        """
+        Short-time Fourier transform.
+
+        Parameters
+        ----------
+        fft_win_length : int, optional
+            Window length in samples, by default 512
+        win_overlap : int, optional
+            Window overlap in samples, by default 256
+        """
+        self.fft_win_length = fft_win_length
+        self.win_overlap = win_overlap
+
+    def transform(self, audio, **kwargs):
+        """
+        Perform a STFT on audio.
+
+        Parameters
+        ----------
+        audio : torch.tensor
+            Audio to transform.
+
+        Returns
+        -------
+        torch.tensor
+            Transformed audio.
+        """
+        hann_window = torch.hann_window(
+            window_length=self.fft_win_length,
+            periodic=True,
+        )
+
+        stft = torch.stft(
+            audio,
+            n_fft=self.fft_win_length,
+            hop_length=self.fft_win_length - self.win_overlap,
+            window=hann_window,
+            return_complex=True,
+            center=False,
+        )
+        stft = torch.abs(stft)  # (N_freq, N_frame)
+        stft = torch.movedim(stft, -1, -2)  # (N_frame, N_freq)
+        stft = torch.squeeze(stft)  # (N_frame, N_freq)
+
+        return stft
diff --git a/environment-paper.yaml b/environment-paper.yaml
new file mode 100644
index 0000000..750b365
--- /dev/null
+++ b/environment-paper.yaml
@@ -0,0 +1,28 @@
+name: alignnet-paper
+channels:
+  - pytorch
+  - nvidia
+  - conda-forge
+dependencies:
+  - python=3.10
+  - clearml=1.16.1
+  - clearml-agent
+  - hydra-core=1.3.2
+  - IPython
+  - matplotlib=3.8.4
+  - numpy=1.26.4
+  - pre-commit
+  - poetry
+  - pytorch=2.3.0
+  - pytorch-cuda=11.8
+  - conda-forge::pytorch-lightning=2.2.2
+  - scipy
+  - torchvision
+  - torchaudio=2.3.0
+  - pip
+  - pip:
+    - pandas==2.2.2
+    - pylance
+    - sox
+    - soundfile
+    - tensorboard==2.16.2
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000..94d1ad0
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,28 @@
+name: alignnet
+channels:
+  - pytorch
+  - nvidia
+  - conda-forge
+dependencies:
+  - python=3.10
+  - clearml
+  - clearml-agent
+  - hydra-core
+  - IPython
+  - matplotlib
+  - numpy
+  - pre-commit
+  - poetry
+  - pytorch
+  - pytorch-cuda=11.8
+  - conda-forge::pytorch-lightning
+  - scipy
+  - torchvision
+  - torchaudio
+  - pip
+  - pip:
+    - pandas
+    - pylance
+    - sox
+    - soundfile
+    - tensorboard
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..33da4eb
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,84 @@
+import hydra
+import os
+import torch
+import warnings
+
+import numpy as np
+import pandas as pd
+
+from omegaconf import DictConfig, OmegaConf, open_dict
+from tqdm import tqdm
+
+import alignnet
+
+from argparse import ArgumentParser,
ArgumentDefaultsHelpFormatter + + +@hydra.main( + config_path="./inference_configs", config_name="config.yaml", version_base=None +) +def main(cfg: DictConfig) -> None: + """ + Run inference on data with a trained model. + + See `python inference.py --help` for more details. + """ + + # Transform + transform = hydra.utils.instantiate(cfg.transform) + + print("Initializing data") + audio_data = hydra.utils.instantiate( + cfg.data, + transform=transform, + ) + print(f"Loading model from {cfg.model.path}") + model = alignnet.load_model(cfg.model.path) + + # Use GPU if available + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + if cfg.model.dataset_index == "reference": + dataset_index = torch.tensor([model.network.aligner.reference_index]) + else: + dataset_index = torch.tensor([cfg.model.dataset_index]) + dataset_index = dataset_index.to(device) + # Switch to eval mode + model.eval() + with torch.no_grad(): + output_dicts = [] + print(f"Generating estimations") + for ix, (audio, mos, dataset) in enumerate( + tqdm(audio_data, total=len(audio_data)) + ): + # Make audio look batched + audio = audio[None, None, :] + audio = audio.to(device) + + est = model(audio, dataset_index) + audio_path = audio_data.score_file.loc[ix, audio_data.pathcol] + output_dicts.append( + { + "file": audio_path, + "estimate": est.to("cpu").numpy()[0], + "dataset": dataset, + "AlignNet dataset index": dataset_index.to("cpu").numpy()[0] + } + ) + # Iterating over Datasets does not always stop appropriately so this ensures it does + if ix == len(audio_data) - 1: + break + output_df = pd.DataFrame(output_dicts) + print("First 5 results:") + print(output_df.head()) + output_dir = os.path.dirname(cfg.output.file) + + if not os.path.exists(output_dir) and output_dir != "": + os.makedirs(output_dir) + print(f"Saving results to {cfg.output.file}") + output_df.to_csv(cfg.output.file, index=False) + + +if __name__ == "__main__": + main() diff --git a/inference_configs/config.yaml b/inference_configs/config.yaml new file mode 100644 index 0000000..0cec210 --- /dev/null +++ b/inference_configs/config.yaml @@ -0,0 +1,28 @@ +hydra: + run: + dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + job: + chdir: False + +defaults: + - input_type: stft + - override hydra/help: inference_help + - _self_ +output: + file: ??? + +data: + data_files: ??? + cache: false + transform_time: get + target: null + +dataclass: + _target_: hydra.utils.get_class + +model: + path: ??? + dataset_index: reference + +transform: + _target_: alignnet.NoneTransform diff --git a/inference_configs/hydra/help/inference_help.yaml b/inference_configs/hydra/help/inference_help.yaml new file mode 100644 index 0000000..f15d4ed --- /dev/null +++ b/inference_configs/hydra/help/inference_help.yaml @@ -0,0 +1,48 @@ +app_name: AlignNet + +header: == Using ${hydra.help.app_name} at inference== + +footer: |- + Powered by Hydra (https://hydra.cc) + Use --hydra-help to view Hydra specific help. + +template: |- + ${hydra.help.header} + + This is the ${hydra.help.app_name} inference program! + + To use a model at inference, you must override three parameters: + * model.path : str pointing to the path containing a trained model + (must have a `model.ckpt` and `config.yaml` file in path.) + * data.data_files : list containing paths to csv files with filepaths to perform inference on. + The path name of the csv must correspond to `data.pathcol` which can be overriden. + * output.file : str to filepath where outputs will be saved. 
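+
+  For example, a call might look like the following (paths here are placeholders):
+
+  ```
+  python inference.py model.path=/path/to/trained_model data.data_files=[/path/to/test.csv] output.file=outputs/estimations.csv
+  ```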
+ + + == Configuration groups == + Compose your configuration from those groups (db=mysql) + + $APP_CONFIG_GROUPS + + == Config == + This is the config generated for this run. + You can override everything. For example, to switch to an audio input type and see all the options, run: + + ``` + python inference.py input_type=audio --help + ``` + + The appropriate input type is determined by what the trained model expects. + The default is stft features. + + The model.dataset_index override allows you to get estimates with different dataset alignment functions at inference. + It defaults to the reference dataset used at training but can be set to the integer corresponding to any other training dataset. + + ------- + $CONFIG + ------- + + To see the config of an example command directly without running it, add + `--cfg job` to your command. + + ${hydra.help.footer} \ No newline at end of file diff --git a/inference_configs/input_type/audio.yaml b/inference_configs/input_type/audio.yaml new file mode 100644 index 0000000..7d5f36d --- /dev/null +++ b/inference_configs/input_type/audio.yaml @@ -0,0 +1,9 @@ +# @package _global_ +data: + _target_: alignnet.AudioData + fs: 16000 + time_dim: 1 + pathcol: audio_path + +dataclass: + path: diff --git a/inference_configs/input_type/stft.yaml b/inference_configs/input_type/stft.yaml new file mode 100644 index 0000000..3b879fe --- /dev/null +++ b/inference_configs/input_type/stft.yaml @@ -0,0 +1,6 @@ +# @package _global_ +data: + _target_: alignnet.FeatureData + fs: null + flatten: false + pathcol: stft_path diff --git a/pretransform_data.py b/pretransform_data.py new file mode 100644 index 0000000..091ca43 --- /dev/null +++ b/pretransform_data.py @@ -0,0 +1,247 @@ +import datetime +import os +import pickle +import torchaudio + +import pandas as pd + +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter +from torch import save +from tqdm import tqdm + +from alignnet import transforms + + +def load_audio(fpath, target_fs=None): + """ + Load audio file and resample as necessary. + + Parameters + ---------- + fpath : str + Path to audio file. + target_fs : int, optional + Target sample rate if resampling required, by default None + + Returns + ------- + torch.tensor + Audio file + int + Sample rate + """ + audio, sample_rate = torchaudio.load(fpath) + if target_fs is not None and sample_rate != target_fs: + # We have a target fs and the current sample rate is not it => resample! + resampler = torchaudio.transforms.Resample( + orig_freq=sample_rate, + new_freq=target_fs, + dtype=audio.dtype, + ) + audio = resampler(audio) + sample_rate = target_fs + return audio, sample_rate + + +def transform_path_walk(datapath, outpath, transform, target_fs): + """ + Transform all audio files within directory via os.walk + + Parameters + ---------- + datapath : str + Path to walk + outpath : str + Path to save transformed files + transform : transform + Transform from alignnet.transforms. Must have transform.transform(audio, **kwargs) + as a method. + target_fs : int + Target sample rate. Audio will be resampled to this prior to transform if needed. 
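+
+    Notes
+    -----
+    Transformed tensors are pickled to .pkl files under outpath, mirroring the
+    directory structure found under datapath. Files that fail to load are
+    collected and reported at the end.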
+ """ + failed_files = [] + for path, directories, files in os.walk(datapath): + print(f"Transforming audio in: {path}") + for file in tqdm(files): + bname, ext = os.path.splitext(file) + if ext == ".wav": + fpath = os.path.join(path, file) + try: + audio, sample_rate = load_audio(fpath, target_fs=target_fs) + except: + failed_files.append(fpath) + continue + audio = transform.transform(audio, sample_rate=sample_rate) + audio = audio.float() + + # Get relative path from datapath to path + subpath = os.path.relpath(path, datapath) + # Make a new path from outpath to path + newpath = os.path.join(outpath, subpath) + + # Make directories if necessary + if not os.path.exists(newpath): + os.makedirs(newpath) + + outfile = os.path.join(newpath, bname + ".pkl") + outfile = os.path.abspath(outfile) + + with open(outfile, "wb") as output: + pickle.dump(audio, output) + print(f"Unable to transform following files:\n{failed_files}") + + +def transform_csv( + datapath, outpath, csv_list, transform, target_fs, pathcol="filename" +): + """ + Transform all audio files listed in csv. + + Parameters + ---------- + datapath : str + Parent path to (potential) relative path within csv pathcol. + outpath : str + Path where transformed audio will be saved. + csv_list : str + Path to csv file containing audio names to transform in pathcol + transform : transform + Transform from alignnet.transforms. Must have transform.transform(audio, **kwargs) + as a method. + target_fs : int + Target sample rate. Audio will be resampled to this prior to transform if needed. + pathcol : str, optional + Column in csv that contains audio filenames, by default "filename" + """ + # Load csv int dataframe + df = pd.read_csv(csv_list) + for ix, row in tqdm(df.iterrows(), total=len(df)): + # Get filename + fname = row[pathcol] + # Create file path + fpath = os.path.join(datapath, fname) + + # Load audio and transform it + audio, sample_rate = load_audio(fpath, target_fs=target_fs) + audio = transform.transform(audio, sample_rate=sample_rate) + audio = audio.float() + + outfile = os.path.join(outpath, fname) + newpath = os.path.dirname(outfile) + + # Make directories if necessary + if not os.path.exists(newpath): + os.makedirs(newpath) + + outfile, _ = os.path.splitext(outfile) + outfile = outfile + ".pkl" + + with open(outfile, "wb") as output: + pickle.dump(audio, output) + + +def main(datapath, outpath, transform_name, csv_list, **kwargs): + log_file = os.path.join(outpath, "readme.log") + if transform_name == "Mel": + transform = transforms.MelTransform() + elif transform_name == "STFT": + transform = transforms.STFTTransform() + for k, v in kwargs.items(): + if hasattr(transform, k): + setattr(transform, k, v) + # We never want to flatten + if hasattr(transform, "flatten"): + setattr(transform, "flatten", False) + os.makedirs(outpath, exist_ok=True) + with open(log_file, "w") as outfile: + time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + time_str = f"Running {__file__}\nStarted: {time}\n" + outfile.write(time_str) + input_str = f"datapath={datapath}, outpath={outpath}, transform_name={transform_name}, csv_list={csv_list}\nkwargs: " + for k, v in kwargs.items(): + input_str += f"{k}={v}, " + outfile.write(f"Inputs: {input_str}") + transform_str = f"type: {type(transform)}\nAttributes: " + for v in dir(transform): + if v[0] != "_": + transform_str += f"{v}={getattr(transform, v)}, " + outfile.write(transform_str) + c = 0 + + _, ext = os.path.splitext(datapath) + if csv_list is not None: + print("Load csv and transform 
files in there") + transform_csv( + datapath=datapath, + outpath=outpath, + csv_list=csv_list, + transform=transform, + target_fs=kwargs["target_fs"], + ) + else: + print("Walk through the input directory") + transform_path_walk( + datapath=datapath, + outpath=outpath, + transform=transform, + target_fs=kwargs["target_fs"], + ) + + with open(log_file, "a") as outfile: + finish_str = ( + f'\nFinished: {datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}\n' + ) + outfile.write(finish_str) + + +if __name__ == "__main__": + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + + parser.add_argument( + "transform_name", + choices=["Mel", "STFT"], + help="Transform to apply. Corresponds to a transform class in alignnet.transforms", + ) + parser.add_argument("datapath", type=str, help="Path to data to transform") + parser.add_argument( + "outpath", + type=str, + help=("Path where transformed version of data is stored."), + ) + + parser.add_argument( + "--target-fs", + default=None, + type=int, + help="Sample rate to resample audio to prior to transformation. If None, no resampling done.", + ) + + parser.add_argument( + "--fft-win-length", default=512, type=int, help="Window length for an STFT." + ) + + parser.add_argument( + "--win-overlap", default=256, type=int, help="Window overlap for an STFT." + ) + + parser.add_argument( + "--csv-list", + type=str, + default=None, + help=( + "CSV file with list of files to transform. Assumes that " + "os.path.join(datapath, x) is the full path to a file, where x is a " + "row of the csv under column 'filename'" + ), + ) + + parser.add_argument( + "--log", + action="store_true", + help="Take log10 of representations.", + ) + + + args = parser.parse_args() + + main(**vars(args)) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f6fa6b3 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,14 @@ +[tool.poetry] +name = "alignnet" +version = "0.1.0" +description = "" +authors = ["Jaden Pieper "] +readme = "README.md" + +[tool.poetry.dependencies] +python = "^3.10" + + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/split_labeled_data.py b/split_labeled_data.py new file mode 100644 index 0000000..7edcbbe --- /dev/null +++ b/split_labeled_data.py @@ -0,0 +1,216 @@ +import os + +import numpy as np +import pandas as pd + +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter + +EXT = ".csv" + + +def get_split_numbers(n_audio, split_fraction): + """ + Number of audio files per split + + Parameters + ---------- + n_audio : int + Number of audio files + split_fraction : list + List of fractions for each split + + Returns + ------- + list + List of number of items in each split. + """ + # Number of audio files per split + split_numbers = [] + for split_frac in split_fraction[:-1]: + split_num = np.round(split_frac * n_audio) + split_numbers.append(int(split_num)) + split_numbers.append(n_audio - np.sum(split_numbers)) + return split_numbers + + +def split_df_by_column(df, split_col, split_names, split_fraction): + """ + Generate dictionary of indices for splitting up a DataFrame while maintaining + balance within splits for a single column. + + Split dataframe while maintaining balance of elements within a specific column. + + Note that if there are n conditions labelled within a certain column, this + ensures that the proper ratio of conditions is maintained within the train, validation, + and test datasets. 
For example, if 80% of the data is condition A,
+    15% is condition B, and 5% is condition C, then those percentage ratios will
+    be preserved in each of the train, validation, and test datasets.
+
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        DataFrame to split.
+    split_col : str
+        Categorical column name in df that will have balance of values preserved
+        in each output dataset.
+    split_names : list
+        List of names for output csvs, used as keys in output dict.
+    split_fraction : list
+        Fraction of df to go into each dictionary item (ordered according to split_names).
+
+    Returns
+    -------
+    dict
+        Dictionary with keys being split_names and values being array of indices
+        that has length of len(df) * split_fraction for each element.
+    """
+    column_vals = np.unique(df[split_col])
+
+    # Initialize empty dictionary
+    split_ix = dict()
+    for name in split_names:
+        split_ix[name] = []
+
+    for col_val in column_vals:
+        df_filt = df[df[split_col] == col_val]
+        split_ix_val = split_df(df_filt, split_names, split_fraction)
+        for name, ix in split_ix_val.items():
+            split_ix[name].extend(ix)
+
+    # One final shuffle
+    rng = np.random.default_rng()
+    for name, ix in split_ix.items():
+        split_ix[name] = rng.choice(ix, len(ix), replace=False)
+    return split_ix
+
+
+def split_df(df, split_names, split_fraction):
+    """
+    Generate dictionary of indices for splitting up a DataFrame.
+
+    Dictionary keys are defined by split_names and the number of items in each key
+    is determined by split_fraction.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        DataFrame to split.
+    split_names : list
+        List of names for output csvs, used as keys in output dict.
+    split_fraction : list
+        Fraction of df to go into each dictionary item (ordered according to split_names).
+
+    Returns
+    -------
+    dict
+        Dictionary with keys being split_names and values being array of indices
+        that has length of len(df) * split_fraction for each element.
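+
+    Notes
+    -----
+    Split sizes are rounded, and the final split receives the remainder, so the
+    fractions do not need to divide len(df) exactly.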
+ """ + # Number of rows in df + n_audio = len(df) + + # Get list with number of audio files per split + split_numbers = get_split_numbers(n_audio, split_fraction) + + # Initialize random number generator + rng = np.random.default_rng() + + # Shuffle index + shuffled_ix = rng.choice(df.index, size=len(df.index), replace=False) + split_ix = dict() + seen = 0 + for n, name in zip(split_numbers, split_names): + start = seen + end = start + n + split_ix[name] = shuffled_ix[start:end] + seen = end + return split_ix + + +def main(args, n=None): + if len(args.split_fraction) != len(args.split_names): + raise ValueError( + f"Split fraction and split names must be same length, {len(args.split_fraction)} != {len(args.split_names)}" + ) + output_dir = args.output_dir + + if n is not None: + output_dir += f"/split{n:02}" + os.makedirs(output_dir, exist_ok=True) + # Read scores + score_df = pd.read_csv(args.label_file) + + if args.split_column is None: + # Split all the files + split_ix = split_df(score_df, args.split_names, args.split_fraction) + else: + split_ix = split_df_by_column( + # Split according to the split_column + score_df, + args.split_column, + args.split_names, + args.split_fraction, + ) + + ext = ".csv" + for name, ix in split_ix.items(): + audio = score_df.iloc[ix] + out_name = os.path.join(output_dir, name + ext) + audio.to_csv(out_name, index=False) + + +if __name__ == "__main__": + parser = ArgumentParser( + description="Split a label_file containing target and pathcol for audio file into train, valid, and test csvs.", + formatter_class=ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "label_file", + type=str, + help=("Path and filename to file with subjective scores and file paths."), + ) + + parser.add_argument( + "--output-dir", + type=str, + default="data-splits", + help="Path where data splits will be stored.", + ) + + parser.add_argument( + "--split-names", + nargs="+", + default=["train", "valid", "test"], + help="Labels for how data is split and saved.", + ) + + parser.add_argument( + "--split-fraction", + nargs="+", + type=float, + default=[0.8, 0.1, 0.1], + help="Amount of data to use for each split-name. Must sum to 1 and be same length as --split-names.", + ) + + parser.add_argument( + "--split-column", + type=str, + default=None, + help=( + "Column for which data should be split according to split-fraction (e.g., force distributions of values in " + "that column across each dataset.)" + ), + ) + + parser.add_argument( + "--n-splits", type=int, default=1, help="Number of independent splits to make." + ) + + parser.add_argument( + "--no-header", action="store_true", help="Flag for no header in csvs." 
+ ) + + args = parser.parse_args() + for k in range(args.n_splits): + main(args, n=k) diff --git a/train.py b/train.py new file mode 100644 index 0000000..5fe9245 --- /dev/null +++ b/train.py @@ -0,0 +1,328 @@ +import hydra +import os +import shutil +import torch +import warnings +import yaml + +import matplotlib.cm as cm +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pytorch_lightning as pl + +from omegaconf import DictConfig, OmegaConf, open_dict +from tqdm import tqdm +import alignnet + +# Load clearml Task only if clearml is imported +try: + from clearml import Task + from clearml.backend_api.session.defs import MissingConfigError +except ModuleNotFoundError as err: + + def Task(**kwargs): + return None + + +def post_train(model, audio_data, loggers, task=None): + audio_data.batch_size = 1 + data_loaders = { + "train": audio_data.train_dataloader(), + "val": audio_data.val_dataloader(), + "test": audio_data.test_dataloader(), + } + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + model.eval() + + # Save estimations + out_df = pd.DataFrame() + with torch.no_grad(): + for dataset_split, loader in data_loaders.items(): + print(f"Loading dataset: {dataset_split}") + results = [] + for audio, mos, dataset in tqdm(loader): + audio = audio.to(device) + dataset = dataset.to(device) + + est = model(audio, dataset) + est = est.numpy(force=True) + est = np.squeeze(est) + + audio_net_est = model.network.audio_net(audio) + audio_net_est = audio_net_est.numpy(force=True) + audio_net_est = np.squeeze(audio_net_est) + + mos = mos.numpy(force=True) + mos = np.squeeze(mos) + + results.append( + [float(mos), float(est), int(dataset), float(audio_net_est)] + ) + + results = np.array(results) + data_df = pd.DataFrame( + results, + columns=["MOS", "Estimation", "Dataset_Index", "AudioNet_Estimation"], + ) + + data_df["Dataset"] = dataset_split + + out_df = pd.concat([out_df, data_df]) + corr = np.corrcoef(results[:, 0], results[:, 1])[0, 1] + print(f"{dataset_split} corr coef: {corr:6f}") + + metric_name = f"test_pearsons/{dataset_split}" + metrics = dict() + metrics[metric_name] = corr + + rmse = np.sqrt(np.mean((results[:, 0] - results[:, 1]) ** 2)) + metrics[f"test_rmse/{dataset_split}"] = rmse + + if dataset_split == "test": + for dix in np.unique(data_df["Dataset_Index"]): + df_sub = data_df[data_df["Dataset_Index"] == dix] + corr = np.corrcoef(df_sub["MOS"], df_sub["Estimation"])[0, 1] + rmse = np.sqrt(np.mean((df_sub["MOS"] - df_sub["Estimation"]) ** 2)) + + corr_name = f"test_pearsons/dataset {dix}" + metrics[corr_name] = corr + + rmse_name = f"test_rmse/dataset {dix}" + metrics[rmse_name] = rmse + + [logger.log_metrics(metrics) for logger in loggers] + + estimations_name = "estimations.csv" + out_df.to_csv(estimations_name, index=False) + + # Store estimations to clearml + if task is not None: + task.upload_artifact(artifact_object=estimations_name, name="estimations csv") + task.upload_artifact(artifact_object=out_df, name="estimations df") + colormap = cm.rainbow + # Plot estimations vs MOS + for k, ds in enumerate(np.unique(out_df["Dataset"])): + df_sub = out_df[out_df["Dataset"] == ds] + mos = df_sub["MOS"] + est = df_sub["Estimation"] + dataset_index = df_sub["Dataset_Index"] + plt.plot([1, 5], [1, 5], color="black", linestyle="dashed") + plt.scatter(x=mos, y=est, c=dataset_index, alpha=0.1, cmap=colormap) + corrcoef = np.corrcoef(mos, est)[0, 1] + rmse = np.sqrt(np.mean((mos - est) ** 2)) + title_str = 
f"{ds} set, LCC={corrcoef:.4f}, RMSE={rmse:.4f}" + for dx in np.unique(dataset_index): + dx_ix = dataset_index == dx + mos_dx = mos[dx_ix] + est_dx = est[dx_ix] + corrcoef_dx = np.corrcoef(mos_dx, est_dx)[0, 1] + rmse_dx = np.sqrt(np.mean((mos_dx - est_dx) ** 2)) + subtitle_str = ( + f", (Dataset {dx}, LCC={corrcoef_dx:.4f}, RMSE={rmse_dx:.4f})" + ) + title_str += subtitle_str + plt.title(title_str) + plt.xlabel("MOS") + plt.ylabel("Estimation") + plt.show() + + # audio_net vs aligner estimations + for k, ds in enumerate(np.unique(out_df["Dataset"])): + df_sub = out_df[out_df["Dataset"] == ds] + mos = df_sub["AudioNet_Estimation"] + est = df_sub["Estimation"] + dataset_index = df_sub["Dataset_Index"] + plt.plot([1, 5], [1, 5], color="black", linestyle="dashed") + plt.scatter(x=mos, y=est, c=dataset_index, alpha=0.1, cmap=colormap) + plt.title(f"{ds} set audio_net vs estimation ") + plt.xlabel("audio_net estimation") + plt.ylabel("final estimation") + plt.show() + + for k, dx in enumerate(np.unique(out_df["Dataset_Index"])): + xv = np.arange(0.5, 5.5, step=0.01) + xv = torch.Tensor(xv) + xv = xv[:, None] + xv = xv.to(device) + + data_tensor = dx * torch.ones(xv.shape) + data_tensor = data_tensor.squeeze() + data_tensor = data_tensor.to(int) + data_tensor = data_tensor.to(device) + + yv = model.network.aligner(xv, data_tensor) + xv = xv.cpu().detach().numpy() + yv = yv.cpu().detach().numpy() + plt.plot([1, 5], [1, 5], color="black", linestyle="dashed") + plt.scatter(xv, yv) + plt.title(f"Alignment function for dataset {dx}") + plt.xlabel("Raw score") + plt.ylabel("Aligned") + plt.show() + + +@hydra.main(config_path="alignnet/config", config_name="conf.yaml", version_base=None) +def main(cfg: DictConfig) -> None: + if cfg.logging.logger == "clearml": + try: + task = Task.init( + project_name=cfg.project.name, + task_name=cfg.project.task, + ) + except MissingConfigError as E: + print(f"{E}") + print( + f"If you do not want to install clearML and want to avoid this error in the future, set `logging=none` override." + ) + task = None + else: + task = None + print("Working directory : {}".format(os.getcwd())) + + # Seed + seed = cfg.common.seed + if seed is None: + rng = np.random.default_rng() + seed = rng.choice(10000) + cfg.common.seed = seed + pl.seed_everything(seed) + + # Transform + transform = hydra.utils.instantiate(cfg.transform) + + data_class = hydra.utils.instantiate(cfg.dataclass) + + audio_data = hydra.utils.instantiate( + cfg.data, transform=transform, DataClass=data_class + ) + + num_datasets = len(cfg.data.data_dirs) + + # Lightning logs + # Initialize tensorboard logger, letting hydra control directory and versions + tb_logger = pl.loggers.TensorBoardLogger( + save_dir=".", + name="", + version="", + ) + loggers = [tb_logger] + [logger.log_hyperparams(dict(cfg)) for logger in loggers] + + checkpoint_callback = hydra.utils.instantiate(cfg.checkpoint) + + callbacks = [checkpoint_callback] + if "earlystop" in cfg: + # Earlystop needs monitor (e.g., val-loss) and mode (e.g., min). This can be added via CLI/cfg. Otherwise steal the checkpoint values. 
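+        # For any key not set in cfg.earlystop, fall back to the corresponding
+        # value from the checkpoint callback config.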
+ stop_params = {"monitor": None, "mode": None} + for k, _ in stop_params.items(): + if k in cfg.earlystop: + stop_params[k] = cfg.earlystop[k] + else: + stop_params[k] = cfg.checkpoint[k] + early_stopping_callback = hydra.utils.instantiate(cfg.earlystop, **stop_params) + callbacks.append(early_stopping_callback) + # Trainer + trainer = hydra.utils.instantiate( + cfg.optimization, callbacks=callbacks, logger=loggers + ) + num_datasets = len(cfg.data.data_dirs) + # Initialize network + network = hydra.utils.instantiate( + cfg.network, aligner={"num_datasets": num_datasets} + ) + loss = hydra.utils.instantiate(cfg.loss) + + optimizer = hydra.utils.instantiate(cfg.optimizer, lr=cfg.common.lr) + + # initialize model + if cfg.finetune.restore_file is not None: + print(f"Loading model from checkpoint: {cfg.finetune.restore_file}") + # initialize model + model_class = hydra.utils.get_class(cfg.model._target_) + # Path to pretrained model checkpoint + model_path = os.path.join(cfg.finetune.restore_file, "model.ckpt") + restore_cfg_path = os.path.join(cfg.finetune.restore_file, "config.yaml") + + with open(restore_cfg_path, "r") as f: + restore_yaml = yaml.safe_load(f) + restore_cfg = DictConfig(restore_yaml) + + restore_network = hydra.utils.instantiate(restore_cfg.network) + # Turn restored audio_net gradients on or off depending on new network settings + old_freeze_name = "audio_net_freeze_steps" + if hasattr(network, old_freeze_name): + frozen_steps = getattr(network, old_freeze_name) + else: + frozen_steps = network.audio_net_freeze_epochs + if frozen_steps > 0: + restore_network.set_audio_net_update_status(False) + else: + restore_network.set_audio_net_update_status(True) + + # Initialize identical network to pretrained version (necessary to appropriately load in aligner) + # aligner is not transferable (different sizes based on number of datasets) + + restored_model = model_class.load_from_checkpoint( + model_path, network=restore_network, loss=loss, optimizer=optimizer + ) + # Grab audio_net from checkpoint + network.audio_net = restored_model.network.audio_net + + model = hydra.utils.instantiate( + cfg.model, network=network, loss=loss, optimizer=optimizer + ) + print(model) + + # Add working directory to config + with open_dict(cfg): + cfg.project.working_dir = os.getcwd() + # Save a version of the config + cfg_yaml = OmegaConf.to_yaml(cfg) + cfg_out = "input_config.yaml" + with open(cfg_out, "w") as file: + file.write(cfg_yaml) + + if cfg.common.auto_batch_size: + tuner = pl.tuner.Tuner(trainer) + tuner.scale_batch_size(model, datamodule=audio_data) + + # Fit Trainer + trainer.fit(model, audio_data) + + best_model_path = trainer.checkpoint_callback.best_model_path + trained_model_path = "trained_model" + os.makedirs(trained_model_path) + + # Save another copy of the top model + top_model_path = os.path.join(trained_model_path, "model.ckpt") + shutil.copy(best_model_path, top_model_path) + print(f'experiment_path = "{os.getcwd()}"') + print(f'model_ckpt = "{best_model_path}"') + + # Create output config + output_config = DictConfig({}) + output_config.model = cfg.model + output_config.network = cfg.network + # Store num datasets directly + output_config.network.aligner.num_datasets = num_datasets + + # Convert to yaml + output_config = OmegaConf.to_yaml(output_config) + + # Save output config + output_config_path = os.path.join(trained_model_path, "config.yaml") + with open(output_config_path, "w") as file: + file.write(output_config) + + # Get model class + model_class = 
hydra.utils.get_class(cfg.model._target_) + # Load best model + model = model_class.load_from_checkpoint(best_model_path, network=network) + post_train(model, audio_data, loggers, task=task) + + +if __name__ == "__main__": + main() diff --git a/trained_models/alignnet_mdf-MOSNet-large_data/config.yaml b/trained_models/alignnet_mdf-MOSNet-large_data/config.yaml new file mode 100755 index 0000000..f7a745f --- /dev/null +++ b/trained_models/alignnet_mdf-MOSNet-large_data/config.yaml @@ -0,0 +1,20 @@ +model: + _target_: alignnet.Model + loss_weights: 1 +network: + _target_: alignnet.AlignNet + aligner_corr_threshold: -1 + audio_net: + _target_: alignnet.MOSNet + aligner: + _target_: alignnet.LinearSequenceAligner + layer_dims: + - 16 + - 16 + - 16 + - 16 + - 1 + reference_index: 0 + embedding_dim: 10 + num_datasets: 4 + audio_net_freeze_epochs: 1 diff --git a/trained_models/alignnet_mdf-MOSNet-large_data/model.ckpt b/trained_models/alignnet_mdf-MOSNet-large_data/model.ckpt new file mode 100755 index 0000000..a74fefb Binary files /dev/null and b/trained_models/alignnet_mdf-MOSNet-large_data/model.ckpt differ diff --git a/trained_models/alignnet_mdf-MOSNet-small_data/config.yaml b/trained_models/alignnet_mdf-MOSNet-small_data/config.yaml new file mode 100755 index 0000000..47d0f81 --- /dev/null +++ b/trained_models/alignnet_mdf-MOSNet-small_data/config.yaml @@ -0,0 +1,20 @@ +model: + _target_: alignnet.Model + loss_weights: 1 +network: + _target_: alignnet.AlignNet + aligner_corr_threshold: -1 + audio_net: + _target_: alignnet.MOSNet + aligner: + _target_: alignnet.LinearSequenceAligner + layer_dims: + - 16 + - 16 + - 16 + - 16 + - 1 + reference_index: 0 + embedding_dim: 10 + num_datasets: 9 + audio_net_freeze_epochs: 1 diff --git a/trained_models/alignnet_mdf-MOSNet-small_data/model.ckpt b/trained_models/alignnet_mdf-MOSNet-small_data/model.ckpt new file mode 100755 index 0000000..12d1028 Binary files /dev/null and b/trained_models/alignnet_mdf-MOSNet-small_data/model.ckpt differ diff --git a/trained_models/pretrained-MOSNet-nisqa/config.yaml b/trained_models/pretrained-MOSNet-nisqa/config.yaml new file mode 100755 index 0000000..e24a31a --- /dev/null +++ b/trained_models/pretrained-MOSNet-nisqa/config.yaml @@ -0,0 +1,17 @@ +model: + _target_: alignnet.Model +network: + _target_: alignnet.AlignNet + aligner_corr_threshold: -1 + audio_net: + _target_: alignnet.MOSNet + aligner: + _target_: alignnet.LinearSequenceAligner + layer_dims: + - 32 + - 32 + - 32 + - 32 + - 1 + reference_index: 0 + num_datasets: 1 diff --git a/trained_models/pretrained-MOSNet-nisqa/model.ckpt b/trained_models/pretrained-MOSNet-nisqa/model.ckpt new file mode 100755 index 0000000..ce6260c Binary files /dev/null and b/trained_models/pretrained-MOSNet-nisqa/model.ckpt differ diff --git a/trained_models/pretrained-MOSNet-tencent/config.yaml b/trained_models/pretrained-MOSNet-tencent/config.yaml new file mode 100755 index 0000000..e24a31a --- /dev/null +++ b/trained_models/pretrained-MOSNet-tencent/config.yaml @@ -0,0 +1,17 @@ +model: + _target_: alignnet.Model +network: + _target_: alignnet.AlignNet + aligner_corr_threshold: -1 + audio_net: + _target_: alignnet.MOSNet + aligner: + _target_: alignnet.LinearSequenceAligner + layer_dims: + - 32 + - 32 + - 32 + - 32 + - 1 + reference_index: 0 + num_datasets: 1 diff --git a/trained_models/pretrained-MOSNet-tencent/model.ckpt b/trained_models/pretrained-MOSNet-tencent/model.ckpt new file mode 100755 index 0000000..2d70a76 Binary files /dev/null and 
b/trained_models/pretrained-MOSNet-tencent/model.ckpt differ