Skip to content

Commit

Permalink
Nick/matlab support (#360)
Browse files Browse the repository at this point in the history
* Internal: Try new link checker

* Internal: Add codespell and fix typos.

* Internal: See if codespell precommit finds config.

* Internal: Found config. Now enable reading it

* MATLAB: Add initial support for more matlab support.

Closes #350
  • Loading branch information
ntjohnson1 authored Jan 6, 2025
1 parent c3249ad commit 81e5217
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pyttb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pyttb.import_data import import_data
from pyttb.khatrirao import khatrirao
from pyttb.ktensor import ktensor
from pyttb.matlab import matlab_support
from pyttb.sptenmat import sptenmat
from pyttb.sptensor import sptendiag, sptenrand, sptensor
from pyttb.sptensor3 import sptensor3
Expand Down Expand Up @@ -51,6 +52,7 @@ def ignore_warnings(ignore=True):
import_data.__name__,
khatrirao.__name__,
ktensor.__name__,
matlab_support.__name__,
sptenmat.__name__,
sptendiag.__name__,
sptenrand.__name__,
Expand Down
1 change: 1 addition & 0 deletions pyttb/matlab/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Partial support of MATLAB users in PYTTB."""
38 changes: 38 additions & 0 deletions pyttb/matlab/matlab_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""A limited number of utilities to support users coming from MATLAB."""

# Copyright 2024 National Technology & Engineering Solutions of Sandia,
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.

from typing import Optional, Union

import numpy as np

from pyttb.tensor import tensor

from .matlab_utilities import _matlab_array_str

PRINT_CLASSES = Union[tensor, np.ndarray]


def matlab_print(
data: Union[tensor, np.ndarray],
format: Optional[str] = None,
name: Optional[str] = None,
):
"""Print data in a format more similar to MATLAB.
Arguments
---------
data: Object to print
format: Numerical formatting
"""
if not isinstance(data, (tensor, np.ndarray)):
raise ValueError(
f"matlab_print only supports inputs of type {PRINT_CLASSES} but got"
f" {type(data)}."
)
if isinstance(data, np.ndarray):
print(_matlab_array_str(data, format, name))
return
print(data._matlab_str(format, name))
68 changes: 68 additions & 0 deletions pyttb/matlab/matlab_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Internal tools to aid in building MATLAB support.
Tensor classes can use these common tools, where matlab_support uses tensors.
matlab_support can depend on this, but tensors and this shouldn't depend on it.
Probably best for everything here to be private functions.
"""

# Copyright 2024 National Technology & Engineering Solutions of Sandia,
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.

import textwrap
from typing import Optional, Tuple, Union

import numpy as np


def _matlab_array_str(
array: np.ndarray,
format: Optional[str] = None,
name: Optional[str] = None,
skip_name: bool = False,
) -> str:
"""Convert numpy array to string more similar to MATLAB."""
if name is None:
name = type(array).__name__
header_str = ""
body_str = ""
if len(array.shape) > 2:
matlab_str = ""
# Iterate over all possible slices (in Fortran order)
for index in np.ndindex(
array.shape[2:][::-1]
): # Skip the first two dimensions and reverse the order
original_index = index[::-1] # Reverse the order back to the original
# Construct the slice indices
slice_indices: Tuple[Union[int, slice], ...] = (
slice(None),
slice(None),
*original_index,
)
slice_data = array[slice_indices]
matlab_str += f"{name}(:,:, {', '.join(map(str, original_index))}) ="
matlab_str += "\n"
array_str = _matlab_array_str(slice_data, format, name, skip_name=True)
matlab_str += textwrap.indent(array_str, "\t")
matlab_str += "\n"
return matlab_str[:-1] # Trim extra newline
elif len(array.shape) == 2:
header_str += f"{name}(:,:) ="
for row in array:
if format is None:
body_str += " ".join(f"{val}" for val in row)
else:
body_str += " ".join(f"{val:{format}}" for val in row)
body_str += "\n"
else:
header_str += f"{name}(:) ="
for val in array:
if format is None:
body_str += f"{val}"
else:
body_str += f"{val:{format}}"
body_str += "\n"

if skip_name:
return body_str
return header_str + "\n" + textwrap.indent(body_str[:-1], "\t")
19 changes: 19 additions & 0 deletions pyttb/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import logging
import textwrap
from collections.abc import Iterable
from inspect import signature
from itertools import combinations_with_replacement, permutations
Expand All @@ -30,6 +31,7 @@
from scipy import sparse

import pyttb as ttb
from pyttb.matlab.matlab_utilities import _matlab_array_str
from pyttb.pyttb_utils import (
IndexVariant,
OneDArray,
Expand Down Expand Up @@ -2723,6 +2725,23 @@ def __repr__(self):

__str__ = __repr__

def _matlab_str(
self, format: Optional[str] = None, name: Optional[str] = None
) -> str:
"""Non-standard representation to be more similar to MATLAB."""
header = name
if name is None:
name = "data"
if header is None:
header = "This"

matlab_str = f"{header} is a tensor of shape " + " x ".join(
map(str, self.shape)
)

array_str = _matlab_array_str(self.data, format, name)
return matlab_str + "\n" + textwrap.indent(array_str, "\t")


def tenones(shape: Shape, order: Union[Literal["F"], Literal["C"]] = "F") -> tensor:
"""Create a tensor of all ones.
Expand Down
44 changes: 44 additions & 0 deletions tests/matlab/test_matlab_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2024 National Technology & Engineering Solutions of Sandia,
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.

import numpy as np
import pytest

from pyttb import matlab_support, tensor


def test_matlab_printing_negative():
with pytest.raises(ValueError):
matlab_support.matlab_print("foo")


def test_np_printing():
"""These are just smoke tests since formatting needs manual style verification."""
# Check different dimensionality support
one_d_array = np.ones((1,))
matlab_support.matlab_print(one_d_array)
two_d_array = np.ones((1, 1))
matlab_support.matlab_print(two_d_array)
three_d_array = np.ones((1, 1, 1))
matlab_support.matlab_print(three_d_array)

# Check name and format
matlab_support.matlab_print(one_d_array, format="5.1f", name="X")
matlab_support.matlab_print(two_d_array, format="5.1f", name="X")
matlab_support.matlab_print(three_d_array, format="5.1f", name="X")


def test_dense_printing():
"""These are just smoke tests since formatting needs manual style verification."""
# Check different dimensionality support
example = tensor(np.arange(16), shape=(2, 2, 2, 2))
# 4D
matlab_support.matlab_print(example)
# 2D
matlab_support.matlab_print(example[:, :, 0, 0])
# 1D
matlab_support.matlab_print(example[:, 0, 0, 0])

# Check name and format
matlab_support.matlab_print(example, format="5.1f", name="X")

0 comments on commit 81e5217

Please sign in to comment.