Refactor activation Hessian computation using model builder (#863)
Removes the old dedicated mechanism that ran inference directly on the model's graph to approximate the Hessian with respect to intermediate activation tensors.
Instead, the new solution uses the model builder to build a model from the graph, and then performs the Hessian approximation computation on that model.
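
To make the new flow concrete, here is a minimal, self-contained sketch in plain Keras. The two-layer model, the layer names, and the iteration count are illustrative assumptions; it only mimics what MCT's FloatKerasModelBuilder with append2output does and is not the toolkit's code:

    import tensorflow as tf

    # Toy stand-in for the float model that the builder produces from the graph.
    inp = tf.keras.Input(shape=(8,))
    target_act = tf.keras.layers.Dense(16, activation="relu", name="target")(inp)
    model_out = tf.keras.layers.Dense(4)(target_act)

    # Analogue of append2output: expose the target activation next to the real output.
    grad_model = tf.keras.Model(inp, [target_act, model_out])

    x = tf.random.normal((32, 8))
    with tf.GradientTape(persistent=True, watch_accessed_variables=False) as g:
        g.watch(x)
        target_tensor, output = grad_model(x)
        scores = []
        for _ in range(50):  # Hutchinson iterations; the count here is arbitrary
            v = tf.random.normal(shape=output.shape, dtype=output.dtype)
            f_v = tf.reduce_sum(v * output)  # v^T y
            with g.stop_recording():
                grad = g.gradient(f_v, target_tensor)    # J^T v
                scores.append(tf.reduce_sum(grad ** 2))  # ||J^T v||^2
    trace_approx = tf.reduce_mean(tf.stack(scores))  # per-tensor trace estimate

Averaging many such iterations, with the additional early-stopping test seen in the diff below, yields the per-tensor approximation that the PER_TENSOR granularity refers to.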

---------

Co-authored-by: Ofir Gordon <Ofir.Gordon@altair-semi.com>
ofirgo and Ofir Gordon authored Nov 22, 2023
1 parent d15d9fc commit e921060
Showing 7 changed files with 116 additions and 341 deletions.
@@ -13,23 +13,19 @@
# limitations under the License.
# ==============================================================================

from typing import List, Tuple, Dict, Any
from typing import List

import tensorflow as tf
from tensorflow.python.keras.engine.base_layer import Layer
from tqdm import tqdm
import numpy as np

from model_compression_toolkit.constants import MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, EPS, \
from model_compression_toolkit.constants import MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, \
HESSIAN_NUM_ITERATIONS
from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianInfoGranularity
from model_compression_toolkit.core.keras.back2framework.instance_builder import OperationHandler
from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
from model_compression_toolkit.core.keras.hessian.trace_hessian_calculator_keras import TraceHessianCalculatorKeras
from model_compression_toolkit.logger import Logger
from tensorflow.python.util.object_identity import Reference as TFReference


class ActivationTraceHessianCalculatorKeras(TraceHessianCalculatorKeras):
@@ -65,24 +61,52 @@ def compute(self) -> List[float]:
List[float]: Approximated trace of the Hessian for an interest point.
"""
if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR:
output_list = [n.node for n in self.graph.get_outputs()]
model_output_nodes = [ot.node for ot in self.graph.get_outputs()]

if self.hessian_request.target_node in model_output_nodes:
Logger.exception("Trying to compute activation Hessian approximation with respect to the model output. "
"This operation is not supported. "
"Remove the output node from the set of node targets in the Hessian request.")

grad_model_outputs = [self.hessian_request.target_node] + model_output_nodes

# Building a model to run Hessian approximation on
model, _ = FloatKerasModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model()

# Record operations for automatic differentiation
with tf.GradientTape(persistent=True, watch_accessed_variables=False) as g:
outputs, interest_points_tensors = self._get_model_outputs_for_single_image(output_list,
gradient_tape=g)
g.watch(self.input_images)

if len(self.input_images) > 1:
outputs = model(self.input_images)
else:
outputs = model(*self.input_images)

if len(outputs) != len(grad_model_outputs):
Logger.error(
f"Model for computing activation Hessian approximation expects {len(grad_model_outputs)} "
f"outputs, but got {len(outputs)} output tensors.")

# Extracting the intermediate activation tensors and the model's real output
# TODO: we assume that the Hessian request is for a single node.
# When we extend it to multiple nodes in the same request, this part should be modified to take
# the first "num_target_nodes" outputs from the output list.
# We also assume that the target nodes are not part of the model's output nodes; if this
# assumption changes, the code should be modified accordingly.
target_activation_tensors = [outputs[0]]
output_tensors = outputs[1:]

# Unfold and concatenate all outputs to form a single tensor
output = self._concat_tensors(outputs)
output = self._concat_tensors(output_tensors)

# List to store the approximated trace of the Hessian for each interest point
trace_approx_by_node = []
# Loop through each interest point activation tensor
for ipt in tqdm(interest_points_tensors): # Per Interest point activation tensor
for ipt in tqdm(target_activation_tensors): # Per Interest point activation tensor
interest_point_scores = [] # List to store scores for each interest point
for j in range(self.num_iterations_for_approximation): # Approximation iterations
# Getting a random vector with normal distribution
v = tf.random.normal(shape=output.shape)
v = tf.random.normal(shape=output.shape, dtype=output.dtype)
f_v = tf.reduce_sum(v * output)

with g.stop_recording():
@@ -97,8 +121,7 @@ def compute(self) -> List[float]:
# Compute the approximation per node's output
score_approx_per_output = []
for grad in gradients:
grad = tf.reshape(grad, [grad.shape[0], -1])
score_approx_per_output.append(tf.reduce_mean(tf.reduce_sum(tf.pow(grad, 2.0))))
score_approx_per_output.append(tf.reduce_sum(tf.pow(grad, 2.0)))

# Free gradients
del grad
@@ -144,125 +167,3 @@ def compute(self) -> List[float]:

else:
Logger.error(f"{self.hessian_request.granularity} is not supported for Keras activation hessian's trace approx calculator")

def _get_model_outputs_for_single_image(self,
output_list: List[str],
gradient_tape: tf.GradientTape) -> Tuple[List[tf.Tensor], List[tf.Tensor]]:
"""
Computes the model's output according to the given graph representation on the given input,
while recording the intermediate tensors needed for gradient computation.
Args:
output_list: List of nodes that are considered as the model's output for the purpose of gradient computation.
gradient_tape: A GradientTape object for recording necessary info for computing gradients.
Returns: A list of output tensors and a list of activation tensors of all interest points.
"""
model_input_tensors = {inode: self.fw_impl.to_tensor(self.input_images[i]) for i, inode in
enumerate(self.graph.get_inputs())}

node_to_output_tensors_dict = dict()

# Build an OperationHandler to handle conversions from graph nodes to Keras operators.
oh = OperationHandler(self.graph)
input_nodes_to_input_tensors = {inode: tf.convert_to_tensor(model_input_tensors[inode]) for
inode in self.graph.get_inputs()} # Cast numpy array to tf.Tensor

interest_points_tensors = []
output_tensors = []
for n in oh.node_sort:
# Build a dictionary from node to its output tensors, by applying the layers sequentially.
op_func = oh.get_node_op_function(n) # Get node operation function

input_tensors = self._build_input_tensors_list(n,
self.graph,
node_to_output_tensors_dict) # Fetch Node inputs

out_tensors_of_n = self._run_operation(n, # Run node operation and fetch outputs
input_tensors,
op_func,
input_nodes_to_input_tensors)

# Gradients can be computed only on float32 tensors
if isinstance(out_tensors_of_n, list):
for i, t in enumerate(out_tensors_of_n):
out_tensors_of_n[i] = tf.dtypes.cast(t, tf.float32)
else:
out_tensors_of_n = tf.dtypes.cast(out_tensors_of_n, tf.float32)

if n.name == self.hessian_request.target_node.name:
# Recording the relevant feature maps onto the gradient tape
gradient_tape.watch(out_tensors_of_n)
interest_points_tensors.append(out_tensors_of_n)
if n in output_list:
output_tensors.append(out_tensors_of_n)

if isinstance(out_tensors_of_n, list):
node_to_output_tensors_dict.update({n: out_tensors_of_n})
else:
node_to_output_tensors_dict.update({n: [out_tensors_of_n]})

return output_tensors, interest_points_tensors

def _build_input_tensors_list(self,
node: BaseNode,
graph: Graph,
node_to_output_tensors_dict: Dict[BaseNode, List[TFReference]]) -> List[List[TFReference]]:
"""
Given a node, build a list of input tensors the node gets. The list is built
based on the node's incoming edges and previous nodes' output tensors.
Args:
node: Node to build its input tensors list.
graph: Graph the node is in.
node_to_output_tensors_dict: A dictionary from a node to its output tensors.
Returns:
A list of the node's input tensors.
"""

input_tensors = []
# Go over a sorted list of the node's incoming edges, and for each source node get its output tensors.
# Append them in a result list.
for ie in graph.incoming_edges(node, sort_by_attr=EDGE_SINK_INDEX):
_input_tensors = [node_to_output_tensors_dict[ie.source_node][ie.source_index]]
input_tensors.append(_input_tensors)
return input_tensors

def _run_operation(self,
n: BaseNode,
input_tensors: List[List[TFReference]],
op_func: Layer,
input_nodes_to_input_tensors: Dict[BaseNode, Any]) -> List[TFReference]:
"""
Applying the layer (op_func) to the input tensors (input_tensors).
Args:
n: The corresponding node of the layer it runs.
input_tensors: List of references to Keras tensors that are the layer's inputs.
op_func: Layer to apply to the input tensors.
input_nodes_to_input_tensors: A dictionary from a node to its input tensors.
Returns:
A list of references to Keras tensors. The layer's output tensors after applying the
layer to the input tensors.
"""

if len(input_tensors) == 0: # Placeholder handling
out_tensors_of_n = input_nodes_to_input_tensors[n]
else:
input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists
# Build a functional node using its args
if isinstance(n, FunctionalNode):
if n.inputs_as_list: # If the first argument should be a list of tensors:
out_tensors_of_n = op_func(input_tensors, *n.op_call_args, **n.op_call_kwargs)
else: # If the input tensors should not be a list but iterated:
out_tensors_of_n = op_func(*input_tensors, *n.op_call_args, **n.op_call_kwargs)
else:
# If operator expects a single input tensor, it cannot be a list as it should have a dtype field.
if len(input_tensors) == 1:
input_tensors = input_tensors[0]
out_tensors_of_n = op_func(input_tensors)

return out_tensors_of_n
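
A note on what the new compute loop estimates (this reading is inferred from the code, not stated in the commit): for a target activation $a$ and model output $y$ with Jacobian $J = \partial y / \partial a$, each iteration draws a random $v$ and computes

$$ f_v = v^\top y, \qquad \nabla_a f_v = J^\top v, \qquad \mathbb{E}_{v \sim \mathcal{N}(0, I)}\left[\lVert J^\top v \rVert_2^2\right] = \operatorname{tr}\left(J J^\top\right), $$

a Hutchinson-style estimate of $\operatorname{tr}(J J^\top)$, the usual label-free proxy for the trace of the loss Hessian with respect to the activation. Note that the replaced reshape-and-mean lines reduced to the same scalar (a sum over all elements followed by the mean of a single scalar), so the new plain sum is a simplification rather than a numerical change.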
@@ -21,7 +21,7 @@
from model_compression_toolkit.constants import MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_NUM_ITERATIONS
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianInfoGranularity
from model_compression_toolkit.core.pytorch.hessian.pytorch_model_gradients import PytorchModelGradients
from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
from model_compression_toolkit.core.pytorch.hessian.trace_hessian_calculator_pytorch import \
TraceHessianCalculatorPytorch
from model_compression_toolkit.core.pytorch.utils import torch_tensor_to_numpy
@@ -61,24 +61,46 @@ def compute(self) -> List[float]:
List[float]: Approximated trace of the Hessian for an interest point.
"""
if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR:
# Set inputs to require_grad
for input_tensor in self.input_images:
input_tensor.requires_grad_()

model_grads_net = PytorchModelGradients(graph_float=self.graph,
trace_hessian_request=self.hessian_request
)
model_output_nodes = [ot.node for ot in self.graph.get_outputs()]

if self.hessian_request.target_node in model_output_nodes:
Logger.exception("Trying to compute activation Hessian approximation with respect to the model output. "
"This operation is not supported. "
"Remove the output node from the set of node targets in the Hessian request.")

grad_model_outputs = [self.hessian_request.target_node] + model_output_nodes
model, _ = FloatPyTorchModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model()
model.eval()

# Run model inference
output_tensors = model_grads_net(self.input_images)
# Set inputs to track gradients during inference
for input_tensor in self.input_images:
input_tensor.requires_grad_()
input_tensor.retain_grad()

outputs = model(*self.input_images)

if len(outputs) != len(grad_model_outputs):
Logger.error(f"Model for computing activation Hessian approximation expects {len(grad_model_outputs)} "
f"outputs, but got {len(outputs)} output tensors.")

# Extracting the intermediate activation tensors and the model's real output
# TODO: we assume that the Hessian request is for a single node.
# When we extend it to multiple nodes in the same request, this part should be modified to take
# the first "num_target_nodes" outputs from the output list.
# We also assume that the target nodes are not part of the model's output nodes; if this
# assumption changes, the code should be modified accordingly.
target_activation_tensors = [outputs[0]]
output_tensors = outputs[1:]
device = output_tensors[0].device

# Concatenate outputs
# First, unfold all outputs that are given as a list, to extract the actual output tensors
output = self.concat_tensors(output_tensors)

ipts_hessian_trace_approx = []
for ipt in tqdm(model_grads_net.interest_points_tensors): # Per Interest point activation tensor
for ipt_tensor in tqdm(target_activation_tensors): # Per Interest point activation tensor
trace_hv = []
for j in range(self.num_iterations_for_approximation): # Approximation iterations
# Getting a random vector with normal distribution
@@ -87,7 +109,7 @@ def compute(self) -> List[float]:

# Computing the Hessian trace approximation by getting the gradient of (output * v)
hess_v = autograd.grad(outputs=f_v,
inputs=ipt,
inputs=ipt_tensor,
retain_graph=True,
allow_unused=True)[0]
if hess_v is None:
@@ -98,8 +120,7 @@ def compute(self) -> List[float]:
requires_grad=True,
device=device))
break
hess_v = torch.reshape(hess_v, [hess_v.shape[0], -1])
hessian_trace_approx = torch.mean(torch.sum(torch.pow(hess_v, 2.0)))
hessian_trace_approx = torch.sum(torch.pow(hess_v, 2.0))

# If the change to the mean Hessian approximation is insignificant, we stop the calculation
if j > MIN_HESSIAN_ITER:
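The early-stopping test referenced by the last visible comment is collapsed in this view. A plausible shape for it, reusing the MIN_HESSIAN_ITER and HESSIAN_COMP_TOLERANCE constants that the file imports (the exact expression is an assumption; the diff does not show it):

    import torch
    from model_compression_toolkit.constants import MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE

    def running_mean_converged(trace_hv, j):
        """Stop the Hutchinson loop once the running mean of the estimates stabilizes.

        trace_hv: list of the scalar estimates collected so far (j + 1 entries).
        """
        if j <= MIN_HESSIAN_ITER or len(trace_hv) < 2:
            return False
        curr = torch.mean(torch.stack(trace_hv))
        prev = torch.mean(torch.stack(trace_hv[:-1]))
        # Relative change of the running mean, guarded against division by zero.
        return (torch.abs(curr - prev) / (torch.abs(prev) + 1e-8)) < HESSIAN_COMP_TOLERANCE

Inside the loop in compute, such a check would run right after appending hessian_trace_approx to trace_hv, breaking out of the iteration loop when it returns True.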
