Merge pull request #190 from centre-for-humanities-computing/lumi
LUMI scripts - Mosaic/llm-foundry
Showing 11 changed files with 524 additions and 0 deletions.
.gitignore
@@ -158,3 +158,7 @@ models/*

# Hydra
outputs/*

# training artifacts
logs/
separate-logs/
.gitmodules
@@ -0,0 +1,4 @@
[submodule "llm-foundry"]
	path = llm-foundry
	url = https://github.com/rlrs/llm-foundry
	branch = lumi

Submodule llm-foundry added at f89ce6
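(After pulling this commit, the pinned submodule can be fetched with the standard git workflow, e.g. `git submodule update --init llm-foundry`.)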
scripts/data/convert_dataset_json.py
@@ -0,0 +1,190 @@
# Copyright 2022 MosaicML LLM Foundry authors
# Modified by @rlrs for Danish Foundation Models
# SPDX-License-Identifier: Apache-2.0

"""Convert jsonl data to streaming MDS format, while tokenizing and concatenating."""
import os
import warnings
from argparse import ArgumentParser, Namespace
from glob import glob
from typing import Dict, Iterable, Optional, Union

import datasets as hf_datasets
import numpy as np
from streaming import MDSWriter
from tqdm import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizerBase


def parse_args() -> Namespace:
    """Parse commandline arguments."""
    parser = ArgumentParser(
        description='Convert dataset into MDS format, tokenizing and concatenating.')
    parser.add_argument('--path', type=str, required=True)
    parser.add_argument('--out_root', type=str, required=True)
    parser.add_argument('--compression', type=str, default='zstd')

    parser.add_argument(
        '--concat_tokens',
        type=int,
        help='Convert text to tokens and concatenate up to this many tokens',
        required=True)

    parser.add_argument('--tokenizer', type=str, required=False, default=None)
    parser.add_argument('--bos_text', type=str, required=False, default=None)
    parser.add_argument('--eos_text', type=str, required=False, default=None)
    parser.add_argument('--no_wrap', default=False, action='store_true')  # why would you do this?

    parser.add_argument('--test_size', type=float, default=0.01)
    parser.add_argument('--seed', type=int, default=42)

    parsed = parser.parse_args()

    if parsed.bos_text is None:
        parsed.bos_text = ''
    if parsed.eos_text is None:
        parsed.eos_text = ''
    return parsed


def generate_chunks(dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset],
                    bos_tokens: list[int], eos_tokens: list[int],
                    chunk_length: int) -> Iterable[Dict[str, bytes]]:
    buffer = np.empty(0, dtype=np.int64, order='C')
    for sample in dataset:
        iids = sample['input_ids']
        buffer = np.append(buffer, [*bos_tokens, *iids, *eos_tokens])
        while len(buffer) >= chunk_length:
            concat_sample = buffer[:chunk_length]
            buffer = buffer[chunk_length:]  # if should_wrap else np.empty(0, dtype=np.int64, order='C')
            yield {
                # convert to bytes to store in MDS binary format
                'tokens': np.asarray(concat_sample, dtype=np.int64).tobytes()  # unsure why the np.asarray is necessary, tbh, but it is
            }


def build_hf_dataset(
    path: str,
    tokenizer: PreTrainedTokenizerBase,
    max_length: Optional[int] = None,
    bos_text: str = '',
    eos_text: str = '',
) -> hf_datasets.Dataset:
    """Build an HF dataset from jsonl source data.

    Args:
        path (str): Path to a jsonl file or a directory of jsonl files
        tokenizer (PreTrainedTokenizerBase): the tokenizer to use
        max_length (int): The length of concatenated tokens
        bos_text (str): text to insert at the beginning of each sequence
        eos_text (str): text to insert at the end of each sequence

    Returns:
        A Dataset.
    """
    if os.path.isdir(path):
        data_files = glob(f'{path}/*')
    else:
        data_files = path

    hf_dataset = hf_datasets.load_dataset('json',
                                          keep_in_memory=False,
                                          data_files=data_files,
                                          split='train')

    if not isinstance(tokenizer, PreTrainedTokenizerBase):
        raise ValueError(f'{tokenizer=} must be of type PreTrainedTokenizerBase')
    if max_length is None:
        raise ValueError('max_length must be set.')
    if bos_text + eos_text == '':
        test_tokens = tokenizer('test')
        if test_tokens['input_ids'][0] != tokenizer.bos_token_id and \
                test_tokens['input_ids'][-1] != tokenizer.eos_token_id:
            tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. '
            tok_error_msg += 'Concatenating with this tokenizer will result in sequences being '
            tok_error_msg += 'attached without a separating token. Please use another tokenizer, '
            tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. '
            tok_error_msg += '--bos_text=<|endoftext|>.'
            raise ValueError(tok_error_msg)

    return hf_dataset


def main(args: Namespace) -> None:
    """Main: create a streaming MDS dataset from jsonl data.

    Args:
        args (Namespace): Commandline arguments.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    # we will enforce length, so suppress warnings about sequences too long for the model
    tokenizer.model_max_length = int(1e30)
    columns = {'tokens': 'bytes'}

    # Get samples
    dataset = build_hf_dataset(path=args.path,
                               max_length=args.concat_tokens,
                               bos_text=args.bos_text,
                               eos_text=args.eos_text,
                               tokenizer=tokenizer)

    bos_tokens = tokenizer(args.bos_text,
                           truncation=False,
                           padding=False,
                           add_special_tokens=False)['input_ids']
    if len(bos_tokens) > 1:
        warnings.warn(
            f'You specified --concat_tokens with --bos_text, but your BOS text is not tokenizing to one token, '
            f'instead we got {bos_tokens}. Quit if this was in error.')

    eos_tokens = tokenizer(args.eos_text,
                           truncation=False,
                           padding=False,
                           add_special_tokens=False)['input_ids']
    if len(eos_tokens) > 1:
        warnings.warn(
            f'You specified --concat_tokens with --eos_text, but your EOS text is not tokenizing to one token, '
            f'instead we got {eos_tokens}. Quit if this was in error.')

    eos_text_provided = args.eos_text != ''
    bos_text_provided = args.bos_text != ''
    test_text = tokenizer('')
    if len(test_text['input_ids']) > 0 and (eos_text_provided or bos_text_provided):
        message = 'both eos and bos' if eos_text_provided and bos_text_provided else (
            'eos_text' if eos_text_provided else 'bos_text')
        warnings.warn(
            f'The provided tokenizer adds special tokens, but you also specified {message}. This may result '
            'in duplicated special tokens. Please be sure this is what you intend.')

    def tokenize(batch):
        return tokenizer(batch['text'],
                         truncation=False,
                         padding=False,
                         add_special_tokens=False)

    # We make a train/test split before chunking since it's way easier, although the number
    # of samples will vary - splitting after chunking would fix this, but has other issues
    dataset = dataset.train_test_split(test_size=args.test_size, seed=args.seed)

    print('Tokenizing dataset')
    dataset = dataset.map(tokenize,
                          batched=True,
                          batch_size=24 * 10,
                          remove_columns=['text'])

    # Write samples while chunking
    for split in dataset.keys():
        print(f'Writing {split} split... (iterations are samples)')
        with MDSWriter(columns=columns,
                       out=os.path.join(args.out_root, split),
                       compression=args.compression) as out:
            chunks = generate_chunks(dataset[split], bos_tokens, eos_tokens, args.concat_tokens)
            for sample in tqdm(chunks):
                out.write(sample)


if __name__ == '__main__':
    main(parse_args())
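For intuition, a minimal sketch of the chunking semantics above (hypothetical toy data; assumes `generate_chunks` from this script is in scope): documents are concatenated into one token stream, each wrapped with the BOS/EOS ids, and emitted as fixed-length chunks; a trailing partial chunk is dropped at the end of the stream.

import numpy as np

# Hypothetical toy data: two pre-tokenized "documents", EOS id 0, chunk_length 4.
toy_dataset = [{'input_ids': [1, 2, 3]}, {'input_ids': [4, 5]}]
chunks = list(generate_chunks(toy_dataset, bos_tokens=[], eos_tokens=[0], chunk_length=4))

# The stream is [1, 2, 3, 0, 4, 5, 0]; one full chunk [1, 2, 3, 0] is emitted,
# and the leftover [4, 5, 0] never reaches chunk_length, so it is dropped.
assert len(chunks) == 1
assert np.frombuffer(chunks[0]['tokens'], dtype=np.int64).tolist() == [1, 2, 3, 0]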
scripts/lumi/README.md
@@ -0,0 +1,12 @@
# Model training on LUMI

## Dataset preparation
From a jsonl file (such as da-gigaword), a command like `python scripts/data/convert_dataset_json.py --path /path/to/da-gigaword.jsonl.tar.gz --out_root ./da-gigaword-mds --concat_tokens 4096 --tokenizer mistralai/Mistral-7B-v0.1 --test_size 0.02` will generate the necessary Mosaic streaming dataset. This takes ~2 hours for da-gigaword, which is a bit slow. When done, copy the output folder to LUMI scratch and configure the data path in the training YAML, e.g. `scripts/lumi/yamls/continue-mistral-7b.yaml`.
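A minimal sketch (not part of the commit) for sanity-checking the output; the path assumes the command above, and `StreamingDataset` comes from the mosaicml-streaming package:

import numpy as np
from streaming import StreamingDataset

# Read one sample back from the generated MDS folder.
ds = StreamingDataset(local='./da-gigaword-mds/train', shuffle=False)
tokens = np.frombuffer(ds[0]['tokens'], dtype=np.int64)
print(len(tokens))  # should equal --concat_tokens, e.g. 4096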

## LUMI setup and training
1. SSH into LUMI
2. Enter the project directory: `cd /scratch/project_465000670/danish-foundation-models`
3. Enter the container: `singularity run --cleanenv --bind /scratch/project_465000670/ /project/project_465000670/pytorch_rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1.sif`
4. Set up the virtual environment: `./scripts/lumi/make_venv.sh`
5. Exit the container
6. Run training: `./scripts/lumi/continue_mistral_mosaic.sh`
scripts/lumi/continue_mistral_mosaic.sh
@@ -0,0 +1,101 @@
#!/bin/bash

##SBATCH --exclude=nid006865,nid005613,nid005988
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=56
#SBATCH --mem=0
#SBATCH --partition=standard-g
#SBATCH --time=0-01:00:00
#SBATCH --gpus-per-node=mi250:8
#SBATCH --exclusive=user
#SBATCH --hint=nomultithread
#SBATCH --account=project_465000670
#SBATCH --output=logs/%j.out
#SBATCH --error=logs/%j.err

# if run without sbatch, submit ourselves via sbatch
if [ -z "${SLURM_JOB_ID:-}" ]; then
    mkdir -p logs
    sbatch "$0"
    exit
fi

# LUMI setup
# module load LUMI/22.08 partition/G singularity-bindings/system-cpeGNU-22.08-noglibc # singularity-bindings is BROKEN
module load LUMI/22.08 partition/G

# These replace the module load of singularity-bindings
local_libfabric_version=1.15.2.0
local_craympich_version=8.1.27
export SINGULARITYENV_LD_LIBRARY_PATH="/lib64:/opt/cray/pe/mpich/$local_craympich_version/ofi/gnu/9.1/lib-abi-mpich:/opt/cray/pe/lib64:/opt/cray/pe:/opt/cray/libfabric/$local_libfabric_version/lib64:/usr/lib64:/opt/cray/pe/gcc-libs:${SINGULARITYENV_LD_LIBRARY_PATH}"
export SINGULARITY_BIND="/opt/cray,/usr/lib64/libbrotlidec.so.1,/usr/lib64/libbrotlicommon.so.1,/usr/lib64/libnl-3.so.200,/usr/lib64/libnl-route-3.so.200,/usr/lib64/libcxi.so.1,/usr/lib64/libcurl.so.4,/usr/lib64/libnghttp2.so.14,/usr/lib64/libidn2.so.0,/usr/lib64/libssh.so.4,/usr/lib64/libpsl.so.5,/usr/lib64/libssl.so.1.1,/usr/lib64/libcrypto.so.1.1,/usr/lib64/libgssapi_krb5.so.2,/usr/lib64/libldap_r-2.4.so.2,/usr/lib64/liblber-2.4.so.2,/usr/lib64/libjson-c.so.3,/usr/lib64/libunistring.so.2,/usr/lib64/libkrb5.so.3,/usr/lib64/libk5crypto.so.3,/usr/lib64/libkrb5support.so.0,/usr/lib64/libsasl2.so.3,/usr/lib64/libkeyutils.so.1,/var/spool/slurmd/mpi_cray_shasta,/usr/lib64/libzstd.so.1,/lib64/libselinux.so.1,/usr/lib64/libpcre.so.1,${SINGULARITY_BIND}"

# These are some more custom exports
export SINGULARITY_BIND=/users/larsenra/aws-ofi-rccl/install:/opt/aws-ofi-rccl,/usr/lib64/libjitterentropy.so.3,${SINGULARITY_BIND}
export SINGULARITYENV_LD_LIBRARY_PATH=/opt/ompi/lib:${EBROOTAWSMINOFIMINRCCL}/lib:/opt/cray/xpmem/2.5.2-2.4_3.47__gd0f7936.shasta/lib64:/opt/aws-ofi-rccl/lib:${SINGULARITYENV_LD_LIBRARY_PATH}
export SINGULARITY_BIND=$(echo $SINGULARITY_BIND | sed 's|,/usr/lib64/libssh.so.4||g') # do not bind host libssh, which is built against the wrong libssl for some reason
export LC_ALL=C
export HF_DATASETS_CACHE="/scratch/project_465000670/.cache/huggingface"
export TRANSFORMERS_CACHE="/scratch/project_465000670/.cache/huggingface"

# values for distributed setup
GPUS_PER_NODE=$SLURM_GPUS_PER_NODE
NNODES=$SLURM_NNODES
export NODE_RANK=$SLURM_NODEID
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=9999
export WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

# compilers in the container
export CC=gcc-11
export CXX=g++-11

CONTAINER="/project/project_465000670/pytorch_rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1.sif"

SING_BIND="/scratch/project_465000670"

# keep separate logs for easier debugging
rm -rf separate-logs
mkdir -p separate-logs

set -exuo pipefail

# symlink logs/latest.out and logs/latest.err
ln -f -s $SLURM_JOB_ID.out logs/latest.out
ln -f -s $SLURM_JOB_ID.err logs/latest.err

CHECKPOINT_PATH=checkpoints

GIT_ROOT=$(git rev-parse --show-toplevel)
PATH_TO_SCRIPTS="scripts/lumi"
cd ${GIT_ROOT} # ensure that we are in the git root for remaining paths to work
CMD=" \
    llm-foundry/scripts/train/train.py \
    ${PATH_TO_SCRIPTS}/yamls/continue-mistral-7b.yaml
    "

# Bind masks from Samuel (TODO: unused for now since composer handles process spawning, but might help performance to use this)
c=fe

# Bind mask for one thread per core
BIND_MASK_1="0x${c}000000000000,0x${c}00000000000000,0x${c}0000,0x${c}000000,0x${c},0x${c}00,0x${c}00000000,0x${c}0000000000"

# Bind mask for two threads per core
BIND_MASK_2="0x${c}00000000000000${c}000000000000,0x${c}00000000000000${c}00000000000000,0x${c}00000000000000${c}0000,0x${c}00000000000000${c}000000,0x${c}00000000000000${c},0x${c}00000000000000${c}00,0x${c}00000000000000${c}00000000,0x${c}00000000000000${c}0000000000"

BIND_MASK="$BIND_MASK_1"
#echo "Using --cpu-bind=mask_cpu:$BIND_MASK"

echo $CMD

echo "START $SLURM_JOBID: $(date)"

# --cpu-bind=mask_cpu:$BIND_MASK \
srun \
    --label \
    singularity exec -B "$SING_BIND" "$CONTAINER" \
    /scratch/project_465000670/danish-foundation-models/scripts/lumi/mosaic_in_container.sh \
    $CMD

echo "END $SLURM_JOBID: $(date)"
scripts/lumi/make_venv.sh
@@ -0,0 +1,30 @@
#!/bin/bash
# Important: should be run in the `rocm/pytorch` container
set -euxo pipefail
export LC_ALL=C

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
GIT_ROOT=$(git rev-parse --show-toplevel)

cd "${GIT_ROOT}"
python3 -m venv .venv
source .venv/bin/activate

pip install --upgrade pip
pip install packaging cmake # build requirements

cd "${GIT_ROOT}/llm-foundry"
pip install -e .

pip install -r "${SCRIPT_DIR}/requirements.txt"

# Install flash attention
TMP_DIR=$(mktemp -d)
git clone --recurse-submodules https://github.com/ROCmSoftwarePlatform/flash-attention "${TMP_DIR}"
cd "${TMP_DIR}"

export GPU_ARCHS="gfx90a"

# export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])') # this is for older versions of pytorch
# patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
python3 setup.py install
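A hypothetical post-install smoke test (not part of the commit), run inside the venv, to confirm the build is usable:

import torch
import flash_attn  # the ROCm fork installs under the same package name

print(flash_attn.__version__)
print(torch.cuda.is_available(), torch.cuda.device_count())  # HIP devices surface through the cuda API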