diff --git a/README.md b/README.md
index 26f8db0f..1c7d7758 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,6 @@
More examples
-
@@ -55,8 +54,6 @@ This repository refines the timestamps of openAI's Whisper model via forced alig
- Character level timestamps (see `*.char.ass` file output)
- Diarization (still in beta, add `--diarize`)
-To enable VAD filtering and Diarization, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation) , [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection) , and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
-
Setup ⚙️
Install this package using
@@ -74,9 +71,13 @@ $ cd whisperX
$ pip install -e .
```
-
You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
+
+### Voice Activity Detection Filtering & Diarization
+To **enable VAD filtering and Diarization**, include your Hugging Face access token that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation) , [Voice Activity Detection (VAD)](https://huggingface.co/pyannote/voice-activity-detection) , and [Speaker Diarization](https://huggingface.co/pyannote/speaker-diarization)
+
+
Usage 💬 (command line)
### English
@@ -152,8 +153,9 @@ In addition to forced alignment, the following two modifications have been made
- Not thoroughly tested, especially for non-english, results may vary -- please post issue to let me know the results on your data
- Whisper normalises spoken numbers e.g. "fifty seven" to arabic numerals "57". Need to perform this normalization after alignment, so the phonemes can be aligned. Currently just ignores numbers.
-- Assumes the initial whisper timestamps are accurate to some degree (within margin of 2 seconds, adjust if needed -- bigger margins more prone to alignment errors)
-- Hacked this up quite quickly, there might be some errors, please raise an issue if you encounter any.
+- If not using VAD filter, whisperx assumes the initial whisper timestamps are accurate to some degree (within margin of 2 seconds, adjust if needed -- bigger margins more prone to alignment errors)
+- Overlapping speech is not handled particularly well by whisper nor whisperx
+- Diariazation is far from perfect.
Contribute 🧑🏫
@@ -176,29 +178,34 @@ The next major upgrade we are working on is whisper with speaker diarization, so
* [x] Incorporating speaker diarization
-* [ ] Improve diarization (word level)
+* [x] Inference speedup with batch processing
+
+* [ ] Improve diarization (word level). *Harder than first thought...*
-* [ ] Inference speedup with batch processing
-Contact maxbain[at]robots[dot]ox[dot]ac[dot]uk for business things.
+Contact maxbain[at]robots[dot]ox[dot]ac[dot]uk for queries
Acknowledgements 🙏
-Of course, this is mostly just a modification to [openAI's whisper](https://github.com/openai/whisper).
-As well as accreditation to this [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
+This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and University of Oxford.
+
+
+
+Of course, this is builds on [openAI's whisper](https://github.com/openai/whisper).
+And borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
Citation
-If you use this in your research, just cite the repo,
+If you use this in your research, for now just cite the repo,
```bibtex
@misc{bain2022whisperx,
- author = {Bain, Max},
+ author = {Bain, Max and Han, Tengda},
title = {WhisperX},
year = {2022},
publisher = {GitHub},
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 44df48ab..c272f18e 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -585,10 +585,9 @@ def cli():
parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
# vad params
parser.add_argument("--vad_filter", action="store_true", help="Whether to first perform VAD filtering to target only transcribe within VAD. Produces more accurate alignment + timestamp, requires more GPU memory & compute.")
- parser.add_argument("--vad_input", default=None, type=str)
parser.add_argument("--parallel_bs", default=-1, type=int, help="Enable parallel transcribing if > 1")
# diarization params
- parser.add_argument("--diarize", action='store_true')
+ parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
parser.add_argument("--min_speakers", default=None, type=int)
parser.add_argument("--max_speakers", default=None, type=int)
# output save params
@@ -632,7 +631,6 @@ def cli():
hf_token: str = args.pop("hf_token")
vad_filter: bool = args.pop("vad_filter")
- vad_input: bool = args.pop("vad_input")
parallel_bs: int = args.pop("parallel_bs")
diarize: bool = args.pop("diarize")
@@ -640,9 +638,9 @@ def cli():
max_speakers: int = args.pop("max_speakers")
vad_pipeline = None
- if vad_input is not None:
- vad_input = pd.read_csv(vad_input, header=None, sep= " ")
- elif vad_filter:
+ if vad_filter:
+ if hf_token is None:
+ print("Warning, no huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model...")
from pyannote.audio import Inference
vad_pipeline = Inference("pyannote/segmentation",
pre_aggregation_hook=lambda segmentation: segmentation,
@@ -650,6 +648,8 @@ def cli():
diarize_pipeline = None
if diarize:
+ if hf_token is None:
+ print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
from pyannote.audio import Pipeline
diarize_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
use_auth_token=hf_token)
@@ -756,7 +756,7 @@ def cli():
# save word tsv
if output_type in ["vad"]:
exp_fp = os.path.join(output_dir, audio_basename + ".sad")
- wrd_segs = pd.concat([x["word-segments"] for x in result_aligned["segments"]])
+ wrd_segs = pd.concat([x["word-segments"] for x in result_aligned["segments"]])[['start','end']]
wrd_segs.to_csv(exp_fp, sep='\t', header=None, index=False)
if __name__ == "__main__":
cli()
diff --git a/whisperx/utils.py b/whisperx/utils.py
index 86d40633..805db26f 100644
--- a/whisperx/utils.py
+++ b/whisperx/utils.py
@@ -65,8 +65,8 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
def write_tsv(transcript: Iterator[dict], file: TextIO):
print("start", "end", "text", sep="\t", file=file)
for segment in transcript:
- print(round(1000 * segment['start']), file=file, end="\t")
- print(round(1000 * segment['end']), file=file, end="\t")
+ print(segment['start'], file=file, end="\t")
+ print(segment['end'], file=file, end="\t")
print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
diff --git a/whisperx/vad.py b/whisperx/vad.py
index eb8bd2ce..2932a137 100644
--- a/whisperx/vad.py
+++ b/whisperx/vad.py
@@ -137,8 +137,6 @@ def __call__(self, scores: SlidingWindowFeature) -> Annotation:
def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
- # because of padding, some active regions might be overlapping: merge them.
- # also: fill same speaker gaps shorter than min_duration_off
active = Annotation()
for k, vad_t in enumerate(vad_arr):
@@ -161,16 +159,27 @@ def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_
if __name__ == "__main__":
- from pyannote.audio import Inference
- hook = lambda segmentation: segmentation
- inference = Inference("pyannote/segmentation", pre_aggregation_hook=hook)
- audio = "/tmp/11962.wav"
- scores = inference(audio)
- binarize = Binarize(max_duration=15)
- anno = binarize(scores)
- res = []
- for ann in anno.get_timeline():
- res.append((ann.start, ann.end))
-
- res = pd.DataFrame(res)
- res[2] = res[1] - res[0]
\ No newline at end of file
+ # from pyannote.audio import Inference
+ # hook = lambda segmentation: segmentation
+ # inference = Inference("pyannote/segmentation", pre_aggregation_hook=hook)
+ # audio = "/tmp/11962.wav"
+ # scores = inference(audio)
+ # binarize = Binarize(max_duration=15)
+ # anno = binarize(scores)
+ # res = []
+ # for ann in anno.get_timeline():
+ # res.append((ann.start, ann.end))
+
+ # res = pd.DataFrame(res)
+ # res[2] = res[1] - res[0]
+ import pandas as pd
+ input_fp = "tt298650_sync.wav"
+ df = pd.read_csv(f"/work/maxbain/tmp/{input_fp}.sad", sep=" ", header=None)
+ print(len(df))
+ N = 0.15
+ g = df[0].sub(df[1].shift())
+ input_base = input_fp.split('.')[0]
+ df = df.groupby(g.gt(N).cumsum()).agg({0:'min', 1:'max'})
+ df.to_csv(f"/work/maxbain/tmp/{input_base}.lab", header=None, index=False, sep=" ")
+ print(df)
+ import pdb; pdb.set_trace()
\ No newline at end of file