Skip to content

Commit

Permalink
Added the 'load_target_platform_model' ahd export_target_platform_mod…
Browse files Browse the repository at this point in the history
…el functions to import and export TargetPlatformModel instances to JSON files with robust validation and error handling. It also includes comprehensive tests for both the 'export_target_platform_model' and 'load_target_platform_model' functions, covering valid use cases, edge cases, and error scenarios.
  • Loading branch information
liord committed Jan 5, 2025
1 parent afed6e3 commit 1d72b7b
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 30 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from pathlib import Path
from typing import Union

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
import json


def load_target_platform_model(tp_model_or_path: Union[TargetPlatformModel, str]) -> TargetPlatformModel:
"""
Parses the tp_model input, which can be either a TargetPlatformModel object
or a string path to a JSON file.
Parameters:
tp_model_or_path (Union[TargetPlatformModel, str]): Input target platform model or path to .JSON file.
Returns:
TargetPlatformModel: The parsed TargetPlatformModel.
Raises:
FileNotFoundError: If the JSON file does not exist.
ValueError: If the JSON content is invalid or cannot initialize the TargetPlatformModel.
TypeError: If the input is neither a TargetPlatformModel nor a valid JSON file path.
"""
if isinstance(tp_model_or_path, TargetPlatformModel):
return tp_model_or_path

if isinstance(tp_model_or_path, str):
path = Path(tp_model_or_path)

if not path.exists() or not path.is_file():
raise FileNotFoundError(f"The path '{tp_model_or_path}' is not a valid file.")
# Verify that the file has a .json extension
if path.suffix.lower() != '.json':
raise ValueError(f"The file '{path}' does not have a '.json' extension.")
try:
with path.open('r', encoding='utf-8') as file:
data = file.read()
except OSError as e:
raise ValueError(f"Error reading the file '{tp_model_or_path}': {e.strerror}.") from e

try:
return TargetPlatformModel.parse_raw(data)
except ValueError as e:
raise ValueError(f"Invalid JSON for loading TargetPlatformModel in '{tp_model_or_path}': {e}.") from e
except Exception as e:
raise ValueError(f"Unexpected error while initializing TargetPlatformModel: {e}.") from e

raise TypeError(
f"tp_model_or_path must be either a TargetPlatformModel instance or a string path to a JSON file, "
f"but received type '{type(tp_model_or_path).__name__}'."
)


def export_target_platform_model(model: TargetPlatformModel, export_path: Union[str, Path]) -> None:
"""
Exports a TargetPlatformModel instance to a JSON file.
Parameters:
model (TargetPlatformModel): The TargetPlatformModel instance to export.
export_path (Union[str, Path]): The file path to export the model to.
Raises:
ValueError: If the model is not an instance of TargetPlatformModel.
OSError: If there is an issue writing to the file.
"""
if not isinstance(model, TargetPlatformModel):
raise ValueError("The provided model is not a valid TargetPlatformModel instance.")

path = Path(export_path)
try:
# Ensure the parent directory exists
path.parent.mkdir(parents=True, exist_ok=True)

# Export the model to JSON and write to the file
with path.open('w', encoding='utf-8') as file:
file.write(model.json(indent=4))
except OSError as e:
raise OSError(f"Failed to write to file '{export_path}': {e.strerror}") from e
145 changes: 115 additions & 30 deletions tests/common_tests/test_tp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \
get_config_options_by_operators_set, is_opset_in_model
from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_model, \
export_target_platform_model
from tests.common_tests.helpers.generate_test_tp_model import generate_test_attr_configs, generate_test_op_qc

tp = mct.target_platform
Expand All @@ -30,42 +32,125 @@
TEST_QCO = schema.QuantizationConfigOptions(quantization_configurations=tuple([TEST_QC]))


class TargetPlatformModelingTest(unittest.TestCase):
def cleanup_file(self, file_path):
if os.path.exists(file_path):
os.remove(file_path)
print(f"Cleaned up: {file_path}")
class TPModelInputOutputTests(unittest.TestCase):

def test_dump_to_json(self):
def setUp(self):
# Setup reusable resources or configurations for tests
self.valid_export_path = "exported_model.json"
self.invalid_export_path = "/invalid/path/exported_model.json"
self.invalid_json_content = '{"field1": "value1", "field2": ' # Incomplete JSON
self.invalid_json_file = "invalid_model.json"
self.nonexistent_file = "nonexistent.json"
op1 = schema.OperatorsSet(name="opset1")
op2 = schema.OperatorsSet(name="opset2")
op3 = schema.OperatorsSet(name="opset3")
op12 = schema.OperatorSetConcat(operators_set=[op1, op2])
model = schema.TargetPlatformModel(default_qco=TEST_QCO,
operator_set=(op1, op2, op3),
fusing_patterns=(schema.Fusing(operator_groups=(op12, op3)),
schema.Fusing(operator_groups=(op1, op2))),
tpc_minor_version=1,
tpc_patch_version=0,
tpc_platform_type="dump_to_json",
add_metadata=False)
json_str = model.json()
# Define the output file path
file_path = "target_platform_model.json"
# Register cleanup to delete the file if it exists
self.addCleanup(self.cleanup_file, file_path)

