Implement adaptive stretch parameters #18

@MaxGhenis

Description

Research and implement adaptive stretch parameters that can be learned during training.

Motivation

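The stretch parameters (gamma, zeta) define the interval (gamma, zeta) to which a concrete sample is stretched before it is clipped to [0, 1], so they control how much probability mass lands on exactly 0 and exactly 1. Learning them could improve convergence and final sparsity quality.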
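
To make the role of stretch concrete, here is a small sketch that computes the probability of a gate being exactly 0 or exactly 1 in closed form, assuming the standard hard concrete construction (a logistic sample squashed by a sigmoid, stretched to (gamma, zeta), then clipped). The names `log_alpha` (gate location) and `beta` (temperature), and the values used for them, are illustrative assumptions and do not come from this issue.

```python
import math


def sigmoid(x):
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def endpoint_mass(log_alpha, beta, gamma, zeta):
    """Return (P(z == 0), P(z == 1)) for a hard concrete gate.

    Assumes gamma < 0 and zeta > 1, so both logarithms are defined.
    """
    # P(stretched sample <= 0)
    p_zero = sigmoid(beta * math.log(-gamma / zeta) - log_alpha)
    # P(stretched sample >= 1)
    p_one = 1.0 - sigmoid(beta * math.log((1.0 - gamma) / (zeta - 1.0)) - log_alpha)
    return p_zero, p_one


# Symmetric stretch, as in the proposal: gamma = -stretch, zeta = 1 + stretch
for stretch in (0.01, 0.1, 0.5):
    p0, p1 = endpoint_mass(log_alpha=0.0, beta=2.0 / 3.0,
                           gamma=-stretch, zeta=1.0 + stretch)
    print(f"stretch={stretch:<4}  P(z=0)={p0:.3f}  P(z=1)={p1:.3f}")
```

Under these assumptions, a larger stretch puts more mass on the two endpoints, which is one reason a learnable stretch might interact with the sparsity penalty.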

Research Questions

  1. Should stretch be global, per-layer, or per-weight?
  2. How should stretch be constrained so that the distribution stays valid (gamma < 0 and zeta > 1)?
  3. Does adaptive stretch improve final accuracy/sparsity?
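
One possible answer to question 2, sketched here as an assumption rather than a recommendation: store an unconstrained value and map it through a sigmoid into a bounded range. Unlike a hard clamp, whose gradient is zero once the value leaves the range, this mapping has a nonzero derivative everywhere. The helper names and the `raw` variable are hypothetical; the bounds 0.01 and 0.5 mirror the range used in the proposed implementation. The sketch uses only the standard library, but the same mapping could be applied to an `nn.Parameter`.

```python
import math

STRETCH_MIN, STRETCH_MAX = 0.01, 0.5  # same range as the clamp in the proposal


def sigmoid(x):
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def stretch_from_raw(raw):
    """Map an unconstrained real number to a stretch inside the allowed range."""
    return STRETCH_MIN + (STRETCH_MAX - STRETCH_MIN) * sigmoid(raw)


def raw_from_stretch(stretch):
    """Inverse map, e.g. to initialize raw from a desired init_stretch."""
    p = (stretch - STRETCH_MIN) / (STRETCH_MAX - STRETCH_MIN)
    return math.log(p / (1.0 - p))


def dstretch_draw(raw):
    """Derivative of stretch with respect to raw (positive for finite raw)."""
    s = sigmoid(raw)
    return (STRETCH_MAX - STRETCH_MIN) * s * (1.0 - s)


raw0 = raw_from_stretch(0.1)          # initialize at init_stretch = 0.1
stretch0 = stretch_from_raw(raw0)
gamma0, zeta0 = -stretch0, 1.0 + stretch0
print(f"raw={raw0:.4f} stretch={stretch0:.4f} gamma={gamma0:.4f} zeta={zeta0:.4f}")
```

Because stretch stays at or above `STRETCH_MIN`, gamma remains negative and zeta remains above 1 for any value of `raw`.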

Proposed Implementation

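import torch
import torch.nn as nn


class AdaptiveHardConcrete(nn.Module):
    def __init__(self, *gate_size, init_stretch=0.1, adaptive=True):
        super().__init__()
        ...
        if adaptive:
            # Learnable scalar stretch, shared by all gates in this module
            self.stretch = nn.Parameter(torch.tensor(init_stretch))
        else:
            # Fixed stretch: saved in state_dict but not returned by parameters()
            self.register_buffer('stretch', torch.tensor(init_stretch))

    @property
    def gamma(self):
        # Lower stretch limit, constrained to [-0.5, -0.01].
        # Note: clamp passes zero gradient once stretch leaves [0.01, 0.5].
        return -self.stretch.clamp(0.01, 0.5)

    @property
    def zeta(self):
        # Upper stretch limit, constrained to [1.01, 1.5]
        return 1.0 + self.stretch.clamp(0.01, 0.5)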
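
For context, the sketch below shows where gamma and zeta would typically enter the sampling step of a hard concrete gate. It is a framework-free illustration under the usual construction, not the elided body of the class above. The function name and the `log_alpha` and `beta` arguments are hypothetical.

```python
import math
import random


def sample_hard_concrete(log_alpha, beta, gamma, zeta, rng):
    """Draw one gate value in [0, 1] from a hard concrete distribution."""
    # Keep u away from 0 and 1 so the logs stay finite
    u = rng.uniform(1e-6, 1.0 - 1e-6)
    x = (math.log(u) - math.log(1.0 - u) + log_alpha) / beta
    # Numerically stable sigmoid gives the concrete sample in (0, 1)
    s = 1.0 / (1.0 + math.exp(-x)) if x >= 0 else math.exp(x) / (1.0 + math.exp(x))
    # Stretch to (gamma, zeta), then clip back to [0, 1]
    s_bar = s * (zeta - gamma) + gamma
    return min(1.0, max(0.0, s_bar))


rng = random.Random(0)
stretch = 0.1  # symmetric stretch: gamma = -stretch, zeta = 1 + stretch
samples = [
    sample_hard_concrete(0.0, 2.0 / 3.0, -stretch, 1.0 + stretch, rng)
    for _ in range(10_000)
]
frac_zero = sum(z == 0.0 for z in samples) / len(samples)
frac_one = sum(z == 1.0 for z in samples) / len(samples)
print(f"fraction exactly 0: {frac_zero:.3f}, exactly 1: {frac_one:.3f}")
```

The clipping step is what produces exact zeros and ones. With stretch fixed at 0 there would be none, which is why the proposal bounds stretch away from zero.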

Experiments

  • Compare fixed vs adaptive stretch
  • Ablation on stretch initialization
  • Interaction with temperature scheduling
