@natinew77-creator
Summary

Fixes #341

The MCTSActor._forward() method previously called tf.expand_dims() directly on the observation, which works only when the observation is a single array (np.ndarray). As a result, nested structures (dicts, tuples) could not be used as observations.

Problem

As described in #341, when passing nested observation structures to the MCTS agent:

# This fails with nested observations
logits, value = self._network(tf.expand_dims(observation, axis=0))

The tf.expand_dims call assumes the observation is a single array, but many environments use nested observation spaces (dictionaries, tuples, etc.).

Solution

Modified acme/agents/tf/mcts/acting.py to use the existing tf2_utils.add_batch_dim() utility, which internally uses tree.map_structure() to apply tf.expand_dims to each leaf of the observation structure:

from acme.tf import utils as tf2_utils
# ...
batched_observation = tf2_utils.add_batch_dim(observation)
logits, value = self._network(batched_observation)

Also updated acme/agents/tf/mcts/types.py to change the Observation type from np.ndarray to Any to properly reflect that nested structures are now supported.
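
The leaf-wise batching described above can be sketched without TensorFlow. The snippet below is only an illustration, not the code in acme/tf/utils.py: it assumes NumPy arrays in place of tensors, and the local map_structure helper is a hypothetical stand-in for tree.map_structure that handles plain dicts, tuples, and lists only.

```python
import numpy as np


def map_structure(fn, structure):
    """Applies fn to every leaf of a nested dict/tuple/list structure.

    Hypothetical stand-in for tree.map_structure; it does not handle
    namedtuples or other container types.
    """
    if isinstance(structure, dict):
        return {key: map_structure(fn, value) for key, value in structure.items()}
    if isinstance(structure, (tuple, list)):
        return type(structure)(map_structure(fn, item) for item in structure)
    return fn(structure)


def add_batch_dim(nest):
    """Adds a leading axis of size 1 to every leaf (NumPy analogue of the utility)."""
    return map_structure(lambda leaf: np.expand_dims(leaf, axis=0), nest)


# A nested observation: a dict holding an array and a tuple of further leaves.
observation = {
    "board": np.zeros((3, 3), dtype=np.float32),
    "extras": (np.zeros((4,), dtype=np.float32), np.float32(1.0)),
}

batched = add_batch_dim(observation)

# The structure is preserved; only the leaves gain a batch axis.
batched_shapes = map_structure(lambda leaf: leaf.shape, batched)
```

A bare array is treated as a structure with a single leaf, so the same call covers the np.ndarray observations that worked before this change.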

Changes

  • acme/agents/tf/mcts/acting.py: Use tf2_utils.add_batch_dim() instead of direct tf.expand_dims()
  • acme/agents/tf/mcts/types.py: Update Observation type alias to Any

Testing

  • Verified syntax is valid with python3 -m py_compile
  • The fix follows the existing pattern used in acme/tf/utils.py (see the add_batch_dim function, lines 28-30)
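
Beyond the syntax check, a behavioural test could look roughly like the sketch below. It is not part of this PR and makes several assumptions: NumPy stands in for TensorFlow, RecordingNetwork and forward are hypothetical names, and forward reproduces only the batching step, with the squeeze of the network outputs assumed rather than taken from acting.py.

```python
import numpy as np


def add_batch_dim(nest):
    """NumPy stand-in: adds a leading axis to each leaf of a dict/tuple nest."""
    if isinstance(nest, dict):
        return {key: add_batch_dim(value) for key, value in nest.items()}
    if isinstance(nest, tuple):
        return tuple(add_batch_dim(value) for value in nest)
    return np.expand_dims(nest, axis=0)


class RecordingNetwork:
    """Hypothetical network double that records the input it was called with."""

    def __init__(self, num_actions):
        self.num_actions = num_actions
        self.last_input = None

    def __call__(self, batched_observation):
        self.last_input = batched_observation
        logits = np.zeros((1, self.num_actions), dtype=np.float32)
        value = np.zeros((1,), dtype=np.float32)
        return logits, value


def forward(network, observation):
    """Batches the observation, calls the network, and drops the batch axis."""
    logits, value = network(add_batch_dim(observation))
    return np.squeeze(logits, axis=0), float(np.squeeze(value, axis=0))


def check_forward_accepts_nested_observation():
    network = RecordingNetwork(num_actions=2)
    observation = {"image": np.zeros((8, 8)), "scalars": (np.zeros(3),)}
    logits, value = forward(network, observation)
    # Every leaf must reach the network with a leading batch axis of size 1.
    assert network.last_input["image"].shape == (1, 8, 8)
    assert network.last_input["scalars"][0].shape == (1, 3)
    # The outputs come back without the batch axis.
    assert logits.shape == (2,)
    assert value == 0.0


check_forward_accepts_nested_observation()
```

A real test would exercise MCTSActor with a TensorFlow network that accepts the nested input; the double here only confirms that each leaf arrives batched.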

The MCTSActor._forward() method previously hard-coded tf.expand_dims()
directly on the observation, which only works for array-like observations
(np.ndarray). This prevented using nested structures (dicts, tuples) as
observations.

Changes:
- Modified acting.py to use tf2_utils.add_batch_dim() which internally
  uses tree.map_structure() to apply tf.expand_dims to each leaf of
  the observation structure
- Updated types.py Observation type from np.ndarray to Any to allow
  nested structures

This follows the pattern used elsewhere in the codebase (see tf/utils.py)
and allows MCTS to work with environments that have complex observation
spaces.

Fixes google-deepmind#341
Development

Successfully merging this pull request may close these issues.

Rigid assumption about Observations in acme.agents.tf.mcts
