
Improve dwain method + vision trainer #9

Merged: 8 commits, Apr 16, 2024
4 changes: 2 additions & 2 deletions examples/trainer_llm/run_decompose_dwain.py
@@ -238,8 +238,8 @@ def main(config_raw: dict[str, Any], output_path: pathlib.Path) -> None:
"mparams_initial": params_initial,
"mparams_final": params_final,
"mparams_frac": params_frac,
"gflops_orig": gflops_initial,
"gflops_initial": gflops_final,
"gflops_initial": gflops_initial,
"gflops_final": gflops_final,
"gflops_frac": gflops_frac,
"time_decomposition_and_perplex_eval": time_decomposition_and_perplex_eval,
"time_lm_eval": time_lm_eval,
2 changes: 1 addition & 1 deletion examples/trainer_llm/version.py
@@ -1,5 +1,5 @@
 # This is parsed by setup.py, so we need to stick to str -> int parsing
-__version__ = "0.0.53"
+__version__ = "0.0.54"

 _ver_major = int(__version__.split(".")[0])
 _ver_minor = int(__version__.split(".")[1])
3 changes: 3 additions & 0 deletions examples/trainer_vision/builder.py
@@ -67,6 +67,9 @@ def log_model_stats(
     kmapps = model_stats["kmapps"]
     mparams = model_stats["mparams"]
     msg = f"{log_prefix} gflops={gflops:.2f} kmapps={kmapps:.2f} Mparams={mparams:.2f}"
+    acc = model_stats.get("accuracy_val")
+    if acc is not None:
+        msg += f" accuracy={acc:.2f}"  # "accuracy_val" is already stored as a percentage
     stats_logger.info(msg)


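With this change, a stats line gains a labeled accuracy field whenever "accuracy_val" is present in the stats dict; a hypothetical example of the resulting log output (values invented for illustration):

```
Original model : gflops=4.11 kmapps=0.96 Mparams=25.55 accuracy=80.10
```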
4 changes: 3 additions & 1 deletion examples/trainer_vision/configurator.py
@@ -55,7 +55,9 @@ class DecomposeFALORConfig(_VersionConfig, _DataConfig):
     nsr_final_threshold: float
     num_data_steps: int
     num_metric_steps: int
-
+    use_mean: bool
+    use_float64: bool
+    use_damping: bool
     model_config = pydantic.ConfigDict(extra="forbid")


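For reference, this is how the three new switches might look in a raw config alongside the existing DecomposeFALORConfig fields; a minimal sketch with illustrative values, not recommended defaults (the other required fields are omitted here):

```python
# Hypothetical fragment of the config_raw dict passed to DecomposeFALORConfig;
# values are illustrative only.
config_raw = {
    # ... model/data fields omitted ...
    "nsr_final_threshold": 0.05,
    "num_data_steps": 64,
    "num_metric_steps": 16,
    "use_mean": True,     # subtract the outer product of the mean when building the covariance
    "use_float64": True,  # accumulate covariance statistics in float64
    "use_damping": True,  # add a small diagonal term before the eigendecomposition
}
```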
105 changes: 100 additions & 5 deletions examples/trainer_vision/run_decompose_falor.py
@@ -2,11 +2,14 @@
 import json
 import logging
 import pathlib
+import time
 from typing import Any

 import nvidia.dali.plugin.pytorch # type:ignore
 import ptdeco.falor
 import torch
+import torchmetrics
+

 import builder
 import configurator
@@ -22,6 +25,46 @@ def make_image_iterator(
         yield d["inputs"].permute(0, 3, 1, 2)


+def calc_accuracy(*, model, valid_pipeline, device, n_batches=None):
+
+    model.eval()
+
+    model.to(device)
+
+    val_iter = datasets_dali.DaliGenericIteratorWrapper(
+        nvidia.dali.plugin.pytorch.DALIGenericIterator(
+            valid_pipeline, ["inputs", "targets"]
+        )
+    )
+    if n_batches is None:
+        n_batches = len(val_iter)
+
+    # pbar = tqdm(total=n_batches)
+    with torch.inference_mode():
+        metrics = torchmetrics.classification.MulticlassAccuracy(
+            num_classes=1000,
+        )
+
+        metrics.to(device)
+
+        for i, batch in enumerate(val_iter):
+            if i >= n_batches:
+                break
+            inputs, targets = batch["inputs"], batch["targets"]
+            inputs = inputs.permute(0, 3, 1, 2)
+            inputs = inputs.to(device)
+            outputs = model(inputs)
+            targets = torch.argmax(targets, dim=1)
+            outputs = torch.softmax(outputs, dim=1)
+            metrics.update(outputs, targets)
+            # pbar.update(1)
+        res = metrics.compute().item()
+
+    del valid_pipeline
+
+    return res
+
+
 def main(config_raw: dict[str, Any], output_path: pathlib.Path) -> None:
     config = configurator.DecomposeFALORConfig(**config_raw)
     b_c_h_w = (1, 3, *config.input_h_w)
@@ -34,7 +77,6 @@ def main(config_raw: dict[str, Any], output_path: pathlib.Path) -> None:
         normalization=config.normalization,
         h_w=config.input_h_w,
     )
-    del valid_pipeline

     if torch.cuda.is_available():
         device = torch.device("cuda")
@@ -50,11 +92,23 @@ def main(config_raw: dict[str, Any], output_path: pathlib.Path) -> None:
     data_iterator = make_image_iterator(train_dataloader)

     model = builder.make_model(config.decompose_model_name)
+    t_eval_start = time.perf_counter()
+    accuracy_val_initial = 100.0 * calc_accuracy(
+        model=model,
+        valid_pipeline=valid_pipeline,
+        device=device,
+    )
+    t_eval_initial = time.perf_counter() - t_eval_start
+    s = f"Initial accuracy {accuracy_val_initial:.2f}, eval took {t_eval_initial:.2f} s"
+    logger.info(s)

     builder.validate_module_names(model, config.blacklisted_modules)

     model.to(device)
-    model_orig_stats = builder.get_model_stats(model, b_c_h_w)
+    stats_initial = builder.get_model_stats(model, b_c_h_w)
+    stats_initial["accuracy_val"] = accuracy_val_initial
+
+    t_decomposition_start = time.perf_counter()
     decompose_config = ptdeco.falor.decompose_in_place(
         module=model,
         device=device,
@@ -65,14 +119,55 @@ def main(config_raw: dict[str, Any], output_path: pathlib.Path) -> None:
         num_data_steps=config.num_data_steps,
         num_metric_steps=config.num_metric_steps,
         blacklisted_module_names=config.blacklisted_modules,
+        use_float64=config.use_float64,
+        use_mean=config.use_mean,
+        use_damping=config.use_damping,
     )
+    t_decomposition = time.perf_counter() - t_decomposition_start

+    stats_final = builder.get_model_stats(model, b_c_h_w)
+    t_eval_start = time.perf_counter()
+    accuracy_val_final = 100.0 * calc_accuracy(
+        model=model,
+        valid_pipeline=valid_pipeline,
+        device=device,
+    )
-    model_deco_stats = builder.get_model_stats(model, b_c_h_w)
+    t_eval_final = time.perf_counter() - t_eval_start
+    s = f"Final accuracy {accuracy_val_final:.2f}, eval took {t_eval_final:.2f} s"
+    logger.info(s)
+    stats_final["accuracy_val"] = accuracy_val_final

     out_decompose_config_path = output_path / "decompose_config.json"
     with open(out_decompose_config_path, "wt") as f:
         json.dump(decompose_config, f)
     out_decompose_state_dict_path = output_path / "decompose_state_dict.pt"
     torch.save(model.state_dict(), out_decompose_state_dict_path)

-    builder.log_model_stats(logger, "Original model :", model_orig_stats)
-    builder.log_model_stats(logger, "Decomposed model:", model_deco_stats)
+    builder.log_model_stats(logger, "Original model :", stats_initial)
+    builder.log_model_stats(logger, "Decomposed model:", stats_final)
+
+    device_str = str(device)
+    if "cuda" in device_str:
+        device_str += " @ " + torch.cuda.get_device_name(device)
+
+    summary = {
+        "accuracy_val_initial": accuracy_val_initial,
+        "accuracy_val_final": accuracy_val_final,
+        "mparams_initial": stats_initial["mparams"],
+        "mparams_final": stats_final["mparams"],
+        "mparams_frac": stats_final["mparams"] / stats_initial["mparams"] * 100.0,
+        "gflops_initial": stats_initial["gflops"],
+        "gflops_final": stats_final["gflops"],
+        "gflops_frac": stats_final["gflops"] / stats_initial["gflops"] * 100.0,
+        "kmapps_initial": stats_initial["kmapps"],
+        "kmapps_final": stats_final["kmapps"],
+        # Should be the same as "gflops_frac", but we log it for completeness
+        "kmapps_frac": stats_final["kmapps"] / stats_initial["kmapps"] * 100.0,
+        "time_eval_initial": t_eval_initial,
+        "time_decomposition": t_decomposition,
+        "time_eval_final": t_eval_final,
+        "device": device_str,
+    }
+
+    with open(output_path / "summary.json", "wt") as f:
+        json.dump(summary, f)
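For context, the new calc_accuracy helper is self-contained enough to use outside main as well; a minimal sketch, assuming a DALI validation pipeline built as above (the model name is a placeholder, not a value from this PR):

```python
# Sketch: standalone use of calc_accuracy (model name and pipeline are placeholders).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = builder.make_model("resnet18")  # hypothetical model name
acc = calc_accuracy(
    model=model,
    valid_pipeline=valid_pipeline,  # DALI pipeline from datasets_dali, as in main()
    device=device,
    n_batches=10,  # cap evaluation at 10 batches for a quick check
)
print(f"top-1 accuracy: {100.0 * acc:.2f}%")
```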
2 changes: 1 addition & 1 deletion examples/trainer_vision/version.py
@@ -1,5 +1,5 @@
 # This is parsed by setup.py, so we need to stick to str -> int parsing
-__version__ = "0.8.2"
+__version__ = "0.8.6"

 _ver_major = int(__version__.split(".")[0])
 _ver_minor = int(__version__.split(".")[1])
2 changes: 1 addition & 1 deletion src/ptdeco/_version.py
@@ -1,5 +1,5 @@
 # This is parsed by setup.py, so we need to stick to str -> int parsing
-__version__ = "0.4.31"
+__version__ = "0.4.34"

 _ver_major = int(__version__.split(".")[0])
 _ver_minor = int(__version__.split(".")[1])
5 changes: 3 additions & 2 deletions src/ptdeco/dwain/decomposition.py
@@ -14,10 +14,10 @@
"decompose_in_place",
]

logger = logging.getLogger(__name__)

EIGEN_DAMPEN_FACTOR = 0.01

logger = logging.getLogger(__name__)


class WrappedDWAINModule(torch.nn.Module):
def __init__(self) -> None:
@@ -153,6 +153,7 @@ def _update_Eyyt_in_place(Eyyt: torch.Tensor, y_reshaped: torch.Tensor) -> None:

 def _get_eigenvectors(Eyyt: torch.Tensor, num_data_steps: int) -> torch.Tensor:
     Eyyt = Eyyt / num_data_steps
+    # https://stats.stackexchange.com/questions/390532/adding-a-small-constant-to-the-diagonals-of-a-matrix-to-stabilize
     damp = EIGEN_DAMPEN_FACTOR * torch.mean(torch.diag(Eyyt))
     diag = torch.arange(Eyyt.shape[-1], device=Eyyt.device)
     Eyyt[diag, diag] = Eyyt[diag, diag] + damp
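The linked thread motivates the damping used here: adding a small multiple of the average diagonal entry keeps the matrix well conditioned for the eigendecomposition. A standalone sketch of the same arithmetic on a toy matrix:

```python
# Toy illustration of the diagonal damping above: C <- C + 0.01 * mean(diag(C)) * I
import torch

EIGEN_DAMPEN_FACTOR = 0.01

cov = torch.tensor([[2.0, 1.0], [1.0, 2.0]], dtype=torch.float64)
damp = EIGEN_DAMPEN_FACTOR * torch.mean(torch.diag(cov))
diag = torch.arange(cov.shape[-1], device=cov.device)
cov[diag, diag] += damp
eigenvalues, eigenvectors = torch.linalg.eigh(cov)  # damped matrix stays safely positive definite
```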
42 changes: 37 additions & 5 deletions src/ptdeco/falor/decomposition.py
@@ -18,6 +18,8 @@

 from .. import utils

+EIGEN_DAMPEN_FACTOR = 0.01
+
 logger = logging.getLogger(__name__)

 __all__ = ["decompose_in_place"]
@@ -164,22 +166,40 @@ def _compute_decompositon_of_covariance_matrix(
     weight: torch.Tensor,
     num_data_steps: int,
     device: torch.device,
+    use_float64: bool,
+    use_mean: bool,
+    use_damping: bool,
 ) -> torch.Tensor:
     root_module.eval()
     decomposed_submodule = root_module.get_submodule(decomposed_submodule_name)
     assert isinstance(decomposed_submodule, WrappedFALORModule)

-    Ey = torch.zeros(weight.shape[0]).to(device)
-    Eyyt = torch.zeros((weight.shape[0], weight.shape[0])).to(device)
-
+    n_out = weight.shape[0]
+    if use_float64:
+        Ey = torch.zeros(n_out, dtype=torch.float64).to(device)
+        Eyyt = torch.zeros((n_out, n_out), dtype=torch.float64).to(device)
+    else:
+        Ey = torch.zeros(n_out, dtype=torch.float32).to(device)
+        Eyyt = torch.zeros((n_out, n_out), dtype=torch.float32).to(device)
     for i in range(num_data_steps):
         inputs = next(data_iterator).to(device)
         _ = root_module(inputs)
         x = decomposed_submodule.get_last_input()
         Ey, Eyyt = _accumulate_Ey_and_Eyyt(Ey=Ey, Eyyt=Eyyt, weight=weight, x=x)
     Ey /= num_data_steps
     Eyyt /= num_data_steps
-    cov = Eyyt - torch.outer(Ey, Ey)
+    if use_mean:
+        logger.info("Using mean for covariance")
+        cov = Eyyt - torch.outer(Ey, Ey)
+    else:
+        logger.info("Not using mean for covariance")
+        cov = Eyyt
+    if use_damping:
+        # https://stats.stackexchange.com/questions/390532/adding-a-small-constant-to-the-diagonals-of-a-matrix-to-stabilize
+        logger.info("Using damping")
+        damp = EIGEN_DAMPEN_FACTOR * torch.mean(torch.diag(cov))
+        diag = torch.arange(cov.shape[-1], device=cov.device)
+        cov[diag, diag] = cov[diag, diag] + damp  # damp cov itself, so eigh sees the damped matrix
+    logger.info(f"{cov.dtype=}")
     _, u = torch.linalg.eigh(cov)
     return u

@@ -256,6 +276,9 @@ def _process_module(
     num_data_steps: int,
     num_metric_steps: int,
     device: torch.device,
+    use_float64: bool,
+    use_mean: bool,
+    use_damping: bool,
 ) -> dict[str, Any]:
     decomposed_submodule = root_module.get_submodule(decomposed_submodule_name)
     decomposed_type = utils.get_type_name(decomposed_submodule)
@@ -289,6 +312,9 @@ def _process_module(
         weight=orig_weight,
         num_data_steps=num_data_steps,
         device=device,
+        use_float64=use_float64,
+        use_mean=use_mean,
+        use_damping=use_damping,
     )

     U, V = torch.empty(0), torch.empty(0)
@@ -390,6 +416,9 @@ def decompose_in_place(
     kl_final_threshold: float,
     num_data_steps: int,
     num_metric_steps: int,
+    use_float64: bool,
+    use_mean: bool,
+    use_damping: bool,
 ) -> dict[str, Any]:
     start_time = time.perf_counter()

@@ -419,6 +448,9 @@
             num_data_steps=num_data_steps,
             num_metric_steps=num_metric_steps,
             device=device,
+            use_float64=use_float64,
+            use_mean=use_mean,
+            use_damping=use_damping,
         )
         results_all[submodule_name] = result

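Taken together, the three flags control how the output covariance fed to torch.linalg.eigh is built. A simplified, self-contained sketch of that logic (the real accumulation happens in _accumulate_Ey_and_Eyyt over activations captured from the wrapped submodule; here a list of (batch, n) tensors stands in for them):

```python
# Simplified sketch of the covariance construction gated by the new flags.
import torch

def build_covariance(
    ys: list[torch.Tensor],  # stand-in for per-step activations y = W x, shape (batch, n)
    use_float64: bool,
    use_mean: bool,
    use_damping: bool,
) -> torch.Tensor:
    dtype = torch.float64 if use_float64 else torch.float32
    n = ys[0].shape[-1]
    Ey = torch.zeros(n, dtype=dtype)
    Eyyt = torch.zeros((n, n), dtype=dtype)
    for y in ys:
        y = y.to(dtype)
        Ey += y.mean(dim=0)
        Eyyt += y.T @ y / y.shape[0]
    Ey /= len(ys)
    Eyyt /= len(ys)
    # E[y y^T] - E[y] E[y]^T (true covariance) vs. the raw second moment
    cov = Eyyt - torch.outer(Ey, Ey) if use_mean else Eyyt
    if use_damping:
        diag = torch.arange(n)
        cov[diag, diag] += 0.01 * torch.mean(torch.diag(cov))
    return cov
```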