Merge pull request #190 from centre-for-humanities-computing/lumi
LUMI scripts - Mosaic/llm-foundry
rlrs authored Jan 15, 2024
2 parents 4dbcf60 + 1f25b55 commit 457f847
Showing 11 changed files with 524 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -158,3 +158,7 @@ models/*

# Hydra
outputs/*

# training artifacts
logs/
separate-logs/
4 changes: 4 additions & 0 deletions .gitmodules
@@ -0,0 +1,4 @@
[submodule "llm-foundry"]
path = llm-foundry
url = https://github.com/rlrs/llm-foundry
branch = lumi
1 change: 1 addition & 0 deletions llm-foundry
Submodule llm-foundry added at f89ce6
190 changes: 190 additions & 0 deletions scripts/data/jsonl_to_mds.py
@@ -0,0 +1,190 @@
# Copyright 2022 MosaicML LLM Foundry authors
# Modified by @rlrs for Danish Foundation Models
# SPDX-License-Identifier: Apache-2.0

"""Convert jsonl data to streaming MDS format, while tokenizing and concatenating."""
import os
import warnings
from argparse import ArgumentParser, Namespace
from glob import glob
from typing import Dict, Iterable, Optional, Union

import datasets as hf_datasets
import numpy as np
from streaming import MDSWriter
from tqdm import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizerBase


def parse_args() -> Namespace:
"""Parse commandline arguments."""
parser = ArgumentParser(
description=
'Convert dataset into MDS format, tokenizing and concatenating.'
)
parser.add_argument('--path', type=str, required=True)
parser.add_argument('--out_root', type=str, required=True)
parser.add_argument('--compression', type=str, default='zstd')

parser.add_argument(
'--concat_tokens',
type=int,
help='Convert text to tokens and concatenate up to this many tokens', required=True)

parser.add_argument('--tokenizer', type=str, required=False, default=None)
parser.add_argument('--bos_text', type=str, required=False, default=None)
parser.add_argument('--eos_text', type=str, required=False, default=None)
    parser.add_argument('--no_wrap', default=False, action='store_true')  # note: currently unused; leftover tokens always wrap into the next chunk (see generate_chunks)

parser.add_argument('--test_size', type=float, default=0.01)
parser.add_argument('--seed', type=int, default=42)

parsed = parser.parse_args()

if parsed.bos_text is None:
parsed.bos_text = ''
if parsed.eos_text is None:
parsed.eos_text = ''
return parsed

def generate_chunks(dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset],
bos_tokens: list[int], eos_tokens: list[int], chunk_length: int) -> Iterable[Dict[str, bytes]]:
buffer = np.empty(0, dtype=np.int64, order='C')
for sample in dataset:
iids = sample['input_ids']
buffer = np.append(buffer, [*bos_tokens, *iids, *eos_tokens])
while len(buffer) >= chunk_length:
concat_sample = buffer[:chunk_length]
buffer = buffer[chunk_length:] #if should_wrap else np.empty(0, dtype=np.int64, order='C')
yield {
# convert to bytes to store in MDS binary format
'tokens': np.asarray(concat_sample, dtype=np.int64).tobytes() # unsure why the np.asarray is necessary, tbh, but it is
}
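# Note: downstream readers can recover the token ids from an MDS sample with
# something like `np.frombuffer(sample['tokens'], dtype=np.int64)`, which is how
# llm-foundry's streaming text dataset interprets this column.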


def build_hf_dataset(
    path: str,
    tokenizer: PreTrainedTokenizerBase,
    max_length: Optional[int] = None,
    bos_text: str = '',
    eos_text: str = '',
) -> hf_datasets.Dataset:
    """Build a HF dataset from jsonl source data.

    Args:
        path (str): Path to a jsonl file or a directory of jsonl files
        tokenizer (PreTrainedTokenizerBase): the tokenizer to use
        max_length (int): The maximum length of concatenated token sequences
        bos_text (str): text to insert at the beginning of each sequence
        eos_text (str): text to insert at the end of each sequence

    Returns:
        A `datasets.Dataset` with a 'text' column.
    """
if os.path.isdir(path):
data_files = glob(f'{path}/*')
else:
data_files = path

hf_dataset = hf_datasets.load_dataset('json',
keep_in_memory=False,
data_files=data_files,
split="train")

if not isinstance(tokenizer, PreTrainedTokenizerBase):
raise ValueError(
f'{tokenizer=} must be of type PreTrainedTokenizerBase')
    if max_length is None:
        raise ValueError('max_length must be set.')
if bos_text + eos_text == '':
test_tokens = tokenizer('test')
        if (test_tokens['input_ids'][0] != tokenizer.bos_token_id and
                test_tokens['input_ids'][-1] != tokenizer.eos_token_id):
tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. '
tok_error_msg += 'Concatenating with this tokenizer will result in sequences being '
tok_error_msg += 'attached without a separating token. Please use another tokenizer, '
tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. '
tok_error_msg += '--bos_text=<|endoftext|>.'
raise ValueError(tok_error_msg)

return hf_dataset

def main(args: Namespace) -> None:
"""Main: create C4/pile streaming dataset.
Args:
args (Namespace): Commandline arguments.
"""

tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
# we will enforce length, so suppress warnings about sequences too long for the model
tokenizer.model_max_length = int(1e30)
columns = {'tokens': 'bytes'}

# Get samples
dataset = build_hf_dataset(path=args.path,
max_length=args.concat_tokens,
bos_text=args.bos_text,
eos_text=args.eos_text,
tokenizer=tokenizer)

bos_tokens = tokenizer(args.bos_text,
truncation=False,
padding=False,
add_special_tokens=False)['input_ids']
    if len(bos_tokens) > 1:
        warnings.warn(
            'You specified --concat_tokens with --bos_text, but your BOS text is not '
            f'tokenizing to one token, instead we got {bos_tokens}. Quit if this was in error.')

eos_tokens = tokenizer(args.eos_text,
truncation=False,
padding=False,
add_special_tokens=False)['input_ids']
    if len(eos_tokens) > 1:
        warnings.warn(
            'You specified --concat_tokens with --eos_text, but your EOS text is not '
            f'tokenizing to one token, instead we got {eos_tokens}. Quit if this was in error.')

eos_text_provided = args.eos_text != ''
bos_text_provided = args.bos_text != ''
test_text = tokenizer('')
    if len(test_text['input_ids']) > 0 and (eos_text_provided or bos_text_provided):
        message = 'both eos and bos' if eos_text_provided and bos_text_provided else (
            'eos_text' if eos_text_provided else 'bos_text')
        warnings.warn(
            f'The provided tokenizer adds special tokens, but you also specified {message}. '
            'This may result in duplicated special tokens. Please be sure this is what you intend.')

def tokenize(batch):
return tokenizer(batch['text'],
truncation=False,
padding=False,
add_special_tokens=False)

# We make a train/test split before chunking since it's way easier, although number
# of samples will vary - splitting after chunking would fix this, but has other issues
dataset = dataset.train_test_split(test_size=args.test_size, seed=args.seed)

print("Tokenizing dataset")
dataset = dataset.map(tokenize, batched=True,
batch_size=24*10,
remove_columns=['text'])

# Write samples while chunking
for split in dataset.keys():
print(f'Writing {split} split... (iterations are samples)')
with MDSWriter(columns=columns,
out=os.path.join(args.out_root, split),
compression=args.compression) as out:
chunks = generate_chunks(dataset[split], bos_tokens, eos_tokens, args.concat_tokens)
for sample in tqdm(chunks):
out.write(sample)

if __name__ == '__main__':
main(parse_args())
12 changes: 12 additions & 0 deletions scripts/lumi/README.md
@@ -0,0 +1,12 @@
# Model training on LUMI

## Dataset preparation
From a jsonl file (such as da-gigaword), a command like `python scripts/data/convert_dataset_json.py --path /path/to/da-gigaword.jsonl.tar.gz --out_root ./da-gigaword-mds --concat_tokens 4096 --tokenizer mistralai/Mistral-7B-v0.1 --test_size 0.02` will generate the necessary Mosaic streaming (MDS) dataset. This takes roughly 2 hours for da-gigaword, which is a bit slow. When done, copy the output folder to LUMI scratch and configure the data path in the training YAML, e.g. `scripts/lumi/yamls/continue-mistral-7b.yaml`.
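For orientation, the output folder should contain one MDS directory per split, roughly as sketched below (shard names are illustrative; MDS writes an `index.json` plus compressed shards per split):

```bash
ls da-gigaword-mds/train
# index.json  shard.00000.mds.zstd  shard.00001.mds.zstd  ...
```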

## LUMI setup and training
1. SSH into LUMI
2. Enter the project: `cd /scratch/project_465000670/danish-foundation-models`
3. Enter the container: `singularity run --cleanenv --bind /scratch/project_465000670/ /project/project_465000670/pytorch_rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1.sif`
4. Set up the virtual environment: `./scripts/lumi/make_venv.sh`
5. Exit the container
6. Run training: `./scripts/lumi/continue_mistral_mosaic.sh` (see the note on logs below)
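The training script submits itself via `sbatch` when run directly (see `continue_mistral_mosaic.sh` below) and symlinks the most recent job logs, so after the last step progress can be followed with, for example:

```bash
tail -f logs/latest.out   # symlinked to <job id>.out by the training script
```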
101 changes: 101 additions & 0 deletions scripts/lumi/continue_mistral_mosaic.sh
@@ -0,0 +1,101 @@
#!/bin/bash

##SBATCH --exclude=nid006865,nid005613,nid005988
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=56
#SBATCH --mem=0
#SBATCH --partition=standard-g
#SBATCH --time=0-01:00:00
#SBATCH --gpus-per-node=mi250:8
#SBATCH --exclusive=user
#SBATCH --hint=nomultithread
#SBATCH --account=project_465000670
#SBATCH --output=logs/%j.out
#SBATCH --error=logs/%j.err

# if run without sbatch, invoke here
if [ -z "${SLURM_JOB_ID:-}" ]; then
mkdir -p logs
sbatch "$0"
exit
fi

# LUMI setup
# module load LUMI/22.08 partition/G singularity-bindings/system-cpeGNU-22.08-noglibc # singularity-bindings is BROKEN
module load LUMI/22.08 partition/G

# These replace the module load of singularity-bindings
local_libfabric_version=1.15.2.0
local_craympich_version=8.1.27
export SINGULARITYENV_LD_LIBRARY_PATH="/lib64:/opt/cray/pe/mpich/$local_craympich_version/ofi/gnu/9.1/lib-abi-mpich:/opt/cray/pe/lib64:/opt/cray/pe:/opt/cray/libfabric/$local_libfabric_version/lib64:/usr/lib64:/opt/cray/pe/gcc-libs:${SINGULARITYENV_LD_LIBRARY_PATH}"
export SINGULARITY_BIND="/opt/cray,/usr/lib64/libbrotlidec.so.1,/usr/lib64/libbrotlicommon.so.1,/usr/lib64/libnl-3.so.200,/usr/lib64/libnl-route-3.so.200,/usr/lib64/libcxi.so.1,/usr/lib64/libcurl.so.4,/usr/lib64/libnghttp2.so.14,/usr/lib64/libidn2.so.0,/usr/lib64/libssh.so.4,/usr/lib64/libpsl.so.5,/usr/lib64/libssl.so.1.1,/usr/lib64/libcrypto.so.1.1,/usr/lib64/libgssapi_krb5.so.2,/usr/lib64/libldap_r-2.4.so.2,/usr/lib64/liblber-2.4.so.2,/usr/lib64/libjson-c.so.3,/usr/lib64/libunistring.so.2,/usr/lib64/libkrb5.so.3,/usr/lib64/libk5crypto.so.3,/usr/lib64/libkrb5support.so.0,/usr/lib64/libsasl2.so.3,/usr/lib64/libkeyutils.so.1,/var/spool/slurmd/mpi_cray_shasta,/usr/lib64/libzstd.so.1,/lib64/libselinux.so.1,/usr/lib64/libpcre.so.1,${SINGULARITY_BIND}"

# These are some more custom exports
export SINGULARITY_BIND=/users/larsenra/aws-ofi-rccl/install:/opt/aws-ofi-rccl,/usr/lib64/libjitterentropy.so.3,${SINGULARITY_BIND}
export SINGULARITYENV_LD_LIBRARY_PATH=/opt/ompi/lib:${EBROOTAWSMINOFIMINRCCL}/lib:/opt/cray/xpmem/2.5.2-2.4_3.47__gd0f7936.shasta/lib64:/opt/aws-ofi-rccl/lib:${SINGULARITYENV_LD_LIBRARY_PATH}
export SINGULARITY_BIND=$(echo $SINGULARITY_BIND | sed 's|,/usr/lib64/libssh.so.4||g') # do not bind host libssh which is built against a wrong libssl for some reason
export LC_ALL=C
export HF_DATASETS_CACHE="/scratch/project_465000670/.cache/huggingface"
export TRANSFORMERS_CACHE="/scratch/project_465000670/.cache/huggingface"

# values for distributed setup
GPUS_PER_NODE=$SLURM_GPUS_PER_NODE
NNODES=$SLURM_NNODES
export NODE_RANK=$SLURM_NODEID
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=9999
export WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
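# With the SBATCH settings above (2 nodes x 8 GPUs per node) this should come out to a world size of 16.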

# compilers in the container
export CC=gcc-11
export CXX=g++-11

CONTAINER="/project/project_465000670/pytorch_rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1.sif"

SING_BIND="/scratch/project_465000670"

# hold separate logs for easier debugging
rm -rf separate-logs
mkdir -p separate-logs

set -exuo pipefail

# symlink logs/latest.out and logs/latest.err
ln -f -s $SLURM_JOB_ID.out logs/latest.out
ln -f -s $SLURM_JOB_ID.err logs/latest.err

CHECKPOINT_PATH=checkpoints

GIT_ROOT=$(git rev-parse --show-toplevel)
PATH_TO_SCRIPTS="scripts/lumi"
cd ${GIT_ROOT} # ensure that we are in the git root for remaining paths to work
CMD=" \
llm-foundry/scripts/train/train.py \
${PATH_TO_SCRIPTS}/yamls/continue-mistral-7b.yaml
"

# Bind masks from Samuel (TODO: unused for now since composer handles process spawning, but might help performance to use this)
c=fe

# Bind mask for one thread per core
BIND_MASK_1="0x${c}000000000000,0x${c}00000000000000,0x${c}0000,0x${c}000000,0x${c},0x${c}00,0x${c}00000000,0x${c}0000000000"

# Bind mask for two threads per core
BIND_MASK_2="0x${c}00000000000000${c}000000000000,0x${c}00000000000000${c}00000000000000,0x${c}00000000000000${c}0000,0x${c}00000000000000${c}000000,0x${c}00000000000000${c},0x${c}00000000000000${c}00,0x${c}00000000000000${c}00000000,0x${c}00000000000000${c}0000000000"

BIND_MASK="$BIND_MASK_1"
#echo "Using --cpu-bind=mask_cpu:$BIND_MASK"

echo $CMD

echo "START $SLURM_JOBID: $(date)"

# --cpu-bind=mask_cpu:$BIND_MASK \
srun \
--label \
singularity exec -B "$SING_BIND" "$CONTAINER" \
/scratch/project_465000670/danish-foundation-models/scripts/lumi/mosaic_in_container.sh \
$CMD

echo "END $SLURM_JOBID: $(date)"
30 changes: 30 additions & 0 deletions scripts/lumi/make_venv.sh
@@ -0,0 +1,30 @@
#!/bin/bash
# Important: should be run in the `rocm/pytorch` container
set -euxo pipefail
export LC_ALL=C

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
GIT_ROOT=$(git rev-parse --show-toplevel)

cd ${GIT_ROOT}
python3 -m venv .venv
source .venv/bin/activate

pip install --upgrade pip
pip install packaging cmake # build requirements

cd ${GIT_ROOT}/llm-foundry
pip install -e .

pip install -r ${SCRIPT_DIR}/requirements.txt

# Install flash attention
TMP_DIR=$(mktemp -d)
git clone --recurse-submodules https://github.com/ROCmSoftwarePlatform/flash-attention ${TMP_DIR}
cd ${TMP_DIR}

export GPU_ARCHS="gfx90a"

# export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])') # this is for older versions of pytorch
# patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
python3 setup.py install