
torch.cuda.DeferredCudaCallError: CUDA call failed lazily at initialization with error: 'NoneType' object is not iterable #2652

Open
Whisht opened this issue Jan 3, 2025 · 3 comments
Labels: bug, triaged

Comments


Whisht commented Jan 3, 2025

System Info

  • CPU architecture: x86
  • CPU/Host memory size: 500 GB
  • GPU properties
    • GPU name: NVIDIA V100
    • GPU memory size: 32 GB * 2
  • Libraries
    • TensorRT-LLM branch: tag 0.16.0
    • CUDA version: 12.1
    • Container used: nvcr.io/nvidia/tritonserver:24.12-trtllm-python-py3
    • NVIDIA driver version: 530.30.02
  • OS: Ubuntu 24.04.1 LTS

Who can help?

@Tracin @kaiyux @byshiue

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Pull the Triton Server container nvcr.io/nvidia/tritonserver:24.12-trtllm-python-py3.
  2. Download the model weights Qwen2.5-32B-Instruct-GPTQ-Int4.
  3. Run convert_checkpoint.py:

     python3 convert_checkpoint.py --model_dir /root/models/Qwen2.5-32B-Instruct-GPTQ-Int4/ \
                                   --output_dir /root/checkpoint/qwen2.5 \
                                   --dtype float16 \
                                   --use_weight_only \
                                   --weight_only_precision int4_gptq \
                                   --per_group \
                                   --tp_size 2

Expected behavior

Conversion succeeds.

Actual behavior

[TensorRT-LLM] TensorRT-LLM version: 0.16.0
0.16.0
5it [00:01,  4.17it/s]
Traceback (most recent call last):
  File "/root/tensorrtllm_backend/tensorrt_llm/examples/qwen/convert_checkpoint.py", line 335, in <module>
    main()
  File "/root/tensorrtllm_backend/tensorrt_llm/examples/qwen/convert_checkpoint.py", line 327, in main
    convert_and_save_hf(args)
  File "/root/tensorrtllm_backend/tensorrt_llm/examples/qwen/convert_checkpoint.py", line 283, in convert_and_save_hf
    execute(args.workers, [convert_and_save_rank] * world_size, args)
  File "/root/tensorrtllm_backend/tensorrt_llm/examples/qwen/convert_checkpoint.py", line 290, in execute
    f(args, rank)
  File "/root/tensorrtllm_backend/tensorrt_llm/examples/qwen/convert_checkpoint.py", line 275, in convert_and_save_rank
    qwen = QWenForCausalLM.from_hugging_face(model_dir,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/models/qwen/model.py", line 438, in from_hugging_face
    loader.generate_tllm_weights(model, arg_dict)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/models/model_weights_loader.py", line 391, in generate_tllm_weights
    self.load(tllm_key,
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/models/model_weights_loader.py", line 305, in load
    v = sub_module.postprocess(tllm_key, v, **postprocess_kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/quantization/layers.py", line 1067, in postprocess
    return postprocess_weight_only_groupwise(tllm_key, weights, torch_dtype,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/quantization/functional.py", line 978, in postprocess_weight_only_groupwise
    qweight = torch.ops.trtllm.preprocess_weights_for_mixed_gemm(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1120, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: [TensorRT-LLM][ERROR] Assertion failed: Unsupported Arch (/workspace/tensorrt_llm/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp:138)
1       0x7fcb8d7687df tensorrt_llm::common::throwRuntimeError(char const*, int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 95
2       0x7fcd6862b771 tensorrt_llm::kernels::cutlass_kernels::getLayoutDetailsForTransform(tensorrt_llm::kernels::cutlass_kernels::QuantType, int) + 321
3       0x7fcd6862bc60 tensorrt_llm::kernels::cutlass_kernels::preprocess_weights_for_mixed_gemm(signed char*, signed char const*, std::vector<unsigned long, std::allocator<unsigned long> > const&, tensorrt_llm::kernels::cutlass_kernels::QuantType, bool) + 240
4       0x7fcd68622964 torch_ext::preprocess_weights_for_mixed_gemm(at::Tensor, c10::ScalarType, c10::ScalarType) + 644
5       0x7fcd686259d3 c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor, c10::ScalarType, c10::ScalarType), at::Tensor, c10::guts::typelist::typelist<at::Tensor, c10::ScalarType, c10::ScalarType> >, true>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) + 147
6       0x7fce3c790677 /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so(+0x5312677) [0x7fce3c790677]
7       0x7fce447fcb7b torch::jit::invokeOperatorFromPython(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, pybind11::args const&, pybind11::kwargs const&, std::optional<c10::DispatchKey>) + 251
8       0x7fce447fcfcd torch::jit::_get_operation_for_overload_or_packet(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, c10::Symbol, pybind11::args const&, pybind11::kwargs const&, bool, std::optional<c10::DispatchKey>) + 557
9       0x7fce446e0847 /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so(+0x8da847) [0x7fce446e0847]
10      0x7fce442773d5 /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so(+0x4713d5) [0x7fce442773d5]
11            0x5820ff python3() [0x5820ff]
12            0x54b07c PyObject_Call + 156
13            0x5db68a _PyEval_EvalFrameDefault + 19514
14            0x54a712 _PyObject_Call_Prepend + 194
15            0x5a3698 python3() [0x5a3698]
16            0x548ec5 _PyObject_MakeTpCall + 117
17            0x5d74d9 _PyEval_EvalFrameDefault + 2697
18            0x54cae4 python3() [0x54cae4]
19            0x54b0f9 PyObject_Call + 281
20            0x5db68a _PyEval_EvalFrameDefault + 19514
21            0x54cae4 python3() [0x54cae4]
22            0x54b0f9 PyObject_Call + 281
23            0x5db68a _PyEval_EvalFrameDefault + 19514
24            0x5d59fb PyEval_EvalCode + 347
25            0x608b52 python3() [0x608b52]
26            0x6b4d83 python3() [0x6b4d83]
27            0x6b4aea _PyRun_SimpleFileObject + 426
28            0x6b491f _PyRun_AnyFileObject + 79
29            0x6bc985 Py_RunMain + 949
30            0x6bc46d Py_BytesMain + 45
31      0x7fce523c91ca /usr/lib/x86_64-linux-gnu/libc.so.6(+0x2a1ca) [0x7fce523c91ca]
32      0x7fce523c928b __libc_start_main + 139
33            0x657c15 _start + 37
Exception ignored in: <function PretrainedModel.__del__ at 0x7fcb39f23b00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/models/modeling_utils.py", line 607, in __del__
    self.release()
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/models/modeling_utils.py", line 604, in release
    release_gc()
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_utils.py", line 533, in release_gc
    torch.cuda.ipc_collect()
  File "/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py", line 968, in ipc_collect
    _lazy_init()
  File "/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py", line 338, in _lazy_init
    raise DeferredCudaCallError(msg) from e
torch.cuda.DeferredCudaCallError: CUDA call failed lazily at initialization with error: 'NoneType' object is not iterable

CUDA call was originally invoked at:

  File "/root/tensorrtllm_backend/tensorrt_llm/examples/qwen/convert_checkpoint.py", line 7, in <module>
    from transformers import AutoConfig
  File "<frozen importlib._bootstrap>", line 1360, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1331, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 935, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 995, in exec_module
  File "<frozen importlib._bootstrap>", line 488, in _call_with_frames_removed
  File "/usr/local/lib/python3.12/dist-packages/transformers/__init__.py", line 26, in <module>
    from . import dependency_versions_check
  File "<frozen importlib._bootstrap>", line 1415, in _handle_fromlist
  File "<frozen importlib._bootstrap>", line 488, in _call_with_frames_removed
  File "<frozen importlib._bootstrap>", line 1360, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1331, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 935, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 995, in exec_module
  File "<frozen importlib._bootstrap>", line 488, in _call_with_frames_removed
  File "/usr/local/lib/python3.12/dist-packages/transformers/dependency_versions_check.py", line 16, in <module>
    from .utils.versions import require_version, require_version_core
  File "<frozen importlib._bootstrap>", line 1360, in _find_and_load

Additional notes

None.

Whisht added the bug label on Jan 3, 2025
nv-guomingz (Collaborator) commented

Volta has not been supported since 0.14.
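
For reference, the "Unsupported Arch" assertion in the traceback fires because the int4-GPTQ weight preprocessor requires a newer compute capability than Volta's SM 7.0. A quick way to check what your GPU reports (a minimal sketch in plain PyTorch, not TensorRT-LLM code):

import torch

# V100 (Volta) reports compute capability (7, 0).
major, minor = torch.cuda.get_device_capability(0)
print(f"Compute capability: SM {major}.{minor}")
if (major, minor) <= (7, 0):
    print("Volta or older: not supported by TensorRT-LLM >= 0.14")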

nv-guomingz added the triaged label on Jan 7, 2025

Whisht commented Jan 8, 2025

Thanks for your reply. Is there another way to run Qwen2.5 with TensorRT-LLM on a V100 GPU?

Justin-12138 commented

@nv-guomingz Hi, I hit the same issue with this command on an RTX 4090:

python3 /tensorrtllm_backend/tensorrt_llm/examples/qwen/convert_checkpoint.py --model_dir /root/7b \
                              --output_dir /root/converted/7b/f16/1gpu-int4_gptq \
                              --dtype auto \
                              --use_weight_only \
                              --tp_size 1 \
                              --weight_only_precision int4_gptq
Traceback (most recent call last):
  File "/tensorrtllm_backend/tensorrt_llm/examples/qwen/convert_checkpoint.py", line 335, in <module>
    main()
  File "/tensorrtllm_backend/tensorrt_llm/examples/qwen/convert_checkpoint.py", line 327, in main
    convert_and_save_hf(args)
  File "/tensorrtllm_backend/tensorrt_llm/examples/qwen/convert_checkpoint.py", line 283, in convert_and_save_hf
    execute(args.workers, [convert_and_save_rank] * world_size, args)
  File "/tensorrtllm_backend/tensorrt_llm/examples/qwen/convert_checkpoint.py", line 290, in execute
    f(args, rank)
  File "/tensorrtllm_backend/tensorrt_llm/examples/qwen/convert_checkpoint.py", line 275, in convert_and_save_rank
    qwen = QWenForCausalLM.from_hugging_face(model_dir,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/models/qwen/model.py", line 438, in from_hugging_face
    loader.generate_tllm_weights(model, arg_dict)
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/models/model_weights_loader.py", line 391, in generate_tllm_weights
    self.load(tllm_key,
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/models/model_weights_loader.py", line 305, in load
    v = sub_module.postprocess(tllm_key, v, **postprocess_kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/quantization/layers.py", line 1067, in postprocess
    return postprocess_weight_only_groupwise(tllm_key, weights, torch_dtype,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/quantization/functional.py", line 937, in postprocess_weight_only_groupwise
    torch.cat(weights[i::len(weights) // 3], dim=1)
TypeError: expected Tensor as element 0 in argument 0, but got NoneType
Exception ignored in: <function PretrainedModel.__del__ at 0x7f24e388fd80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/models/modeling_utils.py", line 607, in __del__
    self.release()
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/models/modeling_utils.py", line 604, in release
    release_gc()
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_utils.py", line 533, in release_gc
    torch.cuda.ipc_collect()
  File "/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py", line 968, in ipc_collect
    _lazy_init()
  File "/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py", line 338, in _lazy_init
    raise DeferredCudaCallError(msg) from e
torch.cuda.DeferredCudaCallError: CUDA call failed lazily at initialization with error: 'NoneType' object is not iterable

CUDA call was originally invoked at:

  File "/tensorrtllm_backend/tensorrt_llm/examples/qwen/convert_checkpoint.py", line 7, in <module>
    from transformers import AutoConfig
  File "<frozen importlib._bootstrap>", line 1360, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1331, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 935, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 995, in exec_module
  File "<frozen importlib._bootstrap>", line 488, in _call_with_frames_removed
