diff --git a/src/myvllm/layers/linear.py b/src/myvllm/layers/linear.py index 07ea306..ee3a1aa 100644 --- a/src/myvllm/layers/linear.py +++ b/src/myvllm/layers/linear.py @@ -22,18 +22,21 @@ def __init__( # create weight parameter with custom weight loader self.weight = nn.Parameter(torch.empty(output_size, input_size)) - self.weight.weight_loader = self.weight_loader + self.weight.loader = self.weight_loader # create bias parameter if bias: self.bias = nn.Parameter(torch.zeros(output_size)) - self.bias.weight_loader = self.weight_loader + self.bias.loader = self.bias_loader else: self.register_parameter('bias', None) def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor): raise NotImplementedError("Subclasses should implement this method.") + def bias_loader(self, param: nn.Parameter, loaded_bias: torch.Tensor): + raise NotImplementedError("Subclasses should implement this method.") + def forward(self, x: torch.Tensor) -> torch.Tensor: raise NotImplementedError("Subclasses should implement this method.") @@ -72,6 +75,9 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor): param.data.copy_(loaded_weights) + def bias_loader(self, param: nn.Parameter, loaded_bias: torch.Tensor): + param.data.copy_(loaded_bias) + def forward(self, x: torch.Tensor) -> torch.Tensor: return nn.functional.linear(x, self.weight, self.bias) @@ -93,17 +99,27 @@ def __init__( # param: parameter after tensor parallelism # loaded_weights: the original full parameter to be loaded into param - def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor): - param_data = param.data + def _load_sharded_dim0(self, param: nn.Parameter, loaded_tensor: torch.Tensor): + param_data = param.data # full_dim on the output column - full_data_output_size = loaded_weights.size(0) + full_size = loaded_tensor.size(0) # dim size after sharding - shard_size = full_data_output_size // self.tp_size + shard_size = full_size // self.tp_size assert shard_size == param_data.size(0), "Shard size does not match parameter size." # starting index start_index = self.tp_rank * shard_size - slided_weight = loaded_weights.narrow(0, start_index, shard_size) - param_data.copy_(slided_weight) + sliced_tensor = loaded_tensor.narrow(0, start_index, shard_size) + param_data.copy_(sliced_tensor) + + def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor): + # weight: [out_features, in_features] + # Column parallel shards weight along output dimension. + self._load_sharded_dim0(param, loaded_weights) + + def bias_loader(self, param: nn.Parameter, loaded_bias: torch.Tensor): + # bias: [out_features] + # Column parallel shards bias along output dimension as well. + self._load_sharded_dim0(param, loaded_bias) def forward(self, x: torch.Tensor) -> torch.Tensor: return nn.functional.linear(x, self.weight, self.bias) @@ -119,6 +135,19 @@ def __init__( self.output_sizes = output_sizes super().__init__(input_size, sum(output_sizes), bias) + def _load_merged_sharded_dim0(self, param: nn.Parameter, loaded_tensor: torch.Tensor, loaded_part_id: int): + param_data = param.data + # compute offset + offset = sum(self.output_sizes[:loaded_part_id]) // self.tp_size + # compute size + shard_size = self.output_sizes[loaded_part_id] // self.tp_size + # find the correct slice to be loaded in the sharded parameter + param_data = param_data.narrow(0, offset, shard_size) + # shard the original full weight + loaded_start_index = self.tp_rank * shard_size + shard_tensor = loaded_tensor.narrow(0, loaded_start_index, shard_size) + param_data.copy_(shard_tensor) + # param: parameter to be reloaded after tensor parallelism # loaded_weights: the original full parameter to be loaded into param # the index of merged matrices (e.g. it's 0 for Q, 1 for K, 2 for V assuming QKV are merged together) @@ -135,17 +164,10 @@ def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor, loade output_sizes=sum([4096, 4096, 4096]), # Q, K, V ) which is also sharded by tp_size """ - param_data = param.data - # compute offset - offset = sum(self.output_sizes[:loaded_weight_id]) // self.tp_size - # compute size - shard_size = self.output_sizes[loaded_weight_id] // self.tp_size - # find the correct slice to be loaded in the sharded parameter - param_data = param_data.narrow(0, offset, shard_size) - # shard the original full weight - loaded_weights_start_index = self.tp_rank * shard_size - shard_weights = loaded_weights.narrow(0, loaded_weights_start_index, shard_size) - param_data.copy_(shard_weights) + self._load_merged_sharded_dim0(param, loaded_weights, loaded_weight_id) + + def bias_loader(self, param: nn.Parameter, loaded_bias: torch.Tensor, loaded_bias_id: int): + self._load_merged_sharded_dim0(param, loaded_bias, loaded_bias_id) class QKVColumnParallelLinear(ColumnParallelLinear): @@ -163,37 +185,40 @@ def __init__( self.num_heads = num_heads // self.tp_size self.num_kv_heads = num_kv_heads // self.tp_size # Calculate per-GPU output size - self.output_size = head_size * (self.num_heads + 2 * self.num_kv_heads) + # self.output_size = head_size * (self.num_heads + 2 * self.num_kv_heads) # Pass TOTAL output size to parent (it will divide by tp_size) total_output_size = head_size * (num_heads + 2 * num_kv_heads) super().__init__(input_size, total_output_size, bias=bias) - # load_weight_id: q, k, v - def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor, load_weight_id: str): + def _load_qkv_sharded_dim0(self, param: nn.Parameter, loaded_tensor: torch.Tensor, load_weight_id: str): # batch_size * num_heads * num_token * head_size param_data = param.data - # loaded_weights: batch_size * num_token * (head_size*num_heads) - assert load_weight_id in ['q', 'k', 'v'], "load_weight_id must be one of 'q', 'k', 'v'" - # compute offset + # compute offset and shard size if load_weight_id == 'q': offset = 0 shard_size = self.head_size * self.num_heads elif load_weight_id == 'k': offset = self.head_size * self.num_heads shard_size = self.head_size * self.num_kv_heads - elif load_weight_id == 'v': + elif load_weight_id == "v": offset = self.head_size * self.num_heads + self.head_size * self.num_kv_heads shard_size = self.head_size * self.num_kv_heads else: raise ValueError(f"Unknown load_weight_id: {load_weight_id}") - + param_data = param_data.narrow(0, offset, shard_size) # shard the original full weight - loaded_weights_start_index = self.tp_rank * shard_size - shard_weights = loaded_weights.narrow(0, loaded_weights_start_index, shard_size) + loaded_start_index = self.tp_rank * shard_size + shard_weights = loaded_tensor.narrow(0, loaded_start_index, shard_size) param_data.copy_(shard_weights) + def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor, load_weight_id: str): + self._load_qkv_sharded_dim0(param, loaded_weights, load_weight_id) + + def bias_loader(self, param: nn.Parameter, loaded_bias: torch.Tensor, load_bias_id: str): + self._load_qkv_sharded_dim0(param, loaded_bias, load_bias_id) + class RowParallelLinear(LinearBase): def __init__( @@ -207,7 +232,7 @@ def __init__( super().__init__(input_size // tp_size, output_size, bias, tp_dim=1) def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor): - param_data = param.data + param_data = param.data # full_dim on the input row full_data_input_size = loaded_weights.size(1) # dim size after sharding @@ -218,10 +243,23 @@ def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor): slided_weight = loaded_weights.narrow(1, start_index, shard_size) param_data.copy_(slided_weight) + def bias_loader(self, param: nn.Parameter, loaded_bias: torch.Tensor): + # bias: [out_features] + # Row parallel does not shard bias. + # Each rank keeps a full copy of the bias. + param.data.copy_(loaded_bias) + def forward(self, x: torch.Tensor) -> torch.Tensor: - result = nn.functional.linear(x, self.weight, self.bias) + # Do not add bias before all_reduce. + # Otherwise, each rank adds bias once and all_reduce will sum them. + result = nn.functional.linear(x, self.weight, None) + if self.tp_size > 1: dist.all_reduce(result, op=dist.ReduceOp.SUM) + + if self.bias is not None: + result = result + self.bias + return result