
Commit cbb6dfc

feat(nn): add WithBias option for Linear layer (#22)
Add functional options pattern to NewLinear for optional bias:

- Add LinearOption type and WithBias(bool) function
- Add HasBias() method for introspection
- Update SwiGLUFFN to use the public API
- Export WithBias in the public nn package

Usage: nn.NewLinear(in, out, backend, nn.WithBias(false))

Closes feature request for Linear without bias support.
1 parent d0d4ea3 commit cbb6dfc

5 files changed, with 188 additions and 39 deletions.


internal/nn/glu_test.go

Lines changed: 1 addition & 1 deletion
@@ -305,7 +305,7 @@ func TestSwiGLUFFN_DefaultFFNDim(t *testing.T) {
 func TestNewLinearNoBias(t *testing.T) {
 	backend := autodiff.New(cpu.New())
 
-	linear := newLinearNoBias[Backend](128, 256, backend)
+	linear := NewLinear[Backend](128, 256, backend, WithBias(false))
 
 	// Check parameters
 	params := linear.Parameters()

internal/nn/linear.go

Lines changed: 66 additions & 8 deletions
@@ -6,17 +6,49 @@ import (
 	"github.com/born-ml/born/internal/tensor"
 )
 
+// LinearOption is a functional option for configuring a Linear layer.
+type LinearOption func(*linearConfig)
+
+// linearConfig holds configuration for Linear layer creation.
+type linearConfig struct {
+	useBias bool
+}
+
+// defaultLinearConfig returns the default configuration.
+func defaultLinearConfig() linearConfig {
+	return linearConfig{
+		useBias: true, // Default: use bias (backwards compatible)
+	}
+}
+
+// WithBias sets whether the Linear layer should use bias.
+//
+// Default is true. Set to false for architectures like LLaMA that don't use bias.
+//
+// Example:
+//
+//	// Linear layer without bias (LLaMA-style)
+//	lm_head := nn.NewLinear(hidden_size, vocab_size, backend, nn.WithBias(false))
+//
+//	// Linear layer with bias (default)
+//	layer := nn.NewLinear(784, 128, backend) // same as WithBias(true)
+func WithBias(useBias bool) LinearOption {
+	return func(cfg *linearConfig) {
+		cfg.useBias = useBias
+	}
+}
+
 // Linear implements a fully connected (dense) layer.
 //
 // Performs the transformation: y = x @ W.T + b
 // where:
 //   - x is the input tensor with shape [batch_size, in_features]
 //   - W is the weight matrix with shape [out_features, in_features]
-//   - b is the bias vector with shape [out_features]
+//   - b is the bias vector with shape [out_features] (optional, see WithBias)
 //   - y is the output tensor with shape [batch_size, out_features]
 //
 // Weights are initialized using Xavier/Glorot initialization.
-// Biases are initialized to zeros.
+// Biases are initialized to zeros (if enabled).
 //
 // Example:
 //
@@ -25,6 +57,9 @@ import (
 //
 //	input := tensor.Randn[float32](tensor.Shape{32, 784}, backend) // batch_size=32
 //	output := layer.Forward(input) // shape: [32, 128]
+//
+//	// Without bias (for LLaMA-style models)
+//	lm_head := nn.NewLinear(512, vocab_size, backend, nn.WithBias(false))
 type Linear[B tensor.Backend] struct {
 	inFeatures  int
 	outFeatures int
@@ -36,24 +71,42 @@ type Linear[B tensor.Backend] struct {
 // NewLinear creates a new Linear layer.
 //
 // Weights are initialized using Xavier/Glorot uniform distribution.
-// Biases are initialized to zeros.
+// Biases are initialized to zeros (if enabled).
 //
 // Parameters:
 //   - inFeatures: Number of input features
 //   - outFeatures: Number of output features
 //   - backend: Backend to use for tensor operations
+//   - opts: Optional configuration (see WithBias)
 //
 // Returns a new Linear layer.
-func NewLinear[B tensor.Backend](inFeatures, outFeatures int, backend B) *Linear[B] {
+//
+// Example:
+//
+//	// With bias (default)
+//	layer := nn.NewLinear(784, 128, backend)
+//
+//	// Without bias (for LLaMA, attention projections, etc.)
+//	lm_head := nn.NewLinear(hidden_size, vocab_size, backend, nn.WithBias(false))
+func NewLinear[B tensor.Backend](inFeatures, outFeatures int, backend B, opts ...LinearOption) *Linear[B] {
+	// Apply options
+	cfg := defaultLinearConfig()
+	for _, opt := range opts {
+		opt(&cfg)
+	}
+
 	// Weight: [out_features, in_features]
 	weightShape := tensor.Shape{outFeatures, inFeatures}
 	weightTensor := Xavier(inFeatures, outFeatures, weightShape, backend)
 	weight := NewParameter("weight", weightTensor)
 
-	// Bias: [out_features]
-	biasShape := tensor.Shape{outFeatures}
-	biasTensor := Zeros(biasShape, backend)
-	bias := NewParameter("bias", biasTensor)
+	// Bias: [out_features] (optional)
+	var bias *Parameter[B]
+	if cfg.useBias {
+		biasShape := tensor.Shape{outFeatures}
+		biasTensor := Zeros(biasShape, backend)
+		bias = NewParameter("bias", biasTensor)
+	}
 
 	return &Linear[B]{
 		inFeatures: inFeatures,
@@ -137,6 +190,11 @@ func (l *Linear[B]) OutFeatures() int {
 	return l.outFeatures
 }
 
+// HasBias returns true if this layer has a bias parameter.
+func (l *Linear[B]) HasBias() bool {
+	return l.bias != nil
+}
+
 // StateDict returns a map of parameter names to raw tensors.
 func (l *Linear[B]) StateDict() map[string]*tensor.RawTensor {
 	stateDict := make(map[string]*tensor.RawTensor)

internal/nn/nn_test.go

Lines changed: 93 additions & 0 deletions
@@ -163,6 +163,99 @@ func TestLinear_ForwardBatch(t *testing.T) {
 	}
 }
 
+// TestLinear_WithBias tests Linear layer WithBias option.
+func TestLinear_WithBias(t *testing.T) {
+	backend := autodiff.New(cpu.New())
+
+	// Test 1: Default (with bias)
+	layerWithBias := nn.NewLinear(10, 5, backend)
+	if !layerWithBias.HasBias() {
+		t.Error("Default Linear should have bias")
+	}
+	if layerWithBias.Bias() == nil {
+		t.Error("Bias() should not be nil for default Linear")
+	}
+	if len(layerWithBias.Parameters()) != 2 {
+		t.Errorf("Default Linear should have 2 parameters, got %d", len(layerWithBias.Parameters()))
+	}
+
+	// Test 2: Explicit WithBias(true)
+	layerExplicitBias := nn.NewLinear(10, 5, backend, nn.WithBias(true))
+	if !layerExplicitBias.HasBias() {
+		t.Error("Linear with WithBias(true) should have bias")
+	}
+	if len(layerExplicitBias.Parameters()) != 2 {
+		t.Errorf("Linear with bias should have 2 parameters, got %d", len(layerExplicitBias.Parameters()))
+	}
+
+	// Test 3: WithBias(false)
+	layerNoBias := nn.NewLinear(10, 5, backend, nn.WithBias(false))
+	if layerNoBias.HasBias() {
+		t.Error("Linear with WithBias(false) should not have bias")
+	}
+	if layerNoBias.Bias() != nil {
+		t.Error("Bias() should be nil for Linear without bias")
+	}
+	if len(layerNoBias.Parameters()) != 1 {
+		t.Errorf("Linear without bias should have 1 parameter, got %d", len(layerNoBias.Parameters()))
+	}
+
+	// Verify weight is still properly initialized
+	weight := layerNoBias.Weight().Tensor()
+	expectedShape := tensor.Shape{5, 10}
+	if !weight.Shape().Equal(expectedShape) {
+		t.Errorf("Weight shape = %v, want %v", weight.Shape(), expectedShape)
+	}
+}
+
+// TestLinear_NoBias_Forward tests forward pass of Linear without bias.
+func TestLinear_NoBias_Forward(t *testing.T) {
+	backend := autodiff.New(cpu.New())
+
+	// Create linear layer without bias
+	layer := nn.NewLinear(2, 2, backend, nn.WithBias(false))
+
+	// Set known weights: [[1, 2], [3, 4]] (out=2, in=2)
+	weightData := []float32{1, 2, 3, 4}
+	copy(layer.Weight().Tensor().Raw().AsFloat32(), weightData)
+
+	// Input: [[1, 1]] (batch=1, in=2)
+	input, _ := tensor.FromSlice([]float32{1, 1}, tensor.Shape{1, 2}, backend)
+
+	// Forward pass
+	output := layer.Forward(input)
+
+	// Expected:
+	// y = x @ W.T (no bias)
+	// W.T = [[1, 3], [2, 4]]
+	// x @ W.T = [1, 1] @ [[1, 3], [2, 4]] = [3, 7]
+	expected := []float32{3.0, 7.0}
+	actual := output.Raw().AsFloat32()
+
+	for i, exp := range expected {
+		if !floatEqual(actual[i], exp, 1e-5) {
+			t.Errorf("Output[%d] = %f, want %f", i, actual[i], exp)
+		}
+	}
+}
+
+// TestLinear_NoBias_StateDict tests StateDict for Linear without bias.
+func TestLinear_NoBias_StateDict(t *testing.T) {
+	backend := autodiff.New(cpu.New())
+
+	layer := nn.NewLinear(4, 3, backend, nn.WithBias(false))
+
+	stateDict := layer.StateDict()
+
+	// Should have weight but no bias
+	if _, ok := stateDict["weight"]; !ok {
+		t.Error("StateDict should contain 'weight'")
+	}
+	if _, ok := stateDict["bias"]; ok {
+		t.Error("StateDict should not contain 'bias' for layer without bias")
+	}
+}
+
 // TestReLU_Forward tests ReLU activation.
 func TestReLU_Forward(t *testing.T) {
 	backend := autodiff.New(cpu.New())

internal/nn/swiglu_ffn.go

Lines changed: 5 additions & 28 deletions
@@ -91,19 +91,12 @@ func NewSwiGLUFFN[B tensor.Backend](cfg SwiGLUFFNConfig, backend B) *SwiGLUFFN[B
 		panic(fmt.Sprintf("SwiGLUFFN: unknown GLUVariant %q, expected swiglu/geglu/reglu/glu", cfg.GLUVariant))
 	}
 
-	// Create projections
+	// Create projections using WithBias option
 	// Note: LLaMA doesn't use bias in FFN layers
-	var gateProj, upProj, downProj *Linear[B]
-
-	if cfg.UseBias {
-		gateProj = NewLinear[B](cfg.EmbedDim, cfg.FFNDim, backend)
-		upProj = NewLinear[B](cfg.EmbedDim, cfg.FFNDim, backend)
-		downProj = NewLinear[B](cfg.FFNDim, cfg.EmbedDim, backend)
-	} else {
-		gateProj = newLinearNoBias[B](cfg.EmbedDim, cfg.FFNDim, backend)
-		upProj = newLinearNoBias[B](cfg.EmbedDim, cfg.FFNDim, backend)
-		downProj = newLinearNoBias[B](cfg.FFNDim, cfg.EmbedDim, backend)
-	}
+	biasOpt := WithBias(cfg.UseBias)
+	gateProj := NewLinear[B](cfg.EmbedDim, cfg.FFNDim, backend, biasOpt)
+	upProj := NewLinear[B](cfg.EmbedDim, cfg.FFNDim, backend, biasOpt)
+	downProj := NewLinear[B](cfg.FFNDim, cfg.EmbedDim, backend, biasOpt)
 
 	return &SwiGLUFFN[B]{
 		gateProj: gateProj,
@@ -191,19 +184,3 @@ func (f *SwiGLUFFN[B]) UpProj() *Linear[B] {
 func (f *SwiGLUFFN[B]) DownProj() *Linear[B] {
 	return f.downProj
 }
-
-// newLinearNoBias creates a Linear layer without bias.
-func newLinearNoBias[B tensor.Backend](inFeatures, outFeatures int, backend B) *Linear[B] {
-	// Initialize weight with Xavier/Glorot
-	weightShape := tensor.Shape{outFeatures, inFeatures}
-	weightTensor := Xavier(inFeatures, outFeatures, weightShape, backend)
-	weight := NewParameter("weight", weightTensor)
-
-	return &Linear[B]{
-		weight:      weight,
-		bias:        nil, // No bias
-		inFeatures:  inFeatures,
-		outFeatures: outFeatures,
-		backend:     backend,
-	}
-}

nn/nn.go

Lines changed: 23 additions & 2 deletions
@@ -25,14 +25,35 @@ func NewParameter[B tensor.Backend](name string, t *tensor.Tensor[float32, B]) *
 // Linear represents a fully connected (dense) layer.
 type Linear[B tensor.Backend] = nn.Linear[B]
 
+// LinearOption is a functional option for configuring a Linear layer.
+type LinearOption = nn.LinearOption
+
+// WithBias sets whether the Linear layer should use bias.
+//
+// Default is true. Set to false for architectures like LLaMA that don't use bias.
+//
+// Example:
+//
+//	// Linear layer without bias (LLaMA-style)
+//	lm_head := nn.NewLinear(hidden_size, vocab_size, backend, nn.WithBias(false))
+//
+//	// Linear layer with bias (default)
+//	layer := nn.NewLinear(784, 128, backend) // same as WithBias(true)
+func WithBias(useBias bool) LinearOption {
+	return nn.WithBias(useBias)
+}
+
 // NewLinear creates a new linear layer with Xavier initialization.
 //
 // Example:
 //
 //	backend := cpu.New()
 //	layer := nn.NewLinear(784, 128, backend)
-func NewLinear[B tensor.Backend](inFeatures, outFeatures int, backend B) *Linear[B] {
-	return nn.NewLinear(inFeatures, outFeatures, backend)
+//
+//	// Without bias (for LLaMA, attention projections, etc.)
+//	lm_head := nn.NewLinear(hidden_size, vocab_size, backend, nn.WithBias(false))
+func NewLinear[B tensor.Backend](inFeatures, outFeatures int, backend B, opts ...LinearOption) *Linear[B] {
+	return nn.NewLinear(inFeatures, outFeatures, backend, opts...)
 }
 
 // Conv2D represents a 2D convolutional layer.
