Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 69 additions & 31 deletions src/myvllm/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,21 @@ def __init__(

# create weight parameter with custom weight loader
self.weight = nn.Parameter(torch.empty(output_size, input_size))
self.weight.weight_loader = self.weight_loader
self.weight.loader = self.weight_loader

# create bias parameter
if bias:
self.bias = nn.Parameter(torch.zeros(output_size))
self.bias.weight_loader = self.weight_loader
self.bias.loader = self.bias_loader
else:
self.register_parameter('bias', None)

def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor):
raise NotImplementedError("Subclasses should implement this method.")

def bias_loader(self, param: nn.Parameter, loaded_bias: torch.Tensor):
raise NotImplementedError("Subclasses should implement this method.")

def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError("Subclasses should implement this method.")

Expand Down Expand Up @@ -72,6 +75,9 @@ def __init__(
def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor):
param.data.copy_(loaded_weights)

def bias_loader(self, param: nn.Parameter, loaded_bias: torch.Tensor):
param.data.copy_(loaded_bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return nn.functional.linear(x, self.weight, self.bias)

Expand All @@ -93,17 +99,27 @@ def __init__(

# param: parameter after tensor parallelism
# loaded_weights: the original full parameter to be loaded into param
def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor):
param_data = param.data
def _load_sharded_dim0(self, param: nn.Parameter, loaded_tensor: torch.Tensor):
param_data = param.data
# full_dim on the output column
full_data_output_size = loaded_weights.size(0)
full_size = loaded_tensor.size(0)
# dim size after sharding
shard_size = full_data_output_size // self.tp_size
shard_size = full_size // self.tp_size
assert shard_size == param_data.size(0), "Shard size does not match parameter size."
# starting index
start_index = self.tp_rank * shard_size
slided_weight = loaded_weights.narrow(0, start_index, shard_size)
param_data.copy_(slided_weight)
sliced_tensor = loaded_tensor.narrow(0, start_index, shard_size)
param_data.copy_(sliced_tensor)

def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor):
# weight: [out_features, in_features]
# Column parallel shards weight along output dimension.
self._load_sharded_dim0(param, loaded_weights)

