Skip to content

Commit

Permalink
add support for model parallelism (#110)
Browse files Browse the repository at this point in the history
* Add support for Model Parallelism

* Add support for Model Parallelism

* Add support for Model Parallelism

* Add support for Model Parallelism

* Add support for Model Parallelism

* Update

* update

* update

* Update trainer.py

* Update trainer.py

* Update

* Update

* Add files via upload

* Update convert_llama_from_megatron_checkpoint_to_pytorch_checkpoint.py

* Update convert_llama_from_pytorch_checkpoint_to_megatron_checkpoint.py

* Update trainer.py

* Update convert_llama_from_megatron_checkpoint_to_pytorch_checkpoint.py

* Update convert_llama_from_pytorch_checkpoint_to_megatron_checkpoint.py

* update dataloader name

* update comment

---------

Co-authored-by: Cheng <435405393@qq.com>
Co-authored-by: “karots123” <“962”813115@qq.com>
Co-authored-by: kaeli <kaeli@tencent.com>
  • Loading branch information
4 people authored Nov 16, 2023
1 parent 669d46c commit f7f18c8
Show file tree
Hide file tree
Showing 29 changed files with 2,894 additions and 138 deletions.
3 changes: 2 additions & 1 deletion preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def main():
"and process them with `processes_num` processes.")
parser.add_argument("--data_processor",
choices=["bert", "lm", "mlm", "bilm", "albert", "mt", "t5", "cls", "prefixlm",
"gsg", "bart", "cls_mlm", "vit", "vilt", "clip", "s2t", "beit", "dalle", "alpaca"], default="bert",
"gsg", "bart", "cls_mlm", "vit", "vilt", "clip", "s2t", "beit", "dalle",
"llm_pretrain", "llm_sft"], default="bert",
help="The data processor of the pretraining model.")
parser.add_argument("--docs_buffer_size", type=int, default=100000,
help="The buffer size of documents in memory, specific to targets that require negative sampling.")
Expand Down
7 changes: 6 additions & 1 deletion pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,14 @@ def main():

# Model options.
model_opts(parser)

# Model parallelism options.
mp_opts(parser)

parser.add_argument("--data_processor",
choices=["bert", "lm", "mlm", "bilm", "albert", "mt", "t5", "cls",
"prefixlm", "gsg", "bart", "cls_mlm", "vit", "vilt", "clip", "s2t", "beit", "dalle", "alpaca"], default="bert",
"prefixlm", "gsg", "bart", "cls_mlm", "vit", "vilt", "clip", "s2t", "beit", "dalle",
"llm_pretrain", "llm_sft"], default="bert",
help="The data processor of the pretraining model.")
parser.add_argument("--deep_init", action="store_true",
help="Scaling initialization of projection layers by a "
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ six>=1.12.0
packaging
numpy
regex
sentencepiece
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import argparse
import os
import torch


parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--input_model_path", type=str, default="models/input_model",
help=".")
parser.add_argument("--output_model_path", type=str, default="models/output_model",
help=".")
parser.add_argument("--layers_num", type=int, default=32)
parser.add_argument("--tensor_model_parallel_size", type=int, default=4)


args = parser.parse_args()

if not os.path.exists(args.output_model_path):
os.mkdir(args.output_model_path)

output_model=torch.load(os.path.join(args.input_model_path,'mp_rank_00_model_states.pt'),map_location='cpu')["module"]

for n in range(1,args.tensor_model_parallel_size):
index=str(n) if len(str(n))==2 else '0'+str(n)
model_name=f"mp_rank_{index}_model_states.pt"
model_piece = torch.load(os.path.join(args.input_model_path,model_name),map_location="cpu")["module"]
output_model["embedding.word.embedding.weight"] = torch.cat((output_model["embedding.word.embedding.weight"],model_piece["embedding.word.embedding.weight"]),dim=-2)

