diff --git a/TrainingExtensions/torch/src/python/aimet_torch/blockwise_quant_tensor_split.py b/TrainingExtensions/torch/src/python/aimet_torch/blockwise_quant_tensor_split.py
new file mode 100644
index 00000000000..813d18c8ad1
--- /dev/null
+++ b/TrainingExtensions/torch/src/python/aimet_torch/blockwise_quant_tensor_split.py
@@ -0,0 +1,117 @@
+# -*- mode: python -*-
+# =============================================================================
+#  @@-COPYRIGHT-START-@@
+#
+#  Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions are met:
+#
+#  1. Redistributions of source code must retain the above copyright notice,
+#     this list of conditions and the following disclaimer.
+#
+#  2. Redistributions in binary form must reproduce the above copyright notice,
+#     this list of conditions and the following disclaimer in the documentation
+#     and/or other materials provided with the distribution.
+#
+#  3. Neither the name of the copyright holder nor the names of its contributors
+#     may be used to endorse or promote products derived from this software
+#     without specific prior written permission.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+#  POSSIBILITY OF SUCH DAMAGE.
+#
+#  SPDX-License-Identifier: BSD-3-Clause
+#
+#  @@-COPYRIGHT-END-@@
+# =============================================================================
+""" Utilities for implementing blockwise quantization using tensor splitting approach """
+
+import torch
+from aimet_torch import elementwise_ops
+from aimet_torch.quantsim import QuantizationSimModel
+
+
+class BlockwiseLinear(torch.nn.Module):
+    """
+    Blockwise Linear implementation.
+    This module replaces a single nn.Linear module, breaking it into a number of smaller linear modules depending on
+    block size. Each separate linear module can operate with per channel quantization independently, and the outputs
+    of each linear module are summed up.
+    """
+    def __init__(self, linear_module: torch.nn.Linear, block_size: int):
+        super(BlockwiseLinear, self).__init__()
+        self.block_size = block_size
+        self.linears = torch.nn.ModuleList()
+        self.elementwise_adds = torch.nn.ModuleList()
+        split_indices = list(range(block_size, linear_module.weight.shape[1], block_size))
+        self.split = elementwise_ops.Split()
+        split_weights = torch.tensor_split(linear_module.weight, split_indices, 1)
+        for idx, split_weight in enumerate(split_weights):
+            linear = torch.nn.Linear(split_weight.shape[1],
+                                     split_weight.shape[0],
+                                     bias=(linear_module.bias is not None and idx == 0))
+            linear.weight = torch.nn.Parameter(split_weight)
+            if linear.bias is not None:
+                linear.bias = linear_module.bias
+            self.linears.append(linear)
+            self.elementwise_adds.append(elementwise_ops.Add())
+        self.elementwise_adds = self.elementwise_adds[:-1]
+        if not self.elementwise_adds:
+            self.elementwise_adds = None
+
+    def forward(self, inp):
+        """ Forward pass """
+        if len(self.linears) == 1:
+            return self.linears[0](inp)
+
+        split_inputs = self.split(inp, self.block_size, -1)
+        out = None
+        for idx, split_input in enumerate(split_inputs):
+            linear_out = self.linears[idx](split_input)
+            if out is None:
+                out = linear_out
+            else:
+                out = self.elementwise_adds[idx-1](out, linear_out)
+        return out
+
+
+def replace_linears_for_blockwise_quant(model: torch.nn.Module, block_size: int):
+    """
+    Replace all instances of torch.nn.Linears in model with equivalent BlockwiseLinear modules.
+    The linear weights are split on the input channel dimension such that all constituent linear modules have weights
+    with input channel dimension <= block_size.
+
+    :param model: Model to replace nn.Linears for
+    :param block_size: Block size to use
+    """
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear):
+            setattr(model, name, BlockwiseLinear(module, block_size))
+
+
+def tie_blockwise_linear_quantizers(quantsim: QuantizationSimModel):
+    """
+    Tie all output quantizers within a BlockwiseLinear block together so they share the same encoding. In other words,
+    all output quantizers of constituent linear layers as well as output quantizers of elementwise add modules will
+    share the same quantization parameters.
+
+    :param quantsim: Quantsim model containing BlockwiseLinear modules to tie output quantizers for.
+    """
+    for module in quantsim.model.modules():
+        if isinstance(module, BlockwiseLinear):
+            output_quantizer = module.linears[0].output_quantizers[0]
+            for linear in module.linears:
+                linear.output_quantizers[0] = output_quantizer
+            if module.elementwise_adds is not None:
+                for add in module.elementwise_adds:
+                    add.output_quantizers[0] = output_quantizer
diff --git a/TrainingExtensions/torch/test/python/test_blockwise_quant_tensor_split.py b/TrainingExtensions/torch/test/python/test_blockwise_quant_tensor_split.py
new file mode 100644
index 00000000000..2b222c5dea3
--- /dev/null
+++ b/TrainingExtensions/torch/test/python/test_blockwise_quant_tensor_split.py
@@ -0,0 +1,188 @@
+# -*- mode: python -*-
+# =============================================================================
+#  @@-COPYRIGHT-START-@@
+#
+#  Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions are met:
+#
+#  1. Redistributions of source code must retain the above copyright notice,
+#     this list of conditions and the following disclaimer.
+#
+#  2. Redistributions in binary form must reproduce the above copyright notice,
+#     this list of conditions and the following disclaimer in the documentation
+#     and/or other materials provided with the distribution.
+#
+#  3. Neither the name of the copyright holder nor the names of its contributors
+#     may be used to endorse or promote products derived from this software
+#     without specific prior written permission.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+#  POSSIBILITY OF SUCH DAMAGE.
+#
+#  SPDX-License-Identifier: BSD-3-Clause
+#
+#  @@-COPYRIGHT-END-@@
+# =============================================================================
+""" Tests for blockwise quant tensor split utility """
+
+import json
+import pytest
+import torch
+from aimet_torch.quantsim import QuantizationSimModel
+from aimet_torch.qc_quantize_op import QcQuantizeWrapper
+from aimet_torch.blockwise_quant_tensor_split import (BlockwiseLinear, replace_linears_for_blockwise_quant,
+                                                      tie_blockwise_linear_quantizers)
+
+class LinearModel(torch.nn.Module):
+    def __init__(self):
+        super(LinearModel, self).__init__()
+        self.linear1 = torch.nn.Linear(8, 3)
+        self.relu1 = torch.nn.ReLU()
+        self.linear2 = torch.nn.Linear(3, 2, bias=False)
+        self.softmax = torch.nn.Softmax()
+
+    def forward(self, inp):
+        x = self.linear1(inp)
+        x = self.relu1(x)
+        x = self.linear2(x)
+        x = self.softmax(x)
+        return x
+
+
+@pytest.mark.parametrize('model, dummy_input, block_size', [(torch.nn.Linear(8, 3), torch.randn(1, 8), 3),
+                                                             (torch.nn.Linear(8, 3, bias=False), torch.randn(1, 8), 3),
+                                                             (torch.nn.Linear(3, 2), torch.randn(1, 3), 3),
+                                                             (torch.nn.Linear(3, 2), torch.randn(1, 3), 4)])
+def test_blockwise_linears(model, dummy_input, block_size):
+    blockwise_linear = BlockwiseLinear(model, block_size=block_size)
+    orig_out = model(dummy_input)
+    new_out = blockwise_linear(dummy_input)
+    assert torch.allclose(orig_out, new_out, atol=1e-6)
+
+def test_replace_linears_for_blockwise_quant():
+    dummy_input = torch.randn(1, 8)
+    model = LinearModel()
+    linear1 = model.linear1
+    orig_out = model(dummy_input)
+    replace_linears_for_blockwise_quant(model, 3)
+
+    assert len(model.linear1.linears) == 3
+    assert torch.equal(model.linear1.linears[0].bias, linear1.bias)
+    for linear in model.linear1.linears[1:]:
+        assert linear.bias is None
+
+    assert len(model.linear1.elementwise_adds) == 2
+    assert len(model.linear2.linears) == 1
+    assert model.linear2.elementwise_adds is None
+    new_out = model(dummy_input)
+    assert torch.allclose(orig_out, new_out, atol=1e-6)
+
+def test_quantize_blockwise_linear():
+    quantsim_config = {
+        "defaults": {
+            "ops": {
+                "is_output_quantized": "True",
+                "is_symmetric": "False"
+            },
+            "params": {
+                "is_quantized": "False",
+                "is_symmetric": "True"
+            },
+            "per_channel_quantization": "True",
+        },
+        "params": {},
+        "op_type": {
+            'Split': {
+                'is_output_quantized': False
+            }
+        },
+        "supergroups": [],
+        "model_input": {},
+        "model_output": {}
+    }
+    with open('./data/quantsim_config.json', 'w') as f:
+        json.dump(quantsim_config, f)
+    dummy_input = torch.randn(1, 8)
+    model = BlockwiseLinear(torch.nn.Linear(8, 3), 3)
+    qsim = QuantizationSimModel(model, dummy_input=dummy_input)
+
+    assert isinstance(qsim.model.split, QcQuantizeWrapper)
+    assert isinstance(qsim.model.linears[0], QcQuantizeWrapper)
+    assert isinstance(qsim.model.linears[1], QcQuantizeWrapper)
+    assert isinstance(qsim.model.linears[2], QcQuantizeWrapper)
+    assert isinstance(qsim.model.elementwise_adds[0], QcQuantizeWrapper)
+    assert isinstance(qsim.model.elementwise_adds[1], QcQuantizeWrapper)
+
+    # Temporary hack to disable split op output quantizers while handling for CG split op is reworked
+    for output_quantizer in qsim.model.split.output_quantizers:
+        output_quantizer.enabled = False
+
+    # Temporary hack to enable model input split op input quantizer while handling for CG split op is reworked
+    qsim.model.split.input_quantizers[0].enabled = True
+
+    tie_blockwise_linear_quantizers(qsim)
+    qsim.compute_encodings(lambda m, _: m(dummy_input), None)
+    _ = qsim.model(dummy_input)
+
+    assert (qsim.model.linears[0].output_quantizers[0].encoding.max ==
+            qsim.model.linears[1].output_quantizers[0].encoding.max)
+    assert (qsim.model.linears[0].output_quantizers[0].encoding.max ==
+            qsim.model.linears[2].output_quantizers[0].encoding.max)
+    assert (qsim.model.linears[0].output_quantizers[0].encoding.max ==
+            qsim.model.elementwise_adds[0].output_quantizers[0].encoding.max)
+    assert (qsim.model.linears[0].output_quantizers[0].encoding.max ==
+            qsim.model.elementwise_adds[1].output_quantizers[0].encoding.max)
+
+def test_blockwise_quant_with_small_linear():
+    quantsim_config = {
+        "defaults": {
+            "ops": {
+                "is_output_quantized": "True",
+                "is_symmetric": "False"
+            },
+            "params": {
+                "is_quantized": "False",
+                "is_symmetric": "True"
+            },
+            "per_channel_quantization": "True",
+        },
+        "params": {},
+        "op_type": {
+            'Split': {
+                'is_output_quantized': False
+            }
+        },
+        "supergroups": [],
+        "model_input": {},
+        "model_output": {}
+    }
+    with open('./data/quantsim_config.json', 'w') as f:
+        json.dump(quantsim_config, f)
+    dummy_input = torch.randn(1, 3)
+    model = BlockwiseLinear(torch.nn.Linear(3, 2), 3)
+    qsim = QuantizationSimModel(model, dummy_input=dummy_input)
+    # Temporary hack to disable split op output quantizers while handling for CG split op is reworked
+    for output_quantizer in qsim.model.split.output_quantizers:
+        output_quantizer.enabled = False
+
+    # Temporary hack to enable model input split op input quantizer while handling for CG split op is reworked
+    qsim.model.split.input_quantizers[0].enabled = True
+
+    tie_blockwise_linear_quantizers(qsim)
+    qsim.compute_encodings(lambda m, _: m(dummy_input), None)
+    assert len(qsim.model.linears) == 1
+    assert qsim.model.elementwise_adds is None
+
+    _ = qsim.model(dummy_input)
+    assert len(qsim.connected_graph.get_all_ops()) == 1