def bias_loader(self, param: nn.Parameter, loaded_bias: torch.Tensor):
# bias: [out_features]
# Column parallel shards bias along output dimension as well.
self._load_sharded_dim0(param, loaded_bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return nn.functional.linear(x, self.weight, self.bias)
Expand All @@ -119,6 +135,19 @@ def __init__(
self.output_sizes = output_sizes
super().__init__(input_size, sum(output_sizes), bias)

def _load_merged_sharded_dim0(self, param: nn.Parameter, loaded_tensor: torch.Tensor, loaded_part_id: int):
param_data = param.data
# compute offset
offset = sum(self.output_sizes[:loaded_part_id]) // self.tp_size
# compute size
shard_size = self.output_sizes[loaded_part_id] // self.tp_size
# find the correct slice to be loaded in the sharded parameter
param_data = param_data.narrow(0, offset, shard_size)
# shard the original full weight
loaded_start_index = self.tp_rank * shard_size
shard_tensor = loaded_tensor.narrow(0, loaded_start_index, shard_size)
param_data.copy_(shard_tensor)

# param: parameter to be reloaded after tensor parallelism
# loaded_weights: the original full parameter to be loaded into param
# the index of merged matrices (e.g. it's 0 for Q, 1 for K, 2 for V assuming QKV are merged together)
Expand All @@ -135,17 +164,10 @@ def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor, loade
output_sizes=sum([4096, 4096, 4096]), # Q, K, V
) which is also sharded by tp_size
"""
param_data = param.data
# compute offset
offset = sum(self.output_sizes[:loaded_weight_id]) // self.tp_size
# compute size
shard_size = self.output_sizes[loaded_weight_id] // self.tp_size
# find the correct slice to be loaded in the sharded parameter
param_data = param_data.narrow(0, offset, shard_size)
# shard the original full weight
loaded_weights_start_index = self.tp_rank * shard_size
shard_weights = loaded_weights.narrow(0, loaded_weights_start_index, shard_size)
param_data.copy_(shard_weights)
self._load_merged_sharded_dim0(param, loaded_weights, loaded_weight_id)

def bias_loader(self, param: nn.Parameter, loaded_bias: torch.Tensor, loaded_bias_id: int):
self._load_merged_sharded_dim0(param, loaded_bias, loaded_bias_id)


class QKVColumnParallelLinear(ColumnParallelLinear):
Expand All @@ -163,37 +185,40 @@ def __init__(
self.num_heads = num_heads // self.tp_size
self.num_kv_heads = num_kv_heads // self.tp_size
# Calculate per-GPU output size
self.output_size = head_size * (self.num_heads + 2 * self.num_kv_heads)
# self.output_size = head_size * (self.num_heads + 2 * self.num_kv_heads)
# Pass TOTAL output size to parent (it will divide by tp_size)
total_output_size = head_size * (num_heads + 2 * num_kv_heads)
super().__init__(input_size, total_output_size, bias=bias)

# load_weight_id: q, k, v
def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor, load_weight_id: str):
def _load_qkv_sharded_dim0(self, param: nn.Parameter, loaded_tensor: torch.Tensor, load_weight_id: str):
# batch_size * num_heads * num_token * head_size
param_data = param.data
# loaded_weights: batch_size * num_token * (head_size*num_heads)
assert load_weight_id in ['q', 'k', 'v'], "load_weight_id must be one of 'q', 'k', 'v'"
# compute offset
# compute offset and shard size
if load_weight_id == 'q':
offset = 0
shard_size = self.head_size * self.num_heads
elif load_weight_id == 'k':
offset = self.head_size * self.num_heads
shard_size = self.head_size * self.num_kv_heads
elif load_weight_id == 'v':
elif load_weight_id == "v":
offset = self.head_size * self.num_heads + self.head_size * self.num_kv_heads
shard_size = self.head_size * self.num_kv_heads
else:
raise ValueError(f"Unknown load_weight_id: {load_weight_id}")

param_data = param_data.narrow(0, offset, shard_size)
# shard the original full weight
loaded_weights_start_index = self.tp_rank * shard_size
shard_weights = loaded_weights.narrow(0, loaded_weights_start_index, shard_size)
loaded_start_index = self.tp_rank * shard_size
shard_weights = loaded_tensor.narrow(0, loaded_start_index, shard_size)

param_data.copy_(shard_weights)

def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor, load_weight_id: str):
self._load_qkv_sharded_dim0(param, loaded_weights, load_weight_id)

def bias_loader(self, param: nn.Parameter, loaded_bias: torch.Tensor, load_bias_id: str):
self._load_qkv_sharded_dim0(param, loaded_bias, load_bias_id)


class RowParallelLinear(LinearBase):
def __init__(
Expand All @@ -207,7 +232,7 @@ def __init__(
super().__init__(input_size // tp_size, output_size, bias, tp_dim=1)

def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor):
param_data = param.data
param_data = param.data
# full_dim on the input row
full_data_input_size = loaded_weights.size(1)
# dim size after sharding
Expand All @@ -218,10 +243,23 @@ def weight_loader(self, param: nn.Parameter, loaded_weights: torch.Tensor):
slided_weight = loaded_weights.narrow(1, start_index, shard_size)
param_data.copy_(slided_weight)

def bias_loader(self, param: nn.Parameter, loaded_bias: torch.Tensor):
# bias: [out_features]
# Row parallel does not shard bias.
# Each rank keeps a full copy of the bias.
param.data.copy_(loaded_bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
result = nn.functional.linear(x, self.weight, self.bias)
# Do not add bias before all_reduce.
# Otherwise, each rank adds bias once and all_reduce will sum them.
result = nn.functional.linear(x, self.weight, None)

if self.tp_size > 1:
dist.all_reduce(result, op=dist.ReduceOp.SUM)

if self.bias is not None:
result = result + self.bias

return result


Expand Down