Hello, and thank you for your work! I inserted your clex layer into my own model; my integration looks like this:
```python
class Encoder(nn.Module):
    def __init__(self, config):
        '''omitted'''
        elif config.my_info_dict.get("algorithm", False) == "clex":
            from .clex_layer import CLEXScalingRotaryEmbedding
            rope_scaling = {
                "factor": 1,
                "max_factor": 64,
                "param_factor": 1,
                "time_dt": 0.01,
                "type": "clex",
                "act": "tanh",
            }
            self.clex_layer = CLEXScalingRotaryEmbedding(
                config.attention_key_size,
                self.config.my_info_dict["train_len"],
                rope_scaling,
            )
        '''omitted'''

    def forward(
        self,
        '''omitted'''
    ):
        '''omitted'''
        if self.config.my_info_dict.get("algorithm", False) == "clex":
            sinusoidal_pos = self.clex_layer(seqlen, do_train)
        '''omitted'''
```
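The part of `forward` that actually consumes `sinusoidal_pos` is omitted above. For completeness, here is a minimal sketch of how I apply the returned `(cos, sin)` pair to the query/key tensors. The function names and the `(batch, num_heads, seq_len, head_dim)` layout are my own assumptions, not part of your code; the half-split rotation is meant to match the `torch.cat((freqs, freqs), dim=-1)` convention used inside the layer:

```python
import torch


def rotate_half(x):
    # Split the last dimension in half and swap the halves with a sign flip:
    # (x1, x2) -> (-x2, x1), the standard half-split rotary rotation.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, sinusoidal_pos):
    # sinusoidal_pos is the (cos, sin) pair returned by self.clex_layer(...),
    # each of shape (1, seq_len, head_dim); q and k are assumed to be
    # (batch, num_heads, seq_len, head_dim).
    cos, sin = sinusoidal_pos
    cos = cos.unsqueeze(1)  # -> (1, 1, seq_len, head_dim), broadcast over heads
    sin = sin.unsqueeze(1)
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot
```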
The CLEXScalingRotaryEmbedding class itself was only modified in ways that do not touch the core operations:
```python
import math

import torch
import torch.nn as nn
from torchdiffeq import odeint  # the frequency ODE is solved with torchdiffeq


class ODELinear(nn.Module):
    def __init__(self, dim: int, factor, act, base=10000, **kwargs):
        super().__init__()
        self.ode_up_proj = nn.Parameter(torch.empty(dim // 2, factor * dim))
        self.ode_down_proj = nn.Parameter(torch.empty(factor * dim, dim // 2))
        self.dim = dim
        self.base = base
        if act == "tanh":
            self.act = torch.nn.Tanh()
        elif act == "silu":
            self.act = torch.nn.SiLU()
        else:
            raise ValueError(f"act must be one of ['tanh', 'silu'], got {act}")
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.ode_up_proj, a=math.sqrt(5))
        nn.init.zeros_(self.ode_down_proj)

    def get_time_embedding(self, t, base=10000, device='cuda', dtype=torch.float32):
        # NTK-style base scaling at continuous time t, plus its derivative w.r.t. t.
        if t < 1:
            alpha = 1
        else:
            alpha = 2 * t - 1
        ntk_base = base * alpha ** (self.dim / (self.dim - 2))
        ntk_inv_freq = 1.0 / (ntk_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32).to(device) / self.dim))
        index = torch.arange(0, self.dim, 2, dtype=torch.float32).to(device)
        delta_ntk_freq = -2 * index / (self.dim - 2) * 1 / (base ** (index / self.dim) * (alpha ** (index / (self.dim - 2) + 1)))
        return delta_ntk_freq.to(device, dtype=dtype), ntk_inv_freq.to(device, dtype=dtype)

    def forward(self, t, x: torch.Tensor):
        # ODE dynamics: x holds log inverse frequencies; return their time derivative.
        device = x.device
        delta_time, time = self.get_time_embedding(t.to(device), device=device, dtype=x.dtype)
        x = x + torch.log(time)
        time_embed = delta_time / time
        delta_inv_freq = self.act(x @ self.ode_up_proj.float()) @ self.ode_down_proj.float()
        delta_inv_freq = delta_inv_freq + time_embed
        return delta_inv_freq


class CLEXScalingRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, rope_scaling=None, base=10000, device=None) -> None:
        super().__init__()
        self.max_t = rope_scaling["max_factor"]
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)

        self.proj_func = ODELinear(dim, rope_scaling["param_factor"], rope_scaling["act"], base)
        self.rope_cached = None
        self.max_t_cached = 0
        self.freq_cached = None
        self.time_dt = rope_scaling["time_dt"]
        self.ode_args = {
            "method": "rk4",
            "options": {"step_size": self.time_dt},
        }

    def sample_random_times(self, max_t, device):
        return torch.randint(1, max_t, (1,), dtype=torch.long, device=device)

    def get_random_position_ids(self, n=2048, max=8192):
        positions = torch.randperm(max)[:n].sort().values
        return positions

    def get_continuous_freq(self, time_grid, ex_positions, device):
        # Integrate the log inverse frequencies along the time grid.
        solution = odeint(
            self.proj_func,
            torch.log(self.inv_freq.to(device, dtype=torch.float32)),
            time_grid,
            **self.ode_args
        )
        if time_grid.size(0) == 2:
            scale_inv_freq = torch.exp(solution[1])
            freqs = torch.outer(ex_positions.float().squeeze(), scale_inv_freq)
        else:
            scale_inv_freq = torch.exp(solution)
            return scale_inv_freq
        embed = torch.cat((freqs, freqs), dim=-1)
        return embed

    def forward(self, seq_len, do_train=False):
        device = self.proj_func.ode_up_proj.device
        scale_factor = seq_len // self.max_position_embeddings
        if do_train:
            # Training path: sample a random scaling factor t and random positions.
            t_val = self.sample_random_times(self.max_t + 1, device)[0]
            if scale_factor < 1.0:
                scale_factor = 1
            sampled_position_ids = self.get_random_position_ids(n=seq_len - 2, max=seq_len * t_val - 2).float()
            ex_positions = torch.cat([
                torch.tensor([0]),
                (sampled_position_ids + 1) / scale_factor,
                torch.tensor([seq_len * t_val // scale_factor - 1])]
            ).to(device, dtype=torch.float32)
        else:
            # Inference path: deterministic scaling based on the requested length.
            t_val = scale_factor if seq_len % self.max_position_embeddings == 0.0 else scale_factor + 1
            t_val = t_val if t_val <= self.max_t else self.max_t
            ex_positions = torch.arange(0, self.max_position_embeddings * t_val, dtype=torch.float32).to(device)

        if t_val == 1.0:
            # No extrapolation: plain RoPE frequencies, no ODE solve.
            scale_inv_freq = self.inv_freq.to(device)
            freqs = torch.outer(ex_positions.float().squeeze(), scale_inv_freq)
            embed = torch.cat((freqs, freqs), dim=-1)
            cos, sin = embed.cos(), embed.sin()
        elif do_train:
            time_grid = torch.tensor([1.0, t_val]).float().to(device)
            embed = self.get_continuous_freq(time_grid, ex_positions, device)
            cos, sin = embed.cos(), embed.sin()
        else:
            if self.freq_cached is None:
                time_grid = torch.arange(1.0, self.max_t + 1.0, dtype=torch.float32).to(device)
                self.freq_cached = self.get_continuous_freq(time_grid, ex_positions, device)
            if t_val != self.max_t_cached:
                scale_inv_freq = self.freq_cached[int(t_val - 1.0)]
                freqs = torch.outer(ex_positions.float().squeeze(), scale_inv_freq)
                embed = torch.cat((freqs, freqs), dim=-1)
                self.rope_cached = torch.cat((embed.cos()[None, :, :], embed.sin()[None, :, :]), dim=0)
                self.max_t_cached = t_val
            cos, sin = self.rope_cached
        return cos[None, :seq_len], sin[None, :seq_len]
```
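For reference, this is how I sanity-check the layer standalone (assuming the classes above are in scope). `head_dim`, `train_len`, and `seq_len` are example values for illustration; my real model passes `config.attention_key_size` and `config.my_info_dict["train_len"]`:

```python
import torch

head_dim, train_len, seq_len = 64, 2048, 4096  # assumed example values
rope_scaling = {"factor": 1, "max_factor": 64, "param_factor": 1,
                "time_dt": 0.01, "type": "clex", "act": "tanh"}

layer = CLEXScalingRotaryEmbedding(head_dim, train_len, rope_scaling)

# Inference path (do_train=False): deterministic positions; the first call
# fills freq_cached by solving the ODE over the grid t = 1 .. max_factor.
cos, sin = layer(seq_len, do_train=False)
print(cos.shape, sin.shape)  # expected: torch.Size([1, 4096, 64]) for both
```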
However, training becomes extremely slow. Following your paper, I fine-tune for one epoch on top of the base model. With the implementation above added to my model, one epoch of fine-tuning takes about two hours, and the loss and gradients become NaN. With CLEX removed, the training time is normal, about three minutes per epoch. My experiments run on 4 Tesla V100-SXM2-32GB GPUs.
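One thing I noticed while reading the code: with `method="rk4"` and a fixed `step_size` of `time_dt=0.01`, a single training-mode forward that samples `t_val` integrates from `t=1` to `t=t_val`, i.e. roughly `(t_val - 1) / 0.01` solver steps with four `ODELinear.forward` evaluations each, so I suspect this solve accounts for at least part of the slowdown, though I am not certain. For the NaNs, I plan to enable PyTorch's autograd anomaly detection around one training step to find the first offending op; `model` and `batch` below are placeholder names for my own training loop:

```python
import torch

# Standard PyTorch debugging switch: with anomaly detection enabled, backward()
# raises at the first operation whose gradient contains NaN and prints the
# forward traceback of that operation, which should point at the responsible layer.
torch.autograd.set_detect_anomaly(True)

loss = model(batch)   # placeholder for one training step of my model
loss.backward()       # raises a RuntimeError here if a NaN gradient is produced
```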
Thank you in advance for your reply!