# Write the JSON string to the file
with open(file_path, "w") as f:
f.write(json_str)

with open(file_path, "r") as f:
json_content = f.read()

loaded_target_model = schema.TargetPlatformModel.parse_raw(json_content)
self.assertEqual(model, loaded_target_model)

self.tp_model = schema.TargetPlatformModel(default_qco=TEST_QCO,
operator_set=(op1, op2, op3),
fusing_patterns=(schema.Fusing(operator_groups=(op12, op3)),
schema.Fusing(operator_groups=(op1, op2))),
tpc_minor_version=1,
tpc_patch_version=0,
tpc_platform_type="dump_to_json",
add_metadata=False)

# Create invalid JSON file
with open(self.invalid_json_file, "w") as file:
file.write(self.invalid_json_content)

def tearDown(self):
# Cleanup files created during tests
for file in [self.valid_export_path, self.invalid_json_file]:
if os.path.exists(file):
os.remove(file)

def test_valid_model_object(self):
"""Test that a valid TargetPlatformModel object is returned unchanged."""
result = load_target_platform_model(self.tp_model)
self.assertEqual(self.tp_model, result)

def test_invalid_json_parsing(self):
"""Test that invalid JSON content raises a ValueError."""
with self.assertRaises(ValueError) as context:
load_target_platform_model(self.invalid_json_file)
self.assertIn("Invalid JSON for loading TargetPlatformModel in", str(context.exception))

def test_nonexistent_file(self):
"""Test that a nonexistent file raises FileNotFoundError."""
with self.assertRaises(FileNotFoundError) as context:
load_target_platform_model(self.nonexistent_file)
self.assertIn("is not a valid file", str(context.exception))

def test_non_json_extension(self):
"""Test that a file with a non-JSON extension raises ValueError."""
non_json_file = "test_model.txt"
try:
with open(non_json_file, "w") as file:
file.write(self.invalid_json_content)
with self.assertRaises(ValueError) as context:
load_target_platform_model(non_json_file)
self.assertIn("does not have a '.json' extension", str(context.exception))
finally:
os.remove(non_json_file)

def test_invalid_input_type(self):
"""Test that an unsupported input type raises TypeError."""
invalid_input = 123 # Not a string or TargetPlatformModel
with self.assertRaises(TypeError) as context:
load_target_platform_model(invalid_input)
self.assertIn("must be either a TargetPlatformModel instance or a string path", str(context.exception))

def test_valid_export(self):
"""Test exporting a valid TargetPlatformModel instance to a file."""
export_target_platform_model(self.tp_model, self.valid_export_path)
# Verify the file exists
self.assertTrue(os.path.exists(self.valid_export_path))

# Verify the contents match the model's JSON representation
with open(self.valid_export_path, "r", encoding="utf-8") as file:
content = file.read()
self.assertEqual(content, self.tp_model.json(indent=4))

def test_export_with_invalid_model(self):
"""Test that exporting an invalid model raises a ValueError."""
with self.assertRaises(ValueError) as context:
export_target_platform_model("not_a_model", self.valid_export_path)
self.assertIn("not a valid TargetPlatformModel instance", str(context.exception))

def test_export_with_invalid_path(self):
"""Test that exporting to an invalid path raises an OSError."""
with self.assertRaises(OSError) as context:
export_target_platform_model(self.tp_model, self.invalid_export_path)
self.assertIn("Failed to write to file", str(context.exception))

def test_export_creates_parent_directories(self):
"""Test that exporting creates missing parent directories."""
nested_path = "nested/directory/exported_model.json"
try:
export_target_platform_model(self.tp_model, nested_path)
# Verify the file exists
self.assertTrue(os.path.exists(nested_path))

# Verify the contents match the model's JSON representation
with open(nested_path, "r", encoding="utf-8") as file:
content = file.read()
self.assertEqual(content, self.tp_model.json(indent=4))
finally:
# Cleanup created directories
if os.path.exists(nested_path):
os.remove(nested_path)
if os.path.exists("nested/directory"):
os.rmdir("nested/directory")
if os.path.exists("nested"):
os.rmdir("nested")

def test_export_then_import(self):
"""Test that a model exported and then imported is identical."""
export_target_platform_model(self.tp_model, self.valid_export_path)
imported_model = load_target_platform_model(self.valid_export_path)
self.assertEqual(self.tp_model, imported_model)

class TargetPlatformModelingTest(unittest.TestCase):
def test_immutable_tp(self):

with self.assertRaises(Exception) as e:
Expand Down

0 comments on commit 1d72b7b

Please sign in to comment.