Improve vision trainer and dwain decomposer #14

Merged · 15 commits · May 16, 2024
3 changes: 2 additions & 1 deletion README.md
@@ -51,7 +51,8 @@ convolutions.
GPU hour (depending on model size and parameters). It can decompose linear
layers and 1x1 convolutions.

**dwain** method does not require pretraining. It can decompose linear layers.
**dwain** method does not require pretraining. It can decompose linear layers and
1x1 convolutions.

## Installation

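For readers new to the project, "decomposing" a linear layer here means replacing it with a low-rank factorization that has fewer parameters. The snippet below is only a minimal sketch of that idea in plain PyTorch; the rank choice and the SVD-based initialization are illustrative, not ptdeco's internals.

```python
# Minimal sketch of low-rank decomposition of a linear layer (illustration only,
# not ptdeco's API): W (out x in) is approximated by U_r @ (S_r @ Vh_r).
import torch

def decompose_linear(layer: torch.nn.Linear, rank: int) -> torch.nn.Sequential:
    u, s, vh = torch.linalg.svd(layer.weight.detach(), full_matrices=False)
    first = torch.nn.Linear(layer.in_features, rank, bias=False)
    second = torch.nn.Linear(rank, layer.out_features, bias=layer.bias is not None)
    first.weight.data = torch.diag(s[:rank]) @ vh[:rank]  # shape (rank, in)
    second.weight.data = u[:, :rank]                      # shape (out, rank)
    if layer.bias is not None:
        second.bias.data = layer.bias.detach().clone()
    return torch.nn.Sequential(first, second)

layer = torch.nn.Linear(512, 512)
decomposed = decompose_linear(layer, rank=64)
x = torch.randn(4, 512)
rel_err = ((layer(x) - decomposed(x)).norm() / layer(x).norm()).item()
print(f"{rel_err=:.3f}")  # approximation error grows as the rank shrinks
```

A 1x1 convolution can be handled the same way, since it is just a linear map applied across the channel dimension.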
2 changes: 2 additions & 0 deletions examples/trainer_llm/Makefile
@@ -12,6 +12,8 @@ PYTHON_FILES=$(PYTHON_LIB_FILES) $(PYTHON_RUN_FILE)

BUILD_DIR=dist

all: build

check:
-isort --check --profile=black $(PYTHON_FILES)
-black --check $(PYTHON_FILES)
8 changes: 5 additions & 3 deletions examples/trainer_llm/builder.py
@@ -9,14 +9,16 @@


def _log_linear_submodules(m: torch.nn.Module) -> None:
res = ["All Linear modules of the model:"]
msg_list = ["All linear modules of the model:"]

i = 1
for name, module in m.named_modules():
if isinstance(module, torch.nn.Linear):
res.append(f" - {name} # ({i}) {tuple(module.weight.shape)}")
bias = "+ bias" if module.bias is not None else "no bias"
msg = f" - {name} # ({i}) {bias} {tuple(module.weight.shape)}"
msg_list.append(msg)
i += 1
logger.info("\n".join(res))
logger.info("\n".join(msg_list))


def _add_pad_token(
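To see what the revised listing produces, here is a standalone rendition of the same enumeration on a toy model, with print standing in for the module-level logger (function and variable names here are illustrative):

```python
# Standalone rendition of the linear-module listing with bias information.
import torch

def list_linear_modules(m: torch.nn.Module) -> str:
    lines = ["All linear modules of the model:"]
    i = 1
    for name, module in m.named_modules():
        if isinstance(module, torch.nn.Linear):
            bias = "+ bias" if module.bias is not None else "no bias"
            lines.append(f" - {name} # ({i}) {bias} {tuple(module.weight.shape)}")
            i += 1
    return "\n".join(lines)

toy = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 8, bias=False),
)
print(list_linear_modules(toy))
#  - 0 # (1) + bias (32, 16)
#  - 2 # (2) no bias (8, 32)
```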
2 changes: 1 addition & 1 deletion examples/trainer_llm/version.py
@@ -1,5 +1,5 @@
# This is parsed by setup.py, so we need to stick to str -> int parsing
__version__ = "0.0.61"
__version__ = "0.0.63"

_ver_major = int(__version__.split(".")[0])
_ver_minor = int(__version__.split(".")[1])
12 changes: 7 additions & 5 deletions examples/trainer_vision/builder.py
@@ -132,7 +132,7 @@ def log_model_stats(


def make_model(
model_name: str, log_linears_an_conv1x1: bool = False
model_name: str, log_linears_and_conv1x1: bool = False
) -> torch.nn.Module:
builder, model_name = model_name.split(".", maxsplit=1)

@@ -143,16 +143,18 @@ def make_model(
else:
raise ValueError(f"Unknown model builder {builder}")

if log_linears_an_conv1x1:
if log_linears_and_conv1x1:
n_linears = 0
n_conv1x1 = 0
i = 1
msg_list = ["All decomposeable modules of the model:"]

for name, module in model.named_modules():
i = n_linears + n_conv1x1 + 1
if isinstance(module, torch.nn.Linear):
bias = "+ bias" if module.bias is not None else "no bias"
msg = f" - {name} # ({i}) linear {bias} {tuple(module.weight.shape)}"
logger.info(msg)
msg_list.append(msg)
n_linears += 1
elif (
isinstance(module, torch.nn.Conv2d)
@@ -162,9 +164,9 @@
):
bias = "+ bias" if module.bias is not None else "no bias"
msg = f" - {name} # ({i}) conv1x1 {bias} {tuple(module.weight.shape)}"
logger.info(msg)
msg_list.append(msg)
n_conv1x1 += 1

logger.info("\n".join(msg_list))
logger.info(f"Decomposeable module statistics {n_linears=} {n_conv1x1=}")
return model

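The collapsed part of this hunk hides the full predicate for 1x1 convolutions; the condition below (kernel size 1x1, groups equal to 1) is an assumption used only to illustrate the counting logic:

```python
# Illustrative predicate for "decomposable" modules; the exact checks in
# builder.py are partly collapsed in the diff, so treat these as assumptions.
import torch

def is_conv1x1(module: torch.nn.Module) -> bool:
    return (
        isinstance(module, torch.nn.Conv2d)
        and module.kernel_size == (1, 1)
        and module.groups == 1
    )

def count_decomposable(model: torch.nn.Module) -> tuple[int, int]:
    n_linears = sum(isinstance(m, torch.nn.Linear) for m in model.modules())
    n_conv1x1 = sum(is_conv1x1(m) for m in model.modules())
    return n_linears, n_conv1x1

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=1),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=1),
    torch.nn.Linear(16, 10),
)
print(count_decomposable(model))  # (1, 1): one linear, one 1x1 conv
```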
7 changes: 6 additions & 1 deletion examples/trainer_vision/configurator.py
@@ -28,7 +28,7 @@ class _TrainConfig(pydantic.BaseModel):
lr: float
lr_t_warmup: str
max_duration: str
optimizer: Literal["Adam", "SGD"]
optimizer: Literal["SGD", "Adam", "AdamW"]
precision: Optional[Literal["fp32", "amp_fp16", "amp_bf16", "amp_fp8"]]
alg_channel_last: bool
alg_gradient_clipping_type: Optional[Literal["norm", "value", "adaptive"]]
@@ -79,6 +79,9 @@ class DecomposeDWAINConfig(_VersionConfig, _DataConfig):

finetuning_run: bool
finetuning_lr: float
finetuning_optimizer: Literal["SGD", "Adam", "AdamW"]
finetuning_reverting: bool
finetuning_batch_norms_in_eval: bool
finetuning_num_steps: int
finetuning_num_log_steps: int
finetuning_num_last_finetuned_modules: int
@@ -148,6 +151,8 @@ def get_optimizer(
logger.info(f"Using optimizer {optimizer_name}")
if optimizer_name == "Adam":
return torch.optim.Adam(params=params, lr=config.lr)
elif optimizer_name == "AdamW":
return torch.optim.AdamW(params=params, lr=config.lr)
elif optimizer_name == "SGD":
return torch.optim.SGD(params=params, lr=config.lr)
else:
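The configuration side of this change is a Literal-constrained field that pydantic validates at load time, paired with a plain dispatch in get_optimizer. A trimmed-down sketch of the pattern (field names simplified, not the project's full config):

```python
# Sketch of the optimizer-selection pattern: Literal validation + dispatch.
from typing import Literal

import pydantic
import torch

class TrainConfig(pydantic.BaseModel):
    optimizer: Literal["SGD", "Adam", "AdamW"]
    lr: float

def get_optimizer(params, config: TrainConfig) -> torch.optim.Optimizer:
    if config.optimizer == "Adam":
        return torch.optim.Adam(params=params, lr=config.lr)
    elif config.optimizer == "AdamW":
        return torch.optim.AdamW(params=params, lr=config.lr)
    elif config.optimizer == "SGD":
        return torch.optim.SGD(params=params, lr=config.lr)
    raise ValueError(f"Unknown optimizer {config.optimizer}")

model = torch.nn.Linear(4, 4)
cfg = TrainConfig(optimizer="AdamW", lr=1e-3)
opt = get_optimizer(model.parameters(), cfg)
# TrainConfig(optimizer="RMSprop", lr=1e-3) would fail pydantic validation,
# so a typo is caught before training starts rather than mid-run.
```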
69 changes: 58 additions & 11 deletions examples/trainer_vision/dwain_wrapper_module.py
@@ -1,11 +1,13 @@
import collections
import json
import logging
import os
import pathlib
import time
from typing import Any
from typing import Any, Optional

import ptdeco
import timm # type:ignore
import torch

PREFIX = "raw_model."
@@ -59,6 +61,22 @@ def save_raw_model_decompose_config_and_state_dict(
torch.save(strip_prefix_dict(state_dict), out_decompose_state_dict_path)


_BATCH_NORM_TYPES = (
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
timm.layers.norm_act.BatchNormAct2d,
)


def _batch_norms_in_eval(m: torch.nn.Module) -> None:
for mod_name, mod in m.named_modules():
if isinstance(mod, _BATCH_NORM_TYPES):
mod.eval()
mod_type_name = ptdeco.utils.get_type_name(mod)
logger.info(f"Switching {mod_name} ({mod_type_name}) to eval mode")


def finetune_full(
*,
model: torch.nn.Module,
@@ -69,6 +87,9 @@ def finetune_full(
num_steps: int = 100,
num_log_steps: int = 10,
lr: float = 0.0001,
reverting_checkpoints_dir: Optional[pathlib.Path] = None,
optimizer_name: str,
batch_norms_in_eval: bool,
) -> torch.nn.Module:

if len(decomposed_modules) == 0:
@@ -83,22 +104,33 @@
logger.info(msg)
else:
param.requires_grad = False
if optimizer_name == "SGD":
optimizer: torch.optim.Optimizer = torch.optim.SGD(model.parameters(), lr=lr)
logger.info("Using SGD optimizer")
elif optimizer_name == "Adam":
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
elif optimizer_name == "AdamW":
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
logger.info("Using AdamW optimizer")
else:
raise ValueError(f"Unknown {optimizer_name=} only SGD and AdamW are allowed")

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# lr_scheduler = transformers.get_linear_schedule_with_warmup(
# optimizer=optimizer,
# num_warmup_steps=10,
# num_training_steps=num_steps,
# )
counter = 0
model.train()
if batch_norms_in_eval:
_batch_norms_in_eval(model)
total_loss = 0.0
for step in range(num_steps):
initial_loss = float("nan")
last_loss = float("nan")

if reverting_checkpoints_dir is not None:
pid = os.getpid()
sd_path = reverting_checkpoints_dir / f"tmp_reverting_state_dict_{pid}.pt"
torch.save(model.state_dict(), sd_path)

for step in range(1, num_steps + 1):
batch = ptdeco.utils.to_device(next(ft_iterator), device)
counter += 1
if step > num_steps:
break
optimizer.zero_grad()
outputs = model(batch)
loss = ce_loss(batch, outputs)
@@ -109,8 +141,23 @@

if step % num_log_steps == 0:
logger.info(f"Step: {step}/{num_steps}, loss: {total_loss / counter}")
# This checks for NaN
if initial_loss != initial_loss:
initial_loss = total_loss / counter
last_loss = total_loss / counter
total_loss = 0.0
counter = 0

if reverting_checkpoints_dir is not None:
if initial_loss == initial_loss and initial_loss < last_loss:
loss_msg = f"{initial_loss=:.4f} < {last_loss=:.4f}"
logger.info(f"{loss_msg}: keeping the orig weights")
model.load_state_dict(torch.load(sd_path))
elif initial_loss == initial_loss and initial_loss >= last_loss:
loss_msg = f"{initial_loss=:.4f} >= {last_loss=:.4f}"
logger.info(f"{loss_msg}: using the fine-tuned weights")
sd_path.unlink(missing_ok=True)

model.eval()
stop = time.perf_counter()
logger.info(f"Full fine-tuning took {stop-start:.2f} seconds")
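The key additions to finetune_full are: optionally keeping batch-norm layers in eval mode so their running statistics are frozen during the short fine-tune, saving a checkpoint before touching the weights, and restoring it when the average loss over the last logging window is no better than over the first one (NaN serves as a "not set yet" marker, exploiting the fact that NaN != NaN). A compressed, self-contained sketch of that control flow with a toy objective standing in for the real data iterator and loss:

```python
# Compressed sketch of the revert-if-worse fine-tuning pattern (toy objective;
# the real loop streams batches from ft_iterator and uses the trainer's loss).
import pathlib
import tempfile

import torch

def finetune_with_revert(model, num_steps=100, num_log_steps=10, lr=1e-4):
    ckpt = pathlib.Path(tempfile.gettempdir()) / "tmp_reverting_state_dict.pt"
    torch.save(model.state_dict(), ckpt)          # snapshot before fine-tuning

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    initial_loss, last_loss = float("nan"), float("nan")
    total_loss, counter = 0.0, 0

    for step in range(1, num_steps + 1):
        x = torch.randn(8, 16)
        loss = (model(x) ** 2).mean()             # toy objective
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        counter += 1
        if step % num_log_steps == 0:
            if initial_loss != initial_loss:      # NaN check: first logged window
                initial_loss = total_loss / counter
            last_loss = total_loss / counter
            total_loss, counter = 0.0, 0

    if initial_loss == initial_loss and initial_loss < last_loss:
        model.load_state_dict(torch.load(ckpt))   # fine-tuning made it worse: revert
    else:
        ckpt.unlink(missing_ok=True)              # keep the fine-tuned weights
    return model

finetune_with_revert(torch.nn.Linear(16, 16))
```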
18 changes: 15 additions & 3 deletions examples/trainer_vision/run_decompose_dwain.py
@@ -6,7 +6,7 @@
from typing import Any

import nvidia.dali.plugin.pytorch # type:ignore
import ptdeco.falor
import ptdeco
import torch

import builder
@@ -34,12 +34,19 @@ def no_finetune(

def make_finetune_fn(
config: configurator.DecomposeDWAINConfig,
output_path: pathlib.Path,
ft_iterator: collections.abc.Iterator[dict[str, torch.Tensor]],
) -> collections.abc.Callable[
[torch.nn.Module, torch.device, list[str]], torch.nn.Module
]:
if config.finetuning_run:
logger.info("Creating full finetuning function")
if config.finetuning_reverting:
reverting_checkpoints_dir = output_path
logger.info("Reverting finetuning is ON")
else:
reverting_checkpoints_dir = None
logger.info("Reverting finetuning is OFF")
return lambda m, device, decomposed_modules: dwain_wrapper_module.finetune_full(
model=m,
device=device,
@@ -49,6 +56,9 @@
num_log_steps=config.finetuning_num_log_steps,
lr=config.finetuning_lr,
num_last_modules_to_finetune=config.finetuning_num_last_finetuned_modules,
reverting_checkpoints_dir=reverting_checkpoints_dir,
optimizer_name=config.finetuning_optimizer,
batch_norms_in_eval=config.finetuning_batch_norms_in_eval,
)
else:
logger.info("Creating empty finetuning function")
@@ -86,7 +96,9 @@ def main(config_raw: dict[str, Any], output_path: pathlib.Path) -> None:

decomposition_it = make_image_iterator(train_dataloader)

model = builder.make_model(config.decompose_model_name, log_linears_an_conv1x1=True)
model = builder.make_model(
config.decompose_model_name, log_linears_and_conv1x1=True
)
builder.validate_module_names(model, config.blacklisted_modules)
model.to(device)

@@ -111,7 +123,7 @@

t_decomposition_start = time.perf_counter()

finetune_fn = make_finetune_fn(config, decomposition_it)
finetune_fn = make_finetune_fn(config, output_path, decomposition_it)

blacklisted_module_names_wrapped = dwain_wrapper_module.add_prefix(
config.blacklisted_modules
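make_finetune_fn hides all configuration behind a closure, so the decomposition routine only ever calls a function of (model, device, decomposed_module_names). A stripped-down sketch of that pattern follows; the stand-in body just tags the model instead of calling finetune_full, and the parameter names are illustrative:

```python
# Sketch of the finetune-callback pattern used by the dwain runner.
import collections.abc

import torch

FinetuneFn = collections.abc.Callable[
    [torch.nn.Module, torch.device, list[str]], torch.nn.Module
]

def make_finetune_fn(enabled: bool, lr: float) -> FinetuneFn:
    if not enabled:
        return lambda model, device, decomposed_modules: model  # no-op fallback

    def finetune(model, device, decomposed_modules):
        # Stand-in for dwain_wrapper_module.finetune_full(...).
        model.finetuned_with_lr = lr
        return model

    return finetune

fn = make_finetune_fn(enabled=True, lr=1e-4)
model = fn(torch.nn.Linear(8, 8), torch.device("cpu"), ["some_module"])
print(model.finetuned_with_lr)  # 0.0001
```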
4 changes: 3 additions & 1 deletion examples/trainer_vision/run_decompose_falor.py
@@ -54,7 +54,9 @@ def main(config_raw: dict[str, Any], output_path: pathlib.Path) -> None:
)
data_iterator = make_image_iterator(train_dataloader)

model = builder.make_model(config.decompose_model_name, log_linears_an_conv1x1=True)
model = builder.make_model(
config.decompose_model_name, log_linears_and_conv1x1=True
)
builder.validate_module_names(model, config.blacklisted_modules)
model.to(device)

2 changes: 1 addition & 1 deletion examples/trainer_vision/version.py
@@ -1,5 +1,5 @@
# This is parsed by setup.py, so we need to stick to str -> int parsing
__version__ = "0.8.28"
__version__ = "0.8.37"

_ver_major = int(__version__.split(".")[0])
_ver_minor = int(__version__.split(".")[1])
2 changes: 1 addition & 1 deletion src/ptdeco/_version.py
@@ -1,5 +1,5 @@
# This is parsed by setup.py, so we need to stick to str -> int parsing
__version__ = "0.4.50"
__version__ = "0.4.51"

_ver_major = int(__version__.split(".")[0])
_ver_minor = int(__version__.split(".")[1])
2 changes: 1 addition & 1 deletion src/ptdeco/dwain/decomposition.py
@@ -479,7 +479,7 @@ def _process_module(
out_features=dim_out,
)
if not decompose_decision:
msg = "{proportion=:.4f} leads to num param increase, not decomposing"
msg = f"{proportion=:.4f} leads to num param increase, not decomposing"
logger.info(f"{indent}{msg}")
else:
decompose_decision = False
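The one-character fix above matters because without the f prefix the braces are never interpolated and the literal template ends up in the log:

```python
proportion = 0.123456
print("{proportion=:.4f} leads to num param increase, not decomposing")
# -> {proportion=:.4f} leads to num param increase, not decomposing
print(f"{proportion=:.4f} leads to num param increase, not decomposing")
# -> proportion=0.1235 leads to num param increase, not decomposing
```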