for i in range(args.layers_num):
for j in range(3):
tensor_a=output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers."+ str(j) +".weight"]
tensor_b=model_piece["encoder.transformer." + str(i) + ".self_attn.linear_layers."+ str(j) +".weight"]
output_model["encoder.transformer." + str(i) + ".self_attn.linear_layers."+ str(j) +".weight"]=torch.cat((tensor_a,tensor_b),dim=-2)

tensor_a=output_model["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"]
tensor_b=model_piece["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"]

output_model["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"]=torch.cat((tensor_a,tensor_b),dim=-1)

tensor_a=output_model["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"]
tensor_b=model_piece["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"]
output_model["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"]=torch.cat((tensor_a,tensor_b),dim=-2)

tensor_a=output_model["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"]
tensor_b=model_piece["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"]
output_model["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"]=torch.cat((tensor_a,tensor_b),dim=-2)

tensor_a=output_model["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"]
tensor_b=model_piece["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"]
output_model["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"]=torch.cat((tensor_a,tensor_b),dim=-1)

tensor_a=output_model["target.lm.output_layer.weight"]
tensor_b=model_piece["target.lm.output_layer.weight"]
output_model["target.lm.output_layer.weight"]=torch.cat((tensor_a,tensor_b),dim=-2)

torch.save(output_model,os.path.join(args.output_model_path,'merge_model.bin'))

Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import argparse
import collections
import torch
import os


parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--input_model_path", type=str, default="models/input_model.bin",
help=".")
parser.add_argument("--output_model_path", type=str, default="models/output_model",
help=".")
parser.add_argument("--layers_num", type=int, default=32)
parser.add_argument("--tensor_model_parallel_size", type=int, default=4)
parser.add_argument("--hidden_size", type=int, default=4096)
parser.add_argument("--feedforward_size", type=int, default=11008)

args = parser.parse_args()

input_model = torch.load(args.input_model_path)

if not os.path.exists(args.output_model_path):
os.mkdir(args.output_model_path)

seg_feed_size=args.feedforward_size // args.tensor_model_parallel_size
seg_hidden_size = args.hidden_size // args.tensor_model_parallel_size
seg_word_size=input_model["embedding.word.embedding.weight"].size()[0] // args.tensor_model_parallel_size

for n in range(args.tensor_model_parallel_size):
model_piece=collections.OrderedDict()
seg_dim=input_model["embedding.word.embedding.weight"].size()[0]//args.tensor_model_parallel_size
model_piece["embedding.word.embedding.weight"] = input_model["embedding.word.embedding.weight"][n* seg_dim :(n+1) * seg_dim,:]

for i in range(args.layers_num):
for j in range(3):
model_piece["encoder.transformer." + str(i) + ".self_attn.linear_layers."+str(j)+".weight"] = input_model["encoder.transformer." + str(i) + ".self_attn.linear_layers."+str(j)+".weight"][n * seg_hidden_size:(n+1) * seg_hidden_size,:]

model_piece["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"] = \
input_model["encoder.transformer." + str(i) + ".self_attn.final_linear.weight"][:,n * seg_hidden_size:(n+1) * seg_hidden_size]

model_piece["encoder.transformer." + str(i) + ".layer_norm_1.weight"] = \
input_model["encoder.transformer." + str(i) + ".layer_norm_1.weight"]

model_piece["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"] = \
input_model["encoder.transformer." + str(i) + ".feed_forward.linear_1.weight"][n * seg_feed_size:(n+1) * seg_feed_size,:]

model_piece["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"]= \
input_model["encoder.transformer." + str(i) + ".feed_forward.linear_gate.weight"][n * seg_feed_size:(n+1) * seg_feed_size,:]

model_piece["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"] = \
input_model["encoder.transformer." + str(i) + ".feed_forward.linear_2.weight"][:,n * seg_feed_size:(n+1) * seg_feed_size]

