diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 3e06dc1c..6374fe8d 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -645,9 +645,12 @@ def cli(): if hf_token is None: print("Warning, no huggingface token used, needs to be saved in environment variable, otherwise will throw error loading VAD model...") from pyannote.audio import Inference - vad_pipeline = Inference("pyannote/segmentation", - pre_aggregation_hook=lambda segmentation: segmentation, - use_auth_token=hf_token) + vad_pipeline = Inference( + "pyannote/segmentation", + pre_aggregation_hook=lambda segmentation: segmentation, + use_auth_token=hf_token, + device=torch.device(device), + ) diarize_pipeline = None if diarize: