@@ -6,17 +6,49 @@ import (
66 "github.com/born-ml/born/internal/tensor"
77)
88
// LinearOption is a functional option that adjusts the configuration used
// when a Linear layer is created.
type LinearOption func(*linearConfig)

// linearConfig holds the optional settings for Linear layer creation.
type linearConfig struct {
	// useBias reports whether the layer should be given a bias parameter.
	useBias bool
}

// defaultLinearConfig returns the configuration used when no options are
// supplied. Bias is enabled by default so that callers passing no options
// keep the previous (backwards compatible) behavior.
func defaultLinearConfig() linearConfig {
	return linearConfig{
		useBias: true,
	}
}

// WithBias returns an option that sets whether the Linear layer should use
// a bias.
//
// The default is true. Pass false for architectures like LLaMA that don't
// use bias.
//
// Example:
//
//	// Linear layer without bias (LLaMA-style)
//	lmHead := nn.NewLinear(hiddenSize, vocabSize, backend, nn.WithBias(false))
//
//	// Linear layer with bias (default)
//	layer := nn.NewLinear(784, 128, backend) // same as WithBias(true)
func WithBias(useBias bool) LinearOption {
	return func(cfg *linearConfig) {
		cfg.useBias = useBias
	}
}
40+
941// Linear implements a fully connected (dense) layer.
1042//
1143// Performs the transformation: y = x @ W.T + b
1244// where:
1345// - x is the input tensor with shape [batch_size, in_features]
1446// - W is the weight matrix with shape [out_features, in_features]
15- // - b is the bias vector with shape [out_features]
47+ // - b is the bias vector with shape [out_features] (optional, see WithBias)
1648// - y is the output tensor with shape [batch_size, out_features]
1749//
1850// Weights are initialized using Xavier/Glorot initialization.
19- // Biases are initialized to zeros.
51+ // Biases are initialized to zeros (if enabled) .
2052//
2153// Example:
2254//
@@ -25,6 +57,9 @@ import (
2557//
2658// input := tensor.Randn[float32](tensor.Shape{32, 784}, backend) // batch_size=32
2759// output := layer.Forward(input) // shape: [32, 128]
60+ //
61+ // // Without bias (for LLaMA-style models)
62+ // lm_head := nn.NewLinear(512, vocab_size, backend, nn.WithBias(false))
2863type Linear [B tensor.Backend ] struct {
2964 inFeatures int
3065 outFeatures int
@@ -36,24 +71,42 @@ type Linear[B tensor.Backend] struct {
3671// NewLinear creates a new Linear layer.
3772//
3873// Weights are initialized using Xavier/Glorot uniform distribution.
39- // Biases are initialized to zeros.
74+ // Biases are initialized to zeros (if enabled) .
4075//
4176// Parameters:
4277// - inFeatures: Number of input features
4378// - outFeatures: Number of output features
4479// - backend: Backend to use for tensor operations
80+ // - opts: Optional configuration (see WithBias)
4581//
4682// Returns a new Linear layer.
47- func NewLinear [B tensor.Backend ](inFeatures , outFeatures int , backend B ) * Linear [B ] {
83+ //
84+ // Example:
85+ //
86+ // // With bias (default)
87+ // layer := nn.NewLinear(784, 128, backend)
88+ //
89+ // // Without bias (for LLaMA, attention projections, etc.)
90+ // lm_head := nn.NewLinear(hidden_size, vocab_size, backend, nn.WithBias(false))
91+ func NewLinear [B tensor.Backend ](inFeatures , outFeatures int , backend B , opts ... LinearOption ) * Linear [B ] {
92+ // Apply options
93+ cfg := defaultLinearConfig ()
94+ for _ , opt := range opts {
95+ opt (& cfg )
96+ }
97+
4898 // Weight: [out_features, in_features]
4999 weightShape := tensor.Shape {outFeatures , inFeatures }
50100 weightTensor := Xavier (inFeatures , outFeatures , weightShape , backend )
51101 weight := NewParameter ("weight" , weightTensor )
52102
53- // Bias: [out_features]
54- biasShape := tensor.Shape {outFeatures }
55- biasTensor := Zeros (biasShape , backend )
56- bias := NewParameter ("bias" , biasTensor )
103+ // Bias: [out_features] (optional)
104+ var bias * Parameter [B ]
105+ if cfg .useBias {
106+ biasShape := tensor.Shape {outFeatures }
107+ biasTensor := Zeros (biasShape , backend )
108+ bias = NewParameter ("bias" , biasTensor )
109+ }
57110
58111 return & Linear [B ]{
59112 inFeatures : inFeatures ,
@@ -137,6 +190,11 @@ func (l *Linear[B]) OutFeatures() int {
137190 return l .outFeatures
138191}
139192
193+ // HasBias returns true if this layer has a bias parameter.
194+ func (l * Linear [B ]) HasBias () bool {
195+ return l .bias != nil
196+ }
197+
140198// StateDict returns a map of parameter names to raw tensors.
141199func (l * Linear [B ]) StateDict () map [string ]* tensor.RawTensor {
142200 stateDict := make (map [string ]* tensor.RawTensor )
0 commit comments