model_piece["encoder.transformer." + str(i) + ".layer_norm_2.weight"] = \
input_model["encoder.transformer." + str(i) + ".layer_norm_2.weight"]

model_piece["encoder.layer_norm.weight"] = input_model["encoder.layer_norm.weight"]

model_piece["target.lm.output_layer.weight"]= input_model["target.lm.output_layer.weight"][n * seg_word_size:(n+1) * seg_word_size,:]

name=str(n) if len(str(n))==2 else '0'+str(n)
torch.save(model_piece, os.path.join(args.output_model_path,"mp_rank_"+str(name)+"_model_states.pt"))

2 changes: 1 addition & 1 deletion tencentpretrain/embeddings/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def forward(self, src, seg):
if i == 0:
emb = embedding(src, seg)
else:
emb += embedding(src, seg)
emb = embedding(src, seg) + emb.clone()

if not self.remove_embedding_layernorm:
emb = self.layer_norm(emb)
Expand Down
6 changes: 5 additions & 1 deletion tencentpretrain/embeddings/word_embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
import torch.nn as nn
from tencentpretrain import mpu


class WordEmbedding(nn.Module):
Expand All @@ -8,7 +9,10 @@ class WordEmbedding(nn.Module):

def __init__(self, args, vocab_size):
super(WordEmbedding, self).__init__()
self.embedding = nn.Embedding(vocab_size, args.emb_size)
if args.use_mp:
self.embedding = mpu.VocabParallelEmbedding(vocab_size, args.emb_size)
else:
self.embedding = nn.Embedding(vocab_size, args.emb_size)
self.emb_size = args.emb_size
self.sinusoidalpos = False
if "sinusoidalpos" in args.embedding:
Expand Down
22 changes: 17 additions & 5 deletions tencentpretrain/encoders/transformer_encoder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
import torch.nn as nn
from tencentpretrain.utils.rope import precompute_freqs_cis
from tencentpretrain.layers.transformer import TransformerLayer
from tencentpretrain.layers.transformer import TransformerLayer, ParallelTransformerLayer
from tencentpretrain.layers.relative_position_embedding import RelativePositionEmbedding
from tencentpretrain.layers import *
from tencentpretrain import mpu

class TransformerEncoder(nn.Module):
"""
Expand All @@ -19,6 +20,7 @@ def __init__(self, args):
self.relative_position_embedding = args.relative_position_embedding
self.rotary_position_embedding = args.rotary_position_embedding
self.has_residual_attention = args.has_residual_attention
self.use_mp = args.use_mp
if "deepspeed_checkpoint_activations" in args:
self.deepspeed_checkpoint_activations = args.deepspeed_checkpoint_activations
self.deepspeed_checkpoint_layers_num = args.deepspeed_checkpoint_layers_num
Expand All @@ -31,11 +33,19 @@ def __init__(self, args):
self.linear = nn.Linear(args.emb_size, args.hidden_size)

if self.parameter_sharing:
self.transformer = TransformerLayer(args)
if self.use_mp:
self.transformer = ParallelTransformerLayer(args)
else:
self.transformer = TransformerLayer(args)
else:
self.transformer = nn.ModuleList(
[TransformerLayer(args) for _ in range(self.layers_num)]
)
if self.use_mp:
self.transformer = nn.ModuleList(
[ParallelTransformerLayer(args) for _ in range(self.layers_num)]
)
else:
self.transformer = nn.ModuleList(
[TransformerLayer(args) for _ in range(self.layers_num)]
)
if self.layernorm_positioning == "pre":
self.layer_norm = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)

Expand Down Expand Up @@ -122,6 +132,8 @@ def custom_forward(*inputs):
return x_, y_

return custom_forward
if self.use_mp:
mpu.reset_checkpointed_activations_memory_buffer()
l = 0
while l < self.layers_num:
hidden, prev_attn = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num),
Expand Down
Loading

0 comments on commit f7f18c8

Please sign in to comment.