diff --git a/timesfm-forecasting/examples/covariates-forecasting/demo_covariates.py b/timesfm-forecasting/examples/covariates-forecasting/demo_covariates.py index 65153c01..20b9394d 100644 --- a/timesfm-forecasting/examples/covariates-forecasting/demo_covariates.py +++ b/timesfm-forecasting/examples/covariates-forecasting/demo_covariates.py @@ -411,9 +411,18 @@ def demonstrate_api() -> None: pip install timesfm[xreg] import timesfm -hparams = timesfm.TimesFmHparams(backend="cpu", per_core_batch_size=32, horizon_len=12) -ckpt = timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-2.5-200m-pytorch") -model = timesfm.TimesFm(hparams=hparams, checkpoint=ckpt) + +model = timesfm.TimesFM_2p5_200M_torch.from_pretrained( + "google/timesfm-2.5-200m-pytorch", + torch_compile=False, +) +model.compile(timesfm.ForecastConfig( + max_context=512, + max_horizon=12, + normalize_inputs=True, + use_continuous_quantile_head=True, + fix_quantile_crossing=True, +)) point_fc, quant_fc = model.forecast_with_covariates( inputs=[sales_a, sales_b, sales_c], @@ -423,8 +432,8 @@ def demonstrate_api() -> None: xreg_mode="xreg + timesfm", normalize_xreg_target_per_input=True, ) -# point_fc: (num_series, horizon_len) -# quant_fc: (num_series, horizon_len, 10) +# point_fc: (num_series, horizon) +# quant_fc: (num_series, horizon, 10) # columns: [mean, q10, q20, ..., q90] """)