Skip to content

Commit

Permalink
[trainer_llm] 0.0.53 Fix model loading for fine-tuning
Browse files (browse the repository at this point in the history)
  • Loading branch information
michal-lopuszynski-tcl committed Apr 15, 2024
1 parent 8f7d7ed commit 5b264ee
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
7 changes: 6 additions & 1 deletion examples/trainer_llm/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,27 @@ def make_model_and_tokenizer(


def apply_decompose_config_and_state_dict_in_place(
*,
model: torch.nn.Module,
decompose_config_path: str,
state_dict_path: str,
device: torch.device,
dtype: torch.dtype,
log_linears: bool = False,
) -> None:

with open(decompose_config_path, "rt") as f:
decompose_config = json.load(f)

ptdeco.utils.apply_decompose_config_in_place(model, decompose_config)
model.to(device)
model.to(dtype)
ptdeco.utils.free_gpu_reserved_memory()
logger.info(f"Applied decompose config {decompose_config_path}")

sd = torch.load(state_dict_path, map_location=device)

model.load_state_dict(sd)

logger.info(f"Loaded state dict {state_dict_path}")
model.eval()

Expand Down
5 changes: 3 additions & 2 deletions examples/trainer_llm/run_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,15 +235,16 @@ def main(config_raw: dict[str, Any], output_path: pathlib.Path) -> None:
dtype=dtype,
log_linears=False,
)

model.to(device)
params_orig = metrics.get_params(model) / 1.0e6
gflops_orig = metrics.get_giga_flops(model, tensor_size=(1, 512))
model.to(device)

builder.apply_decompose_config_and_state_dict_in_place(
model=model,
decompose_config_path=config.decompose_config,
state_dict_path=config.decompose_state_dict,
device=device,
dtype=dtype,
log_linears=True,
)

Expand Down
2 changes: 1 addition & 1 deletion examples/trainer_llm/version.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This is parsed by setup.py, so we need to stick to str -> int parsing
-__version__ = "0.0.52"
+__version__ = "0.0.53"

_ver_major = int(__version__.split(".")[0])
_ver_minor = int(__version__.split(".")[1])
Expand Down

0 comments on commit 5b264ee

Please sign in to comment.