Skip to content

Commit

Permalink
code review modification 2
Browse files Browse the repository at this point in the history
  • Loading branch information
eladc-git committed Oct 30, 2023
1 parent 1b88e58 commit a85cb79
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -124,5 +124,9 @@ def compute(self) -> np.ndarray:
# Compute the mean of the approximations
final_approx = torch.mean(torch.stack(approximation_per_iteration), dim=0)

# Make sure all final shape are tensors and not scalar
if self.hessian_request.granularity == HessianInfoGranularity.PER_TENSOR:
final_approx = final_approx.reshape(1)

return final_approx.detach().cpu().numpy()

Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def get_expected_shape(weights_shape, granularity):
if granularity==hessian_common.HessianInfoGranularity.PER_ELEMENT:
return weights_shape
elif granularity==hessian_common.HessianInfoGranularity.PER_TENSOR:
return ()
return (1,)
else:
return (weights_shape[0],)

Expand Down

0 comments on commit a85cb79

Please sign in to comment.