Open
Description
Research and implement adaptive stretch parameters that can be learned during training.
Motivation
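The stretch parameters (gamma, zeta) set the interval to which the concrete distribution is stretched before it is clamped to [0, 1], which determines how much probability mass falls on exactly 0 and 1. Learning them could improve convergence and final sparsity quality.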
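To make the role of gamma and zeta concrete, here is a minimal pure-Python sketch of the standard hard concrete gate (stretch, then clamp). The names `log_alpha` (gate location) and `beta` (temperature) do not appear in this issue; they are assumed from the usual hard concrete formulation, and the values used are illustrative only.

```python
import math


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def hard_concrete_sample(u, log_alpha, beta, gamma, zeta):
    """Map uniform noise u in (0, 1) to a gate value in [0, 1]."""
    # Binary concrete sample in (0, 1).
    s = sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / beta)
    # Stretch to (gamma, zeta), with gamma < 0 and zeta > 1.
    s_bar = s * (zeta - gamma) + gamma
    # Clamp back to [0, 1]; mass outside the unit interval lands on 0 or 1.
    return min(1.0, max(0.0, s_bar))


def prob_nonzero(log_alpha, beta, gamma, zeta):
    """Probability that the gate is not exactly zero."""
    return sigmoid(log_alpha - beta * math.log(-gamma / zeta))


beta = 2.0 / 3.0  # assumed temperature, not specified in this issue
for stretch in (0.1, 0.3, 0.5):
    p = prob_nonzero(0.0, beta, -stretch, 1.0 + stretch)
    print(f"stretch={stretch:.1f}  P(gate != 0)={p:.3f}")
```

With `log_alpha` and `beta` held fixed, a wider stretch lowers the probability of a nonzero gate, so the stretch interacts directly with the sparsity penalty.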
Research Questions
- Should stretch be global, per-layer, or per-weight?
- How should stretch be constrained so the distribution stays valid (gamma < 0 and zeta > 1)?
- Does adaptive stretch improve final accuracy/sparsity?
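One possible answer to the constraint question, sketched here as an assumption rather than a settled design: learn an unconstrained value and map it into the allowed range with a sigmoid. Unlike a hard clamp, which passes no gradient once the value leaves the range, this mapping has a nonzero slope everywhere. The helper names below are hypothetical; the range 0.01 to 0.5 is taken from the clamp in the proposal.

```python
import math

STRETCH_MIN, STRETCH_MAX = 0.01, 0.5  # range used by the clamp in the proposal


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def constrained_stretch(raw: float) -> float:
    """Map an unconstrained value to [STRETCH_MIN, STRETCH_MAX]."""
    return STRETCH_MIN + (STRETCH_MAX - STRETCH_MIN) * sigmoid(raw)


def raw_from_stretch(stretch: float) -> float:
    """Inverse mapping, useful for initialising from a target stretch."""
    t = (stretch - STRETCH_MIN) / (STRETCH_MAX - STRETCH_MIN)
    return math.log(t / (1.0 - t))


def bounds(raw: float):
    """Return (gamma, zeta) for an unconstrained value."""
    s = constrained_stretch(raw)
    return -s, 1.0 + s


raw = raw_from_stretch(0.1)  # start at the proposal's default init_stretch
gamma, zeta = bounds(raw)
print(f"raw={raw:.4f}  gamma={gamma:.4f}  zeta={zeta:.4f}")
```

Because stretch stays at or above `STRETCH_MIN`, gamma remains negative and zeta remains above 1 for any real input, so the gate can always reach exact 0 and 1.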
Proposed Implementation

    import torch
    import torch.nn as nn


    class AdaptiveHardConcrete(nn.Module):
        def __init__(self, *gate_size, init_stretch=0.1, adaptive=True):
            super().__init__()
            ...
            if adaptive:
                # Learn the stretch as a scalar parameter.
                self.stretch = nn.Parameter(torch.tensor(init_stretch))
            else:
                # Fixed stretch: a buffer is saved with the module but not trained.
                self.register_buffer('stretch', torch.tensor(init_stretch))

        @property
        def gamma(self):
            # Lower stretch bound, constrained to [-0.5, -0.01].
            return -self.stretch.clamp(0.01, 0.5)

        @property
        def zeta(self):
            # Upper stretch bound, constrained to [1.01, 1.5].
            return 1.0 + self.stretch.clamp(0.01, 0.5)

Experiments
- Compare fixed vs adaptive stretch
- Ablation on stretch initialization
- Interaction with temperature scheduling
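The three experiments above could be laid out as a single grid. The following is only a sketch: `StretchRun`, `temperature`, `build_grid`, the schedule names, and every numeric value are hypothetical placeholders, not settings proposed in this issue.

```python
from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class StretchRun:
    adaptive: bool             # fixed vs adaptive stretch
    init_stretch: float        # ablation on initialization
    temperature_schedule: str  # interaction with temperature scheduling


def temperature(schedule, step, total_steps, start=1.0, end=0.1):
    """Temperature at a given step (hypothetical schedules)."""
    if schedule == "constant":
        return start
    if schedule == "exponential":
        frac = step / max(1, total_steps)
        # Geometric interpolation from start to end.
        return start * (end / start) ** frac
    raise ValueError(f"unknown schedule: {schedule}")


def build_grid():
    """Cross every setting so interactions can be compared directly."""
    return [
        StretchRun(adaptive, init, schedule)
        for adaptive, init, schedule in product(
            (False, True),                # placeholder: both modes
            (0.05, 0.1, 0.3),             # placeholder init values
            ("constant", "exponential"),  # placeholder schedules
        )
    ]


for run in build_grid():
    print(run)
```

A full cross keeps the fixed-stretch runs as a baseline for each initialization and schedule, at the cost of more runs than three separate one-factor sweeps.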