-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathmodels.py
More file actions
2611 lines (2384 loc) · 139 KB
/
models.py
File metadata and controls
2611 lines (2384 loc) · 139 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Main DLP model for single-image and dynamics.
"""
# imports
import numpy as np
# torch
import torch
import torch.nn.functional as F
import torch.nn as nn
from modules.modules import DLPEncoder, DLPDecoder, DLPContext
from modules.modules import DLPDynamics
# util functions
from utils.util_func import calc_model_size, generate_dlp_logo
from utils.loss_functions import calc_reconstruction_loss, calc_kl_beta_dist, calc_kl, LossLPIPS, calc_kl_categorical, \
ChamferLossKL
from modules.vision_modules import rgb_to_minusoneone, minusoneone_to_rgb
class DLP(nn.Module):
    def __init__(self,
                 # Input configuration
                 cdim=3,  # Number of input image channels
                 image_size=64,  # Input image size (assumed square)
                 normalize_rgb=False,  # If True, normalize RGB to [-1, 1], else keep [0, 1]
                 n_views=1,  # number of input views (e.g., multiple cameras)
                 # Keypoint and patch configuration
                 n_kp_per_patch=1,  # Number of proposal/prior keypoints to extract per patch
                 patch_size=16,  # Size of patches for keypoint proposal network
                 anchor_s=0.25,  # Glimpse size ratio relative to image size
                 n_kp_enc=20,  # Number of posterior keypoints to learn
                 n_kp_prior=64,  # Number of keypoints to filter from prior proposals
                 warmup_n_kp_ratio=1.0,  # ratio of keypoints used during warmup (passed through to the encoder)
                 mask_bg_in_enc=True,  # before encoding the bg, mask with the particles' obj_on
                 # Network configuration
                 pad_mode='zeros',  # Padding mode for CNNs ('zeros' or 'replicate')
                 dropout=0.1,  # Dropout rate for transformers
                 # Feature representation
                 features_dist='gauss',  # Distribution type for features ('gauss' or 'categorical')
                 learned_feature_dim=16,  # Dimension of learned visual features
                 learned_bg_feature_dim=None,  # Background feature dimension (if None, equals learned_feature_dim)
                 n_fg_categories=8,  # Number of foreground feature categories (if categorical)
                 n_fg_classes=4,  # Number of foreground feature classes per category
                 n_bg_categories=4,  # Number of background feature categories
                 n_bg_classes=4,  # Number of background feature classes per category
                 # Prior distributions parameters
                 scale_std=0.3,  # Prior standard deviation for scale
                 offset_std=0.2,  # Prior standard deviation for offset
                 # z_t (transparency)
                 obj_on_alpha=0.01,  # Alpha parameter for transparency Beta distribution
                 obj_on_beta=0.01,  # Beta parameter for transparency Beta distribution
                 obj_on_min=1e-4,  # Minimum concentration in Beta dist transparency value
                 obj_on_max=100,  # Maximum concentration in Beta dist for transparency value
                 # Object decoder architecture
                 obj_res_from_fc=8,  # Initial resolution for object encoder-decoder
                 obj_ch_mult_prior=(1, 2, 3),  # Channel multipliers for prior patch encoder (kp proposal)
                 obj_ch_mult=(1, 2, 3),  # Channel multipliers for object encoder-decoder
                 obj_base_ch=32,  # Base channels for object encoder-decoder
                 obj_final_cnn_ch=32,  # Final CNN channels for object encoder-decoder
                 # Background decoder architecture
                 bg_res_from_fc=8,  # Initial resolution for background encoder-decoder
                 bg_ch_mult=(1, 2, 3),  # Channel multipliers for background encoder-decoder
                 bg_base_ch=32,  # Base channels for background decoder
                 bg_final_cnn_ch=32,  # Final CNN channels for background encoder-decoder
                 # Network architecture options
                 use_resblock=True,  # Use residual blocks in decoders
                 num_res_blocks=2,  # Number of residual blocks per resolution
                 cnn_mid_blocks=False,  # Use middle blocks in CNN
                 mlp_hidden_dim=256,  # Hidden dimension for MLPs
                 attn_norm_type='rms',  # Normalization type for attention ('rms' or 'ln')
                 # Particle interaction transformer (PINT) configuration
                 pint_enc_layers=1,  # Number of PINT encoder layers
                 pint_enc_heads=1,  # Number of PINT encoder attention heads
                 embed_init_std=0.02,  # Standard deviation for embedding initialization
                 particle_positional_embed=True,  # Use positional embeddings for particles
                 use_z_orig=True,  # Include patch center coordinates in particle features
                 particle_score=False,  # Use particle confidence score as feature
                 filtering_heuristic='none',  # Method to filter prior keypoints ('none','distance','variance','random')
                 # Dynamics configuration
                 timestep_horizon=10,  # Number of timesteps to predict ahead
                 n_static_frames=1,  # Number of initial frames for static KL optimization
                 predict_delta=False,  # Predict position deltas instead of absolute positions
                 context_dim=None,  # Context latent dimension (if None, equals learned_feature_dim)
                 ctx_dist='gauss',  # Context distribution type ('gauss' or 'categorical')
                 n_ctx_categories=8,  # Number of context categories (if categorical)
                 n_ctx_classes=4,  # Number of context classes per category
                 causal_ctx=True,  # Use causal attention for context modeling
                 ctx_pool_mode='none',  # Context pooling mode ('none' = per-particle context)
                 global_ctx_pool=False,  # learn global latent context in addition to per-particle context
                 # global_ctx_pool: EXPERIMENTAL, NOT USED IN THE PAPER
                 # EXPERIMENTAL, NOT USED IN THE PAPER:
                 pool_ctx_dim=7,  # pool dimension for the global ctx latent
                 n_pool_ctx_categories=8,  # Number of global context categories (if categorical)
                 n_pool_ctx_classes=4,  # Number of global context classes per category
                 global_local_fuse_mode='none',  # concatenate/add global and local z_ctx to condition the dynamics
                 condition_local_on_global=True,  # condition z_context on z_context_global
                 # END EXPERIMENTAL
                 # Context and dynamics transformer configuration
                 pint_dyn_layers=6,  # Number of dynamics transformer layers
                 pint_dyn_heads=8,  # Number of dynamics transformer heads
                 pint_dim=512,  # Hidden dimension for PINT
                 pint_ctx_layers=4,  # Number of context transformer layers
                 pint_ctx_heads=8,  # Number of context transformer heads
                 # external conditioning
                 action_condition=False,  # condition on actions
                 action_dim=0,  # dimension of input actions
                 null_action_embed=False,  # learn a "no-input-action" embedding, to learn on action-free videos as well
                 random_action_condition=False,  # condition on random actions
                 random_action_dim=0,  # dimension of sampled random actions
                 action_in_ctx_module=True,  # use action to condition context generation
                 language_condition=False,  # condition on language embedding
                 language_embed_dim=0,  # embedding dimension for each token
                 language_max_len=64,  # maximum tokens per prompt
                 img_goal_condition=False,  # image as goal conditioning for dynamics
                 # initialization
                 init_zero_bias=True,  # zero bias for conv and linear layers
                 init_ssm_last_layer=True,  # spatial softmax initialization
                 init_conv_layers=True,  # initialize conv layers with normal dist
                 init_conv_fg_std=0.02,  # std for conv fg normal dist
                 init_conv_bg_std=0.005,  # std for conv bg normal dist (<fg -> prioritize fg in learning)
                 ):
        super(DLP, self).__init__()
        # NOTE(review): the string below is a bare expression statement, not the
        # method's docstring (it is not the first statement) -- kept in place to
        # preserve byte-identical code; consider moving it to the class docstring.
        """
        Args:
            cdim (int): Number of input image channels. Defaults to 3.
            image_size (int): Size of input images (assumed square). Defaults to 64.
            normalize_rgb (bool): Normalize RGB values to [-1, 1] instead of [0, 1]. Defaults to False.
            n_kp_per_patch (int): Number of keypoints to extract per patch. Defaults to 1.
            patch_size (int): Size of patches for keypoint proposal network. Defaults to 16.
            anchor_s (float): Glimpse size as ratio of image_size (e.g., 0.25 for 32px glimpse on 128px image). Defaults to 0.25.
            n_kp_enc (int): Number of posterior keypoints to learn. Defaults to 20.
            n_kp_prior (int): Number of keypoints to filter from prior proposals. Defaults to 64.
            pad_mode (str): Padding mode for CNNs ('zeros' or 'replicate'). Defaults to 'zeros'.
            dropout (float): Dropout rate for transformers. Defaults to 0.1.
            features_dist (str): Distribution type for features ('gauss' or 'categorical'). Defaults to 'gauss'.
            learned_feature_dim (int): Dimension of learned visual features. Defaults to 16.
            learned_bg_feature_dim (Optional[int]): Background feature dimension. If None, equals learned_feature_dim. Defaults to None.
            n_fg_categories (int): Number of foreground feature categories if categorical. Defaults to 8.
            n_fg_classes (int): Number of foreground feature classes per category. Defaults to 4.
            n_bg_categories (int): Number of background feature categories. Defaults to 4.
            n_bg_classes (int): Number of background feature classes per category. Defaults to 4.
            scale_std (float): Prior standard deviation for scale. Defaults to 0.3.
            offset_std (float): Prior standard deviation for offset. Defaults to 0.2.
            obj_on_alpha (float): Alpha parameter for transparency Beta distribution. Defaults to 0.01.
            obj_on_beta (float): Beta parameter for transparency Beta distribution. Defaults to 0.01.
            obj_on_min (float): Minimum concentration value in Beta dist for transparency value. Defaults to 1e-4.
            obj_on_max (float): Maximum concentration value in Beta dist transparency value. Defaults to 100.
            obj_res_from_fc (int): Initial resolution for object encoder-decoder. Defaults to 8.
            obj_ch_mult_prior (tuple): Channel multipliers for prior patch encoder (kp proposals). Defaults to (1, 2, 3).
            obj_ch_mult (tuple): Channel multipliers for object encoder-decoder. Defaults to (1, 2, 3).
            obj_base_ch (int): Base channels for object encoder-decoder. Defaults to 32.
            obj_final_cnn_ch (int): Final CNN channels for object encoder-decoder. Defaults to 32.
            bg_res_from_fc (int): Initial resolution for background encoder-decoder. Defaults to 8.
            bg_ch_mult (tuple): Channel multipliers for background encoder-decoder. Defaults to (1, 2, 3).
            bg_base_ch (int): Base channels for background encoder-decoder. Defaults to 32.
            bg_final_cnn_ch (int): Final CNN channels for background encoder-decoder. Defaults to 32.
            use_resblock (bool): Use residual blocks in encoders-decoders. Defaults to True.
            num_res_blocks (int): Number of residual blocks per resolution. Defaults to 2.
            cnn_mid_blocks (bool): Use middle blocks in CNN. Defaults to False.
            mlp_hidden_dim (int): Hidden dimension for MLPs. Defaults to 256.
            attn_norm_type (str): Normalization type for attention ('rms' or 'ln'). Defaults to 'rms'.
            pint_enc_layers (int): Number of PINT encoder layers. Defaults to 1.
            pint_enc_heads (int): Number of PINT encoder attention heads. Defaults to 1.
            embed_init_std (float): Standard deviation for embedding initialization. Defaults to 0.02.
            particle_positional_embed (bool): Use positional embeddings for particles. Defaults to True.
            use_z_orig (bool): Include patch center coordinates in particle features. Defaults to True.
            particle_score (bool): Use particle confidence score as feature. Defaults to False.
            filtering_heuristic (str): Method to filter prior keypoints ('none','distance','variance','random'). Defaults to 'none'.
            timestep_horizon (int): Number of timesteps to predict ahead. Defaults to 10.
            n_static_frames (int): Number of initial frames for static KL optimization. Defaults to 1.
            predict_delta (bool): Predict position deltas instead of absolute positions. Defaults to False.
            context_dim (Optional[int]): Context latent dimension. If None, equals learned_feature_dim. Defaults to None.
            ctx_dist (str): Context distribution type ('gauss' or 'categorical'). Defaults to 'gauss'.
            n_ctx_categories (int): Number of context categories if categorical. Defaults to 8.
            n_ctx_classes (int): Number of context classes per category. Defaults to 4.
            causal_ctx (bool): Use causal attention for context modeling. Defaults to True.
            ctx_pool_mode (str): Context pooling mode ('none' = per-particle context). Defaults to 'none'.
            pint_dyn_layers (int): Number of dynamics transformer layers. Defaults to 6.
            pint_dyn_heads (int): Number of dynamics transformer heads. Defaults to 8.
            pint_dim (int): Hidden dimension for PINT. Defaults to 512.
            pint_ctx_layers (int): Number of context transformer layers. Defaults to 4.
            pint_ctx_heads (int): Number of context transformer heads. Defaults to 8.
        Example: see in models.py after model definition
        ----
        Deep Latent Particles (DLP) Model
        DLP is an unsupervised/self-supervised object-centric model that decomposes input images into a set
        of latent particles. Each particle represents a local region in the image and is characterized by:
        - Position (x,y): 2D coordinate-keypoint (Gaussian distributed)
        - Scale: 2D bounding box dimensions (Gaussian distributed)
        - Depth: Local depth ordering parameter (Gaussian distributed)
        - Transparency: Visibility parameter in [0,1] (Beta distributed)
        - Features: Visual features within the bounding box (Gaussian or Categorical distributed)
        The background is modeled as a single particle with its own feature dimension.
        For dynamic scenes (LPWM), the model includes latent context variables that capture transitions
        between particles in consecutive timesteps, similar to latent actions.
        Pipeline:
        1. Prior Network: Proposes n_kp_prior keypoints by:
           - Processing image patches through a CNN
           - Using spatial-softmax to locate highest activations
           - Generating keypoint proposals per activation map
        2. Posterior Network: Filters n_kp_enc keypoints by:
           - Processing each proposal to extract particle attributes
           - Optionally filtering based on keypoint variance (confidence)
           - Modeling positions as offsets from prior keypoints
        3. Background Processing:
           - Creates background mask using particle transparency
           - Masks out regions modeled by active particles
           - Encodes masked background separately
        4. Decoding:
           - Decodes particles and background separately
           - Stitches complete image using differentiable spatial transformer network (STN)
        5. Dynamic Modeling (LPWM):
           - Context encoder with shared causal transformer backbone:
             * Posterior (inverse model): p(c_t|z_t+1, z_t)
             * Prior (policy): p(c_t|z_t)
           - Dynamics module: p(z_t+1|z_t, c_t)
           - Uses AdaLN conditioning for particle transitions
           - Optimizes KL divergence between posterior and prior
        Note: Patch extraction and stitching use differentiable spatial transformer networks (STN).
        """
        self.cdim = cdim  # number of input image channels
        self.image_size = image_size
        self.normalize_rgb = normalize_rgb  # normalize to [-1, 1] or keep [0, 1]
        self.n_views = n_views  # number of input views (e.g., multiple cameras)
        self.dropout = dropout
        self.num_patches = int((image_size // patch_size) ** 2)
        # decoder-side particle filtering is tied to being a dynamics model (horizon > 1)
        self.filter_particles_in_decoder = (timestep_horizon > 1)
        self.n_kp_per_patch = n_kp_per_patch
        self.n_kp_total = self.n_kp_per_patch * self.num_patches
        # cannot keep more prior keypoints than total proposals
        self.n_kp_prior = min(self.n_kp_total, n_kp_prior)
        # when filtering in the decoder, the encoder keeps all prior keypoints
        self.n_kp_enc = self.n_kp_prior if self.filter_particles_in_decoder else n_kp_enc
        self.n_kp_dec = n_kp_enc
        self.warmup_n_kp_ratio = warmup_n_kp_ratio
        self.kp_range = (-1, 1)
        self.kp_activation = 'tanh'  # since keypoints are in [-1, 1], we use tanh activation for kp heads
        self.anchor_s = anchor_s  # posterior patch ratio, i.e., anchor size, glimpse-size = anchor_s * image_size
        self.patch_size = patch_size  # prior patch size, to propose prior keypoints
        self.obj_patch_size = np.round(self.anchor_s * (image_size - 1)).astype(int)
        self.mask_bg_in_enc = mask_bg_in_enc  # before encoding the bg, mask with the particles' obj_on
        self.features_dist = features_dist
        self.n_fg_categories = n_fg_categories
        self.n_fg_classes = n_fg_classes
        self.n_bg_categories = n_bg_categories
        self.n_bg_classes = n_bg_classes
        if self.features_dist == 'categorical':
            # categorical features: dimension is the flattened (categories x classes) logits
            self.learned_feature_dim = int(self.n_fg_categories * self.n_fg_classes)
            self.learned_bg_feature_dim = int(self.n_bg_categories * self.n_bg_classes)
        else:
            self.learned_feature_dim = learned_feature_dim
            self.learned_bg_feature_dim = learned_feature_dim if learned_bg_feature_dim is None else learned_bg_feature_dim
        assert learned_feature_dim > 0, "learned_feature_dim must be greater than 0"
        assert self.learned_bg_feature_dim > 0, "bg_learned_feature_dim must be greater than 0"
        # concentrations are stored in log-space
        self.obj_on_min = np.log(obj_on_min)
        self.obj_on_max = np.log(obj_on_max)
        assert filtering_heuristic in ['distance', 'variance',
                                       'random', 'none'], f'unknown filtering heuristic: {filtering_heuristic}'
        self.filtering_heuristic = filtering_heuristic
        self.particle_score = particle_score
        # patch-center coordinates are only meaningful when no encoder-side filtering changes particle identity
        self.use_z_orig = use_z_orig if self.n_kp_enc == self.n_kp_prior else False
        # attention hyper-parameters
        self.attn_norm_type = attn_norm_type
        self.pint_enc_layers = pint_enc_layers
        self.pint_enc_heads = pint_enc_heads
        # positional embeddings also require a fixed particle<->patch correspondence
        self.particle_positional_embed = particle_positional_embed if self.n_kp_enc == self.n_kp_prior else False
        self.embed_init_std = embed_init_std
        # cnn hyper-parameters
        self.use_resblock = use_resblock
        self.num_res_blocks = num_res_blocks
        self.cnn_mid_blocks = cnn_mid_blocks
        self.mlp_hidden_dim = mlp_hidden_dim
        self.pad_mode = pad_mode
        self.obj_res_from_fc = obj_res_from_fc
        self.obj_ch_mult_prior = obj_ch_mult_prior
        self.obj_ch_mult = obj_ch_mult
        self.obj_base_ch = obj_base_ch
        self.obj_final_cnn_ch = obj_final_cnn_ch
        self.bg_res_from_fc = bg_res_from_fc
        self.bg_ch_mult = bg_ch_mult
        self.bg_base_ch = bg_base_ch
        self.bg_final_cnn_ch = bg_final_cnn_ch
        # priors (registered as buffers so they move with the module's device/dtype)
        self.register_buffer('logvar_kp', torch.log(torch.tensor(1.0 ** 2)))
        # scale prior mean in logit space (inverse-sigmoid of 0.75 * anchor_s)
        self.register_buffer('mu_scale_prior',
                             torch.tensor(np.log(0.75 * self.anchor_s / (1 - 0.75 * self.anchor_s + 1e-5))))
        self.register_buffer('logvar_scale_p', torch.log(torch.tensor(scale_std ** 2)))
        self.register_buffer('logvar_offset_p', torch.log(torch.tensor(offset_std ** 2)))
        self.register_buffer('obj_on_a_p', torch.tensor(obj_on_alpha))
        self.register_buffer('obj_on_b_p', torch.tensor(obj_on_beta))
        # dynamics
        self.timestep_horizon = timestep_horizon
        self.is_dynamics_model = (self.timestep_horizon > 1)
        self.n_static_frames = n_static_frames
        self.predict_delta = predict_delta
        self.context_dist = ctx_dist
        assert self.context_dist in ["gauss", "beta", "categorical"], f'ctx distribution {ctx_dist} unrecognized'
        self.ctx_pool_mode = ctx_pool_mode
        assert self.ctx_pool_mode in ["none", "token", "mlp", "mean", "last"], \
            f'ctx pooling {ctx_pool_mode} unrecognized'
        self.n_ctx_categories = n_ctx_categories
        self.n_ctx_classes = n_ctx_classes
        if self.is_dynamics_model:
            if self.context_dist == 'categorical':
                # NOTE(review): the trailing `if self.is_dynamics_model else None` is always
                # True inside this branch -- the conditional is redundant.
                self.context_dim = int(self.n_ctx_categories * self.n_ctx_classes) if self.is_dynamics_model else None
            else:
                if context_dim is None:
                    self.context_dim = learned_feature_dim
                else:
                    self.context_dim = context_dim
        else:
            self.context_dim = 0
        self.causal_ctx = causal_ctx
        # global latent context: EXPERIMENTAL, NOT USED IN THE PAPER
        self.global_ctx_pool = global_ctx_pool
        self.pool_ctx_dim = pool_ctx_dim
        self.n_pool_ctx_categories = n_pool_ctx_categories
        self.n_pool_ctx_classes = n_pool_ctx_classes
        if self.is_dynamics_model and self.context_dist == 'categorical':
            self.pool_ctx_dim = int(self.n_pool_ctx_categories * self.n_pool_ctx_classes)
        self.global_local_fuse_mode = global_local_fuse_mode
        self.condition_local_on_global = condition_local_on_global
        self.pint_dyn_layers = pint_dyn_layers
        self.pint_dyn_heads = pint_dyn_heads
        self.pint_ctx_layers = pint_ctx_layers
        self.pint_ctx_heads = pint_ctx_heads
        self.pint_dim = pint_dim
        pint_inner_dim = self.pint_dim
        pte_dropout = dropout
        max_particles = n_kp_enc + 2  # particle positional bias, +1 for the bg particle, +1 for context
        # dynamics conditioning
        # actions
        self.action_condition = action_condition
        self.action_dim = action_dim
        self.random_action_condition = random_action_condition
        self.random_action_dim = random_action_dim
        self.learn_null_action_embed = null_action_embed
        self.action_in_ctx_module = action_in_ctx_module
        # language
        self.language_condition = language_condition
        self.language_embed_dim = language_embed_dim
        self.language_max_len = language_max_len
        # image
        self.img_goal_condition = img_goal_condition
        # initialization
        self.init_zero_bias = init_zero_bias  # zero bias for conv and linear layers
        self.init_ssm_last_layer = init_ssm_last_layer  # spatial softmax initialization
        self.init_conv_layers = init_conv_layers  # initialize conv layers with normal dist
        self.init_conv_fg_std = init_conv_fg_std  # std for conv fg normal dist
        self.init_conv_bg_std = init_conv_bg_std  # std for conv bg normal dist
        # encoder
        self.encoder_module = DLPEncoder(cdim=self.cdim,
                                         image_size=self.image_size,
                                         n_views=self.n_views,
                                         patch_size=self.patch_size,
                                         n_kp_per_patch=self.n_kp_per_patch,
                                         n_kp_enc=self.n_kp_enc,
                                         n_kp_prior=self.n_kp_prior,
                                         n_kp_dec=self.n_kp_dec,
                                         warmup_n_kp_ratio=self.warmup_n_kp_ratio,
                                         kp_range=self.kp_range,
                                         kp_activation=self.kp_activation,
                                         anchor_s=self.anchor_s,
                                         mask_bg_in_enc=self.mask_bg_in_enc,
                                         features_dist=self.features_dist,
                                         n_fg_categories=n_fg_categories,
                                         n_fg_classes=n_fg_classes,
                                         n_bg_categories=n_bg_categories,
                                         n_bg_classes=n_bg_classes,
                                         obj_on_min=self.obj_on_min,
                                         obj_on_max=self.obj_on_max,
                                         use_z_orig=self.use_z_orig,
                                         learned_feature_dim=self.learned_feature_dim,
                                         learned_bg_feature_dim=self.learned_bg_feature_dim,
                                         pad_mode=self.pad_mode,
                                         obj_ch_mult_prior=self.obj_ch_mult_prior,
                                         obj_ch_mult=self.obj_ch_mult,
                                         obj_base_ch=self.obj_base_ch,
                                         obj_final_cnn_ch=self.obj_final_cnn_ch,
                                         bg_ch_mult=self.bg_ch_mult,
                                         bg_base_ch=self.bg_base_ch,
                                         bg_final_cnn_ch=self.bg_final_cnn_ch,
                                         use_resblock=self.use_resblock,
                                         num_res_blocks=self.num_res_blocks,
                                         cnn_mid_blocks=self.cnn_mid_blocks,
                                         mlp_hidden_dim=self.mlp_hidden_dim,
                                         particle_score=self.particle_score,
                                         embed_init_std=self.embed_init_std,
                                         attn_norm_type=self.attn_norm_type,
                                         pte_layers=self.pint_enc_layers,
                                         pte_heads=self.pint_enc_heads,
                                         dropout=pte_dropout,
                                         particle_positional_embed=self.particle_positional_embed,
                                         projection_dim=self.mlp_hidden_dim,
                                         interaction_obj_on=False,  # use attention for transparency
                                         interaction_depth=True,  # use attention for depth
                                         interaction_features=True,  # use attention for visual features
                                         timestep_horizon=self.timestep_horizon,
                                         add_particle_temp_embed=False,
                                         context_dim=self.context_dim,
                                         init_zero_bias=init_zero_bias,  # zero bias for conv and linear layers
                                         init_ssm_last_layer=init_ssm_last_layer,  # spatial softmax initialization
                                         init_conv_layers=init_conv_layers,  # initialize conv layers with normal dist
                                         init_conv_fg_std=init_conv_fg_std,  # std for conv fg normal dist
                                         init_conv_bg_std=init_conv_bg_std  # std for conv bg normal dist
                                         )
        # prior -- shared with the encoder (same underlying module, not a copy)
        self.prior_module = self.encoder_module.prior_encoder
        particle_anchors = self.encoder_module.patch_centers[:, :-1]  # [1, n_patches, 2], no need for (0,0)-the bg
        particle_anchors = particle_anchors.unsqueeze(-2).repeat(1, 1, self.n_kp_per_patch, 1).view(1, -1, 2)
        # [1, n_patches * n_kp_per_patch, 2]
        # decoder
        self.decoder_module = DLPDecoder(cdim=cdim, image_size=image_size,
                                         learned_feature_dim=self.learned_feature_dim,
                                         learned_bg_feature_dim=self.learned_bg_feature_dim,
                                         anchor_s=anchor_s, n_kp_enc=self.n_kp_dec, pad_mode=pad_mode,
                                         context_dim=self.context_dim,
                                         obj_res_from_fc=obj_res_from_fc, obj_ch_mult=obj_ch_mult,
                                         obj_base_ch=obj_base_ch, obj_final_cnn_ch=obj_final_cnn_ch,
                                         bg_res_from_fc=bg_res_from_fc, bg_ch_mult=bg_ch_mult, bg_base_ch=bg_base_ch,
                                         bg_final_cnn_ch=bg_final_cnn_ch,
                                         num_res_blocks=num_res_blocks, decode_with_ctx=False,
                                         timestep_horizon=timestep_horizon, use_resblock=use_resblock,
                                         normalize_rgb=normalize_rgb, cnn_mid_blocks=cnn_mid_blocks,
                                         mlp_hidden_dim=mlp_hidden_dim,
                                         init_zero_bias=init_zero_bias,  # zero bias for conv and linear layers
                                         init_conv_layers=init_conv_layers,  # initialize conv layers with normal dist
                                         init_conv_fg_std=init_conv_fg_std,  # std for conv fg normal dist
                                         init_conv_bg_std=init_conv_bg_std  # std for conv bg normal dist
                                         )
        # context (latent actions) -- only built for dynamics models (context_dim > 0)
        if self.context_dim > 0:
            self.ctx_module = DLPContext(n_kp_enc=self.n_kp_enc, dropout=pte_dropout,
                                         learned_feature_dim=self.learned_feature_dim,
                                         learned_bg_feature_dim=self.learned_bg_feature_dim,
                                         embed_init_std=embed_init_std, projection_dim=pint_inner_dim,
                                         timestep_horizon=timestep_horizon, pte_layers=pint_ctx_layers,
                                         pte_heads=pint_ctx_heads,
                                         attn_norm_type=attn_norm_type,
                                         context_dim=self.context_dim,
                                         hidden_dim=pint_inner_dim,
                                         ctx_pool_mode=self.ctx_pool_mode,
                                         bg=True, n_views=self.n_views,
                                         particle_positional_embed=particle_positional_embed,
                                         particle_score=self.particle_score,
                                         causal=self.causal_ctx, norm_layer=True,
                                         shared_logvar=False, ctx_dist=ctx_dist,
                                         n_ctx_categories=n_ctx_categories, n_ctx_classes=n_ctx_classes,
                                         particle_anchors=particle_anchors, use_z_orig=self.use_z_orig,
                                         global_ctx_pool=self.global_ctx_pool,
                                         ctx_pool_dim=self.pool_ctx_dim,
                                         n_pool_ctx_categories=self.n_pool_ctx_categories,
                                         n_pool_ctx_classes=self.n_pool_ctx_classes,
                                         global_local_fuse_mode=global_local_fuse_mode,
                                         condition_local_on_global=condition_local_on_global,
                                         # external conditioning
                                         action_condition=self.action_condition,
                                         # condition on actions
                                         action_dim=self.action_dim,  # dimension of input actions
                                         random_action_condition=self.random_action_condition,
                                         random_action_dim=self.random_action_dim,
                                         null_action_embed=self.learn_null_action_embed,
                                         # learn a "no-input-action" embedding, to learn on action-free videos as well
                                         action_as_particle=self.action_condition and not self.action_in_ctx_module,
                                         language_condition=self.language_condition,  # condition on language embedding
                                         language_embed_dim=self.language_embed_dim,
                                         # embedding dimension for each token
                                         language_max_len=self.language_max_len,  # maximum tokens per prompt
                                         img_goal_condition=self.img_goal_condition
                                         )
            # the encoder holds a reference to the same context module (shared weights)
            self.encoder_module.ctx_enc = self.ctx_module
        else:
            self.ctx_module = None
        # dynamics
        if self.is_dynamics_model:
            dyn_activ = self.kp_activation
            ctx_cond_mode = 'adaln'
            # the dynamics module reuses the context module as its context decoder
            context_decoder_dyn = self.ctx_module
            dyn_particle_anchors = particle_anchors if (self.n_kp_enc == self.n_kp_prior) else None
            if self.global_ctx_pool and self.global_local_fuse_mode == 'concat':
                # global + local contexts concatenated -> wider conditioning vector
                dyn_ctx_dim = self.pool_ctx_dim + self.context_dim
            else:
                dyn_ctx_dim = self.context_dim
            self.dyn_module = DLPDynamics(self.learned_feature_dim,
                                          self.learned_bg_feature_dim,
                                          pint_inner_dim,
                                          pint_inner_dim,
                                          n_head=pint_dyn_heads,
                                          n_layer=pint_dyn_layers,
                                          block_size=timestep_horizon,
                                          dropout=dropout,
                                          kp_activation=dyn_activ,
                                          predict_delta=predict_delta,
                                          max_delta=1.0,
                                          positional_bias=False,
                                          max_particles=max_particles,
                                          context_dim=dyn_ctx_dim,
                                          attn_norm_type=attn_norm_type,
                                          n_fg_particles=self.n_kp_enc,
                                          ctx_pool_mode=ctx_pool_mode,
                                          particle_positional_embed=particle_positional_embed,
                                          particle_anchors=dyn_particle_anchors,
                                          particle_score=self.particle_score,
                                          init_std=self.embed_init_std,
                                          ctx_mode=ctx_cond_mode,
                                          pint_ctx_layers=pint_ctx_layers,
                                          pint_ctx_heads=pint_ctx_heads,
                                          ctx_dist=ctx_dist,
                                          n_ctx_categories=n_ctx_categories,
                                          n_ctx_classes=n_ctx_classes,
                                          context_decoder=context_decoder_dyn,
                                          features_dist=self.features_dist,
                                          n_fg_categories=n_fg_categories,
                                          n_fg_classes=n_fg_classes,
                                          n_bg_categories=n_bg_categories,
                                          n_bg_classes=n_bg_classes,
                                          scale_init=self.anchor_s,
                                          obj_on_min=self.obj_on_min,
                                          obj_on_max=self.obj_on_max,
                                          use_z_orig=self.use_z_orig,
                                          n_views=self.n_views,
                                          # external conditioning
                                          action_condition=(self.action_condition and not self.action_in_ctx_module),
                                          # condition on actions
                                          action_dim=self.action_dim,  # dimension of input actions
                                          random_action_condition=(
                                                  self.random_action_condition and not self.action_in_ctx_module),
                                          random_action_dim=self.random_action_dim,
                                          null_action_embed=(
                                                  self.learn_null_action_embed and not self.action_in_ctx_module),
                                          # learn a "no-input-action" embedding, to learn on action-free videos as well
                                          )
        else:
            # single-image model: dynamics is a no-op
            self.dyn_module = nn.Identity()
        self.init_weights()
def init_weights(self):
if self.init_zero_bias:
# all conv, linear layers are specific to modules
for m in self.modules():
if isinstance(m, nn.Conv2d):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
self.prior_module.init_weights()
self.encoder_module.init_weights()
self.decoder_module.init_weights()
if isinstance(self.ctx_module, DLPContext):
self.ctx_module.init_weights()
if isinstance(self.dyn_module, DLPDynamics):
self.dyn_module.init_weights()
def info(self):
    """
    Build a human-readable, multi-section summary of the model: basic
    configuration, feature/context distributions, conditioning modes
    (actions / language / image goal), CNN architecture, latent-space
    layout and parameter counts.

    Returns:
        str: the full report, joined with newlines.
    """
    def create_section_header(title):
        return f"\n{'=' * 80}\n{title}\n{'=' * 80}\n"

    def format_row(label, value):
        return f"{label:<45} | {value}"

    sections = []
    # DLP Logo
    sections.append(generate_dlp_logo())
    # Basic Configuration
    basic_config = [
        ("Prior Keypoint Filtering", f"{self.n_kp_total} -> {self.n_kp_prior}"),
        ("Filtering Heuristic", self.filtering_heuristic),
        ("Prior Patch Size", self.patch_size),
        ("Posterior Particles (Encoder)", self.n_kp_enc),
        ("Posterior Particles (Decoder)", self.n_kp_dec),
        ("Filter Particles in Decoder", self.filter_particles_in_decoder),
        ("Include Origin Patch Center", self.use_z_orig),
        ("Posterior Object Patch Size", self.obj_patch_size),
        ("Attention Layer Normalization", self.attn_norm_type),
        ("Number of Input Views (Cameras)", self.n_views),
    ]
    sections.append(create_section_header("Basic Configuration"))
    sections.extend(format_row(label, value) for label, value in basic_config)
    # Feature Distribution Information
    sections.append(create_section_header("Feature Distribution"))
    if self.features_dist == 'categorical':
        feature_info = [
            ("Distribution Type", self.features_dist),
            ("Foreground Dimension", self.learned_feature_dim),
            ("Background Dimension", self.learned_bg_feature_dim),
            ("Foreground Categories/Classes", f"{self.n_fg_categories}/{self.n_fg_classes}"),
            ("Background Categories/Classes", f"{self.n_bg_categories}/{self.n_bg_classes}")
        ]
    else:
        feature_info = [
            ("Distribution Type", self.features_dist),
            ("Particle Visual Feature Dimension", self.learned_feature_dim),
            ("Background Visual Feature Dimension", self.learned_bg_feature_dim)
        ]
    sections.extend(format_row(label, value) for label, value in feature_info)
    # Context Distribution
    sections.append(create_section_header("Context Information"))
    if self.context_dist == 'categorical':
        context_info = [
            ("Distribution Type", self.context_dist),
            ("Dimension", self.context_dim),
            ("Categories/Classes", f"{self.n_ctx_categories}/{self.n_ctx_classes}")
        ]
    else:
        context_info = [
            ("Distribution Type", self.context_dist),
            ("Dimension", self.context_dim)
        ]
    if self.ctx_module is not None:
        ctx_size_dict = calc_model_size(self.ctx_module)
        ctx_n_params = ctx_size_dict['n_params']
        context_info.append(("CTX Module Parameters", f"{ctx_n_params} ({ctx_n_params / 1e6:.4f}M)"))
    sections.extend(format_row(label, value) for label, value in context_info)
    # random actions
    sections.append(create_section_header("Random Action Conditioning via AdaLN Information"))
    if self.random_action_condition:
        rand_action_info = [
            ("Random Action Conditioning", self.random_action_condition),
            ("Dimension", self.random_action_dim),
            ("Condition in CTX Module", self.action_in_ctx_module),
        ]
    else:
        rand_action_info = [
            ("Random Action Conditioning", self.random_action_condition),
        ]
    sections.extend(format_row(label, value) for label, value in rand_action_info)
    # actions
    sections.append(create_section_header("Action Conditioning via AdaLN Information"))
    if self.action_condition:
        action_info = [
            ("Action Conditioning", self.action_condition),
            ("Dimension", self.action_dim),
            ("Condition in CTX Module", self.action_in_ctx_module),
            ("Learn Null Embedding for Actions", self.learn_null_action_embed),
        ]
    else:
        action_info = [
            ("Action Conditioning", self.action_condition),
        ]
    sections.extend(format_row(label, value) for label, value in action_info)
    # language
    sections.append(create_section_header("Language Conditioning"))
    if self.language_condition:
        lang_info = [
            ("Language Conditioning", self.language_condition),
            ("Dimension", self.language_embed_dim),
            ("Maximum Language Tokens", self.language_max_len),
        ]
    else:
        lang_info = [
            ("Language Conditioning", self.language_condition),
        ]
    sections.extend(format_row(label, value) for label, value in lang_info)
    # image goal conditioning
    sections.append(create_section_header("Image Goal Conditioning"))
    img_goal_info = [
        ("Image Goal Conditioning", self.img_goal_condition),
    ]
    sections.extend(format_row(label, value) for label, value in img_goal_info)
    # CNN Architecture
    sections.append(create_section_header("CNN Architecture"))
    cnn_info = [
        ("Prior CNN Pre-pool Output Size", self.prior_module.enc.conv_output_size),
        ("Object CNN Output Shape", self.encoder_module.particle_enc.particle_features_enc.cnn_out_shape),
        ("Background CNN Output Shape", self.encoder_module.bg_encoder.cnn_out_shape),
        ("Decoder Background Upsamples", self.decoder_module.num_bg_upsample),
        ("Decoder Object Upsamples", self.decoder_module.num_obj_upsample)
    ]
    sections.extend(format_row(label, value) for label, value in cnn_info)
    # Latent Information
    sections.append(create_section_header("Latent Space Information"))
    # pooled context contributes a single context vector; unpooled contributes
    # one per foreground particle plus one more
    context_coeff = 1 if self.ctx_pool_mode != 'none' else (self.n_kp_enc + 1)
    latent_dim = ((6 + self.learned_feature_dim) * self.n_kp_enc
                  + self.learned_bg_feature_dim
                  + context_coeff * self.context_dim)
    sections.append(
        format_row("Encoder Particle Features", self.encoder_module.particle_enc.particle_features_enc.info))
    sections.append(format_row("Background Encoder", self.encoder_module.bg_encoder.info))
    if self.encoder_module.particle_inter_enc is not None:
        sections.append(format_row("Particle Intermediate Encoder", self.encoder_module.particle_inter_enc.info))
    sections.append(format_row("Particle Decoder", self.decoder_module.particle_dec.info))
    sections.append(format_row("Background Decoder", self.decoder_module.bg_dec.info))
    # Add latent dimension with formula
    latent_formula = (f"(6 + {self.learned_feature_dim}) * {self.n_kp_enc} + "
                      f"{self.learned_bg_feature_dim} + "
                      f"{context_coeff} * {self.context_dim}")
    sections.append(format_row("Latent Dimension Formula", latent_formula))
    sections.append(format_row("Total Latent Dimension", f"{latent_formula} = {latent_dim}"))
    # Dynamics Module Information (if applicable)
    if self.is_dynamics_model:
        sections.append(create_section_header("Dynamics Module Information"))
        pint_size_dict = calc_model_size(self.dyn_module)
        pint_n_params = pint_size_dict['n_params']
        if self.context_dim > 0:
            # exclude the context decoder's parameters from the PINT count
            ctx_size_dict = calc_model_size(self.dyn_module.context_decoder)
            pint_n_params = pint_n_params - ctx_size_dict['n_params']
        dynamics_info = [
            ("Dropout (PINT)", self.dropout),
            ("Burn-in Frames", self.n_static_frames),
            ("Prior Predicts Delta", self.predict_delta),
            ("PINT Relative Positional Bias", self.dyn_module.particle_transformer.positional_bias),
            ("PINT Parameters", f"{pint_n_params} ({pint_n_params / 1e6:.4f}M)")
        ]
        sections.extend(format_row(label, value) for label, value in dynamics_info)
    # Model Size Information
    sections.append(create_section_header("Model Size Information"))
    prior_size_dict = calc_model_size(self.prior_module)
    enc_size_dict = calc_model_size(self.encoder_module)
    enc_n_params = enc_size_dict['n_params']
    if self.ctx_module is not None:
        # exclude the context encoder's parameters from the encoder count
        ctx_size_dict = calc_model_size(self.encoder_module.ctx_enc)
        enc_n_params = enc_n_params - ctx_size_dict['n_params']
    dec_size_dict = calc_model_size(self.decoder_module)
    size_dict = calc_model_size(self)
    model_size_info = [
        ("Prior Parameters", f"{prior_size_dict['n_params']} ({prior_size_dict['n_params'] / 1e6:.4f}M)"),
        ("Encoder Parameters", f"{enc_n_params} ({enc_n_params / 1e6:.4f}M)"),
        ("Decoder Parameters", f"{dec_size_dict['n_params']} ({dec_size_dict['n_params'] / 1e6:.4f}M)"),
        ("Total Parameters", f"{size_dict['n_params']} ({size_dict['n_params'] / 1e6:.4f}M)"),
        ("Estimated Size on Disk", f"{size_dict['size_mb']:.3f}MB")
    ]
    sections.extend(format_row(label, value) for label, value in model_size_info)
    return "\n".join(sections)
def encode_prior(self, x):
    """Encode input images with the prior keypoint-proposal module."""
    prior_out = self.prior_module(x)
    return prior_out
def encode_all(self, x, deterministic=False, warmup=False, actions=None, actions_mask=None, lang_embed=None,
               x_goal=None):
    """
    Encode posterior particles from an image sequence.

    Accepts x as [bs, T, ch, h, w] or a single frame [bs, ch, h, w]
    (a time axis of length 1 is inserted in the latter case). Returns the
    encoder output dict, augmented with 'cropped_objects_rgb' -- an
    RGB-space view of the cropped object glimpses.
    """
    # promote a single frame to a length-1 sequence: [bs, ch, h, w] -> [bs, 1, ch, h, w]
    if x.dim() == 4:
        x = x.unsqueeze(1)
    enc_dict = self.encoder_module(x, deterministic, warmup, actions=actions, actions_mask=actions_mask,
                                   lang_embed=lang_embed, x_goal=x_goal)
    crops = enc_dict['cropped_objects']
    # map [-1, 1]-normalized crops back to RGB space when normalization is enabled
    enc_dict['cropped_objects_rgb'] = minusoneone_to_rgb(crops) if self.normalize_rgb else crops
    return enc_dict
def decode_all(self, z, z_scale, z_features, obj_on_sample, z_depth, z_bg_features, z_ctx,
               warmup=False, filter_key=None):
    """
    Decode latent particles into images via the decoder module.

    When `filter_key` is given, only the particles with the SMALLEST
    filter-key values are kept (torch.topk with largest=False) before
    decoding; during warmup the particle budget may be further reduced by
    `warmup_n_kp_ratio`. Particle tensors may be [bs, n_kp, ...] or
    [bs, T, n_kp, ...]; temporal inputs are flattened to [bs * T, ...] for
    the selection and restored afterwards.

    Returns the decoder output dict, augmented with RGB-space versions of
    the reconstruction ('rec_rgb'), background ('bg_rgb') and decoded
    objects ('dec_objects_trans', 'dec_objects_original_rgb').
    """
    if filter_key is not None:
        orig_shape = z.shape
        # filter_key: [batch_size, n_kp]
        if len(filter_key.shape) == 3:
            # [bs, T, n_kp]
            filter_key = filter_key.view(-1, filter_key.shape[-1])
        if len(orig_shape) == 4:
            # [bs, T, n_kp, ...] -> [bs * T, n_kp, ...]
            z = z.view(-1, *z.shape[2:])
            z_scale = z_scale.view(-1, *z_scale.shape[2:])
            z_depth = z_depth.view(-1, *z_depth.shape[2:])
            z_features = z_features.view(-1, *z_features.shape[2:])
            obj_on_sample = obj_on_sample.view(-1, *obj_on_sample.shape[2:])
        # k = self.n_kp_dec
        # discourage "lazy" particles that don't move by choking the model to use less particles for reconstruction
        k = self.n_kp_dec if not warmup else min(self.n_kp_dec, int(self.warmup_n_kp_ratio * self.n_kp_enc))
        # indices of the k particles with the smallest filter-key values
        _, embed_ind = torch.topk(filter_key, k=k, dim=-1, largest=False)
        # make selection
        batch_ind = torch.arange(z.shape[0], device=z.device)[:, None]
        z = z[batch_ind, embed_ind]  # [bs * T, n_kp_dec, 2]
        z_scale = z_scale[batch_ind, embed_ind]  # [bs * T, n_kp_dec, 2]
        obj_on_sample = obj_on_sample[batch_ind, embed_ind]  # [bs * T, n_kp_dec, 1]
        z_depth = z_depth[batch_ind, embed_ind]  # [bs * T, n_kp_dec, 1]
        z_features = z_features[batch_ind, embed_ind]  # [bs * T, n_kp_dec, features_dim]
        if len(orig_shape) == 4:
            # [bs * T, n_kp, ...] -> [bs, T, n_kp, ...]
            z = z.reshape(orig_shape[0], orig_shape[1], *z.shape[1:])
            z_scale = z_scale.reshape(orig_shape[0], orig_shape[1], *z_scale.shape[1:])
            z_depth = z_depth.reshape(orig_shape[0], orig_shape[1], *z_depth.shape[1:])
            z_features = z_features.reshape(orig_shape[0], orig_shape[1], *z_features.shape[1:])
            obj_on_sample = obj_on_sample.reshape(orig_shape[0], orig_shape[1], *obj_on_sample.shape[1:])
    dec_dict = self.decoder_module(z, z_scale, z_features, obj_on_sample, z_depth, z_bg_features, z_ctx, warmup)
    dec_objects = dec_dict['dec_objects']
    dec_objects_trans = dec_dict['dec_objects_trans']
    rec = dec_dict['rec']
    bg_rec = dec_dict['bg_rec']
    if self.normalize_rgb:
        # decoder outputs are [-1, 1]-normalized; convert back to RGB space
        rec_rgb = minusoneone_to_rgb(rec)
        bg_rec_rgb = minusoneone_to_rgb(bg_rec)
        dec_objects_trans = minusoneone_to_rgb(dec_objects_trans)
        dec_objects_rgb = minusoneone_to_rgb(dec_objects)
    else:
        rec_rgb = rec
        bg_rec_rgb = bg_rec
        dec_objects_rgb = dec_objects
    dec_dict['rec_rgb'] = rec_rgb
    dec_dict['bg_rgb'] = bg_rec_rgb
    dec_dict['dec_objects_trans'] = dec_objects_trans
    dec_dict['dec_objects_original_rgb'] = dec_objects_rgb
    return dec_dict
def sample_from_x(self, x, num_steps=10, deterministic=True, cond_steps=None, return_z=False, use_all_ctx=False,
actions=None, actions_mask=None, lang_embed=None, x_goal=None, decode=True, n_pred_eq_gt=True,
return_context_posterior=False):
"""
(Conditional) Sampling from LPWM: x is the conditional frames, encoded to latent particles
which are unrolled to the future with PINT, and the predicted particles are then decoded to a sequence
of RGB images.
"""
# use_all_ctx: if True, will encode context from the entire trajectory to condition the prediction
# this is meant to see if the model is able to follow conditions and reconstruct stochastic trajectories
# that involve latent stochastic actions
# x: [bs, T, ...]
assert self.is_dynamics_model, f"model timesteps: {self.timestep_horizon} -> non-dynamics model"
# encode-decode
batch_size, timestep_horizon_all = x.size(0), x.size(1)
timestep_horizon = self.timestep_horizon if cond_steps is None else cond_steps
if self.normalize_rgb:
x = rgb_to_minusoneone(x)
if self.action_condition and actions is not None:
actions_enc = actions[:, :timestep_horizon].contiguous()
else:
actions_enc = None
if self.action_condition and actions_mask is not None:
actions_mask_enc = actions_mask[:, :timestep_horizon].contiguous()
else:
actions_mask_enc = None
# encode particles
enc_dict = self.encode_all(x[:, :timestep_horizon].contiguous(), deterministic=True, actions=actions_enc,
actions_mask=actions_mask_enc, lang_embed=lang_embed, x_goal=x_goal)
x_in = x[:, :timestep_horizon].reshape(-1, *x.shape[2:]) # [bs * T, ...]
# encoder
z = enc_dict['z']
z_features = enc_dict['z_features']
z_bg_features = enc_dict['z_bg_features']
z_obj_on = enc_dict['obj_on']
z_depth = enc_dict['z_depth']
z_scale = enc_dict['z_scale']
z_context = enc_dict['z_context']
z_score = enc_dict['z_score']
z_goal_proj = enc_dict['z_goal_proj'] # [bs, 1, N, proj_dim] if img_goal_cond else None
filter_key = enc_dict['z_base_var'].sum(-1) if self.filter_particles_in_decoder else None
# "latent actions", [bs, T-1, ctx_dim], ctx models every pair of consecutive steps
if timestep_horizon_all == 1:
z_context = None
else:
z_context = z_context[:, 1:].contiguous() if z_context is not None else None
if timestep_horizon_all > 1 and use_all_ctx and z_context is not None:
# encode context from the entire trajectory
while z_context.shape[1] < timestep_horizon_all - 1:
causal = self.causal_ctx
if causal:
end_step = z_context.shape[1] + 1
start_step = max(end_step - self.timestep_horizon, 0)
if self.action_condition and actions is not None:
actions_enc = actions[:, start_step:end_step + 1].contiguous()
else:
actions_enc = None
if self.action_condition and actions_mask is not None:
actions_mask_enc = actions_mask[:, start_step:end_step + 1].contiguous()
else:
actions_mask_enc = None
enc_dict = self.encode_all(x[:, start_step:end_step + 1].contiguous(), deterministic,
actions=actions_enc, actions_mask=actions_mask_enc, x_goal=x_goal,
lang_embed=lang_embed)
z_context_t = enc_dict['z_context'][:, -1:].contiguous()
z_context = torch.cat([z_context, z_context_t], dim=1)
else:
start_step = z_context.shape[1]
end_step = start_step + self.timestep_horizon + 1
enc_dict = self.encode_all(x[:, start_step:end_step].contiguous(), deterministic)
z_context_t = enc_dict['z_context'][:, 1:].contiguous()
z_context = torch.cat([z_context, z_context_t], dim=1)
# decoder
if decode:
if z_context is not None:
z_ctx = z_context[:, :timestep_horizon - 1].contiguous()
else:
z_ctx = None
dec_dict = self.decode_all(z, z_scale, z_features, z_obj_on, z_depth, z_bg_features,
z_ctx=z_ctx, filter_key=filter_key)
rec = dec_dict['rec_rgb']
rec = rec.view(batch_size, -1, *rec.shape[1:])
else:
rec = None
# dynamics
if self.action_condition and actions is not None:
actions_dyn = actions[:, :timestep_horizon + num_steps].contiguous()
else:
actions_dyn = None
if self.action_condition and actions_mask is not None:
actions_mask_dyn = actions_mask[:, :timestep_horizon + num_steps].contiguous()
else:
actions_mask_dyn = None
dyn_out = self.dyn_module.sample(z, z_scale, z_obj_on, z_depth, z_features, z_bg_features, z_context,
z_score, steps=num_steps, deterministic=deterministic, actions=actions_dyn,
actions_mask=actions_mask_dyn, lang_embed=lang_embed, z_goal=z_goal_proj,
return_context_posterior=return_context_posterior)
z_dyn = dyn_out['z']
z_scale_dyn = dyn_out['z_scale']
z_obj_on_dyn = dyn_out['z_obj_on']
z_depth_dyn = dyn_out['z_depth']
z_features_dyn = dyn_out['z_features']
z_bg_features_dyn = dyn_out['z_bg_features']
z_context_dyn = dyn_out['z_context']
z_score_dyn = dyn_out['z_score']
z_context_posterior = dyn_out['z_context_posterior']
mu_context_posterior = dyn_out['mu_context_posterior']
if return_z:
z_ids = 1 + torch.arange(z_dyn.shape[2], device=z_dyn.device) # num_particles, ids start from 1
z_ids = z_ids[None, None, :].repeat(z_dyn.shape[0], z_dyn.shape[1], 1) # [bs, T, n_particles]
z_out = {'z_pos': z_dyn.detach(), 'z_scale': z_scale_dyn.detach(), 'z_obj_on': z_obj_on_dyn.detach(),
'z_depth': z_depth_dyn.detach(), 'z_features': z_features_dyn.detach(),
'z_context': z_context_dyn.detach(), 'z_bg_features': z_bg_features_dyn.detach(), 'z_ids': z_ids,
'z_score': z_score_dyn, 'z_goal_proj': z_goal_proj,
'z_context_posterior': z_context_posterior, 'mu_context_posterior': mu_context_posterior}
else:
z_out = None
z_dyn = z_dyn[:, -num_steps:].contiguous()
z_features_dyn = z_features_dyn[:, -num_steps:].contiguous()
z_bg_features_dyn = z_bg_features_dyn[:, -num_steps:].contiguous()
z_obj_on_dyn = z_obj_on_dyn[:, -num_steps:].contiguous()
z_depth_dyn = z_depth_dyn[:, -num_steps:].contiguous()
z_scale_dyn = z_scale_dyn[:, -num_steps:].contiguous()
z_context_dyn = z_context_dyn[:, -num_steps:].contiguous()
z_score_dyn = z_score_dyn[:, -num_steps:].contiguous()
if self.filter_particles_in_decoder: