Support nested observation structures in MCTS agent #351
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
Fixes #341
The
MCTSActor._forward()method previously hard-codedtf.expand_dims()directly on the observation, which only works for array-like observations (np.ndarray). This prevented using nested structures (dicts, tuples) as observations.Problem
As described in #341, when passing nested observation structures to the MCTS agent:
The
tf.expand_dimscall assumes the observation is a single array, but many environments use nested observation spaces (dictionaries, tuples, etc.).Solution
Modified
acme/agents/tf/mcts/acting.pyto use the existingtf2_utils.add_batch_dim()utility, which internally usestree.map_structure()to applytf.expand_dimsto each leaf of the observation structure:Also updated
acme/agents/tf/mcts/types.pyto change theObservationtype fromnp.ndarraytoAnyto properly reflect that nested structures are now supported.Changes
acme/agents/tf/mcts/acting.py: Usetf2_utils.add_batch_dim()instead of directtf.expand_dims()acme/agents/tf/mcts/types.py: UpdateObservationtype alias toAnyTesting
python3 -m py_compileacme/tf/utils.py(seeadd_batch_dimfunction, line 28-30)