diff --git a/CMakeLists.txt b/CMakeLists.txt index 8b7e65e..9056cf5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,14 +14,23 @@ FetchContent_MakeAvailable(trieste) set(CMAKE_CXX_STANDARD 23) set(CMAKE_CXX_STANDARD_REQUIRED True) +include_directories(src) + add_executable(gitmem src/gitmem.cc src/reader.cc src/parser.cc + src/execution_state.cc src/passes/expressions.cc src/passes/statements.cc src/passes/check_refs.cc src/passes/branching.cc + src/linear/sync_protocol.cc + src/linear/version_store.cc + src/branching/base_sync_protocol.cc + src/branching/base_version_store.cc + src/branching/lazy/version_store.cc + src/branching/eager/version_store.cc src/interpreter.cc src/debugger.cc src/model_checker.cc diff --git a/README.md b/README.md index 651bee4..3022a86 100644 --- a/README.md +++ b/README.md @@ -1,48 +1,231 @@ # gitmem -Experimental interpreter for creating execution diagrams for a new -concurrency model. +An experimental interpreter and model checker for exploring concurrent programs with Git-inspired memory semantics. Gitmem allows you to write multi-threaded programs and automatically explore all possible interleavings to detect data races, deadlocks, and assertion failures. -## Building and Running +## Overview + +Gitmem is a research tool that models concurrent memory operations using version control semantics. It provides: + +- **Multiple sync protocols**: Linear and branching (with eager/lazy variants) semantics for thread synchronization +- **Automatic model checking**: Explores all possible execution paths to find concurrency bugs +- **Interactive debugging**: Step through different thread schedules interactively +- **Execution visualization**: Generates GraphViz diagrams showing execution traces and revision graphs + +## Language Features -To build you need CMake and Ninja. CMake will fetch any other dependencies. 
+The gitmem language supports: -The following commands should set you up: +- **Shared variables**: `x = value` (global shared state) +- **Thread-local registers**: `$r = value` (prefixed with `$`) +- **Thread operations**: `spawn { ... }`, `join $thread` +- **Synchronization**: `lock var`, `unlock var` +- **Control flow**: `if (condition) { ... } else { ... }` +- **Assertions**: `assert(condition)` +- **Operators**: `==`, `!=`, `+` +### Example Program + +```gitmem +x = 0; +$t1 = spawn { + lock l; + x = x + 1; + unlock l; +}; +$t2 = spawn { + lock l; + x = x + 1; + unlock l; +}; +join $t1; +join $t2; +assert(x == 2); ``` + +## Building and Running + +### Prerequisites + +- CMake (3.14+) +- Ninja build system +- C++23 compatible compiler (Clang recommended) +- Python 3 (for tests) + +### Build Instructions + +```bash mkdir build cd build cmake -G Ninja .. -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Debug ninja ``` -If you need, set the C standard with `-DCMAKE_CXX_STANDARD=20`. -If you are running a recent version of CMake, you may need -`-DCMAKE_POLICY_VERSION_MINIMUM=3.5`. +Optional CMake flags: +- `-DCMAKE_CXX_STANDARD=23` - Set C++ standard (if needed) +- `-DCMAKE_BUILD_TYPE=Release` - Build optimized version -You can test if the build was successful by running the following -command in the `build` directory: +### Quick Test -``` +Test your build with: + +```bash ./gitmem -e ../examples/race_condition.gm ``` -The build script creates two executables: +## Usage + +### Basic Execution + +Run a single execution trace: + +```bash +./gitmem examples/race_condition.gm +``` + +This generates a GraphViz `.dot` file showing the execution trace. 
+ +### Model Checking Mode + +Explore all possible execution paths: + +```bash +./gitmem -e examples/race_condition.gm +``` + +Model checking will: +- Try all possible thread interleavings +- Report data races, deadlocks, and assertion failures +- Generate execution diagrams for failing traces +- Exit with status 0 if all paths succeed, non-zero if any fail + +### Interactive Mode + +Step through executions manually: + +```bash +./gitmem -i examples/singleton.gm +``` + +Commands in interactive mode: +- `?` - Show help +- `s` - Show current state +- `n` - Step one thread forward +- `r` - Run to next synchronization point -- `gitmem` parses source code and runs the interpreter in order to - create an execution diagram (work in progress). You can run the - interpreter interactively with the `-i` flag, and automatically - explore all possible traces with the `-e` flag (showing failing - runs). -- `gitmem_trieste` is the default - [Trieste](https://github.com/microsoft/Trieste) driver which can - be used to inspect the parsed source code and test the parser. - Running `gitmem_trieste build foo.gm` will create a file - `foo.trieste` with the parsed source code as an S-expression. +### Sync Protocols + +Gitmem supports different memory models: + +#### Linear (Default) +```bash +./gitmem --sync linear program.gm +``` +Traditional sequential consistency model. + +#### Branching (Eager) +```bash +./gitmem --sync branching --branching-mode eager program.gm +``` +Git-like branching semantics where threads create branches that merge at synchronization points. Conflicts are detected eagerly. + +#### Branching (Lazy) +```bash +./gitmem --sync branching --branching-mode lazy program.gm +``` +Lazy conflict detection variant that defers checking until synchronization. 
+ +Additional flags: +- `--include-empty-commits` - Include empty commits in branching output +- `--raise-early-conflicts` - Raise conflicts before write suppression (lazy mode only) +- `-v, --verbose` - Enable verbose interpreter output +- `-o, --output ` - Specify output file path + +## Testing + +Run the test suite: + +```bash +# From build directory +ninja run_gitmem_tests + +# Or using CTest +ctest +``` + +The test suite includes: +- **Accept tests**: Programs that should execute successfully +- **Reject tests**: Programs with errors (deadlocks, races, assertion failures) +- Tests for both linear and branching semantics + +## Project Structure + +``` +src/ + ├── gitmem.cc - Main entry point + ├── lang.hh - Language token definitions + ├── parser.cc - Parser implementation + ├── interpreter.cc - Interpreter core + ├── model_checker.cc - Model checking engine + ├── debugger.cc - Interactive debugger + ├── execution_state.hh - Thread and memory state + ├── sync_protocol.hh - Sync protocol interface + ├── linear/ - Linear sync protocol + └── branching/ - Branching sync protocols + ├── base_sync_protocol.cc + ├── eager/ - Eager conflict detection + └── lazy/ - Lazy conflict detection + +examples/ + ├── accept/semantics/ - Valid programs + │ ├── linear/ - Linear semantics tests + │ └── branching/ - Branching semantics tests + └── reject/semantics/ - Programs with bugs + ├── linear/ - Deadlocks, races for linear + └── branching/ - Bugs for branching semantics +``` + +## Executables + +The build produces two binaries: + +### `gitmem` +The main interpreter and model checker. Executes programs and generates execution diagrams. + +### `gitmem_trieste` +Parser diagnostic tool built on [Trieste](https://github.com/microsoft/Trieste). 
Use it to inspect the AST: + +```bash +./gitmem_trieste build program.gm +# Creates program.trieste with S-expression AST +``` ## VSCode Extension -You should be able to use `Developer: Install Extension from -Location` in the VSCode command palette to install a rudimentary -extension in the `gitmem-extension` directory and get syntax -highlighting.. +A syntax highlighting extension is available in `gitmem-extension/`. + +Install via: +1. Open VSCode Command Palette (`Cmd+Shift+P` / `Ctrl+Shift+P`) +2. Run: `Developer: Install Extension from Location` +3. Select the `gitmem-extension` directory + +This provides syntax highlighting for `.gm` files. + +## Output + +Gitmem generates GraphViz `.dot` files visualizing: + +- **Execution traces**: Thread operations and memory states +- **Revision graphs**: For branching semantics, shows branch/merge structure +- **Conflict detection**: Highlights data races and conflicts + +View `.dot` files with GraphViz: + +```bash +dot -Tpng output.dot -o output.png +``` + +## Exit Codes + +- `0` - All execution paths succeeded +- `1` - Assertion failure, deadlock, data race, or error detected +- Other non-zero codes indicate internal errors \ No newline at end of file diff --git a/examples/passing/semantics/addition.gm b/examples/accept/semantics/branching/addition.gm similarity index 100% rename from examples/passing/semantics/addition.gm rename to examples/accept/semantics/branching/addition.gm diff --git a/examples/passing/semantics/conditional_non_race.gm b/examples/accept/semantics/branching/conditional_non_race.gm similarity index 72% rename from examples/passing/semantics/conditional_non_race.gm rename to examples/accept/semantics/branching/conditional_non_race.gm index 1a4e743..b8d4206 100644 --- a/examples/passing/semantics/conditional_non_race.gm +++ b/examples/accept/semantics/branching/conditional_non_race.gm @@ -12,6 +12,7 @@ $t2 = spawn { }; join $t1; join $t2; +// FIXME in branching these assertions are always true, but 
not in linear assert (x != 0); assert (y != 0); assert (x == y); diff --git a/examples/passing/semantics/globals.gm b/examples/accept/semantics/branching/globals.gm similarity index 100% rename from examples/passing/semantics/globals.gm rename to examples/accept/semantics/branching/globals.gm diff --git a/examples/passing/semantics/if.gm b/examples/accept/semantics/branching/if.gm similarity index 100% rename from examples/passing/semantics/if.gm rename to examples/accept/semantics/branching/if.gm diff --git a/examples/passing/semantics/join_fastforward.gm b/examples/accept/semantics/branching/join_fastforward.gm similarity index 100% rename from examples/passing/semantics/join_fastforward.gm rename to examples/accept/semantics/branching/join_fastforward.gm diff --git a/examples/passing/semantics/join_nop.gm b/examples/accept/semantics/branching/join_nop.gm similarity index 100% rename from examples/passing/semantics/join_nop.gm rename to examples/accept/semantics/branching/join_nop.gm diff --git a/examples/passing/semantics/join_pulled_variable.gm b/examples/accept/semantics/branching/join_pulled_variable.gm similarity index 91% rename from examples/passing/semantics/join_pulled_variable.gm rename to examples/accept/semantics/branching/join_pulled_variable.gm index ae481e8..7054065 100644 --- a/examples/passing/semantics/join_pulled_variable.gm +++ b/examples/accept/semantics/branching/join_pulled_variable.gm @@ -9,6 +9,7 @@ t = spawn { }; assert(x == 2); }; +assert (x == 0); join t; assert (x == 2); join t2; diff --git a/examples/passing/semantics/join_spawn.gm b/examples/accept/semantics/branching/join_spawn.gm similarity index 100% rename from examples/passing/semantics/join_spawn.gm rename to examples/accept/semantics/branching/join_spawn.gm diff --git a/examples/passing/semantics/local.gm b/examples/accept/semantics/branching/local.gm similarity index 100% rename from examples/passing/semantics/local.gm rename to examples/accept/semantics/branching/local.gm 
diff --git a/examples/passing/semantics/lock.gm b/examples/accept/semantics/branching/lock.gm similarity index 100% rename from examples/passing/semantics/lock.gm rename to examples/accept/semantics/branching/lock.gm diff --git a/examples/passing/semantics/lock_as_sync.gm b/examples/accept/semantics/branching/lock_as_sync.gm similarity index 100% rename from examples/passing/semantics/lock_as_sync.gm rename to examples/accept/semantics/branching/lock_as_sync.gm diff --git a/examples/accept/semantics/branching/lock_relock_as_sync.gm b/examples/accept/semantics/branching/lock_relock_as_sync.gm new file mode 100644 index 0000000..28fe69d --- /dev/null +++ b/examples/accept/semantics/branching/lock_relock_as_sync.gm @@ -0,0 +1,11 @@ +x = 0; +$t = spawn { + lock l1; + x = 42; + unlock l1; +}; +lock l1; +x = 2; +unlock l1; +lock l1; +unlock l1; \ No newline at end of file diff --git a/examples/accept/semantics/branching/singleton.gm b/examples/accept/semantics/branching/singleton.gm new file mode 100644 index 0000000..b39cecc --- /dev/null +++ b/examples/accept/semantics/branching/singleton.gm @@ -0,0 +1,41 @@ +// Construct the singleton pattern: +// A shared variable should be initialised only once by any thread +// Once initialised all threads should read the same value +// Perform the initialisation using double checked locking +// Each thread: +// 1. Checks the variable +// 1a. If it is uninitialised (here we use 0), then take the lock +// 1b. Taking the lock may pull in other thread updates, so check if the +// variable is still uninitialised +// 1bi. If the variable is still uninitialised then we know we need to +// initialise it. So do that and store the initialised value in a local +// var. +// 1bii. Otherwise the variable was initialised and we can read the value +// 2a. Otherwise the variable was initialised and we can read the value +// 3. 
Check that in all code paths we read the expected initialised value + +instance = 0; + +$t1 = spawn { + if (instance == 0) { + lock l; + if (instance == 0) { + instance = 100; + } + unlock l; + } + assert(instance == 100); +}; + + +if (instance == 0) { + lock l; + if (instance == 0) { + instance = 100; + } + unlock l; +} +assert(instance == 100); + +join $t1; +assert(instance == 100); \ No newline at end of file diff --git a/examples/passing/semantics/spawn.gm b/examples/accept/semantics/branching/spawn.gm similarity index 100% rename from examples/passing/semantics/spawn.gm rename to examples/accept/semantics/branching/spawn.gm diff --git a/examples/passing/semantics/spawn_read_global.gm b/examples/accept/semantics/branching/spawn_read_global.gm similarity index 100% rename from examples/passing/semantics/spawn_read_global.gm rename to examples/accept/semantics/branching/spawn_read_global.gm diff --git a/examples/accept/semantics/linear/addition.gm b/examples/accept/semantics/linear/addition.gm new file mode 100644 index 0000000..3ebb8a9 --- /dev/null +++ b/examples/accept/semantics/linear/addition.gm @@ -0,0 +1,4 @@ +$r1 = 1; +x = 2; +y = $r1 + x; +assert(y + 1 != x + 1); diff --git a/examples/accept/semantics/linear/conditional_non_race.gm b/examples/accept/semantics/linear/conditional_non_race.gm new file mode 100644 index 0000000..baf06f1 --- /dev/null +++ b/examples/accept/semantics/linear/conditional_non_race.gm @@ -0,0 +1,18 @@ +x = 0; +y = 0; +$t1 = spawn { + if (x == 0) { + y = 1; + } +}; +$t2 = spawn { + if (y == 0) { + x = 1; + } +}; +join $t1; +join $t2; + +// We don't know the interleaving of the two threads, but we know that they won't conflict, +// and that the reads are reading consistent values. +// We can get all of x = 0, y = 1; x = 1, y = 0; or x = 1, y = 1; but we won't get x = 0, y = 0. 
\ No newline at end of file diff --git a/examples/accept/semantics/linear/globals.gm b/examples/accept/semantics/linear/globals.gm new file mode 100644 index 0000000..283e6eb --- /dev/null +++ b/examples/accept/semantics/linear/globals.gm @@ -0,0 +1,5 @@ +nop; +x = 0; +x = 2; +y = 2; +assert(x == y); diff --git a/examples/accept/semantics/linear/if.gm b/examples/accept/semantics/linear/if.gm new file mode 100644 index 0000000..eb270f6 --- /dev/null +++ b/examples/accept/semantics/linear/if.gm @@ -0,0 +1,13 @@ +$r = 0; +x = 0; +if ($r == 0) { + x = x + 1; +} else { + x = 0; +} +if ($r == 1) { + x = 0; +} else { + x = x + 1; +} +assert (x == 2); diff --git a/examples/accept/semantics/linear/join_fastforward.gm b/examples/accept/semantics/linear/join_fastforward.gm new file mode 100644 index 0000000..d488fe4 --- /dev/null +++ b/examples/accept/semantics/linear/join_fastforward.gm @@ -0,0 +1,4 @@ +x = 2; +$t = spawn { x = 3; }; +join $t; +assert(x == 3); diff --git a/examples/accept/semantics/linear/join_nop.gm b/examples/accept/semantics/linear/join_nop.gm new file mode 100644 index 0000000..ed09147 --- /dev/null +++ b/examples/accept/semantics/linear/join_nop.gm @@ -0,0 +1,4 @@ +x = 2; +$t = spawn { nop; }; +join $t; +assert(x == 2); diff --git a/examples/accept/semantics/linear/join_pulled_variable.gm b/examples/accept/semantics/linear/join_pulled_variable.gm new file mode 100644 index 0000000..ef82df7 --- /dev/null +++ b/examples/accept/semantics/linear/join_pulled_variable.gm @@ -0,0 +1,15 @@ +x = 0; +t = spawn { + assert(x == 0); + x = 2; + t2 = spawn { + assert(x == 2); + x = 14; + assert(x == 14); + }; + assert(x == 2); +}; +join t; +// assert (x == 2 || x == 14); `||` is not supported, but in linear this is all we know to be true here +join t2; +assert (x == 14); \ No newline at end of file diff --git a/examples/accept/semantics/linear/join_spawn.gm b/examples/accept/semantics/linear/join_spawn.gm new file mode 100644 index 0000000..ba5acf0 --- /dev/null +++ 
b/examples/accept/semantics/linear/join_spawn.gm @@ -0,0 +1,7 @@ +x = 1; +join spawn { + assert(x == 1); + x = 2; + assert(x == 2); +}; +assert(x == 2); \ No newline at end of file diff --git a/examples/accept/semantics/linear/local.gm b/examples/accept/semantics/linear/local.gm new file mode 100644 index 0000000..266605b --- /dev/null +++ b/examples/accept/semantics/linear/local.gm @@ -0,0 +1,7 @@ +$t1 = 1; +$t2 = 2; +nop; +$t1 = $t2; +nop; +assert($t1 == $t2); +nop; diff --git a/examples/accept/semantics/linear/lock.gm b/examples/accept/semantics/linear/lock.gm new file mode 100644 index 0000000..bc95524 --- /dev/null +++ b/examples/accept/semantics/linear/lock.gm @@ -0,0 +1,2 @@ +lock l1; +unlock l1; diff --git a/examples/accept/semantics/linear/lock_as_sync.gm b/examples/accept/semantics/linear/lock_as_sync.gm new file mode 100644 index 0000000..63fdd4b --- /dev/null +++ b/examples/accept/semantics/linear/lock_as_sync.gm @@ -0,0 +1,10 @@ +x = 0; +lock l1; +$t = spawn { + assert (x == 0); + lock l1; + unlock l1; + assert (x == 2); +}; +x = 2; +unlock l1; diff --git a/examples/accept/semantics/linear/singleton.gm b/examples/accept/semantics/linear/singleton.gm new file mode 100644 index 0000000..b39cecc --- /dev/null +++ b/examples/accept/semantics/linear/singleton.gm @@ -0,0 +1,41 @@ +// Construct the singleton pattern: +// A shared variable should be initialised only once by any thread +// Once initialised all threads should read the same value +// Perform the initialisation using double checked locking +// Each thread: +// 1. Checks the variable +// 1a. If it is uninitialised (here we use 0), then take the lock +// 1b. Taking the lock may pull in other thread updates, so check if the +// variable is still uninitialised +// 1bi. If the variable is still uninitialised then we know we need to +// initialise it. So do that and store the initialised value in a local +// var. +// 1bii. Otherwise the variable was initialised and we can read the value +// 2a. 
Otherwise the variable was initialised and we can read the value +// 3. Check that in all code paths we read the expected initialised value + +instance = 0; + +$t1 = spawn { + if (instance == 0) { + lock l; + if (instance == 0) { + instance = 100; + } + unlock l; + } + assert(instance == 100); +}; + + +if (instance == 0) { + lock l; + if (instance == 0) { + instance = 100; + } + unlock l; +} +assert(instance == 100); + +join $t1; +assert(instance == 100); \ No newline at end of file diff --git a/examples/accept/semantics/linear/spawn.gm b/examples/accept/semantics/linear/spawn.gm new file mode 100644 index 0000000..0200097 --- /dev/null +++ b/examples/accept/semantics/linear/spawn.gm @@ -0,0 +1,4 @@ +x = 2; +y = 2; +$t = spawn { nop; }; +assert(x == y); diff --git a/examples/accept/semantics/linear/spawn_read_global.gm b/examples/accept/semantics/linear/spawn_read_global.gm new file mode 100644 index 0000000..83c2126 --- /dev/null +++ b/examples/accept/semantics/linear/spawn_read_global.gm @@ -0,0 +1,8 @@ +x = 2; +y = 2; +$t = spawn { + assert(x == y); + x = 42; + assert(x == 42); +}; +assert(x == y); diff --git a/examples/oracle/counters.gm b/examples/oracle/counters.gm deleted file mode 100644 index 90126bd..0000000 --- a/examples/oracle/counters.gm +++ /dev/null @@ -1,23 +0,0 @@ -// Two threads update the same 'variable' to the same value but there -// should still be a race - -x = 0; -$t = spawn { - // The thread inherits a copy of the snapshot of the heap state known by the - // spawning thread at the point when the thread was spawned. - // So we always know that the assert(x == 0) will succeed. - // This thread will then proceed to mutate its versioned copy of the heap. - assert(x == 0); - x = 1; -}; - -// The spawning thread continues but will mutate and read only from its version -// of the heap which is isolated from the thread it just spawned. So, we know -// that assert(x == 0) will always succeed. 
-assert(x == 0); -x = 1; - -// When we join the threads, the joining thread will attempt to pull in the -// changes made to the thread. Even thought the two values are the same, this -// represents a data race and so we have to crash. -join $t; // Data race \ No newline at end of file diff --git a/examples/reject/semantics/branching/conditional_race.gm b/examples/reject/semantics/branching/conditional_race.gm new file mode 100644 index 0000000..427be52 --- /dev/null +++ b/examples/reject/semantics/branching/conditional_race.gm @@ -0,0 +1,35 @@ +// The racy schedule 0 1 1 1 2 2 2 0 0 (thread 1 gets the lock first) +// thread 1 sets $r and flag inside a lock, and then mutates x outside of a lock +// thread 2 sees the set flag inside of a lock, and also mutates x outside of a lock +x = 0; +y = 0; +flag = 0; +$t1 = spawn { + lock l1; + $r = 0; + if (flag == 0) { + flag = 1; + $r = 1; + } + unlock l1; + if ($r == 1) { + x = 1; + } +}; +$t2 = spawn { + lock l1; + $r = 0; + if (flag == 0) { + flag = 1; + $r = 1; + } + unlock l1; + if ($r == 1) { + y = 1; + } else { + x = 1; + } +}; +join $t1; +join $t2; +assert (x != y); \ No newline at end of file diff --git a/examples/failing/semantics/deadlock.gm b/examples/reject/semantics/branching/deadlock.gm similarity index 100% rename from examples/failing/semantics/deadlock.gm rename to examples/reject/semantics/branching/deadlock.gm diff --git a/examples/failing/semantics/error_and_race.gm b/examples/reject/semantics/branching/error_and_race.gm similarity index 100% rename from examples/failing/semantics/error_and_race.gm rename to examples/reject/semantics/branching/error_and_race.gm diff --git a/examples/failing/semantics/failed_assertion.gm b/examples/reject/semantics/branching/failed_assertion.gm similarity index 100% rename from examples/failing/semantics/failed_assertion.gm rename to examples/reject/semantics/branching/failed_assertion.gm diff --git a/examples/failing/semantics/failed_assertion_many_schedules.gm 
b/examples/reject/semantics/branching/failed_assertion_many_schedules.gm similarity index 100% rename from examples/failing/semantics/failed_assertion_many_schedules.gm rename to examples/reject/semantics/branching/failed_assertion_many_schedules.gm diff --git a/examples/failing/semantics/failed_assertion_neq.gm b/examples/reject/semantics/branching/failed_assertion_neq.gm similarity index 100% rename from examples/failing/semantics/failed_assertion_neq.gm rename to examples/reject/semantics/branching/failed_assertion_neq.gm diff --git a/examples/failing/semantics/join_blocked_thread.gm b/examples/reject/semantics/branching/join_blocked_thread.gm similarity index 100% rename from examples/failing/semantics/join_blocked_thread.gm rename to examples/reject/semantics/branching/join_blocked_thread.gm diff --git a/examples/reject/semantics/branching/join_datarace.gm b/examples/reject/semantics/branching/join_datarace.gm new file mode 100644 index 0000000..59ea95b --- /dev/null +++ b/examples/reject/semantics/branching/join_datarace.gm @@ -0,0 +1,11 @@ +x = 2; +$t = spawn { + assert (x == 2); + x = 3; + assert (x == 3); +}; +assert(x == 2); +x = 4; +assert(x == 4); +join $t; +assert(x == x); \ No newline at end of file diff --git a/examples/failing/semantics/join_deadlock.gm b/examples/reject/semantics/branching/join_deadlock.gm similarity index 100% rename from examples/failing/semantics/join_deadlock.gm rename to examples/reject/semantics/branching/join_deadlock.gm diff --git a/examples/failing/semantics/join_nonexisting.gm b/examples/reject/semantics/branching/join_nonexisting.gm similarity index 100% rename from examples/failing/semantics/join_nonexisting.gm rename to examples/reject/semantics/branching/join_nonexisting.gm diff --git a/examples/failing/semantics/read_unassigned.gm b/examples/reject/semantics/branching/read_unassigned.gm similarity index 100% rename from examples/failing/semantics/read_unassigned.gm rename to 
examples/reject/semantics/branching/read_unassigned.gm diff --git a/examples/failing/semantics/test.gm b/examples/reject/semantics/branching/test.gm similarity index 100% rename from examples/failing/semantics/test.gm rename to examples/reject/semantics/branching/test.gm diff --git a/examples/failing/semantics/unlock.gm b/examples/reject/semantics/branching/unlock.gm similarity index 100% rename from examples/failing/semantics/unlock.gm rename to examples/reject/semantics/branching/unlock.gm diff --git a/examples/failing/semantics/unlock_another_threads_lock.gm b/examples/reject/semantics/branching/unlock_another_threads_lock.gm similarity index 100% rename from examples/failing/semantics/unlock_another_threads_lock.gm rename to examples/reject/semantics/branching/unlock_another_threads_lock.gm diff --git a/examples/reject/semantics/branching/variable_under_two_locks.gm b/examples/reject/semantics/branching/variable_under_two_locks.gm new file mode 100644 index 0000000..79fb341 --- /dev/null +++ b/examples/reject/semantics/branching/variable_under_two_locks.gm @@ -0,0 +1,11 @@ +x = 0; +$t1 = spawn { + lock l1; + x = 1; + unlock l1; +}; +lock l2; +x = 2; +unlock l2; +join $t1; +assert(x == x); \ No newline at end of file diff --git a/examples/failing/semantics/conditional_race.gm b/examples/reject/semantics/linear/conditional_race.gm similarity index 65% rename from examples/failing/semantics/conditional_race.gm rename to examples/reject/semantics/linear/conditional_race.gm index 6dd0a36..6adbc05 100644 --- a/examples/failing/semantics/conditional_race.gm +++ b/examples/reject/semantics/linear/conditional_race.gm @@ -1,6 +1,9 @@ x = 0; y = 0; flag = 0; +// if t1 gets the lock l1 first then there is a race +// t1 will see the flag is not set, set the flag to 1, and set its x to 1 +// t2 will see the flag is set, and will also set its x to 1 $t1 = spawn { lock l1; $r = 0; @@ -29,4 +32,4 @@ $t2 = spawn { }; join $t1; join $t2; -assert (x != y); +assert (x != y); \ No 
newline at end of file diff --git a/examples/reject/semantics/linear/deadlock.gm b/examples/reject/semantics/linear/deadlock.gm new file mode 100644 index 0000000..ccf9fe9 --- /dev/null +++ b/examples/reject/semantics/linear/deadlock.gm @@ -0,0 +1,12 @@ +$t1 = spawn { + lock l1; + lock l2; + unlock l2; + unlock l1; +}; +$t2 = spawn { + lock l2; + lock l1; + unlock l1; + unlock l2; +}; diff --git a/examples/reject/semantics/linear/error_and_race.gm b/examples/reject/semantics/linear/error_and_race.gm new file mode 100644 index 0000000..636e9e7 --- /dev/null +++ b/examples/reject/semantics/linear/error_and_race.gm @@ -0,0 +1,13 @@ +x = 1; +$t1 = spawn { + lock l1; + x = 1; + unlock l1; +}; +$t2 = spawn { + x = 2; +}; +join $t2; +lock l1; +assert(x == 1); +unlock l2; diff --git a/examples/reject/semantics/linear/failed_assertion.gm b/examples/reject/semantics/linear/failed_assertion.gm new file mode 100644 index 0000000..2c4edd7 --- /dev/null +++ b/examples/reject/semantics/linear/failed_assertion.gm @@ -0,0 +1,6 @@ +x = 0; +$t = spawn { + x = 1; +}; +join $t; +assert(x == 0); diff --git a/examples/reject/semantics/linear/failed_assertion_many_schedules.gm b/examples/reject/semantics/linear/failed_assertion_many_schedules.gm new file mode 100644 index 0000000..9d6d1e9 --- /dev/null +++ b/examples/reject/semantics/linear/failed_assertion_many_schedules.gm @@ -0,0 +1,5 @@ +$t1 = spawn { assert(1 == 2); }; +$t2 = spawn { nop; }; +$t3 = spawn { nop; }; +$t4 = spawn { nop; }; +$t5 = spawn { nop; }; diff --git a/examples/reject/semantics/linear/failed_assertion_neq.gm b/examples/reject/semantics/linear/failed_assertion_neq.gm new file mode 100644 index 0000000..bd2db83 --- /dev/null +++ b/examples/reject/semantics/linear/failed_assertion_neq.gm @@ -0,0 +1,6 @@ +x = 0; +$t = spawn { + x = 1; +}; +join $t; +assert(x != 1); diff --git a/examples/reject/semantics/linear/join_blocked_thread.gm b/examples/reject/semantics/linear/join_blocked_thread.gm new file mode 100644 index 
0000000..4aafbda --- /dev/null +++ b/examples/reject/semantics/linear/join_blocked_thread.gm @@ -0,0 +1,2 @@ +$t2 = spawn { join 0; }; +join $t2; \ No newline at end of file diff --git a/examples/failing/semantics/join_datarace.gm b/examples/reject/semantics/linear/join_datarace.gm similarity index 100% rename from examples/failing/semantics/join_datarace.gm rename to examples/reject/semantics/linear/join_datarace.gm diff --git a/examples/reject/semantics/linear/join_deadlock.gm b/examples/reject/semantics/linear/join_deadlock.gm new file mode 100644 index 0000000..e4feaee --- /dev/null +++ b/examples/reject/semantics/linear/join_deadlock.gm @@ -0,0 +1 @@ +join spawn { join 0; }; \ No newline at end of file diff --git a/examples/reject/semantics/linear/join_nonexisting.gm b/examples/reject/semantics/linear/join_nonexisting.gm new file mode 100644 index 0000000..3a85a1b --- /dev/null +++ b/examples/reject/semantics/linear/join_nonexisting.gm @@ -0,0 +1,2 @@ +$t = 42; +join $t; diff --git a/examples/reject/semantics/linear/read_unassigned.gm b/examples/reject/semantics/linear/read_unassigned.gm new file mode 100644 index 0000000..94120df --- /dev/null +++ b/examples/reject/semantics/linear/read_unassigned.gm @@ -0,0 +1 @@ +x = y; \ No newline at end of file diff --git a/examples/reject/semantics/linear/test.gm b/examples/reject/semantics/linear/test.gm new file mode 100644 index 0000000..6958578 --- /dev/null +++ b/examples/reject/semantics/linear/test.gm @@ -0,0 +1,11 @@ +nop; +x = 0; +$t = spawn { + lock l1; + $r = 1; + x = $r; + unlock l1; +}; +x = 2; // Data race! 
+join $t; +assert(x == 2); \ No newline at end of file diff --git a/examples/reject/semantics/linear/unlock.gm b/examples/reject/semantics/linear/unlock.gm new file mode 100644 index 0000000..bc083f6 --- /dev/null +++ b/examples/reject/semantics/linear/unlock.gm @@ -0,0 +1,2 @@ +unlock l1; +x=1; \ No newline at end of file diff --git a/examples/reject/semantics/linear/unlock_another_threads_lock.gm b/examples/reject/semantics/linear/unlock_another_threads_lock.gm new file mode 100644 index 0000000..41ef1ca --- /dev/null +++ b/examples/reject/semantics/linear/unlock_another_threads_lock.gm @@ -0,0 +1,5 @@ +lock l; +t = spawn { + unlock l; +}; +join t; diff --git a/examples/reject/semantics/linear/variable_under_two_locks.gm b/examples/reject/semantics/linear/variable_under_two_locks.gm new file mode 100644 index 0000000..79fb341 --- /dev/null +++ b/examples/reject/semantics/linear/variable_under_two_locks.gm @@ -0,0 +1,11 @@ +x = 0; +$t1 = spawn { + lock l1; + x = 1; + unlock l1; +}; +lock l2; +x = 2; +unlock l2; +join $t1; +assert(x == x); \ No newline at end of file diff --git a/examples/failing/syntax/bad_add.gm b/examples/reject/syntax/bad_add.gm similarity index 100% rename from examples/failing/syntax/bad_add.gm rename to examples/reject/syntax/bad_add.gm diff --git a/examples/failing/syntax/bad_add2.gm b/examples/reject/syntax/bad_add2.gm similarity index 100% rename from examples/failing/syntax/bad_add2.gm rename to examples/reject/syntax/bad_add2.gm diff --git a/examples/failing/syntax/bad_assert.gm b/examples/reject/syntax/bad_assert.gm similarity index 100% rename from examples/failing/syntax/bad_assert.gm rename to examples/reject/syntax/bad_assert.gm diff --git a/examples/failing/syntax/bad_condition.gm b/examples/reject/syntax/bad_condition.gm similarity index 100% rename from examples/failing/syntax/bad_condition.gm rename to examples/reject/syntax/bad_condition.gm diff --git a/examples/failing/syntax/bad_join.gm b/examples/reject/syntax/bad_join.gm 
similarity index 100% rename from examples/failing/syntax/bad_join.gm rename to examples/reject/syntax/bad_join.gm diff --git a/examples/failing/syntax/bad_lhs.gm b/examples/reject/syntax/bad_lhs.gm similarity index 100% rename from examples/failing/syntax/bad_lhs.gm rename to examples/reject/syntax/bad_lhs.gm diff --git a/examples/failing/syntax/bad_lock.gm b/examples/reject/syntax/bad_lock.gm similarity index 100% rename from examples/failing/syntax/bad_lock.gm rename to examples/reject/syntax/bad_lock.gm diff --git a/examples/failing/syntax/bad_rhs.gm b/examples/reject/syntax/bad_rhs.gm similarity index 100% rename from examples/failing/syntax/bad_rhs.gm rename to examples/reject/syntax/bad_rhs.gm diff --git a/examples/failing/syntax/empty.gm b/examples/reject/syntax/empty.gm similarity index 100% rename from examples/failing/syntax/empty.gm rename to examples/reject/syntax/empty.gm diff --git a/examples/failing/syntax/empty_assert.gm b/examples/reject/syntax/empty_assert.gm similarity index 100% rename from examples/failing/syntax/empty_assert.gm rename to examples/reject/syntax/empty_assert.gm diff --git a/examples/failing/syntax/empty_assign.gm b/examples/reject/syntax/empty_assign.gm similarity index 100% rename from examples/failing/syntax/empty_assign.gm rename to examples/reject/syntax/empty_assign.gm diff --git a/examples/failing/syntax/empty_brace.gm b/examples/reject/syntax/empty_brace.gm similarity index 100% rename from examples/failing/syntax/empty_brace.gm rename to examples/reject/syntax/empty_brace.gm diff --git a/examples/failing/syntax/if_no_brace.gm b/examples/reject/syntax/if_no_brace.gm similarity index 100% rename from examples/failing/syntax/if_no_brace.gm rename to examples/reject/syntax/if_no_brace.gm diff --git a/examples/failing/syntax/if_no_cond.gm b/examples/reject/syntax/if_no_cond.gm similarity index 100% rename from examples/failing/syntax/if_no_cond.gm rename to examples/reject/syntax/if_no_cond.gm diff --git 
a/examples/failing/syntax/no_semicolon.gm b/examples/reject/syntax/no_semicolon.gm similarity index 100% rename from examples/failing/syntax/no_semicolon.gm rename to examples/reject/syntax/no_semicolon.gm diff --git a/examples/failing/syntax/spurious_else.gm b/examples/reject/syntax/spurious_else.gm similarity index 100% rename from examples/failing/syntax/spurious_else.gm rename to examples/reject/syntax/spurious_else.gm diff --git a/examples/failing/syntax/top_eq.gm b/examples/reject/syntax/top_eq.gm similarity index 100% rename from examples/failing/syntax/top_eq.gm rename to examples/reject/syntax/top_eq.gm diff --git a/examples/failing/unassigned_reg.gm b/examples/reject/unassigned_reg.gm similarity index 100% rename from examples/failing/unassigned_reg.gm rename to examples/reject/unassigned_reg.gm diff --git a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc new file mode 100644 index 0000000..c15c82f --- /dev/null +++ b/src/branching/base_sync_protocol.cc @@ -0,0 +1,164 @@ +#include "base_sync_protocol.hh" +#include "overloaded.hh" +#include "branching/eager/sync_protocol.hh" +#include "branching/lazy/sync_protocol.hh" + +namespace gitmem { + +namespace branching { + +LocalVersionStore& get_store(ThreadContext& ctx) { + return static_cast(*ctx.sync); +} + +LockState& get_store(Lock& ctx) { + return static_cast(*ctx.sync); +} + +// -------------------- +// BranchingSyncProtocolBase +// -------------------- + +BranchingSyncProtocolBase::~BranchingSyncProtocolBase() = default; + +std::ostream &BranchingSyncProtocolBase::print(std::ostream &os) const { + os << _global_store; + return os; +} + +ReadResult BranchingSyncProtocolBase::read(ThreadContext &ctx, + const std::string &var) { + auto& store = get_store(ctx); + + return std::visit(overloaded{ + [](std::monostate) -> ReadResult { return std::monostate{}; }, + [](const ValueWithSource& v) -> ReadResult { return v; }, + [&](const Conflict& c) -> ReadResult { + return 
std::make_shared( + c.obj, std::pair{c.timestamp_a, c.timestamp_b} + ); + }, + + }, store.read(var)); +} + +void BranchingSyncProtocolBase::write(ThreadContext &ctx, const std::string &var, + ValueWithSource value) { + auto& store = get_store(ctx); + store.stage(var, value); +} + +std::optional> +BranchingSyncProtocolBase::on_spawn(ThreadContext &parent, ThreadContext &child) { + auto& parent_store = get_store(parent); + parent_store.commit_staging(); + + auto& child_store = get_store(child); + child_store.adopt_history(parent_store); + + // a conflict cannot occur here + return std::nullopt; +} + +std::optional> +BranchingSyncProtocolBase::on_join(ThreadContext &joiner, ThreadContext &joinee) { + auto& joiner_store = get_store(joiner); + auto& joinee_store = get_store(joinee); + + joiner_store.commit_staging(); + assert(joinee_store.has_commited() && "joinee has staged changes"); + + std::optional conflict = joiner_store.merge_with_commit(joinee_store.get_head()); + if (conflict) { + return std::make_shared( + conflict->obj, + std::make_pair(conflict->timestamp_a, conflict->timestamp_b)); + } + + return std::nullopt; +} + +std::optional> +BranchingSyncProtocolBase::on_start(ThreadContext &thread) { + // nothing to do, the thread will have inherited the parent commit on spawn + return std::nullopt; +}; + +std::optional> +BranchingSyncProtocolBase::on_end(ThreadContext &thread) { + auto& store = get_store(thread); + store.commit_staging(); + + return std::nullopt; +}; + +std::optional> +BranchingSyncProtocolBase::on_lock(ThreadContext &thread, Lock &lock) { + auto& store = get_store(thread); + store.commit_staging(); + + LockState& lock_state = get_store(lock); + std::shared_ptr lock_commit = lock_state.commit; + + if (lock_commit != nullptr) { + std::optional conflict = store.merge_with_commit(lock_commit); + if (conflict) { + return std::make_shared( + conflict->obj, + std::make_pair(conflict->timestamp_a, conflict->timestamp_b)); + } + + lock_state.commit =
store.get_head(); + } + + return std::nullopt; +} + +std::optional> +BranchingSyncProtocolBase::on_unlock(ThreadContext &thread, Lock &lock) { + auto& store = get_store(thread); + store.commit_staging(); + + // we know that the last committer was this thread, so no need to merge + // this sort of mixes protocol logic and lock state, i am unsure if this is ideal + LockState& lock_state = get_store(lock); + lock_state.commit = store.get_head(); + + return std::nullopt; +} + +std::string BranchingSyncProtocolBase::build_revision_graph_dot( + const std::vector& thread_states) const { + + std::vector> heads; + + for (const ThreadSyncState* state_ptr : thread_states) { + const auto* local_store = dynamic_cast(state_ptr); + if (local_store && local_store->get_head()) { + heads.push_back(local_store->get_head()); + } + } + + return build_commit_graph_dot(heads); +} + +bool BranchingSyncProtocolBase::is_scheduling_point(SyncOperation op) const { + // For branching protocol, only operations that actually synchronize state + // (lock/unlock) or require waiting (join) are scheduling points + switch (op) { + case SyncOperation::Lock: + case SyncOperation::Unlock: + case SyncOperation::Join: + return true; + case SyncOperation::Spawn: + case SyncOperation::Start: + case SyncOperation::End: + // These just inherit/commit locally - no scheduling decision needed + return false; + } + return false; +} + +} // end branching + +} // end gitmem \ No newline at end of file diff --git a/src/branching/base_sync_protocol.hh b/src/branching/base_sync_protocol.hh new file mode 100644 index 0000000..2e1091a --- /dev/null +++ b/src/branching/base_sync_protocol.hh @@ -0,0 +1,61 @@ +#pragma once + +#include "../sync_protocol.hh" +#include "base_version_store.hh" + +namespace gitmem { + +using BranchingConflict = Conflict; + +namespace branching { + +class BranchingSyncProtocolBase : public SyncProtocol { +protected: + GlobalVersionStore _global_store; + bool verbose_commits; + + explicit 
BranchingSyncProtocolBase(bool verbose_commits) + : verbose_commits(verbose_commits) {} + +public: + ~BranchingSyncProtocolBase() override; + + std::unique_ptr clone() const override = 0; + + ReadResult read(ThreadContext &ctx, const std::string &var) override; + + void write(ThreadContext &ctx, const std::string &var, + ValueWithSource value) override; + + std::optional> + on_spawn(ThreadContext &parent, ThreadContext &child) override; + + std::optional> + on_join(ThreadContext &joiner, ThreadContext &joinee) override; + + std::optional> + on_start(ThreadContext &thread) override; + + std::optional> + on_end(ThreadContext &thread) override; + + std::optional> + on_lock(ThreadContext &thread, Lock &lock) override; + + std::optional> + on_unlock(ThreadContext &thread, Lock &lock) override; + + std::ostream &print(std::ostream &os) const override; + + std::string build_revision_graph_dot(const std::vector& thread_states) const override; + + bool is_scheduling_point(SyncOperation op) const override; + + std::unique_ptr make_lock_state() const override { + return std::make_unique(); + } +}; + +} // end branching + +} // end gitmem \ No newline at end of file diff --git a/src/branching/base_version_store.cc b/src/branching/base_version_store.cc new file mode 100644 index 0000000..e3d0360 --- /dev/null +++ b/src/branching/base_version_store.cc @@ -0,0 +1,259 @@ +#include "base_version_store.hh" +#include +#include +#include +#include +#include "debug.hh" +#include + +namespace gitmem { + +namespace branching { + +// Helper: recursive print with 2-space indentation and cycle protection +void print_commit_recursive(std::ostream& os, + const std::shared_ptr& commit, + std::unordered_set& visited, + int depth = 0) +{ + if (!commit) return; + if (!visited.insert(commit.get()).second) { + os << std::string(depth * 2, ' ') << "(already printed commit " << commit->id << ")\n"; + return; + } + + os << std::string(depth * 2, ' ') << "Commit " << commit->id << " {\n"; + + // Print 
changes + for (const auto& [obj, val] : commit->changes) { + os << std::string((depth + 1) * 2, ' ') << obj << " -> " << val.value << "\n"; + } + + // Print parents + if (!commit->parents.empty()) { + os << std::string((depth + 1) * 2, ' ') << "Parents: "; + for (size_t i = 0; i < commit->parents.size(); ++i) { + os << commit->parents[i]->id; + if (i + 1 < commit->parents.size()) os << ", "; + } + os << "\n"; + } + + os << std::string(depth * 2, ' ') << "}\n"; + + // Recursively print parents + for (auto& parent : commit->parents) { + print_commit_recursive(os, parent, visited, depth + 1); + } +} + +// operator<< for Commit +std::ostream& operator<<(std::ostream& os, const Commit& commit) { + std::unordered_set visited; + // Wrap the commit in a shared_ptr to reuse the recursive helper + print_commit_recursive(os, std::make_shared(commit), visited); + return os; +} + +void LocalVersionStore::stage(std::string obj, ValueWithSource value) { + staging[obj] = value; +} + +void LocalVersionStore::commit_staging() { + // No-op commit does nothing unless verbose mode + if (staging.empty() && !verbose) { + return; + } + + // Create the new commit with the staged changes + auto new_commit = std::make_shared(base_timestamp++, std::move(staging)); + + // Update last_writer for each staged variable + for (const auto& [obj, _] : new_commit->changes) { + last_writer[obj] = new_commit; + } + + // Clear staging + staging.clear(); + + // Set parent to previous head if it exists + if (head) + new_commit->parents.push_back(head); + + // Update head + head = new_commit; +} + +BranchingReadResult LocalVersionStore::read(std::string var) const { + auto it = staging.find(var); + if (it != staging.end()) + return it->second; + + return get_committed(var); +} + +void LocalVersionStore::adopt_history(const LocalVersionStore& other) { + // Inherit the DAG head + head = other.head; + + // Inherit the last_writer cache so the child sees all latest commits + last_writer = other.last_writer; +} 
+ +std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { + os << "LocalVersionStore{" + << "base=" << store.base_timestamp + << ", head="; + + if (store.head) + os << store.head->id; + else + os << "null"; + + os << ", staged={"; + + bool first = true; + for (const auto& [obj, val] : store.staging) { + if (!first) os << ", "; + first = false; + os << obj << "->" << val.value; + } + + os << "}}"; + return os; +} + +bool LocalVersionStore::operator==(const LocalVersionStore& other) const { + return base_timestamp == other.base_timestamp && + head == other.head && + staging == other.staging; +} + +std::ostream& operator<<(std::ostream& os, const GlobalVersionStore& store) { + os << "GlobalVersionStore()"; + return os; +} + +bool can_reach(const std::shared_ptr& commit, const std::shared_ptr& other, std::unordered_map, bool>& memo) { + if (!commit) + return false; + + if (commit == other) + return true; + + auto it = memo.find(commit); + if (it != memo.end()) + return it->second; + + for (const auto& parent : commit->parents) { + if (can_reach(parent, other, memo)) { + memo[commit] = true; + return true; + } + } + + memo[commit] = false; + return false; +} + +std::string build_commit_graph_dot(const std::vector>& leaves) { + std::ostringstream dot; + dot << "digraph CommitGraph {\n"; + dot << " rankdir=BT;\n"; + dot << " node [shape=box];\n"; + + std::unordered_set visited; + std::unordered_map>> commits_by_thread; + std::stack> stack; + + // First pass: collect all commits and organize by thread + for (const auto& leaf : leaves) + if (leaf) stack.push(leaf); + + while (!stack.empty()) { + auto commit = stack.top(); + stack.pop(); + + if (!commit || !visited.insert(commit.get()).second) + continue; + + commits_by_thread[commit->id.thread].push_back(commit); + + for (const auto& parent : commit->parents) { + if (parent) stack.push(parent); + } + } + + // Create subgraph clusters for each thread + for (const auto& [thread_id, commits] : 
commits_by_thread) { + dot << " subgraph cluster_" << thread_id << " {\n"; + dot << " label=\"Thread " << thread_id << "\";\n"; + dot << " style=dashed;\n"; + + for (const auto& commit : commits) { + const std::string cid = to_string(commit->id); + + std::ostringstream label; + label << cid; + + // Mark merge commits + if (commit->parents.size() >= 2) { + label << " (merge"; + if (commit->conflicted) { + label << " - CONFLICT"; + } + label << ")"; + } + + if (!commit->changes.empty()) { + label << "\\n"; + bool first = true; + for (const auto& [obj, val] : commit->changes) { + if (!first) label << "\\n"; + first = false; + label << obj << "→" << val.value; + } + } + + // Style merge commits differently + if (commit->parents.size() >= 2) { + std::string fillcolor = commit->conflicted ? "pink" : "lightgray"; + dot << " \"" << cid << "\" [label=\"" << label.str() << "\", style=filled, fillcolor=" << fillcolor << "];\n"; + } else { + dot << " \"" << cid << "\" [label=\"" << label.str() << "\"];\n"; + } + } + + dot << " }\n"; + } + + // Draw edges (outside clusters so they can cross boundaries) + visited.clear(); + for (const auto& leaf : leaves) + if (leaf) stack.push(leaf); + + while (!stack.empty()) { + auto commit = stack.top(); + stack.pop(); + + if (!commit || !visited.insert(commit.get()).second) + continue; + + const std::string cid = to_string(commit->id); + + for (const auto& parent : commit->parents) { + if (!parent) continue; + + const std::string pid = to_string(parent->id); + dot << " \"" << cid << "\" -> \"" << pid << "\";\n"; + stack.push(parent); + } + } + + dot << "}\n"; + return dot.str(); +} + +} // branching + +} // gitmem \ No newline at end of file diff --git a/src/branching/base_version_store.hh b/src/branching/base_version_store.hh new file mode 100644 index 0000000..c62ff9f --- /dev/null +++ b/src/branching/base_version_store.hh @@ -0,0 +1,147 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include 
"thread_id.hh" +#include "sync_state.hh" +#include "read_result.hh" + +namespace gitmem { + +namespace branching { + +struct Timestamp { + ThreadID thread; + size_t counter; + + auto operator<=>(const Timestamp &) const = default; + + // pre-increment + Timestamp& operator++() { + ++counter; + return *this; + } + + // post-increment + Timestamp operator++(int) { + Timestamp old = *this; + ++(*this); + return old; + } + + friend std::ostream &operator<<(std::ostream &os, + const Timestamp &ts) { + os << ts.thread << ":" << ts.counter; + return os; + } +}; + +inline std::string to_string(const Timestamp& ts) { + std::ostringstream ss; + ss << ts; + return ss.str(); +} + +struct Commit { + Timestamp id; + std::unordered_map changes; + std::vector> parents; + bool conflicted = false; +}; + +std::string build_commit_graph_dot(const std::vector>& leaves); + +bool can_reach(const std::shared_ptr& commit, const std::shared_ptr& lca, std::unordered_map, bool>& memo); + +std::ostream& operator<<(std::ostream& os, const Commit& commit); + +struct Conflict { + std::string obj; + Timestamp timestamp_a; + Timestamp timestamp_b; +}; + +using BranchingReadResult = std::variant; + +inline std::ostream& operator<<(std::ostream& os, const Conflict& c) { + return os << "Conflict{obj=" << c.obj + << ", timestamp_a=" << c.timestamp_a + << ", timestamp_b=" << c.timestamp_b << "}"; +} + +class LocalVersionStore : public ThreadSyncState { +protected: + Timestamp base_timestamp; + std::shared_ptr head; + std::unordered_map staging; + + std::unordered_map> last_writer; // cached + + bool verbose; + +public: + ~LocalVersionStore() = default; + + LocalVersionStore(ThreadID tid, bool verbose = false): base_timestamp(tid, 0), verbose(verbose) {} + + void stage(std::string obj, ValueWithSource value); + void commit_staging(); + + bool has_commited() { return staging.empty(); } + + std::shared_ptr get_head() const { return head; } + +private: + virtual BranchingReadResult get_committed(std::string 
var) const = 0; + +public: + BranchingReadResult read(std::string var) const; + + void adopt_history(const LocalVersionStore& other); + virtual std::optional merge_with_commit(const std::shared_ptr& other_head) = 0; + + friend std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store); + std::ostream &print(std::ostream &os) const override { + os << *dynamic_cast(this); + return os; + } + + bool operator==(const LocalVersionStore& other) const; + bool operator==(const ThreadSyncState& other) const override { + auto* o = dynamic_cast(&other); + if (!o) + return false; + return *this == *o; + } +}; + +class GlobalVersionStore { +public: + friend std::ostream& operator<<(std::ostream&, const GlobalVersionStore&); +}; + +class LockState : public LockSyncState { +public: + ~LockState() = default; + + std::shared_ptr commit; + + inline std::ostream &print(std::ostream &os) const override { + os << "LockState{commit="; + if (commit) + os << commit->id; + else + os << "empty"; + os << "}"; + return os; + } +}; + +} // namespace branching + +} // namespace gitmem \ No newline at end of file diff --git a/src/branching/eager/sync_protocol.hh b/src/branching/eager/sync_protocol.hh new file mode 100644 index 0000000..adcae17 --- /dev/null +++ b/src/branching/eager/sync_protocol.hh @@ -0,0 +1,28 @@ +#pragma once + +#include "branching/base_sync_protocol.hh" +#include "branching/eager/version_store.hh" + +namespace gitmem { + +namespace branching { + +class BranchingEagerSyncProtocol final : public BranchingSyncProtocolBase { +public: + explicit BranchingEagerSyncProtocol(bool verbose = false) + : BranchingSyncProtocolBase(verbose) {} + + ~BranchingEagerSyncProtocol() = default; + + std::unique_ptr clone() const override { + return std::make_unique(verbose_commits); + } + + std::unique_ptr make_thread_state(ThreadID tid) const override { + return std::make_unique(tid, verbose_commits); + } +}; + +} + +} \ No newline at end of file diff --git 
a/src/branching/eager/version_store.cc b/src/branching/eager/version_store.cc new file mode 100644 index 0000000..9da040a --- /dev/null +++ b/src/branching/eager/version_store.cc @@ -0,0 +1,183 @@ +#include "branching/eager/version_store.hh" +#include "debug.hh" + +#include + +namespace gitmem { + +namespace branching { + +bool traverse_until_lca( + const std::shared_ptr& commit, + const std::shared_ptr& lca, + std::unordered_map>& out_map, + std::unordered_set>& visited, + std::unordered_map, bool>& reach_memo) +{ + if (!commit || commit == lca || !visited.insert(commit).second) + return true; + + if (!can_reach(commit, lca, reach_memo)) + return true; + + for (const auto& [obj, _] : commit->changes) { + // first write seen dominates + if (out_map.find(obj) == out_map.end()) + out_map[obj] = commit; + } + + for (auto& parent : commit->parents) { + if (!traverse_until_lca(parent, lca, out_map, visited, reach_memo)) + return false; + } + + return true; +} + +std::shared_ptr +find_lowest_common_ancestor(std::shared_ptr a, + std::shared_ptr b) +{ + if (!a || !b) return nullptr; + if (a == b) return a; + + // Check if 'a' is an ancestor of 'b' + { + std::unordered_set> visited; + std::queue> q; + q.push(b); + while (!q.empty()) { + auto c = q.front(); q.pop(); + if (!c) continue; + if (c == a) return a; + if (!visited.insert(c).second) continue; + for (auto& p : c->parents) + q.push(p); + } + } + + // Check if 'b' is an ancestor of 'a' + { + std::unordered_set> visited; + std::queue> q; + q.push(a); + while (!q.empty()) { + auto c = q.front(); q.pop(); + if (!c) continue; + if (c == b) return b; + if (!visited.insert(c).second) continue; + for (auto& p : c->parents) + q.push(p); + } + } + + // Step 1: collect all ancestors of 'a' + std::unordered_set> ancestors_a; + std::queue> q; + q.push(a); + + while (!q.empty()) { + auto c = q.front(); q.pop(); + if (!c) continue; + if (!ancestors_a.insert(c).second) continue; // already visited + + for (auto& p : c->parents) + 
q.push(p); + } + + // Step 2: BFS from 'b' to find first common ancestor + std::unordered_set> visited_b; + q.push(b); + + while (!q.empty()) { + auto c = q.front(); q.pop(); + if (!c) continue; + if (!visited_b.insert(c).second) continue; + + if (ancestors_a.count(c)) + return c; // first common ancestor seen + + for (auto& p : c->parents) + q.push(p); + } + + return nullptr; // disjoint histories (shouldn’t happen) +} + +std::optional EagerLocalVersionStore::merge_with_commit(const std::shared_ptr& commit) { + assert(staging.empty()); + assert(commit != nullptr); + + // trivial case: same history + if (head == commit) + return std::nullopt; + + // Find lowest common ancestor of the two heads + std::shared_ptr lca = find_lowest_common_ancestor(head, commit); + verbose::out << "found lca of " << head->id << " and " << commit->id << " to be " << lca->id << std::endl; + + // Collect all writes after LCA for each branch + std::unordered_map> branch_a, branch_b; + std::unordered_set> visited; + + std::unordered_map, bool> reach_memo; + traverse_until_lca(head, lca, branch_a, visited, reach_memo); + visited.clear(); + traverse_until_lca(commit, lca, branch_b, visited, reach_memo); + + // 1. Eager conflict detection + std::optional conflict; + for (const auto& [obj, commit_a] : branch_a) { + auto it = branch_b.find(obj); + if (it != branch_b.end() && it->second != commit_a) { + conflict = Conflict{ + .obj = obj, + .timestamp_a = commit_a->id, + .timestamp_b = it->second->id + }; + break; + } + } + + // Create merge commit (even if conflicted, for visualization) + auto merge_commit = std::make_shared( + Commit{ + .id = base_timestamp++, + .changes = {}, // merge commit does not write anything + .parents = {head, commit}, + .conflicted = conflict.has_value() + } + ); + + // If there was a conflict, update head but return the conflict + if (conflict) { + head = merge_commit; + return conflict; + } + + // 2. 
Update thread-local last_writer incrementally + // Only overwrite variables that were touched along either branch after LCA + for (const auto& [obj, commit] : branch_a) + last_writer[obj] = commit; + + for (const auto& [obj, commit] : branch_b) + last_writer[obj] = commit; + + // 3. Variables not touched in either branch remain unchanged (from before LCA) + + // 4. Update head + head = merge_commit; + + return std::nullopt; +} + +BranchingReadResult EagerLocalVersionStore::get_committed(std::string var) const { + if (auto it = last_writer.find(var); it != last_writer.end()) + return it->second->changes.at(var); + + return std::monostate{}; +} + +} // end branching + +} // end gitmem \ No newline at end of file diff --git a/src/branching/eager/version_store.hh b/src/branching/eager/version_store.hh new file mode 100644 index 0000000..cc92be9 --- /dev/null +++ b/src/branching/eager/version_store.hh @@ -0,0 +1,22 @@ +#pragma once + +#include "branching/base_version_store.hh" + +namespace gitmem { + +namespace branching { + +class EagerLocalVersionStore : public LocalVersionStore { +public: + ~EagerLocalVersionStore() = default; + + EagerLocalVersionStore(ThreadID tid, bool verbose) : LocalVersionStore(tid, verbose) {} + + std::optional merge_with_commit(const std::shared_ptr&) override; + BranchingReadResult get_committed(std::string var) const override; + +}; + +} + +} \ No newline at end of file diff --git a/src/branching/lazy/sync_protocol.hh b/src/branching/lazy/sync_protocol.hh new file mode 100644 index 0000000..ac8a8bc --- /dev/null +++ b/src/branching/lazy/sync_protocol.hh @@ -0,0 +1,31 @@ +#pragma once + +#include "branching/base_sync_protocol.hh" +#include "branching/lazy/version_store.hh" + +namespace gitmem { + +namespace branching { + +class BranchingLazySyncProtocol final : public BranchingSyncProtocolBase { +private: + bool raise_early_conflicts; // currently not used + +public: + explicit BranchingLazySyncProtocol(bool verbose = false, bool 
raise_early_conflicts = false) + : BranchingSyncProtocolBase(verbose), raise_early_conflicts(raise_early_conflicts) {} + + ~BranchingLazySyncProtocol() = default; + + std::unique_ptr clone() const override { + return std::make_unique(verbose_commits, raise_early_conflicts); + } + + std::unique_ptr make_thread_state(ThreadID tid) const override { + return std::make_unique(tid, verbose_commits, raise_early_conflicts); + } +}; + +} + +} \ No newline at end of file diff --git a/src/branching/lazy/version_store.cc b/src/branching/lazy/version_store.cc new file mode 100644 index 0000000..5bdc6df --- /dev/null +++ b/src/branching/lazy/version_store.cc @@ -0,0 +1,96 @@ +#include "branching/lazy/version_store.hh" +#include "debug.hh" +#include +#include + +namespace gitmem { + +namespace branching { + +std::optional LazyLocalVersionStore::merge_with_commit(const std::shared_ptr& commit) { + assert(staging.empty()); + assert(commit != nullptr); + + // trivial case: same history + if (head == commit) + return std::nullopt; + + // Create merge commit (no changes itself) + auto merge_commit = std::make_shared( + Commit{ + .id = base_timestamp++, + .parents = {head, commit}, + .changes = {} // merge commit does not write anything + } + ); + + // don't check for conflicts, we do that later, when a variable is read + head = merge_commit; + + // whenever we merge, we lose all the information about the last writer + // last_writer.clear(); + + return std::nullopt; +} + +// Thought, if we merge two paths that conflict on a variable, but we never read it +// and just write to it, is that okay?
+// if (auto it = last_writer.find(number); it != last_writer.end()) { +// return it->second->changes.at(number); +// } + +BranchingReadResult LazyLocalVersionStore::get_committed(std::string var) const { + std::vector> writers; + + std::function)> dfs; + dfs = [&](std::shared_ptr c) { + if (!c) return; + + // Check if c is an ancestor of any existing writer + { + std::unordered_map, bool> reach_memo; + for (const auto& writer : writers) { + if (can_reach(writer, c, reach_memo)) { + // c is ancestor of existing writer, ignore this path + return; + } + } + } + + // Remove any existing writers that are ancestors of c + { + std::unordered_map, bool> reach_memo; + writers.erase( + std::remove_if(writers.begin(), writers.end(), + [&](const auto& writer) { return can_reach(c, writer, reach_memo); }), + writers.end() + ); + } + + if (c->changes.contains(var)) { + writers.push_back(c); + return; + } + + for (auto& p : c->parents) + dfs(p); + }; + + dfs(head); + + if (writers.empty()) return std::monostate{}; + if (writers.size() == 1) { + // last_writer[var] = writers[0]; + return writers[0]->changes.at(var); + } + + // conflict + auto a = writers[0]->id; + auto b = writers[1]->id; + return Conflict(var, a, b); +} + + +} + +} \ No newline at end of file diff --git a/src/branching/lazy/version_store.hh b/src/branching/lazy/version_store.hh new file mode 100644 index 0000000..261c154 --- /dev/null +++ b/src/branching/lazy/version_store.hh @@ -0,0 +1,25 @@ +#pragma once + +#include "branching/base_version_store.hh" + +namespace gitmem { + +namespace branching { + +class LazyLocalVersionStore : public LocalVersionStore { +private: + bool raise_early_conflicts; + +public: + ~LazyLocalVersionStore() = default; + + LazyLocalVersionStore(ThreadID tid, bool verbose, bool raise_early_conflicts) : LocalVersionStore(tid, verbose), raise_early_conflicts(raise_early_conflicts) {} + + std::optional merge_with_commit(const std::shared_ptr&) override; + BranchingReadResult 
get_committed(std::string var) const override; + +}; + +} + +} \ No newline at end of file diff --git a/src/conflict.hh b/src/conflict.hh new file mode 100644 index 0000000..9c1e49d --- /dev/null +++ b/src/conflict.hh @@ -0,0 +1,47 @@ +#pragma once + +#include + +namespace gitmem { + +struct ConflictBase { + virtual ~ConflictBase() = default; + virtual std::ostream &print(std::ostream &os) const = 0; + friend std::ostream &operator<<(std::ostream &os, + const ConflictBase &conflict) { + return conflict.print(os); + } + virtual bool operator==(const ConflictBase &other) const = 0; +}; + +template +struct Conflict : ConflictBase { + std::string var; + std::pair versions; + + Conflict(std::string var, std::pair versions) + : var(std::move(var)), versions(std::move(versions)) {} + + std::ostream &print(std::ostream &os) const override; + + bool operator==(const Conflict &other) const { + return var == other.var && versions == other.versions; + } + + bool operator==(const ConflictBase& other) const override { + auto* o = dynamic_cast(&other); + if (!o) + return false; + return *this == *o; + } +}; + +template +std::ostream &Conflict::print(std::ostream &os) const { + os << "conflict on " << var << " { " << versions.first << ", " + << versions.second << " }"; + return os; +} + + +} \ No newline at end of file diff --git a/src/debug.hh b/src/debug.hh new file mode 100644 index 0000000..19dc66e --- /dev/null +++ b/src/debug.hh @@ -0,0 +1,28 @@ +#pragma once + +#include + +namespace gitmem { + +namespace verbose { + +/* For debug printing */ +inline struct Verbose { + bool enabled = false; + + template const Verbose &operator<<(const T &msg) const { + if (enabled) + std::cout << msg; + return *this; + } + + const Verbose &operator<<(std::ostream &(*manip)(std::ostream &)) const { + if (enabled) + std::cout << manip; + return *this; + } +} out; + +} // namespace verbose + +} // namespace gitmem \ No newline at end of file diff --git a/src/debugger.cc b/src/debugger.cc index 
d0d3eef..5949e1e 100644 --- a/src/debugger.cc +++ b/src/debugger.cc @@ -1,380 +1,306 @@ #include +#include +#include "debug.hh" +#include "debugger.hh" #include "interpreter.hh" +#include "overloaded.hh" -namespace gitmem -{ - /** A command that can be parsed by the debugger. Some commands store a - * ThreadID argument. */ - struct Command - { - enum - { - Step, // Run a specified thread to next sync point - Finish, // Finish the rest of the program - Restart, // Start the program from the beginning - List, // List all threads - Print, // Print the execution graph - Graph, // Toggle automatically printing the execution graph - Quit, // Quit the interpreter - Info, // Show commands - Skip, // Do nothing, used for invalid commands - } cmd; - ThreadID argument = 0; - }; - - void show_global(const std::string &var, const Global &global) - { - std::cout << var << " = " << global.val - << " [" << (global.commit ? std::to_string(*global.commit) : "_") << "; "; - for (size_t i = 0; i < global.history.size(); ++i) - { - std::cout << global.history[i]; - if (i < global.history.size() - 1) - { - std::cout << ", "; - } - } - std::cout << "]" << std::endl; - } +namespace gitmem { +/** A command that can be parsed by the debugger. Some commands store a + * ThreadID argument. */ +struct Command { + enum { + Step, // Run a specified thread to next sync point + Finish, // Finish the rest of the program + Restart, // Start the program from the beginning + List, // List all threads + Print, // Print the execution graph + Graph, // Toggle automatically printing the execution graph + Quit, // Quit the interpreter + Info, // Show commands + Skip, // Do nothing, used for invalid commands + } cmd; + ThreadID argument = 0; +}; - /** Print the state of a thread, including its local and global variables, - * and the current position in the program. 
*/ - void show_thread(const Thread &thread, size_t tid) - { - std::cout << "---- Thread " << tid << std::endl; - if (thread.ctx.locals.size() > 0) - { - for (auto &[reg, val] : thread.ctx.locals) - { - std::cout << reg << " = " << val << std::endl; - } - std::cout << "--" << std::endl; - } +/** Clear the terminal screen (platform-specific) */ +void clear_terminal() { +#ifdef _WIN32 + std::system("cls"); +#else + std::system("clear"); +#endif +} - if (thread.ctx.globals.size() > 0) - { - for (auto &[var, val] : thread.ctx.globals) - { - show_global(var, val); - } - std::cout << "--" << std::endl; - } +/** Print a visual separator line */ +void print_separator() { + std::cout << std::string(60, '=') << std::endl; +} - size_t idx = 0; - for (const auto &stmt : *thread.block) - { - if (idx == thread.pc) - { - std::cout << "-> "; - } - else - { - std::cout << " "; - } - // Fix indentation of nested blocks - auto s = std::string(stmt->location().view()); - s = std::regex_replace(s, std::regex("\n"), "\n "); - std::cout << s << ";" << std::endl; +/** Parse a command. See the help string for the 'Info' command for details. 
+ */ +Command parse_command(std::string &input) { + auto command = std::string(input); + command.erase(0, command.find_first_not_of(" \t\n\r")); + command.erase(command.find_last_not_of(" \t\n\r") + 1); - idx++; - } - if (thread.pc == thread.block->size()) - { - std::cout << "-> " << std::endl; - } + if (command.find_first_not_of("0123456789") == std::string::npos) { + // Interpret numbers as stepping + return {Command::Step, std::stoul(command)}; + } else if (command == "s" || + (command.at(0) == 's' && !std::isalpha(command.at(1)))) { + auto arg = command.substr(1); + arg.erase(0, arg.find_first_not_of(" \t\n\r")); + if (arg.size() > 0 && + arg.find_first_not_of("0123456789") == std::string::npos) { + return {Command::Step, std::stoul(arg)}; + } else { + std::cout << "Expected thread id" << std::endl; + return {Command::Skip}; } + } else if (command == "q") { + return {Command::Quit}; + } else if (command == "r") { + return {Command::Restart}; + } else if (command == "f") { + return {Command::Finish}; + } else if (command == "l") { + return {Command::List}; + } else if (command == "g") { + return {Command::Graph}; + } else if (command == "p") { + return {Command::Print}; + } else if (command == "?") { + return {Command::Info}; + } else { + std::cout << "Unknown command: " << input << std::endl; + return {Command::Skip}; + } +} - void show_lock(const std::string &lock_name, const struct Lock &lock) - { - std::cout << lock_name << ": "; - if (lock.owner) - { - std::cout << "held by thread " << *lock.owner; - } - else - { - std::cout << ""; - } - std::cout << std::endl; - for (auto &[var, global] : lock.globals) - { - show_global(var, global); - } - } +enum class StepKind { + Progressed, // Thread made progress + Blocked, // Thread is blocked on sync + Terminated, // Thread terminated this step + Invalid // Invalid thread id, etc. +}; - /** Show the global context, including locks and non-completed threads. 
If - * show_all is true, show all threads, even those that have terminated - * normally. */ - void show_global_context(const GlobalContext &gctx, bool show_all = false) - { - auto &threads = gctx.threads; - bool showed_any = false; - for (size_t i = 0; i < threads.size(); i++) - { - auto thread = threads[i]; - if (show_all || !thread->terminated || *threads[i]->terminated != TerminationStatus::completed) - { - show_thread(*threads[i], i); - std::cout << std::endl; - showed_any = true; - } - } +struct StepUIResult { + StepKind kind; + std::optional termination; + std::optional message; - if (showed_any && gctx.locks.size() > 0) - { - std::cout << "---- Locks" << std::endl; + static StepUIResult progressed() { + return {StepKind::Progressed, std::nullopt, std::nullopt}; + } - for (const auto &[lock_name, lock] : gctx.locks) - { - show_lock(lock_name, lock); - } + static StepUIResult blocked(std::string msg) { + return {StepKind::Blocked, std::nullopt, std::move(msg)}; + } - if (gctx.locks.size() > 0) - std::cout << "--" << std::endl; - } - } + static StepUIResult terminated(TerminationStatus t) { + return {StepKind::Terminated, t, std::nullopt}; + } - /** Parse a command. See the help string for the 'Info' command for details. 
- */ - Command parse_command(std::string &input) - { - auto command = std::string(input); - command.erase(0, command.find_first_not_of(" \t\n\r")); - command.erase(command.find_last_not_of(" \t\n\r") + 1); - - if (command.find_first_not_of("0123456789") == std::string::npos) - { - // Interpret numbers as stepping - return {Command::Step, std::stoul(command)}; - } - else if (command == "s" || (command.at(0) == 's' && !std::isalpha(command.at(1)))) - { - auto arg = command.substr(1); - arg.erase(0, arg.find_first_not_of(" \t\n\r")); - if (arg.size() > 0 && arg.find_first_not_of("0123456789") == std::string::npos) - { - return {Command::Step, std::stoul(arg)}; - } - else - { - std::cout << "Expected thread id" << std::endl; - return {Command::Skip}; - } - } - else if (command == "q") - { - return {Command::Quit}; - } - else if (command == "r") - { - return {Command::Restart}; - } - else if (command == "f") - { - return {Command::Finish}; - } - else if (command == "l") - { - return {Command::List}; - } - else if (command == "g") - { - return {Command::Graph}; - } - else if (command == "p") - { - return {Command::Print}; - } - else if (command == "?") - { - return {Command::Info}; - } - else - { - std::cout << "Unknown command: " << input << std::endl; - return {Command::Skip}; - } + static StepUIResult invalid(std::string msg) { + return {StepKind::Invalid, std::nullopt, std::move(msg)}; + } + + bool has_message() { return message.has_value(); } + std::string& get_message() { return *message; } + + bool has_terminated() { return termination.has_value(); } + TerminationStatus& get_termination() { return *termination; } +}; + +StepUIResult step_thread(Interpreter& interp, ThreadID tid) { + GlobalContext& gctx = interp.context(); + + if (tid >= gctx.threads.size()) { + return StepUIResult::invalid( + "Invalid thread id: " + std::to_string(tid)); + } + + auto& thread = gctx.threads[tid]; + + if (thread.terminated) { + return StepUIResult::terminated(*thread.terminated); // already terminated: report the recorded status, never step it again + } + +
auto prog_or_term = interp.progress_thread(gctx.threads[tid]); + + if (auto prog = std::get_if(&prog_or_term)) { + if (*prog == ProgressStatus::no_progress) { + auto stmt = thread.block->at(thread.pc); + return StepUIResult::blocked( + "Thread " + std::to_string(tid) + " is blocking on '" + + std::string(stmt->location().view()) + "'"); } + return StepUIResult::progressed(); + } - /** Perform the Step command on a given thread. Error messages are assigned - * to `msg`. The return value signals whether threads should be printed - * after stepping or not. */ - bool step_thread(ThreadID tid, GlobalContext &gctx, std::string &msg) - { - if (tid >= gctx.threads.size()) - { - msg = "Invalid thread id: " + std::to_string(tid); - return false; - } + auto term = std::get(prog_or_term); + return StepUIResult::terminated(term); +} - auto thread = gctx.threads[tid]; - if (auto term = thread->terminated) - { - if (*term == TerminationStatus::completed) - { - msg = "Thread " + std::to_string(tid) + " has terminated normally"; - } - else - { - msg = "Thread " + std::to_string(tid) + " has terminated with an error"; - } - return false; - } +/** Print the execution graph if requested */ +void maybe_print_graph(Interpreter& interp, + bool print_graphs, + const std::filesystem::path &output_file) { + if (print_graphs) { + interp.print_revision_graph(output_file); + interp.print_execution_graph(output_file); + verbose::out << "Execution graph written to " << output_file << std::endl; + } +} - auto prog_or_term = progress_thread(gctx, tid, thread); - if (ProgressStatus *prog = std::get_if(&prog_or_term)) - { - if (!*prog) - { - auto stmt = thread->block->at(thread->pc); - msg = "Thread " + std::to_string(tid) + " is blocking on '" + std::string(stmt->location().view()) + "'"; - return false; - } - } - else if (TerminationStatus *term = std::get_if(&prog_or_term)) - { - switch (*term) - { - case TerminationStatus::completed: - msg = "Thread " + std::to_string(tid) + " terminated 
normally"; - return true; - case TerminationStatus::datarace_exception: - // TODO: Say on which variable the datarace occurred. To - // do this, have pull return an optional variable that - // is in a race and have the data race exception - // remember that variable. - msg = "Thread " + std::to_string(tid) + " encountered a data race and was terminated"; - return false; - case TerminationStatus::assertion_failure_exception: - { - auto expr = thread->block->at(thread->pc) / Stmt / Expr; - msg = "Thread " + std::to_string(tid) + " failed assertion '" + std::string(expr->location().view()) + "' and was terminated"; - return false; - } - case TerminationStatus::unassigned_variable_read_exception: - throw std::runtime_error("Thread " + std::to_string(tid) + " read an uninitialised variable"); - case TerminationStatus::unlock_exception: - throw std::runtime_error("Thread " + std::to_string(tid) + " unlocked an unlocked lock"); - default: - throw std::runtime_error("Thread " + std::to_string(tid) + " has an unhandled termination state"); - } - } - return true; +/** Step a single thread and return the StepUIResult. Also prints the message. */ +StepUIResult do_step(Interpreter &interp, + ThreadID tid, + bool print_graphs, + const std::filesystem::path &output_file) { + StepUIResult result = step_thread(interp, tid); + if (result.has_message()) + std::cout << result.get_message() << std::endl; + if (result.has_terminated()) { + std::cout << "Thread " << tid << ": "; + std::visit( + overloaded{ + [&](const auto &t) { + // Any non-completed termination is exceptional + std::cout << t << std::endl; + } + }, + result.get_termination() + ); } - /** Interpret the AST in an interactive way, letting the user choose which - * thread to schedule next. 
*/ - int interpret_interactive(const Node ast, const std::filesystem::path &output_file) - { - GlobalContext gctx(ast); - - size_t prev_no_threads = 1; - Command command = {Command::List}; - std::string msg = ""; - bool print_graphs = true; - gctx.print_execution_graph(output_file); - while (command.cmd != Command::Quit) - { - if (command.cmd != Command::Skip || prev_no_threads != gctx.threads.size()) - { - bool show_all = command.cmd == Command::List; - show_global_context(gctx, show_all); - } - prev_no_threads = gctx.threads.size(); + maybe_print_graph(interp, print_graphs, output_file); + return result; +} - if (!msg.empty()) - { - std::cout << msg << std::endl; - msg.clear(); - } +/** Reset the interpreter to a fresh state */ +void do_restart(Interpreter &interp, + const trieste::Node ast, + bool print_graphs, + const std::filesystem::path &output_file) { + interp = Interpreter(GlobalContext(ast, interp.context().protocol->clone())); + maybe_print_graph(interp, print_graphs, output_file); +} - std::cout << "> "; - std::string input; - std::getline(std::cin, input); - if (!input.empty() && input.find_first_not_of(" \t\n\r") != std::string::npos) - { - command = parse_command(input); - } +/** Print the list of threads and optionally all threads */ +void do_list(Interpreter &interp, bool show_all) { + // Uncomment the next line if you prefer clearing the screen + // clear_terminal(); + + print_separator(); + interp.print_state(std::cout, show_all); + print_separator(); +} - if (command.cmd == Command::Step) - { - auto tid = command.argument; - if (!step_thread(tid, gctx, msg)) command = {Command::Skip}; +void do_finish(Interpreter& interp, bool print_graphs, const std::filesystem::path &output_file) { + if (!interp.run()) { + std::cout << "Program finished successfully" << std::endl; + } else { + std::cout << "Program terminated with an error" << std::endl; + } - if (print_graphs) - { - gctx.print_execution_graph(output_file); - verbose << "Execution graph written 
to " << output_file << std::endl; - } - } - else if (command.cmd == Command::Finish) - { - // Finish the program - if (!run_threads(gctx)) - msg = "Program finished successfully"; - else - msg = "Program terminated with an error"; - - if (print_graphs) - { - gctx.print_execution_graph(output_file); - verbose << "Execution graph written to " << output_file << std::endl; - } + maybe_print_graph(interp, print_graphs, output_file); +} + +/** Print interactive command help */ +void print_help() { + std::cout << "Commands:\n"; + std::cout << "s [tid] - Step to next sync point in thread\n"; + std::cout << "[tid] - Step to next sync point in thread\n"; + std::cout << "f - Finish the program\n"; + std::cout << "r - Restart the program\n"; + std::cout << "l - List all threads\n"; + std::cout << "g - Toggle automatic execution graph printing\n"; + std::cout << "p - Print the execution graph immediately\n"; + std::cout << "q - Quit the interpreter\n"; + std::cout << "? - Display this help message\n"; +} + +/** Main interactive interpreter loop */ +int interpret_interactive(const trieste::Node ast, + const std::filesystem::path &output_file, + std::unique_ptr protocol) { + Interpreter interp(GlobalContext(ast, std::move(protocol))); + GlobalContext &gctx = interp.context(); + + size_t prev_no_threads = 1; + Command command = {Command::List}; + bool print_graphs = true; + + // clear the graph at the start + maybe_print_graph(interp, print_graphs, output_file); + + while (command.cmd != Command::Quit) { + // Print threads if new threads appeared or command is List + if (command.cmd != Command::Skip || prev_no_threads != gctx.threads.size()) { + do_list(interp, command.cmd == Command::List); + } + prev_no_threads = gctx.threads.size(); + + // Read user input + std::cout << "> "; + std::string input; + std::getline(std::cin, input); + if (!input.empty() && input.find_first_not_of(" \t\n\r") != std::string::npos) + command = parse_command(input); + + switch (command.cmd) { + case 
Command::Step: { + ThreadID tid = command.argument; + StepUIResult res = do_step(interp, tid, print_graphs, output_file); + if (res.kind != StepKind::Progressed && res.kind != StepKind::Terminated) + command = {Command::Skip}; + break; } - else if (command.cmd == Command::Restart) - { - // Start the program from the beginning - gctx = GlobalContext(ast); + + case Command::Finish: + do_finish(interp, print_graphs, output_file); + break; + + case Command::Restart: + do_restart(interp, ast, print_graphs, output_file); command = {Command::List}; - if (print_graphs) - { - gctx.print_execution_graph(output_file); - verbose << "Execution graph written to " << output_file << std::endl; - } - } - else if (command.cmd == Command::List) - { - // Listing is a no-op - } - else if (command.cmd == Command::Graph) - { - // Toggle printing execution graph automatically + break; + + case Command::List: + // Already handled before reading input, no-op here + break; + + case Command::Graph: print_graphs = !print_graphs; - std::cout << "graphs " << (print_graphs ? "will" : "won't") << " print automatically" << std::endl; + std::cout << "Graphs " << (print_graphs ? 
"will" : "won't") + << " print automatically" << std::endl; command = {Command::Skip}; - } - else if (command.cmd == Command::Print) - { - // Print the execution graph - gctx.print_execution_graph(output_file); - verbose << "Execution graph written to " << output_file << std::endl; + break; + + case Command::Print: + maybe_print_graph(interp, /*print_graphs=*/true, output_file); command = {Command::Skip}; - } - else if (command.cmd == Command::Skip) - { - // Skip is a no-op - } - else if (command.cmd == Command::Info) - { - std::cout << "Commands:" << std::endl; - std::cout << "s [tid] - Step to next sync point in thread" << std::endl; - std::cout << "[tid] - Step to next sync point in thread" << std::endl; - std::cout << "f - Finish the program" << std::endl; - std::cout << "r - Restart the program" << std::endl; - std::cout << "l - List all threads" << std::endl; - std::cout << "g - Toggle printing the execution graph at sync points" << std::endl; - std::cout << "p - Printing the execution graph at current sync point" << std::endl; - std::cout << "q - Quit the interpreter" << std::endl; - std::cout << "?
- Display this help message" << std::endl; + break; + + case Command::Info: + print_help(); command = {Command::Skip}; - } - else if (command.cmd == Command::Quit) - { - // Quit is a no-op - } - } + break; + + case Command::Skip: + // No-op + break; - return 0; + case Command::Quit: + // No-op + break; + } } + + return 0; } + +} // namespace gitmem \ No newline at end of file diff --git a/src/debugger.hh b/src/debugger.hh new file mode 100644 index 0000000..90fbe28 --- /dev/null +++ b/src/debugger.hh @@ -0,0 +1,10 @@ +#pragma once + +#include +#include "sync_protocol.hh" + +namespace gitmem { + int interpret_interactive(const trieste::Node, + const std::filesystem::path &output_file, + std::unique_ptr protocol); +} \ No newline at end of file diff --git a/src/execution_state.cc b/src/execution_state.cc new file mode 100644 index 0000000..46d76fc --- /dev/null +++ b/src/execution_state.cc @@ -0,0 +1,198 @@ +#include + +#include "execution_state.hh" +#include "sync_protocol.hh" + +namespace gitmem { + +ThreadContext::ThreadContext(ThreadID tid, std::unique_ptr& protocol) { + sync = protocol->make_thread_state(tid); +} + +bool ThreadContext::operator==(const ThreadContext &other) const { + if (locals != other.locals) + return false; + + // ignore the graph node, we're not interested in that + + return *sync == *other.sync; +} + +bool Thread::operator==(const Thread &other) const { + return ctx == other.ctx && + block == other.block && + pc == other.pc && + terminated == other.terminated; +} + +GlobalContext::GlobalContext(const trieste::Node &ast, + std::unique_ptr protocol) + : protocol(std::move(protocol)) { + trieste::Node starting_block = ast / lang::File / lang::Block; + + ThreadID main_tid = 0; + + ThreadContext starting_ctx(main_tid, this->protocol); + + this->threads.emplace_back(main_tid, std::move(starting_ctx), starting_block); +} + +GlobalContext::~GlobalContext() = default; + +Lock& GlobalContext::get_lock(std::string lock) { + auto it = locks.find(lock); 
+ if (it != locks.end()) + return it->second; + + auto [new_it, inserted] = locks.emplace( + lock, + Lock{ + .owner = std::nullopt, + .last_unlock_event = nullptr, + .sync = protocol->make_lock_state() + } + ); + + return new_it->second; +} + +// void GlobalContext::print_execution_graph( +// const std::filesystem::path &output_path) const { +// return; // FIXME +// // Loop over the threads and add pending nodes to running threads +// // to indicate a threads next step +// for (const auto &t : threads) { +// assert(t->ctx.tail); +// if (t->terminated || +// dynamic_pointer_cast(t->ctx.tail->next)) +// continue; + +// trieste::Node block = t->block; +// size_t &pc = t->pc; +// trieste::Node stmt = block->at(pc); +// thread_append_node(t->ctx, +// std::string(stmt->location().view())); +// } + +// graph::GraphvizPrinter gv(output_path); +// gv.visit(entry_node.get()); +// } + +bool GlobalContext::operator==(const GlobalContext &other) const { + if (threads.size() != other.threads.size() || + locks.size() != other.locks.size()) + return false; + + // Threads may have been spawned in a different order, so we + // find the thread with the same block in the other context + for (auto &thread : threads) { + auto it = std::find_if(other.threads.begin(), other.threads.end(), + [&thread](auto &t) + { return t.block == thread.block; }); + if (it == other.threads.end() || !(thread == *it)) + return false; + } + + for (auto &[name, lock] : locks) { + if (!other.locks.contains(name)) + return false; + auto &other_lock = other.locks.at(name); + if (lock.owner != other_lock.owner) + return false; + } + return true; +} + +/** Print the state of a thread, including its local and global variables, + * and the current position in the program. 
*/ +std::ostream& operator<<(std::ostream& os, const Thread& thread) { + os << thread.ctx << std::endl; + + size_t idx = 0; + for (const auto &stmt : *(thread.block)) { + if (idx == thread.pc) { + os << "-> "; + } else { + os << " "; + } + + // This should be somewhere else + // Fix indentation of nested blocks + auto s = std::string(stmt->location().view()); + s = std::regex_replace(s, std::regex("\n"), "\n "); + os << s << ";" << std::endl; + + idx++; + } + if (thread.pc == thread.block->size()) { + os << "-> " << std::endl; + } + + return os; +} + +std::ostream& operator<<(std::ostream& os, const ThreadContext& ctx) { + os << "ThreadContext{locals={"; + + bool first = true; + for (const auto& [k, v] : ctx.locals) { + if (!first) os << ", "; + first = false; + os << k << "=" << v; + } + + os << "}, "; //, tail=" << ctx.tail; + + os << *(ctx.sync); + + os << "}"; + return os; +} + +void show_lock(const std::string &lock_name, const struct Lock &lock, std::ostream &os = std::cout) { // os defaults to std::cout so existing two-argument callers keep working + os << lock_name << ": "; + if (lock.owner) { + os << "held by thread " << *lock.owner; + } else { + os << ""; + } + if (lock.sync) { + os << ", " << *(lock.sync); + } + os << std::endl; +} + +void GlobalContext::print(std::ostream& os, bool show_all) const { + os << *protocol << std::endl; + + bool showed_any = false; + for (size_t i = 0; i < threads.size(); i++) { + auto& thread = threads[i]; + if (show_all || !thread.terminated || + !std::holds_alternative(*thread.terminated)) { + os << "---- Thread " << i << std::endl; + os << threads[i] << std::endl; + os << std::endl; + showed_any = true; + } + } + + if (showed_any && locks.size() > 0) { + os << "---- Locks" << std::endl; + + for (const auto &[lock_name, lock] : locks) { + show_lock(lock_name, lock, os); // keep lock output on the same stream as the thread output + } + + if (locks.size() > 0) + os << "--" << std::endl; + } +} + +std::ostream& operator<<(std::ostream& os, const GlobalContext& gctx) { + gctx.print(os); + return os; +} + + +} // namespace gitmem \ No newline at end of file
diff --git a/src/execution_state.hh b/src/execution_state.hh new file mode 100644 index 0000000..bcae76d --- /dev/null +++ b/src/execution_state.hh @@ -0,0 +1,104 @@ +#pragma once + +#include +#include +#include +#include + +#include "lang.hh" +#include "sync_state.hh" +#include "graphviz.hh" +#include "termination_status.hh" +#include "thread_trace.hh" +#include "thread_id.hh" + +namespace gitmem { + +class SyncProtocol; + +struct ThreadContext { + std::unordered_map locals; + + std::unique_ptr sync; + + ThreadContext(const ThreadContext&) = delete; + ThreadContext& operator=(const ThreadContext&) = delete; + + ThreadContext(ThreadContext&&) = default; + ThreadContext& operator=(ThreadContext&&) = default; + + ThreadContext(ThreadID tid, std::unique_ptr&); + + bool operator==(const ThreadContext &other) const; + + friend std::ostream& operator<<(std::ostream&, const ThreadContext&); +}; + +using termination::TerminationStatus; + +struct Thread { + ThreadID tid; + ThreadContext ctx; + ThreadTrace trace; + trieste::Node block; + size_t pc = 0; + std::optional terminated = std::nullopt; + + Thread(ThreadID tid, ThreadContext ctx, trieste::Node block): + tid(tid), ctx(std::move(ctx)), trace(tid), block(block) {}; + + Thread(const Thread&) = delete; + Thread& operator=(const Thread&) = delete; + + Thread(Thread&&) = default; + Thread& operator=(Thread&&) = default; + + bool operator==(const Thread &other) const; + + friend std::ostream& operator<<(std::ostream&, const Thread&); +}; + +struct Lock { + std::optional owner = std::nullopt; + std::shared_ptr last_unlock_event = nullptr; + std::unique_ptr sync; +}; + +struct GlobalContext { + // Execution state + std::deque threads; +private: + std::unordered_map locks; +public: + Lock& get_lock(std::string); + + // AST evaluation cache + lang::NodeMap cache; + + // Graph root + // std::shared_ptr entry_node; + + // Synchronisation semantics (policy) + std::unique_ptr protocol; + + GlobalContext(const trieste::Node &ast, + 
std::unique_ptr protocol); + ~GlobalContext(); + + GlobalContext clone() const; + + GlobalContext(GlobalContext&&) = default; + GlobalContext& operator=(GlobalContext&&) = default; + + GlobalContext(const GlobalContext&) = delete; + GlobalContext& operator=(const GlobalContext&) = delete; + + bool operator==(const GlobalContext &other) const; + + void print(std::ostream& os, bool show_all = false) const; + friend std::ostream& operator<<(std::ostream&, const GlobalContext&); + + // void print_execution_graph(const std::filesystem::path &output_path) const; +}; + +} // namespace gitmem \ No newline at end of file diff --git a/src/gitmem.cc b/src/gitmem.cc index 5ea111f..b016c76 100644 --- a/src/gitmem.cc +++ b/src/gitmem.cc @@ -1,101 +1,144 @@ #include -#include "lang.hh" +#include "debug.hh" #include "interpreter.hh" +#include "model_checker.hh" +#include "debugger.hh" +#include "lang.hh" +#include "linear/sync_protocol.hh" +#include "branching/base_sync_protocol.hh" +#include "branching/eager/sync_protocol.hh" +#include "branching/lazy/sync_protocol.hh" + +int main(int argc, char **argv) { + using namespace trieste; + CLI::App app; + + std::filesystem::path input_path; + app.add_option("input", input_path, "Path to the input file ") + ->required() + ->check(CLI::ExistingFile); + + std::filesystem::path output_path = ""; + app.add_option("-o,--output", output_path, "Path to the output file."); + + bool verbose = false; + app.add_flag("-v,--verbose", verbose, + "Enable verbose output from the interpreter."); + + std::string sync_protocol = "linear"; + auto sync_opt = app.add_option("--sync", sync_protocol, "Select a sync protocol for execution (default: linear)") + ->check(CLI::IsMember({"linear", "branching"})) + ->type_name("KIND"); + + std::string branching_mode = "eager"; + auto branching_mode_opt = app.add_option("--branching-mode", branching_mode, "Select branching mode: eager or lazy (default: eager)") + ->check(CLI::IsMember({"eager", "lazy"})) + 
->type_name("MODE"); + + bool include_empty_commits = false; + auto include_empty_opt = app.add_flag("--include-empty-commits", include_empty_commits, + "Include empty commits in branching protocol output (branching mode only)."); + + bool raise_early_conflicts = false; + auto raise_early_opt = app.add_flag("--raise-early-conflicts", raise_early_conflicts, + "Raise conflict errors before suppressing writes (lazy branching mode only)."); + + // Set up option dependencies + branching_mode_opt->needs(sync_opt); + include_empty_opt->needs(sync_opt); + raise_early_opt->needs(branching_mode_opt); + + bool interactive = false; + app.add_flag("-i,--interactive", interactive, + "Enable interactive scheduling mode (use command ? for help)."); + + bool model_check = false; + app.add_flag("-e,--explore", model_check, + "Explore all possible execution paths."); + + try { + app.parse(argc, argv); + + // Additional validation for logical consistency + if (sync_protocol == "linear") { + if (*branching_mode_opt) { + std::cerr << "Error: --branching-mode is only valid with --sync branching" << std::endl; + return 1; + } + if (include_empty_commits) { + std::cerr << "Error: --include-empty-commits is only valid with --sync branching" << std::endl; + return 1; + } + if (raise_early_conflicts) { + std::cerr << "Error: --raise-early-conflicts is only valid with --sync branching" << std::endl; + return 1; + } + } -int main(int argc, char **argv) -{ - using namespace trieste; - CLI::App app; - - std::filesystem::path input_path; - app.add_option("input", input_path, "Path to the input file ")->required()->check(CLI::ExistingFile); - - std::filesystem::path output_path = ""; - app.add_option( - "-o,--output", - output_path, - "Path to the output file." - ); - - bool verbose = false; - app.add_flag( - "-v,--verbose", - verbose, - "Enable verbose output from the interpreter." 
- ); - - // TODO: These should probably be subcommands - bool interactive = false; - app.add_flag( - "-i,--interactive", - interactive, - "Enable interactive scheduling mode (use command ? for help)."); - - bool model_check = false; - app.add_flag( - "-e,--explore", - model_check, - "Explore all possible execution paths."); - - try - { - app.parse(argc, argv); + if (sync_protocol == "branching" && branching_mode == "eager" && raise_early_conflicts) { + std::cerr << "Error: --raise-early-conflicts is only valid with --branching-mode lazy" << std::endl; + return 1; } - catch (const CLI::ParseError &e) - { - return app.exit(e); + } catch (const CLI::ParseError &e) { + return app.exit(e); + } + + try { + gitmem::verbose::out.enabled = verbose; + + gitmem::verbose::out << "Reading file " << input_path << std::endl; + if (!std::filesystem::exists(input_path)) { + std::cerr << "Input file does not exist: " << input_path << std::endl; + return 1; } - try - { - gitmem::verbose.enabled = verbose; - - gitmem::verbose << "Reading file " << input_path << std::endl; - if (!std::filesystem::exists(input_path)) - { - std::cerr << "Input file does not exist: " << input_path << std::endl; - return 1; - } - - auto reader = gitmem::reader().file(input_path); - auto result = reader.read(); - - if (!result.ok) - { - trieste::logging::Error err; - result.print_errors(err); - trieste::logging::Debug() << result.ast; - return 1; - } - - if (output_path.empty()) - output_path = input_path.stem().replace_extension(".dot"); - - gitmem::verbose << "Output will be written to " << output_path << std::endl; - - int exit_status; - wf::push_back(gitmem::wf); - if (model_check) - { - exit_status = gitmem::model_check(result.ast, output_path); - } - else if (interactive) - { - exit_status = gitmem::interpret_interactive(result.ast, output_path); - } - else - { - exit_status = gitmem::interpret(result.ast, output_path); - } - wf::pop_front(); - - gitmem::verbose << "Execution finished with exit status 
" << exit_status << std::endl; - return exit_status; + auto reader = gitmem::lang::reader().file(input_path); + auto result = reader.read(); + + if (!result.ok) { + trieste::logging::Error err; + result.print_errors(err); + trieste::logging::Debug() << result.ast; + return 1; } - catch (const std::exception &e) - { - std::cerr << "Exception caught: " << e.what() << std::endl; - return 1; + + if (output_path.empty()) + output_path = input_path.stem().replace_extension(".dot"); + + gitmem::verbose::out << "Output will be written to " << output_path << std::endl; + + // Build protocol based on command line options + std::unique_ptr protocol; + if (sync_protocol == "linear") { + protocol = std::make_unique(); + } else if (sync_protocol == "branching") { + if (branching_mode == "eager") { + protocol = std::make_unique(include_empty_commits); + } else { // lazy + protocol = std::make_unique( + include_empty_commits, + raise_early_conflicts + ); + } + } + + int exit_status; + wf::push_back(gitmem::lang::wf); + if (model_check) { + exit_status = gitmem::model_check(result.ast, output_path, std::move(protocol)); + } else if (interactive) { + exit_status = gitmem::interpret_interactive(result.ast, output_path, std::move(protocol)); + } else { + exit_status = gitmem::interpret(result.ast, output_path, std::move(protocol)); } + wf::pop_front(); + + gitmem::verbose::out << "Execution finished with exit status " << exit_status + << std::endl; + return exit_status; + } catch (const std::exception &e) { + std::cerr << "Exception caught: " << e.what() << std::endl; + return 1; + } } diff --git a/src/gitmem_trieste.cc b/src/gitmem_trieste.cc index 19cd4ef..a715915 100644 --- a/src/gitmem_trieste.cc +++ b/src/gitmem_trieste.cc @@ -1,7 +1,6 @@ -#include #include "lang.hh" +#include -int main(int argc, char** argv) -{ - return trieste::Driver(gitmem::reader()).run(argc, argv); +int main(int argc, char **argv) { + return trieste::Driver(gitmem::lang::reader()).run(argc, argv); } diff 
--git a/src/graph.hh b/src/graph.hh index 3a3c071..d261533 100644 --- a/src/graph.hh +++ b/src/graph.hh @@ -1,181 +1,184 @@ #pragma once +#include #include #include -#include -#include +#include namespace gitmem { - namespace graph { - - struct Visitor; - - struct Node - { - std::shared_ptr next = nullptr; - - virtual void accept(Visitor*) const = 0; - }; - - struct Start; - struct End; - struct Write; - struct Read; - struct Spawn; - struct Join; - struct Lock; - struct Unlock; - struct AssertionFailure; - struct Pending; - - struct Conflict - { - std::string var; - std::pair, std::shared_ptr> sources; - }; - - struct Visitor - { - virtual void visitStart(const Start*) = 0; - virtual void visitEnd(const End*) = 0; - virtual void visitWrite(const Write*) = 0; - virtual void visitRead(const Read*) = 0; - virtual void visitSpawn(const Spawn*) = 0; - virtual void visitJoin(const Join*) = 0; - virtual void visitLock(const Lock*) = 0; - virtual void visitUnlock(const Unlock*) = 0; - virtual void visitAssertionFailure(const AssertionFailure*) = 0; - virtual void visitPending(const Pending*) = 0; - virtual void visit(const Node* n) { n->accept(this); } - }; - - struct Start : Node - { - size_t id; - - Start(size_t id): id(id) {} - - void accept(Visitor* v) const override - { - v->visitStart(this); - } - }; - - struct End : Node - { - End() {} - - void accept(Visitor* v) const override - { - v->visitEnd(this); - } - }; - - struct Write : Node - { - const std::string var; - const size_t value; - const size_t id; - - Write(const std::string var, const size_t value, const size_t id): var(var), value(value), id(id) {} - - void accept(Visitor* v) const override - { - v->visitWrite(this); - } - }; - - struct Read : Node - { - const std::string var; - const size_t value; - const size_t id; - const std::shared_ptr sauce; - - - Read(const std::string var, const size_t value, const size_t id, const std::shared_ptr sauce): var(var), value(value), id(id), sauce(sauce) {} - - void 
accept(Visitor* v) const override - { - v->visitRead(this); - } - }; - - struct Spawn : Node - { - const size_t tid; - const std::shared_ptr spawned; - - Spawn(const size_t tid, const std::shared_ptr spawned): tid(tid), spawned(spawned) {} - - void accept(Visitor* v) const override - { - v->visitSpawn(this); - } - }; - - struct Join : Node - { - const size_t tid; - const std::shared_ptr joinee; - const std::optional conflict; - - Join(const size_t tid, const std::shared_ptr joinee, std::optional conflict = std::nullopt): tid(tid), joinee(joinee), conflict(conflict) {} - - void accept(Visitor* v) const override - { - v->visitJoin(this); - } - }; - - struct Lock : Node - { - const std::string var; - const std::shared_ptr ordered_after; - const std::optional conflict; - - Lock(const std::string var, const std::shared_ptr ordered_after, std::optional conflict = std::nullopt): var(var), ordered_after(ordered_after), conflict(conflict) {} - - void accept(Visitor* v) const override - { - v->visitLock(this); - } - }; - - struct Unlock : Node - { - const std::string var; - - Unlock(const std::string var): var(var) {} - - void accept(Visitor* v) const override - { - v->visitUnlock(this); - } - }; - - struct AssertionFailure : Node - { - const std::string cond; - - AssertionFailure(const std::string &cond): cond(cond) {} - - void accept(Visitor* v) const override - { - v->visitAssertionFailure(this); - } - }; - - struct Pending : Node - { - const std::string statement; - - Pending(const std::string statement): statement(statement) {} - void accept(Visitor* v) const override - { - v->visitPending(this); - } - }; +namespace graph { + +struct Visitor; + +struct Node { + std::shared_ptr next = nullptr; + + virtual void accept(Visitor *) const = 0; +}; + +struct Start; +struct End; +struct Write; +struct Read; +struct Spawn; +struct Join; +struct Lock; +struct Unlock; +struct Assertion; +struct Pending; + +struct Conflict { + std::string var; + // Optional: if we can determine the 
conflicting nodes, store them here + // Otherwise these can be nullptr and we just mark the node red + std::pair, std::shared_ptr> sources; + + // Constructor that allows creating conflicts without sources + Conflict(std::string v) : var(std::move(v)), sources{nullptr, nullptr} {} + Conflict(std::string v, std::pair, std::shared_ptr> s) + : var(std::move(v)), sources(std::move(s)) {} +}; + +struct Visitor { + virtual void visitStart(const Start *) = 0; + virtual void visitEnd(const End *) = 0; + virtual void visitWrite(const Write *) = 0; + virtual void visitRead(const Read *) = 0; + virtual void visitSpawn(const Spawn *) = 0; + virtual void visitJoin(const Join *) = 0; + virtual void visitLock(const Lock *) = 0; + virtual void visitUnlock(const Unlock *) = 0; + virtual void visitAssertion(const Assertion *) = 0; + virtual void visitPending(const Pending *) = 0; + virtual void visit(const Node *n) { n->accept(this); } +}; + +struct Start : Node { + size_t id; + + Start(size_t id) : id(id) {} + + void accept(Visitor *v) const override { v->visitStart(this); } +}; + +struct End : Node { + End() {} + + void accept(Visitor *v) const override { v->visitEnd(this); } +}; + +struct Write : Node { + const std::string var; + const size_t value; + const size_t id; + + Write(const std::string var, const size_t value, const size_t id) + : var(var), value(value), id(id) {} + + void accept(Visitor *v) const override { v->visitWrite(this); } +}; + +struct Read : Node { + const std::string var; + const size_t id; + + struct SuccessfulRead { + size_t value; + std::shared_ptr source; + }; + + const std::variant read_result; + + // Constructor for successful read + Read(const std::string var, const size_t value, const size_t id, + const std::shared_ptr source) + : var(var), id(id), read_result(SuccessfulRead{value, source}) {} + + // Constructor for conflicting read + Read(const std::string var, const size_t id, Conflict conflict) + : var(var), id(id), read_result(std::move(conflict)) 
{} + + void set_source(const std::shared_ptr source) { + if (std::holds_alternative(read_result)) { + auto &sr = std::get(read_result); + const_cast&>(sr.source) = source; + } } -} + + void accept(Visitor *v) const override { v->visitRead(this); } +}; + +struct Spawn : Node { + const size_t tid; + const std::shared_ptr spawned; + + Spawn(const size_t tid, const std::shared_ptr spawned) + : tid(tid), spawned(spawned) {} + + void accept(Visitor *v) const override { v->visitSpawn(this); } +}; + +struct Join : Node { + const size_t tid; + const std::shared_ptr joinee; + const std::optional conflict; + + Join(const size_t tid, const std::shared_ptr joinee, + std::optional conflict = std::nullopt) + : tid(tid), joinee(joinee), conflict(conflict) {} + + void accept(Visitor *v) const override { v->visitJoin(this); } +}; + +struct Lock : Node { + const std::string var; + const std::shared_ptr ordered_after; + const std::optional conflict; + + Lock(const std::string var, const std::shared_ptr ordered_after, + std::optional conflict = std::nullopt) + : var(var), ordered_after(ordered_after), conflict(conflict) {} + + void accept(Visitor *v) const override { v->visitLock(this); } +}; + +struct Unlock : Node { + const std::string var; + + Unlock(const std::string var) : var(var) {} + + void accept(Visitor *v) const override { v->visitUnlock(this); } +}; + +struct Assertion : Node { + const std::string cond; + const bool passed; + + Assertion(const std::string &cond, const bool passed) : cond(cond), passed(passed) {} + + void accept(Visitor *v) const override { v->visitAssertion(this); } +}; + +struct Pending : Node { + const std::string statement; + + Pending(const std::string statement) : statement(statement) {} + void accept(Visitor *v) const override { v->visitPending(this); } +}; + +struct ExecutionGraph { + std::shared_ptr entry; + std::vector> threads; + + ExecutionGraph(std::shared_ptr entry) : entry(entry) {} + + ExecutionGraph(const ExecutionGraph&) = delete; + 
ExecutionGraph& operator=(const ExecutionGraph&) = delete; + + ExecutionGraph(ExecutionGraph&&) = default; + ExecutionGraph& operator=(ExecutionGraph&&) = default; +}; + +} // namespace graph +} // namespace gitmem diff --git a/src/graphviz.cc b/src/graphviz.cc index 1198ba3..36370fa 100644 --- a/src/graphviz.cc +++ b/src/graphviz.cc @@ -1,152 +1,191 @@ #include "graphviz.hh" +#include "overloaded.hh" #include namespace gitmem { namespace graph { - using std::to_string; - - void GraphvizPrinter::emitNode(const Node* n, const std::string& label, const std::string& style) { - file << "\t" << (size_t)n << "[label=\"" << label << "\", shape=rectangle, style=\"rounded,filled\", "; - if (!style.empty()) file << style; - file << "]" << ";" << std::endl; - } - - void GraphvizPrinter::emitEdge(const Node* from, const Node* to, const std::string& label, const std::string& style) { - if (!from || !to) return; - - file << "\t" << (size_t)from << " -> " << (size_t)to; - if (!style.empty() || !label.empty()) { - file << "["; - if (!style.empty()) file << style; - if (!label.empty()) file << " label=\"" << label << "\""; - file << "]"; - } - file << ";" << std::endl; - } - - void GraphvizPrinter::emitProgramOrderEdge(const Node* from, const Node* to) { - emitEdge(from, to, ""); - } - - void GraphvizPrinter::emitReadFromEdge(const Node* from, const Node* to) { - emitEdge(from, to, "rf", "style=dashed, constraint=false"); - } - - void GraphvizPrinter::emitConflictEdge(const Node* from, const Node* to) { - emitEdge(from, to, "race", "style=dashed, color=red, constraint=false"); - } - - void GraphvizPrinter::emitSyncEdge(const Node* from, const Node* to) { - emitEdge(from, to, "sync", "style=bold, constraint=false"); - } - - void GraphvizPrinter::emitFillColor(const Node* n, const std::string& color) { - file << "\t" << (size_t)n << "[fillcolor = " << color << "];" << std::endl; - } - - void GraphvizPrinter::emitConflict(const Node* n, const Conflict& conflict) { - emitFillColor(n, 
"red"); - auto [s1, s2] = conflict.sources; +using std::to_string; + +void GraphvizPrinter::emitNode(const Node *n, const std::string &label, + const std::string &style) { + file << "\t" << (size_t)n << "[label=\"" << label + << "\", shape=rectangle, style=\"rounded,filled\", "; + if (!style.empty()) + file << style; + file << "]" << ";" << std::endl; +} + +void GraphvizPrinter::emitEdge(const Node *from, const Node *to, + const std::string &label, + const std::string &style) { + if (!from || !to) + return; + + file << "\t" << (size_t)from << " -> " << (size_t)to; + if (!style.empty() || !label.empty()) { + file << "["; + if (!style.empty()) + file << style; + if (!label.empty()) + file << " label=\"" << label << "\""; + file << "]"; + } + file << ";" << std::endl; +} + +void GraphvizPrinter::emitProgramOrderEdge(const Node *from, const Node *to) { + emitEdge(from, to, ""); +} + +void GraphvizPrinter::emitReadFromEdge(const Node *from, const Node *to) { + emitEdge(from, to, "rf", "style=dashed, constraint=false"); +} + +void GraphvizPrinter::emitConflictEdge(const Node *from, const Node *to) { + emitEdge(from, to, "race", "style=dashed, color=red, constraint=false"); +} + +void GraphvizPrinter::emitSyncEdge(const Node *from, const Node *to) { + emitEdge(from, to, "sync", "style=bold, constraint=false"); +} + +void GraphvizPrinter::emitFillColor(const Node *n, const std::string &color) { + file << "\t" << (size_t)n << "[fillcolor = " << color << "];" << std::endl; +} + +void GraphvizPrinter::emitShape(const Node *n, const std::string &shape) { + file << "\t" << (size_t)n << "[shape = " << shape << "];" << std::endl; +} + +void GraphvizPrinter::emitConflict(const Node *n, const Conflict &conflict) { + emitFillColor(n, "red"); + // emitShape(n, "doubleoctagon"); + + // Only draw conflict edges if we have actual source nodes + auto [s1, s2] = conflict.sources; + if (s1) { emitConflictEdge(n, s1.get()); + } + if (s2) { emitConflictEdge(n, s2.get()); } +} - 
GraphvizPrinter::GraphvizPrinter(std::string filename) noexcept { - file.open(filename); - } +GraphvizPrinter::GraphvizPrinter(std::string filename) noexcept { + file.open(filename); +} + +void GraphvizPrinter::visit(const Node *n) { + file << "digraph G {" << std::endl; + if (n) n->accept(this); + file << "}" << std::endl; +} - void GraphvizPrinter::visit(const Node* n) { - file << "digraph G {" << std::endl; +void GraphvizPrinter::visitProgramOrder(const Node *n) { + if (n) { n->accept(this); + } else { file << "}" << std::endl; } - - void GraphvizPrinter::visitProgramOrder(const Node* n) { - if(n) - { - n->accept(this); +} + +void GraphvizPrinter::visitStart(const Start *n) { + file << "subgraph cluster_Thread_" << n->id << "{" << std::endl; + file << "\tlabel = \"Thread #" << n->id << "\";" << std::endl; + file << "\tcolor=black;" << std::endl; + emitNode(n, "", "shape=circle width=.3 style=filled color=black"); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); +} + +void GraphvizPrinter::visitEnd(const End *n) { + assert(!n->next); + emitNode(n, "", "shape=doublecircle width=.2 style=empty"); + file << "}" << std::endl; +} + +void GraphvizPrinter::visitWrite(const Write *n) { + emitNode(n, "W" + n->var + " = " + to_string(n->value)); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); +} + +void GraphvizPrinter::visitRead(const Read *n) { + std::string label = "R" + n->var + " = "; + + std::visit(overloaded{ + [&](const Read::SuccessfulRead& success) { + label += to_string(success.value); + }, + [&](const Conflict& conflict) { + label += "conflict"; } - else - { - file << "}" << std::endl; + }, n->read_result); + + emitNode(n, label); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); + + if (auto* conflict = std::get_if(&n->read_result)) { + emitConflict(n, *conflict); + } else { + auto& success = std::get(n->read_result); + if (success.source) { + emitReadFromEdge(n, 
success.source.get()); } } - - void GraphvizPrinter::visitStart(const Start* n) { - file << "subgraph cluster_Thread_" << n->id << "{" << std::endl; - file << "\tlabel = \"Thread #" << n->id << "\";" << std::endl; - file << "\tcolor=black;" << std::endl; - emitNode(n, "", "shape=circle width=.3 style=filled color=black"); - emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - } - - void GraphvizPrinter::visitEnd(const End* n) { - assert(!n->next); - emitNode(n, "", "shape=doublecircle width=.2 style=empty"); - file << "}" << std::endl; - } - - void GraphvizPrinter::visitWrite(const Write* n) { - emitNode(n, "W" + n->var + " = " + to_string(n->value)); - emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - } - - void GraphvizPrinter::visitRead(const Read* n) { - emitNode(n, "R" + n->var + " = " + to_string(n->value)); - emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - - assert(n->sauce); - emitReadFromEdge(n, n->sauce.get()); - } - - void GraphvizPrinter::visitSpawn(const Spawn* n) { - emitNode(n, "Spawn " + std::to_string(n->tid)); - emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - if (n->spawned) { - emitSyncEdge(n, n->spawned.get()); - visitProgramOrder(n->spawned.get()); - } - } - - void GraphvizPrinter::visitJoin(const Join* n) { - emitNode(n, "Join " + std::to_string(n->tid)); - emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - if (n->joinee) emitSyncEdge(n->joinee.get(), n); - if (n->conflict) emitConflict(n, n->conflict.value()); - } - - void GraphvizPrinter::visitLock(const Lock* n) { - emitNode(n, "Lock " + n->var); - emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - if (n->ordered_after) emitSyncEdge(n->ordered_after.get(), n); - if (n->conflict) emitConflict(n, n->conflict.value()); - } - - void GraphvizPrinter::visitUnlock(const Unlock* n) { - emitNode(n, "Unlock " + n->var); - 
emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - } - - void GraphvizPrinter::visitAssertionFailure(const AssertionFailure* n) { - emitNode(n, "Assert " + n->cond); - emitFillColor(n, "red"); - emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - } - - void GraphvizPrinter::visitPending(const Pending* n) { - assert(!n->next); - emitNode(n, "" + n->statement + "", "style=dashed"); - file << "}" << std::endl; - } +} + +void GraphvizPrinter::visitSpawn(const Spawn *n) { + emitNode(n, "Spawn " + std::to_string(n->tid)); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); + if (n->spawned) { + emitSyncEdge(n, n->spawned.get()); + visitProgramOrder(n->spawned.get()); + } +} + +void GraphvizPrinter::visitJoin(const Join *n) { + emitNode(n, "Join " + std::to_string(n->tid)); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); + if (n->joinee) + emitSyncEdge(n->joinee.get(), n); + if (n->conflict) + emitConflict(n, n->conflict.value()); +} + +void GraphvizPrinter::visitLock(const Lock *n) { + emitNode(n, "Lock " + n->var); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); + if (n->ordered_after) + emitSyncEdge(n->ordered_after.get(), n); + if (n->conflict) + emitConflict(n, n->conflict.value()); +} + +void GraphvizPrinter::visitUnlock(const Unlock *n) { + emitNode(n, "Unlock " + n->var); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); +} + +void GraphvizPrinter::visitAssertion(const Assertion *n) { + emitNode(n, "Assert " + n->cond); + if (!n->passed) emitFillColor(n, "red"); + // emitShape(n, "doubleoctagon"); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); +} + +void GraphvizPrinter::visitPending(const Pending *n) { + assert(!n->next); + emitNode(n, "" + n->statement + "", "style=dashed"); + file << "}" << std::endl; +} } // namespace graph } // namespace gitmem diff --git 
a/src/graphviz.hh b/src/graphviz.hh index 263bad0..c3e294b 100644 --- a/src/graphviz.hh +++ b/src/graphviz.hh @@ -2,32 +2,36 @@ #include "graph.hh" namespace gitmem { - namespace graph { - struct GraphvizPrinter : Visitor { - void visitStart(const Start*) override; - void visitEnd(const End*) override; - void visitWrite(const Write*) override; - void visitRead(const Read*) override; - void visitSpawn(const Spawn*) override; - void visitJoin(const Join*) override; - void visitLock(const Lock*) override; - void visitUnlock(const Unlock*) override; - void visitAssertionFailure(const AssertionFailure*) override; - void visitPending(const Pending*) override; - void visit(const Node* n) override; +namespace graph { +struct GraphvizPrinter : Visitor { + void visitStart(const Start *) override; + void visitEnd(const End *) override; + void visitWrite(const Write *) override; + void visitRead(const Read *) override; + void visitSpawn(const Spawn *) override; + void visitJoin(const Join *) override; + void visitLock(const Lock *) override; + void visitUnlock(const Unlock *) override; + void visitAssertion(const Assertion *) override; + void visitPending(const Pending *) override; + void visit(const Node *n) override; - GraphvizPrinter(std::string filename) noexcept; - private: - std::ofstream file; - void emitNode(const Node* n, const std::string& label, const std::string& style = ""); - void emitEdge(const Node* from, const Node* to, const std::string& label, const std::string& style = ""); - void emitProgramOrderEdge(const Node* from, const Node* to); - void emitReadFromEdge(const Node* from, const Node* to); - void emitFillColor(const Node* n, const std::string& color); - void emitConflictEdge(const Node* from, const Node* to); - void emitSyncEdge(const Node* from, const Node* to); - void emitConflict(const Node* n, const Conflict& conflict); - void visitProgramOrder(const Node* n); - }; - } -} + GraphvizPrinter(std::string filename) noexcept; + +private: + std::ofstream 
file; + void emitNode(const Node *n, const std::string &label, + const std::string &style = ""); + void emitEdge(const Node *from, const Node *to, const std::string &label, + const std::string &style = ""); + void emitProgramOrderEdge(const Node *from, const Node *to); + void emitReadFromEdge(const Node *from, const Node *to); + void emitFillColor(const Node *n, const std::string &color); + void emitShape(const Node *n, const std::string &shape); + void emitConflictEdge(const Node *from, const Node *to); + void emitSyncEdge(const Node *from, const Node *to); + void emitConflict(const Node *n, const Conflict &conflict); + void visitProgramOrder(const Node *n); +}; +} // namespace graph +} // namespace gitmem diff --git a/src/internal.hh b/src/internal.hh index 4687161..c258ab7 100644 --- a/src/internal.hh +++ b/src/internal.hh @@ -1,23 +1,25 @@ #pragma once #include "lang.hh" -namespace gitmem -{ - using namespace trieste; +namespace gitmem { - Parse parser(); - PassDef expressions(); - PassDef statements(); - PassDef check_refs(); - PassDef branching(); +namespace lang { - inline const auto parse_token = - Reg | Var | Const | Nop | Brace | Paren | - Spawn | Join | Lock | Unlock | Assert | If | Else; +using namespace trieste; - inline const auto parse_op = Group | Assign | Eq | Neq | Add | Semi; +Parse parser(); +PassDef expressions(); +PassDef statements(); +PassDef check_refs(); +PassDef branching(); - // clang-format off +inline const auto parse_token = Reg | Var | Const | Nop | Brace | Paren | + Spawn | Join | Lock | Unlock | Assert | If | + Else; + +inline const auto parse_op = Group | Assign | Eq | Neq | Add | Semi; + +// clang-format off inline const wf::Wellformed parser_wf = (Top <<= File) | (File <<= ~parse_op) @@ -82,5 +84,8 @@ namespace gitmem | (Jump <<= Const) | (Cond <<= Expr * Const) ; - // clang-format on -} +// clang-format on + +} // namespace lang + +} // namespace gitmem \ No newline at end of file diff --git a/src/interpreter.cc 
b/src/interpreter.cc index 37c63ad..30f25eb 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -1,612 +1,715 @@ +#include #include #include -#include +#include +#include "debug.hh" #include "interpreter.hh" -#include "graphviz.hh" - -namespace gitmem -{ - using namespace trieste; - - /* Interpreter for a gitmem program. Threads can read and write local - * variables as well as versioned global variables. Globals are not stored - * in a single memory location but instead in the state of 'synchronising - * objects' which include threads and locks. Synchronising actions between - * threads, and between threads and locks, synchronise the versioned memory - * and if both objects see updates to the same versioned global variable - * then a data race is detected. These synchronising actions include: - * - thread t1 joining a thread t2, which waits for t2 to complete before - * trying to 'pull' the new data into t1 - * - t locking a lock l, which waits for the lock l to be available before - * trying to 'pull' the new data into t - * - t unlocking a lock l, which updates l to have t's versioned memory - */ - - bool is_syncing(Node stmt) - { - auto s = stmt / Stmt; - return s == Join || s == Lock || s == Unlock; - } +#include "sync_protocol.hh" +#include "overloaded.hh" + +namespace gitmem { + +using namespace trieste; + +/* Interpreter for a gitmem program. Threads can read and write local + * variables as well as versioned global variables. Globals are not stored + * in a single memory location but instead in the state of 'synchronising + * objects' which include threads and locks. Synchronising actions between + * threads, and between threads and locks, synchronise the versioned memory + * and if both objects see updates to the same versioned global variable + * then a data race is detected. 
These synchronising actions include: + * - thread t1 joining a thread t2, which waits for t2 to complete before + * trying to 'pull' the new data into t1 + * - t locking a lock l, which waits for the lock l to be available before + * trying to 'pull' the new data into t + * - t unlocking a lock l, which updates l to have t's versioned memory + */ + +// Map AST node types to sync operations +static std::optional get_sync_operation(Node stmt) { + auto s = stmt / lang::Stmt; + if (s == lang::Join) return SyncOperation::Join; + if (s == lang::Lock) return SyncOperation::Lock; + if (s == lang::Unlock) return SyncOperation::Unlock; + + // Spawn is an expression, not a statement, but we check for assignment of spawn + if (s == lang::Assign) { + // A little gross but okay for now + auto rhs = s / lang::Expr / lang::Expr; + if (rhs == lang::Spawn) return SyncOperation::Spawn; + } + + return std::nullopt; +} + +// Check if a statement is a scheduling point according to the protocol +static bool is_syncing(const SyncProtocol& protocol, Node stmt) { + if (auto op = get_sync_operation(stmt)) { + return protocol.is_scheduling_point(*op); + } + return false; +} + +static bool is_syncing(const SyncProtocol& protocol, Thread &thread) { + // Can only be true if a thread hasn't terminated + // Either it has executed all statements but not yet terminated (and may sync) + // Or it is at a synchronisation node + // The lazy eval here is important + return !thread.terminated && + ((thread.pc >= thread.block->size()) || is_syncing(protocol, thread.block->at(thread.pc))); +} - bool is_syncing(Thread &thread) - { - return !thread.terminated && is_syncing(thread.block->at(thread.pc)); +/* Evaluating an expression either returns the result of the expression or + * a the exceptional termination status of the thread. 
+ */ +std::variant +Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { + ThreadContext& ctx = thread.ctx; + + auto e = expr / lang::Expr; + if (e == lang::Reg) { + // It is invalid to read a previously unwritten value + auto var = std::string(expr->location().view()); + if (ctx.locals.contains(var)) { + return ctx.locals[var]; + } else { + return termination::UnassignedRead(var); + } + } else if (e == lang::Var) { + auto var = std::string(expr->location().view()); + + auto result = gctx.protocol->read(ctx, var); + + return std::visit(overloaded{ + [&](std::monostate) -> std::variant { + // invalid: reading a variable that hasn't been written + return termination::UnassignedRead(var); + }, + [&](ValueWithSource value_with_source) -> std::variant { + // normal read + thread.trace.on_read(var, value_with_source); + return value_with_source.value; + }, + [&](std::shared_ptr& conflict) -> std::variant { + verbose::out << (*conflict) << std::endl; + thread.trace.on_read(var, conflict); + return termination::DataRace(conflict); + } + }, result); + } else if (e == lang::Const) { + return size_t(std::stoi(std::string(e->location().view()))); + } else if (e == lang::Add) { + size_t sum = 0; + for (auto &child : *e) { + auto result = evaluate_expression(child, thread); + if (std::holds_alternative(result)) + return result; + sum += std::get(result); + } + return sum; + } else if (e == lang::Spawn) { + ThreadID child_tid = gctx.threads.size(); + ThreadContext child_ctx(child_tid, gctx.protocol); + + if (std::optional> conflict = + gctx.protocol->on_spawn(ctx, child_ctx)) { + throw std::logic_error("This code path should never be reached"); } - /* At a commit point, walk through all the versioned variables and see if - * they have a pending commit, if so commit the value by appending to - * the variables history. 
- */ - void commit(Globals &globals) { - for (auto& [var, global] : globals) { - if (global.commit) - { - global.history.push_back(*global.commit); - verbose << "Committed global '" << var << "' with id " << *global.commit << std::endl; - global.commit.reset(); - } - } + gctx.threads.emplace_back(child_tid, std::move(child_ctx), e / lang::Block); + thread.trace.on_spawn(child_tid); + return child_tid; + } else if (e == lang::Eq || e == lang::Neq) { + auto lhs = e / lang::Lhs; + auto rhs = e / lang::Rhs; + + auto lhsEval = evaluate_expression(lhs, thread); + if (std::holds_alternative(lhsEval)) + return lhsEval; + + auto rhsEval = evaluate_expression(rhs, thread); + if (std::holds_alternative(rhsEval)) + return rhsEval; + + return e == lang::Eq + ? (std::get(lhsEval)) == (std::get(rhsEval)) + : (std::get(lhsEval)) != (std::get(rhsEval)); + } else { + throw std::runtime_error("Unknown expression: " + + std::string(expr->type().str())); + } +} + +/* Evaluating a statement either returns the resulting change of the program + * counter (0 if waiting for some other thread) or the exceptional + * termination status of the thread. + */ +std::variant Interpreter::run_statement(Node stmt, Thread& thread) { + ThreadContext& ctx = thread.ctx; + + auto s = stmt / lang::Stmt; + if (s == lang::Nop) { + + verbose::out << "Nop" << std::endl; + + } else if (s == lang::Jump) { + + auto cnst = s / lang::Const; + auto delta = std::stoi(std::string(cnst->location().view())); + assert(delta > 0); + return delta; + + } else if (s == lang::Cond) { + + auto expr = s / lang::Expr; + auto cnst = s / lang::Const; + auto result = evaluate_expression(expr, thread); + + if (auto b = std::get_if(&result)) { + auto delta = std::stoi(std::string(cnst->location().view())); + assert(delta > 0); + return *b ? 
1 : delta; + } else { + return std::get(result); } + } else if (s == lang::Assign) { - /* A versioned value can be fastforwarded to another version, if one - * version's history is a prefix of another version's history. - * A conflict between two commit histories exists if neither history is a - * prefix of the other. - */ - std::optional> has_conflict(CommitHistory& h1, CommitHistory& h2) - { - size_t length = std::min(h1.size(), h2.size()); + auto lhs = s / lang::LVal; + auto var = std::string(lhs->location().view()); + auto rhs = s / lang::Expr; + auto val_or_term = evaluate_expression(rhs, thread); - for (size_t i = 0; i < length; i++) - { - if (h1[i] != h2[i]) return std::pair{h1[i], h2[i]}; - } + if (size_t *val = std::get_if(&val_or_term)) { + if (lhs == lang::Reg) { + + // Local variables can be re-assigned whenever + verbose::out << "Set register '" << lhs->location().view() << "' to " << *val + << std::endl; + ctx.locals[var] = *val; + + } else if (lhs == lang::Var) { - return std::nullopt; + auto write_event = thread.trace.on_write(var, *val); + gctx.protocol->write(ctx, var + , ValueWithSource{*val, write_event}); + } else { + throw std::runtime_error("Bad left-hand side: " + + std::string(lhs->type().str())); + } + } else { + return std::get(val_or_term); + } + } else if (s == lang::Join) { + // A join must waiting for the terminating thread to continue, + // we don't want to re-evaluate the expression repeatedly as this + // may be effecting so store the result in the cache. 
+ auto expr = s / lang::Expr; + + if (!gctx.cache.contains(expr)) { + auto val_or_term = evaluate_expression(expr, thread); + if (size_t *val = std::get_if(&val_or_term)) { + gctx.cache[expr] = *val; + } else { + return std::get(val_or_term); + } } - struct Conflict - { - std::string var; - std::pair commits; - }; - - /* Walk through all the global versions from source and update the versions - * in destination to be the most up-to-date version (this could come from - * either source or destination). This means destination will now also - * include variables it previously did not know about. - */ - std::optional pull(Globals &dst, Globals &src) { - for (auto& [var, global] : src) { - if (dst.contains(var)) - { - auto& src_var = src[var]; - auto& dst_var = dst[var]; - if (auto conflict = has_conflict(src_var.history, dst_var.history)) - { - auto [s1, s2] = *conflict; - verbose << "A data race on '" << var << "' was detected from commits " << s1 << " and " << s2 << std::endl; - return Conflict(var, *conflict); - } - else if (src_var.history.size() > dst_var.history.size()) - { - verbose << "Fast-forward '" << var << "' to id " << src_var.val << std::endl; - dst_var.val = src_var.val; - dst_var.history = src_var.history; - } - } - else - { - dst[var].val = src[var].val; - dst[var].history = src[var].history; - } - } - return std::nullopt; + auto result = gctx.cache[expr]; + // Check if the thread ID is valid + if (result >= gctx.threads.size()) { + verbose::out << "Join: invalid thread ID " << result + << ". 
gctx.threads.size()=" << gctx.threads.size() << std::endl; + return termination::UnassignedRead(std::to_string(result)); } - template - std::shared_ptr thread_append_node(ThreadContext& ctx, Args&&...args) - { - assert(ctx.tail); - auto node = std::make_shared(std::forward(args)...); - ctx.tail->next = node; - ctx.tail = node; - return node; + auto &joinee = gctx.threads[result]; + if (joinee.terminated && + std::holds_alternative(*joinee.terminated)) { + if (auto conflict = gctx.protocol->on_join(ctx, joinee.ctx)) { + verbose::out << (**conflict) << std::endl; + thread.trace.on_join(result, *conflict); + return termination::DataRace(*conflict); + } else { + thread.trace.on_join(result); + } + + } else { + verbose::out << "Waiting on thread " << result << std::endl; + return 0; + } + } else if (s == lang::Lock) { + // We can only lock unlocked locks, if a lock hasn't been used + // before it is implicitly created, we then commit the pending + // updates of this thread and pull the updates from the lock. + auto v = s / lang::Var; + auto var = std::string(v->location().view()); + + Lock& lock = gctx.get_lock(var); + if (lock.owner) { + verbose::out << "Waiting for lock " << var << " owned by " + << lock.owner.value() << std::endl; + return 0; } - template<> - std::shared_ptr thread_append_node(ThreadContext& ctx, std::string&& stmt) - { - // pending nodes don't update the tail position as we will destroy them - // once we execute the node - auto s = std::regex_replace(stmt, std::regex("\n"), "\\l "); - auto node = make_shared(std::move(s)); - ctx.tail->next = node; - return node; + lock.owner = thread.tid; + + if (auto conflict = gctx.protocol->on_lock(ctx, lock)) { + verbose::out << (**conflict) << std::endl; + thread.trace.on_lock(var, lock.last_unlock_event, *conflict); + return termination::DataRace(*conflict); } - /* Evaluating an expression either returns the result of the expression or - * a the exceptional termination status of the thread. 
- */ - std::variant evaluate_expression(Node expr, GlobalContext &gctx, ThreadContext &ctx) - { - auto e = expr / Expr; - if (e == Reg) - { - // It is invalid to read a previously unwritten value - auto var = std::string(expr->location().view()); - if (ctx.locals.contains(var)) - { - return ctx.locals[var]; - } - else - { - return TerminationStatus::unassigned_variable_read_exception; - } - } - else if (e == Var) - { - // It is invalid to read a previously unwritten value - auto var = std::string(expr->location().view()); - if (ctx.globals.contains(var)) - { - auto& global = ctx.globals[var]; - auto commit = global.commit.value_or(global.history.back()); - auto source_node = gctx.commit_map[commit]; - thread_append_node(ctx, var, global.val, commit, source_node); - return global.val; - } - else - { - return TerminationStatus::unassigned_variable_read_exception; - } - } - else if (e == Const) - { - return size_t(std::stoi(std::string(e->location().view()))); - } - else if (e == Add) - { - size_t sum = 0; - for (auto &child : *e) - { - auto result = evaluate_expression(child, gctx, ctx); - if (std::holds_alternative(result)) return result; - sum += std::get(result); - } - return sum; - } - else if (e == Spawn) - { - // Spawning is a sync point, commit local pending commits, and - // copy the global state to the spawned thread - commit(ctx.globals); - ThreadID tid = gctx.threads.size(); - auto node = std::make_shared(tid); + thread.trace.on_lock(var, lock.last_unlock_event); + verbose::out << "Locked " << var << std::endl; - ThreadContext new_ctx = { Locals(), ctx.globals, node }; - gctx.threads.push_back(std::make_shared(new_ctx, e / Block)); + } else if (s == lang::Unlock) { + // We can only unlock locks we previously locked. We commit any + // pending updates and then copy the threads versioned globals + // to the locks versioned globals (nobody could have changed + // them since we locked the lock). 
- thread_append_node(ctx, tid, node); + // commit(ctx.globals); + auto v = s / lang::Var; + auto var = std::string(v->location().view()); - return tid; - } - else if (e == Eq || e == Neq) - { - auto lhs = e / Lhs; - auto rhs = e / Rhs; + Lock& lock = gctx.get_lock(var); + if (!lock.owner || (lock.owner && *lock.owner != thread.tid)) { + return termination::UnlockError(var); + } - auto lhsEval = evaluate_expression(lhs, gctx, ctx); - if (std::holds_alternative(lhsEval)) return lhsEval; + if (auto conflict = gctx.protocol->on_unlock(ctx, lock)) { + verbose::out << (**conflict) << std::endl; + thread.trace.on_unlock(var, *conflict); + return termination::DataRace(*conflict); + } - auto rhsEval = evaluate_expression(rhs, gctx, ctx); - if (std::holds_alternative(rhsEval)) return rhsEval; + // lock.globals = ctx.globals; + lock.owner.reset(); - return e == Eq? (std::get(lhsEval)) == (std::get(rhsEval)) - : (std::get(lhsEval)) != (std::get(rhsEval)); - } - else - { - throw std::runtime_error("Unknown expression: " + std::string(expr->type().str())); - } + lock.last_unlock_event = thread.trace.on_unlock(var); + + verbose::out << "Unlocked " << var << std::endl; + + } else if (s == lang::Assert) { + + auto expr = s / lang::Expr; + auto result_or_term = evaluate_expression(expr, thread); + if (size_t *result = std::get_if(&result_or_term)) { + if (*result) { + verbose::out << "Assertion passed: " << expr->location().view() << std::endl; + thread.trace.on_assert_pass(std::string(expr->location().view())); + } else { + verbose::out << "Assertion failed: " << expr->location().view() << std::endl; + thread.trace.on_assert_fail(std::string(expr->location().view())); + return termination::AssertionFailure(std::string(expr->location().view())); + } + } else { + return std::get(result_or_term); } - /* Evaluating a statement either returns the resulting change of the program - * counter (0 if waiting for some other thread) or the exceptional - * termination status of the thread. 
- */ - std::variant run_statement(Node stmt, GlobalContext &gctx, ThreadContext &ctx, const ThreadID& tid) - { - auto s = stmt / Stmt; - if (s == Nop) - { - verbose << "Nop" << std::endl; - } - else if (s == Jump) - { - auto cnst = s / Const; - auto delta = std::stoi(std::string(cnst->location().view())); - assert(delta > 0); - return delta; - } - else if (s == Cond) - { - auto expr = s / Expr; - auto cnst = s / Const; - auto result = evaluate_expression(expr, gctx, ctx); - - if (auto b = std::get_if(&result)) - { - auto delta = std::stoi(std::string(cnst->location().view())); - assert(delta > 0); - return *b? 1 : delta; - } - else - { - return std::get(result); - } - } - else if (s == Assign) - { - auto lhs = s / LVal; - auto var = std::string(lhs->location().view()); - auto rhs = s / Expr; - auto val_or_term = evaluate_expression(rhs, gctx, ctx); - if(size_t* val = std::get_if(&val_or_term)) - { - if (lhs == Reg) - { - // Local variables can be re-assigned whenever - verbose << "Set register '" << lhs->location().view() << "' to " << *val << std::endl; - ctx.locals[var] = *val; - } - else if (lhs == Var) - { - // Global variable writes need to create a new commit id - // to track the history of updates - auto &global = ctx.globals[var]; - global.val = *val; - global.commit = gctx.uuid++; - verbose << "Set global '" << lhs->location().view() << "' to " << *val << " with id " << *(global.commit) << std::endl; - - auto node = thread_append_node(ctx, var, global.val, *global.commit); - gctx.commit_map[*(global.commit)] = node; - } - else - { - throw std::runtime_error("Bad left-hand side: " + std::string(lhs->type().str())); - } - } - else - { - return std::get(val_or_term); - } - } - else if (s == Join) - { - // A join must waiting for the terminating thread to continue, - // we don't want to re-evaluate the expression repeatedly as this - // may be effecting so store the result in the cache. 
- auto expr = s / Expr; - - if (!gctx.cache.contains(expr)) - { - auto val_or_term = evaluate_expression(expr, gctx, ctx); - if (size_t* val = std::get_if(&val_or_term)) - { - gctx.cache[expr] = *val; - } - else - { - return std::get(val_or_term); - } - } + } else { + throw std::runtime_error("Unknown statement: " + + std::string(stmt->type().str())); + } + return 1; +} - // when joining, we commit the updates of both threads (the joined - // thread will not necessarily have commited them), we then - // pull the updates into the joining thread. - auto result = gctx.cache[expr]; - auto& thread = gctx.threads[result]; - if (thread->terminated && (*thread->terminated == TerminationStatus::completed)) - { - commit(ctx.globals); - commit(thread->ctx.globals); - verbose << "Pulling from thread " << result << std::endl; - if(auto conflict = pull(ctx.globals, thread->ctx.globals)) - { - using graph::Node; - auto [s1, s2] = conflict->commits; - auto sources = std::pair, std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; - auto graph_conflict = graph::Conflict(conflict->var, sources); - thread_append_node(ctx, result, thread->ctx.tail, graph_conflict); - return TerminationStatus::datarace_exception; - } - - thread_append_node(ctx, result, thread->ctx.tail); - } - else - { - verbose << "Waiting on thread " << result << std::endl; - return 0; - } - } - else if (s == Lock) - { - // We can only lock unlocked locks, if a lock hasn't been used - // before it is implicitly created, we then commit the pending - // updates of this thread and pull the updates from the lock. - auto v = s / Var; - auto var = std::string(v->location().view()); - - auto& lock = gctx.locks[var]; - if (lock.owner) { - verbose << "Waiting for lock " << var << " owned by " << lock.owner.value() << std::endl; - return 0; - } +/* Run a particular thread until it reaches a synchronisation point or until + * it terminates. 
Report whether the thread was able to progress or not, or + * whether it terminated. + */ +std::variant +Interpreter::run_single_thread_to_sync(Thread& thread) { + if (thread.terminated) + return *thread.terminated; - lock.owner = tid; - commit(ctx.globals); - if(auto conflict = pull(ctx.globals, lock.globals)) - { - using graph::Node; - auto [s1, s2] = conflict->commits; - auto sources = std::pair, std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; - auto graph_conflict = graph::Conflict(conflict->var, sources); - thread_append_node(ctx, var, lock.last, graph_conflict); - return TerminationStatus::datarace_exception; - } + auto& ctx = thread.ctx; + auto& pc = thread.pc; + Node block = thread.block; - thread_append_node(ctx, var, lock.last); + // Initial sync when thread starts executing + if (pc == 0) { + gctx.protocol->on_start(ctx); + thread.trace.on_start(); + } - verbose << "Locked " << var << std::endl; + bool made_progress = false; - } - else if (s == Unlock) - { - // We can only unlock locks we previously locked. We commit any - // pending updates and then copy the threads versioned globals - // to the locks versioned globals (nobody could have changed - // them since we locked the lock). 
- commit(ctx.globals); - auto v = s / Var; - auto var = std::string(v->location().view()); - - auto& lock = gctx.locks[var]; - if (!lock.owner || (lock.owner && *lock.owner != tid)) - { - return TerminationStatus::unlock_exception; - } + while (pc < block->size()) { + Node stmt = block->at(pc); - lock.globals = ctx.globals; - lock.owner.reset(); + // Stop *before* executing a sync statement (except first) + if (made_progress && is_syncing(*gctx.protocol, stmt)) + return ProgressStatus::progress; - thread_append_node(ctx, var); - lock.last = ctx.tail; + auto result = run_statement(stmt, thread); - verbose << "Unlocked " << var << std::endl; - } - else if (s == Assert) - { - auto expr = s / Expr; - auto result_or_term = evaluate_expression(expr, gctx, ctx); - if (size_t* result = std::get_if(&result_or_term)) - { - if (*result) - { - verbose << "Assertion passed: " << expr->location().view() << std::endl; - } - else - { - verbose << "Assertion failed: " << expr->location().view() << std::endl; - thread_append_node(ctx, std::string(expr->location().view())); - return TerminationStatus::assertion_failure_exception; - } - } - else - { - return std::get(result_or_term); - } - } - else - { - throw std::runtime_error("Unknown statement: " + std::string(stmt->type().str())); - } - return 1; + if (auto term = std::get_if(&result)) { + thread.terminated = *term; + return *term; } - /* Run a particular thread until it reaches a synchronisation point or until - * it terminates. Report whether the thread was able to progress or not, or - * whether it terminated. 
- */ - std::variant run_single_thread_to_sync(GlobalContext& gctx, const ThreadID tid, std::shared_ptr thread) - { - if (thread->terminated) { - return *(thread->terminated); - } - Node block = thread->block; - size_t &pc = thread->pc; - ThreadContext &ctx = thread->ctx; - - bool first_statement = true; - while(pc < block->size()) - { - Node stmt = block->at(pc); - - if (!first_statement && is_syncing(stmt)) - { - return ProgressStatus::progress; - } + int delta = std::get(result); - auto delta_or_term = run_statement(stmt, gctx, ctx, tid); - if (auto term = std::get_if(&delta_or_term)) - { - thread->terminated = *term; - thread_append_node(ctx); - return *term; - } + // Blocked (e.g. waiting on lock/join) + if (delta == 0) + return made_progress ? ProgressStatus::progress + : ProgressStatus::no_progress; - auto delta = std::get(delta_or_term); + pc += delta; + made_progress = true; + } - if(delta == 0) - { - return first_statement ? ProgressStatus::no_progress : ProgressStatus::progress; - } + // If we ran *any* statements, finishing is a sync point for next iteration + if (made_progress && gctx.protocol->is_scheduling_point(SyncOperation::End)) + return ProgressStatus::progress; - pc += delta; - first_statement = false; - } + // Otherwise, we truly reached the end this iteration + if (auto conflict = gctx.protocol->on_end(ctx)) { + verbose::out << (**conflict) << std::endl; + TerminationStatus term = termination::DataRace(*conflict); + thread.terminated = term; + return term; + } + + thread.terminated = termination::Completed(); + thread.trace.on_end(); + return termination::Completed(); +} - thread->terminated = TerminationStatus::completed; - thread_append_node(ctx); - return TerminationStatus::completed; +/** + * Run a thread to the next sync point, including any threads spawned by that + * thread + */ +std::variant +Interpreter::progress_thread(Thread& thread) { + auto no_threads = gctx.threads.size(); + auto prog_or_term = run_single_thread_to_sync(thread); 
+ + bool any_progress = + std::holds_alternative(prog_or_term) && + std::get(prog_or_term) == ProgressStatus::progress; + + for (size_t i = no_threads; i < gctx.threads.size(); ++i) { + // If there are new threads, we can run them to sync as well + any_progress = true; + auto& new_thread = gctx.threads[i]; + if (!is_syncing(*gctx.protocol, new_thread)) { + verbose::out << "==== Thread " << i << " (spawn) ====" << std::endl; + progress_thread(new_thread); } + } - /** - * Run a thread to the next sync point, including any threads spawned by that thread - */ - std::variant - progress_thread(GlobalContext &gctx, const ThreadID tid, std::shared_ptr thread) - { - auto no_threads = gctx.threads.size(); - auto prog_or_term = run_single_thread_to_sync(gctx, tid, thread); - - bool any_progress = std::holds_alternative(prog_or_term) && - std::get(prog_or_term) == ProgressStatus::progress; - for (size_t i = no_threads; i < gctx.threads.size(); ++i) - { - // If there are new threads, we can run them to sync as well - any_progress = true; - auto new_thread = gctx.threads[i]; - if (!is_syncing(*new_thread)) - { - verbose << "==== Thread " << i << " (spawn) ====" << std::endl; - progress_thread(gctx, i, new_thread); - } - } + if (std::holds_alternative(prog_or_term)) + return prog_or_term; - if (std::holds_alternative(prog_or_term)) - return prog_or_term; + return any_progress ? ProgressStatus::progress : ProgressStatus::no_progress; +} - return any_progress ? 
ProgressStatus::progress : ProgressStatus::no_progress; +/* Try to evaluate all threads until a sync point or termination point + */ +std::variant +Interpreter::run_threads_to_sync() { + verbose::out << "-----------------------" << std::endl; + bool all_completed = true; + ProgressStatus any_progress = ProgressStatus::no_progress; + for (size_t i = 0; i < gctx.threads.size(); ++i) { + verbose::out << "==== t" << i << " ====" << std::endl; + auto& thread = gctx.threads[i]; + if (!thread.terminated) { + auto prog_or_term = run_single_thread_to_sync(thread); + if (ProgressStatus *prog = std::get_if(&prog_or_term)) { + any_progress |= *prog; + } else { + // We could return termination status of any error here and stop + // at the first error + thread.terminated = std::get(prog_or_term); + any_progress |= ProgressStatus::progress; + } + + all_completed &= thread.terminated.has_value(); + // if a thread spawns a new thread, it will end up at the end so + // we will always include the new threads in the termination + // criteria } + } - /* Try to evaluate all threads until a sync point or termination point - */ - std::variant run_threads_to_sync(GlobalContext& gctx) - { - verbose << "-----------------------" << std::endl; - bool all_completed = true; - ProgressStatus any_progress = ProgressStatus::no_progress; - for (size_t i = 0; i < gctx.threads.size(); ++i) - { - verbose << "==== t" << i << " ====" << std::endl; - auto thread = gctx.threads[i]; - if (!thread->terminated) - { - auto prog_or_term = run_single_thread_to_sync(gctx, i, thread); - if (ProgressStatus* prog = std::get_if(&prog_or_term)) - { - any_progress |= *prog; - } - else - { - // We could return termination status of any error here and stop - // at the first error - thread->terminated = std::get(prog_or_term); - any_progress |= ProgressStatus::progress; - } - - all_completed &= thread->terminated.has_value(); - // if a thread spawns a new thread, it will end up at the end so - // we will always include 
the new threads in the termination - // criteria - } - } + if (all_completed) + return termination::Completed(); + + return any_progress; +} - if (all_completed) return TerminationStatus::completed; +static bool is_finished(const StepResult& r) { + // Either, the system is stuck and made no progress in which case there + // is a deadlock (or a thread is stuck waiting for a crashed thread?) + // Or, there was some termination criteria in which case we stop + return is_terminated(r) || + std::get(r) == ProgressStatus::no_progress; +} - return any_progress; +/* Try to evaluate all threads until they have all terminated in some way + * or we have reached a stuck configuration. + */ +int Interpreter::run() { + std::variant prog_or_term; + do { + prog_or_term = run_threads_to_sync(); + } while (!is_finished(prog_or_term)); + + verbose::out << "----------- execution complete -----------" << std::endl; + + bool exception_detected = false; + for (size_t i = 0; i < gctx.threads.size(); ++i) { + auto &thread = gctx.threads[i]; + + if (thread.terminated) { + verbose::out << "Thread " << i << ": "; + + std::visit( + overloaded{ + [&](const termination::Completed &t) { + verbose::out << t << std::endl; + }, + + [&](const auto &t) { + // Any non-completed termination is exceptional + verbose::out << t << std::endl; + exception_detected = true; + } + }, + *thread.terminated + ); + } else { + exception_detected = true; + thread.trace.on_end(); + verbose::out << "Thread " << i << " is stuck" << std::endl; } + } + + verbose::out << "------------------------------------------" << std::endl; + + print_thread_traces(); + + return exception_detected ? 1 : 0; +} + +void Interpreter::print_state(std::ostream& os, bool show_all) const { + gctx.print(os, show_all); +} - bool is_finished(std::variant& prog_or_term) - { - // Either, the system is stuck and made no progress in which case there - // is a deadlock (or a thread is stuck waiting for a crashed thread?) 
- if (ProgressStatus* prog = std::get_if(&prog_or_term)) - return (*prog) == ProgressStatus::no_progress; +void Interpreter::print_thread_traces() { + for (size_t tid = 0; tid < gctx.threads.size(); ++tid) { + const auto& thread = gctx.threads[tid]; + verbose::out << "=== Thread " << tid << " ===" << std::endl; + verbose::out << thread.trace; + verbose::out << "====================================\n"; + } +} - // Or, there was some termination criteria in which case we stop - return true; +void Interpreter::print_revision_graph(const std::filesystem::path& output_path) { + // Build revision graph - collect const raw pointers + std::vector thread_state_ptrs; + for (const auto& thread : gctx.threads) { + thread_state_ptrs.push_back(thread.ctx.sync.get()); + } + + std::string dot = gctx.protocol->build_revision_graph_dot(thread_state_ptrs); + if (!dot.empty()) { + // Write to file + auto dot_file = output_path.parent_path() / (output_path.stem().string() + "_revision_graph.dot"); + std::ofstream out(dot_file); + if (out) { + out << dot; + verbose::out << "Revision graph written to " << dot_file << std::endl; + } else { + verbose::out << "Failed to write revision graph to " << dot_file << std::endl; } + } +} - /* Try to evaluate all threads until they have all terminated in some way - * or we have reached a stuck configuration. 
- */ - int run_threads(GlobalContext &gctx) - { - std::variant prog_or_term; - do { - prog_or_term = run_threads_to_sync(gctx); - } while (!is_finished(prog_or_term)); - - verbose << "----------- execution complete -----------" << std::endl; - - bool exception_detected = false; - for (size_t i = 0; i < gctx.threads.size(); ++i) - { - const auto& thread = gctx.threads[i]; - if (thread->terminated) - { - switch (thread->terminated.value()) - { - case TerminationStatus::completed: - verbose << "Thread " << i << " terminated normally" << std::endl; - break; - - case TerminationStatus::unlock_exception: - verbose << "Thread " << i << " unlocked a lock it does not own" << std::endl; - exception_detected = true; - break; - - case TerminationStatus::datarace_exception: - verbose << "Thread " << i << " encountered a data-race" << std::endl; - exception_detected = true; - break; - - case TerminationStatus::assertion_failure_exception: - verbose << "Thread " << i << " failed an assertion" << std::endl; - exception_detected = true; - break; - - case TerminationStatus::unassigned_variable_read_exception: - verbose << "Thread " << i << " read an uninitialised value" << std::endl; - exception_detected = true; - break; - - default: - verbose << "Thread " << i << " has an unhandled termination state" << std::endl; - break; - } - } - else - { - exception_detected = true; - thread_append_node(thread->ctx); - verbose << "Thread " << i << " is stuck" << std::endl; +graph::ExecutionGraph Interpreter::build_execution_graph_from_traces() { + // Map each thread ID to its last graph node in the execution graph (per-thread program order) + std::unordered_map> thread_tails; + + // Track the last unlock event for each lock (for lock->unlock edges) + std::unordered_map> last_unlock_per_lock; + + // Track join nodes that need fixing up after all threads are processed + std::vector> joins_to_fix; + + // Track read nodes that need their source fixed up + std::vector, std::shared_ptr>> 
reads_to_fix; + + // Map from trace events to graph nodes + std::unordered_map, std::shared_ptr> event_to_node; + + // Create Start nodes for all threads + std::vector> thread_starts; + thread_starts.reserve(gctx.threads.size()); + + for (ThreadID tid = 0; tid < gctx.threads.size(); ++tid) { + auto node = std::make_shared(tid); + thread_starts.push_back(node); + thread_tails[tid] = node; + } + + // The entry point is thread 0's start + graph::ExecutionGraph g(thread_starts[0]); + g.threads = std::move(thread_starts); + + // Helper to link a node in program order for its thread + auto link_in_program_order = [&](ThreadID tid, std::shared_ptr node) { + if (thread_tails[tid]) { + thread_tails[tid]->next = node; + } + thread_tails[tid] = node; + }; + + // Process events from all threads + for (ThreadID tid = 0; tid < gctx.threads.size(); ++tid) { + auto& thread = gctx.threads[tid]; + + // Process all events from the trace + for (const auto& event : thread.trace) { + if (!event) { + // Safety check: skip null events (shouldn't happen) + continue; + } + + std::visit(overloaded{ + [&](const StartEvent&) { + // Skip: Start nodes already created above + }, + [&](const EndEvent&) { + auto node = std::make_shared(); + link_in_program_order(tid, node); + event_to_node[event] = node; + }, + [&](const WriteEvent& arg) { + auto node = std::make_shared(arg.var, arg.value, tid); + link_in_program_order(tid, node); + event_to_node[event] = node; + }, + [&](const ReadEvent& arg) { + // Link to the write that produced this value + std::shared_ptr node; + std::visit(overloaded{ + [&](const ReadValue& val) { + // Create the read node, but we might need to fix up the source later + node = std::make_shared(arg.var, val.value, tid, nullptr); + assert(val.source_event && "source missing"); + reads_to_fix.push_back({node, val.source_event}); + }, + [&](const std::shared_ptr&) { + node = std::make_shared(arg.var, tid, graph::Conflict(arg.var)); } + }, arg.value_or_conflict); + + 
link_in_program_order(tid, node); + event_to_node[event] = node; + }, + [&](const SpawnEvent& arg) { + // Link to the child thread's start node + auto node = std::make_shared(arg.child_tid, g.threads[arg.child_tid]); + link_in_program_order(tid, node); + event_to_node[event] = node; + }, + [&](const JoinEvent& arg) { + // Create join node - will fix up joinee pointer later + std::optional conflict; + if (arg.maybe_conflict) { + // Just mark as conflicting - version IDs don't map directly to nodes + conflict = graph::Conflict(""); // empty var name for joins + } + auto node = std::make_shared(arg.joinee_tid, nullptr, conflict); + joins_to_fix.push_back(node); + link_in_program_order(tid, node); + event_to_node[event] = node; + }, + [&](const LockEvent& arg) { + // Link to the last unlock event using the event-to-node mapping + std::shared_ptr ordered_after = nullptr; + if (arg.last_unlock_event && event_to_node.contains(arg.last_unlock_event)) { + ordered_after = event_to_node[arg.last_unlock_event]; + } + + std::optional conflict; + if (arg.maybe_conflict) { + // Mark as conflicting with the lock name + conflict = graph::Conflict(arg.lock_name); + } + auto node = std::make_shared(arg.lock_name, ordered_after, conflict); + link_in_program_order(tid, node); + event_to_node[event] = node; + }, + [&](const UnlockEvent& arg) { + auto node = std::make_shared(arg.lock_name); + last_unlock_per_lock[arg.lock_name] = node; + link_in_program_order(tid, node); + event_to_node[event] = node; + + // Mark conflict if present + if (arg.maybe_conflict) { + // TODO: Visualize unlock conflicts + } + }, + [&](const AssertEvent& arg) { + auto node = std::make_shared(arg.condition, arg.pass); + link_in_program_order(tid, node); + event_to_node[event] = node; } + }, event->data); + } - return exception_detected ? 
1 : 0; + // Add pending node if thread hasn't terminated + if (!thread.terminated) { + if (thread.pc < thread.block->size()) { + // Thread is stuck waiting at a specific statement + trieste::Node stmt = thread.block->at(thread.pc); + auto pending = std::make_shared(std::string(stmt->location().view())); + link_in_program_order(tid, pending); + } else { + // Thread has finished all statements but hasn't terminated yet + auto pending = std::make_shared("..."); + link_in_program_order(tid, pending); + } } + } + + // Fix up join nodes to point to the actual end of the joined threads + for (auto& join_node : joins_to_fix) { + ThreadID joinee_tid = join_node->tid; + // thread_tails[joinee_tid] now points to the end (or pending) of that thread + const_cast&>(join_node->joinee) = thread_tails[joinee_tid]; + } + + // Fix up read nodes to point to their source write events + for (auto& [read_node, source_event] : reads_to_fix) { + assert(event_to_node.contains(source_event) && "source missing in event_to_node map"); + read_node->set_source(event_to_node[source_event]); + } + + return g; +} - int interpret(const Node ast, const std::filesystem::path &output_path) - { - GlobalContext gctx(ast); - auto result = run_threads(gctx); - gctx.print_execution_graph(output_path); +void Interpreter::print_execution_graph(const std::filesystem::path& output_path) { + auto exec_graph = build_execution_graph_from_traces(); + graph::GraphvizPrinter gv(output_path); + gv.visit(exec_graph.entry.get()); +} - return result; - } +int interpret(const Node ast, const std::filesystem::path &output_path, + std::unique_ptr protocol) { + Interpreter interp(GlobalContext(ast, std::move(protocol))); + int result = interp.run(); + + interp.print_execution_graph(output_path); + interp.print_revision_graph(output_path); + + return result; } + +} // namespace gitmem \ No newline at end of file diff --git a/src/interpreter.hh b/src/interpreter.hh index df2c7fc..93aea3c 100644 --- a/src/interpreter.hh +++ 
b/src/interpreter.hh @@ -1,213 +1,59 @@ #pragma once -#include -#include "lang.hh" +#include "execution_state.hh" #include "graph.hh" #include "graphviz.hh" +#include "lang.hh" +#include "progress_status.hh" +#include +#include "sync_protocol.hh" +#include "termination_status.hh" + +namespace gitmem { -namespace gitmem -{ - /* For debug printing */ - inline struct Verbose - { - bool enabled = false; - - template - const Verbose &operator<<(const T &msg) const - { - if (enabled) - { - std::cout << msg; - } - return *this; - } - - const Verbose &operator<<(std::ostream &(*manip)(std::ostream &)) const - { - if (enabled) - { - std::cout << manip; - } - return *this; - } - } verbose; - - /* A 'Global' is a structure to capture the current synchronising objects - * representation of a global variable. The structure is the current value, - * the current commit id for the variable, and the history of commited ids. - */ - - using Commit = size_t; - using CommitHistory = std::vector; - - struct Global - { - size_t val; - std::optional commit; - CommitHistory history; - }; - - using Globals = std::unordered_map; - - enum class TerminationStatus - { - completed, - datarace_exception, - unlock_exception, - assertion_failure_exception, - unassigned_variable_read_exception, - }; - - using Locals = std::unordered_map; - - struct ThreadContext - { - Locals locals; - Globals globals; - std::shared_ptr tail; - }; - - using ThreadStatus = std::optional; - - struct Thread - { - ThreadContext ctx; - Node block; - size_t pc = 0; - ThreadStatus terminated = std::nullopt; - - bool operator==(const Thread &other) const - { - // Globals have a history that we don't care about, so we only - // compare values - if (ctx.globals.size() != other.ctx.globals.size()) - return false; - for (const auto &[var, global] : ctx.globals) - { - if (!other.ctx.globals.contains(var) || - ctx.globals.at(var).val != other.ctx.globals.at(var).val) - { - return false; - } - } - return ctx.locals == 
other.ctx.locals && - block == other.block && - pc == other.pc && - terminated == other.terminated; - } - }; - - using ThreadID = size_t; - - struct Lock - { - Globals globals; - std::optional owner = std::nullopt; - std::shared_ptr last; - }; - - using Threads = std::vector>; - - using Locks = std::unordered_map; - - template - std::shared_ptr thread_append_node(ThreadContext& ctx, Args&&...args); - - template<> - std::shared_ptr thread_append_node(ThreadContext& ctx, std::string&& stmt); - - struct GlobalContext - { - Threads threads; - Locks locks; - NodeMap cache; - std::shared_ptr entry_node; - std::unordered_map> commit_map; - Commit uuid = 0; - - GlobalContext(const Node &ast) - { - Node starting_block = ast / File / Block; - entry_node = std::make_shared(0); - ThreadContext starting_ctx = {{}, {}, entry_node}; - auto main_thread = std::make_shared(starting_ctx, starting_block); - - this->threads = {main_thread}; - this->locks = {}; - this->cache = {}; - } - - bool operator==(const GlobalContext &other) const - { - if (threads.size() != other.threads.size() || locks.size() != other.locks.size()) - return false; - - // Threads may have been spawned in a different order, so we - // find the thread with the same block in the other context - for (auto &thread : threads) - { - auto it = std::find_if(other.threads.begin(), other.threads.end(), - [&thread](auto &t) - { return t->block == thread->block; }); - if (it == other.threads.end() || !(*thread == **it)) - return false; - } - - for (auto &[name, lock] : locks) - { - if (!other.locks.contains(name)) - return false; - auto &other_lock = other.locks.at(name); - if (lock.owner != other_lock.owner) - return false; - } - return true; - } - - void print_execution_graph(const std::filesystem::path &output_path) const - { - // Loop over the threads and add pending nodes to running threads - // to indicate a threads next step - for (const auto& t: threads) - { - assert(t->ctx.tail); - if (t->terminated || 
dynamic_pointer_cast(t->ctx.tail->next)) - continue; - - Node block = t->block; - size_t &pc = t->pc; - Node stmt = block->at(pc); - thread_append_node(t->ctx, std::string(stmt->location().view())); - } - - graph::GraphvizPrinter gv(output_path); - gv.visit(entry_node.get()); - } - }; - - enum class ProgressStatus - { - progress, - no_progress - }; - - inline bool operator!(ProgressStatus p) { return p == ProgressStatus::no_progress; } - - inline ProgressStatus operator||(const ProgressStatus &p1, const ProgressStatus &p2) - { - return (p1 == ProgressStatus::progress || p2 == ProgressStatus::progress) ? ProgressStatus::progress : ProgressStatus::no_progress; - } - - inline void operator|=(ProgressStatus &p1, const ProgressStatus &p2) { p1 = (p1 || p2); } - - // Entry functions - int interpret(const Node, const std::filesystem::path &output_file); - int interpret_interactive(const Node, const std::filesystem::path &output_file); - int model_check(const Node, const std::filesystem::path &output_file); - - // Internal functions - int run_threads(GlobalContext &); - - std::variant - progress_thread(GlobalContext &, const ThreadID, std::shared_ptr); +using termination::TerminationStatus; + +template +using StepResult = std::variant; + +template +bool is_terminated(const StepResult& r) { + return std::holds_alternative(r); } + +template +T& value(StepResult& r) { + return std::get(r); +} + +class Interpreter { +private: + GlobalContext gctx; + + graph::ExecutionGraph build_execution_graph_from_traces(); + +public: + Interpreter(GlobalContext gctx): gctx(std::move(gctx)) {} + + GlobalContext& context() { return gctx; } + + int run(); + + StepResult evaluate_expression(trieste::Node, Thread&); + StepResult run_statement(trieste::Node, Thread&); + + StepResult progress_thread(Thread&); + StepResult run_single_thread_to_sync(Thread&); + StepResult run_threads_to_sync(); + + void print_state(std::ostream& os, bool show_all = false) const; + void print_thread_traces(); + void 
print_revision_graph(const std::filesystem::path& output_path); + void print_execution_graph(const std::filesystem::path& output_path); +}; + +// Entry function +int interpret(const trieste::Node, const std::filesystem::path &output_file, + std::unique_ptr protocol); + +} // namespace gitmem \ No newline at end of file diff --git a/src/lang.hh b/src/lang.hh index 4aa3ef0..e244b99 100644 --- a/src/lang.hh +++ b/src/lang.hh @@ -1,59 +1,61 @@ #pragma once #include -namespace gitmem -{ - using namespace trieste; - - Reader reader(); - - // Variables - inline const auto Reg = TokenDef("reg", flag::print); - inline const auto Var = TokenDef("var", flag::print); - - // Constants - inline const auto Const = TokenDef("const", flag::print); - - // Arithmetic - inline const auto Add = TokenDef("+"); - - // Comparison - inline const auto Eq = TokenDef("=="); - inline const auto Neq = TokenDef("!="); - - // Statements - inline const auto Semi = TokenDef(";"); - inline const auto Assign = TokenDef("=", flag::lookup); - inline const auto Spawn = TokenDef("spawn"); - inline const auto Join = TokenDef("join"); - inline const auto Lock = TokenDef("lock"); - inline const auto Unlock = TokenDef("unlock"); - inline const auto Nop = TokenDef("nop"); - inline const auto Assert = TokenDef("assert"); - inline const auto If = TokenDef("if"); - inline const auto Else = TokenDef("else"); - - // Branching - inline const auto Jump = TokenDef("jump"); - inline const auto Cond = TokenDef("cond"); - - // Grouping tokens - inline const auto Brace = TokenDef("brace"); - inline const auto Paren = TokenDef("paren"); - - inline const auto Stmt = TokenDef("stmt"); - inline const auto Expr = TokenDef("expr"); - inline const auto Block = TokenDef("block", flag::symtab | flag::defbeforeuse); - - // Convenience - inline const auto LVal = TokenDef("lval"); - inline const auto Lhs = TokenDef("lhs"); - inline const auto Rhs = TokenDef("rhs"); - inline const auto Op = TokenDef("op"); - inline const auto Then = 
TokenDef("then"); - - // Well-formedness - // clang-format off +namespace gitmem { + +namespace lang { + +using namespace trieste; + +Reader reader(); + +// Variables +inline const auto Reg = TokenDef("reg", flag::print); +inline const auto Var = TokenDef("var", flag::print); + +// Constants +inline const auto Const = TokenDef("const", flag::print); + +// Arithmetic +inline const auto Add = TokenDef("+"); + +// Comparison +inline const auto Eq = TokenDef("=="); +inline const auto Neq = TokenDef("!="); + +// Statements +inline const auto Semi = TokenDef(";"); +inline const auto Assign = TokenDef("=", flag::lookup); +inline const auto Spawn = TokenDef("spawn"); +inline const auto Join = TokenDef("join"); +inline const auto Lock = TokenDef("lock"); +inline const auto Unlock = TokenDef("unlock"); +inline const auto Nop = TokenDef("nop"); +inline const auto Assert = TokenDef("assert"); +inline const auto If = TokenDef("if"); +inline const auto Else = TokenDef("else"); + +// Branching +inline const auto Jump = TokenDef("jump"); +inline const auto Cond = TokenDef("cond"); + +// Grouping tokens +inline const auto Brace = TokenDef("brace"); +inline const auto Paren = TokenDef("paren"); + +inline const auto Stmt = TokenDef("stmt"); +inline const auto Expr = TokenDef("expr"); +inline const auto Block = TokenDef("block", flag::symtab | flag::defbeforeuse); + +// Convenience +inline const auto LVal = TokenDef("lval"); +inline const auto Lhs = TokenDef("lhs"); +inline const auto Rhs = TokenDef("rhs"); +inline const auto Op = TokenDef("op"); +inline const auto Then = TokenDef("then"); + +// Well-formedness +// clang-format off inline const wf::Wellformed wf = (Top <<= File) | (File <<= Block) @@ -72,6 +74,8 @@ namespace gitmem | (Jump <<= Const) | (Cond <<= Expr * Const) ; - // clang-format on +// clang-format on + +} // namespace lang -} +} // namespace gitmem diff --git a/src/linear/sync_protocol.cc b/src/linear/sync_protocol.cc new file mode 100644 index 0000000..1a65c9c --- 
/dev/null +++ b/src/linear/sync_protocol.cc @@ -0,0 +1,211 @@ +#include "linear/sync_protocol.hh" +#include "debug.hh" +#include +#include +#include + +namespace gitmem { + +namespace linear { + +LocalVersionStore& get_store(ThreadContext& ctx) { + return static_cast(*ctx.sync); +} + +// -------------------- +// LinearSyncProtocol +// -------------------- + +std::ostream &LinearSyncProtocol::print(std::ostream &os) const { + os << _global_store << std::endl; + return os; +} + +std::string LinearSyncProtocol::build_revision_graph_dot( + const std::vector& thread_states) const { + + std::ostringstream dot; + dot << "digraph LinearHistory {\n"; + dot << " rankdir=BT;\n"; + dot << " node [shape=box];\n"; + + const auto& history = _global_store.get_history(); + + if (history.empty()) { + dot << "}\n"; + return dot.str(); + } + + // Create a subgraph for each variable showing its version history + for (const auto& [obj_name, versions] : history) { + dot << " subgraph cluster_" << obj_name << " {\n"; + dot << " label=\"" << obj_name << "\";\n"; + dot << " style=dashed;\n"; + + // Create nodes for each version + for (size_t i = 0; i < versions.size(); ++i) { + const auto& version = versions[i]; + std::ostringstream node_id; + node_id << obj_name << "_v" << i; + + std::ostringstream label; + label << version.timestamp() << "\\n" << obj_name << "=" << version.value().value; + + dot << " \"" << node_id.str() << "\" [label=\"" << label.str() << "\"];\n"; + } + + // Create edges between consecutive versions + for (size_t i = 1; i < versions.size(); ++i) { + std::ostringstream prev_id, curr_id; + prev_id << obj_name << "_v" << (i - 1); + curr_id << obj_name << "_v" << i; + dot << " \"" << curr_id.str() << "\" -> \"" << prev_id.str() << "\";\n"; + } + + dot << " }\n"; + } + + dot << "}\n"; + return dot.str(); +} + +std::optional +LinearSyncProtocol::push(LocalVersionStore &local) { + if (auto conflict = _global_store.check_conflicts(local.timestamp(), + local.staged_changes())) { 
+ + return std::make_optional( + conflict->object, + std::make_pair(conflict->local_base, conflict->global_head)); + } + + uint64_t new_base = _global_store.apply_changes( + local.thread(), local.timestamp(), local.staged_changes()); + + local.clear_staging(); + local.advance_base(new_base); + return std::nullopt; +} + +std::optional +LinearSyncProtocol::pull(LocalVersionStore &local) { + if (auto conflict = _global_store.check_conflicts(local.timestamp(), + local.staged_changes())) { + + return std::make_optional( + conflict->object, + std::make_pair(conflict->local_base, conflict->global_head)); + } + + local.advance_base( _global_store.current_counter()); + return std::nullopt; +} + +LinearSyncProtocol::~LinearSyncProtocol() = default; + +ReadResult LinearSyncProtocol::read(ThreadContext &ctx, + const std::string &var) { + auto& store = get_store(ctx); + + if (auto result = store.get_staged(var)) { + return *result; + } + + std::optional value = _global_store.get_version_for_timestamp( + var, store.timestamp()); + if (value) { + return value.value(); + } + + return std::monostate{}; +} + +void LinearSyncProtocol::write(ThreadContext &ctx, const std::string &var, + ValueWithSource value) { + // write into the staging area of the thread + auto& store = get_store(ctx); + store.stage(var, value); +} + +std::optional> +LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child) { + // push parent to global history + auto& store = get_store(parent); + if (auto conflict = push(store)) + return std::make_shared(std::move(*conflict)); + + // pull into the child + store = get_store(child); + if (auto conflict = pull(store)) { + throw std::logic_error("This code path should never be reached"); + } + + return std::nullopt; +} + +std::optional> +LinearSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee) { + // we assume the joinee has already terminated and pushed + + // pull changes into parent + auto& store = get_store(joiner); + if (auto 
conflict = pull(store)) + return std::make_shared(std::move(*conflict)); + + return std::nullopt; +} + +std::optional> +LinearSyncProtocol::on_start(ThreadContext &thread) { + // pull state from global history + auto& store = get_store(thread); + auto conflict = pull(store); + assert(!conflict && "cannot conflict from starting state"); + + return std::nullopt; +}; + +std::optional> +LinearSyncProtocol::on_end(ThreadContext &thread) { + // push changes to global history + auto& store = get_store(thread); + if (auto conflict = push(store)) + return std::make_shared(std::move(*conflict)); + + return std::nullopt; +}; + +std::optional> +LinearSyncProtocol::on_lock(ThreadContext &thread, Lock &lock) { + + auto& store = get_store(thread); + if (auto conflict = pull(store)) + return std::make_shared(std::move(*conflict)); + + return std::nullopt; +} + +std::optional> +LinearSyncProtocol::on_unlock(ThreadContext &thread, Lock &) { + // push changes to global history + auto& store = get_store(thread); + if (auto conflict = push(store)) + return std::make_shared(std::move(*conflict)); + + return std::nullopt; +} + +bool LinearSyncProtocol::is_scheduling_point(SyncOperation op) const { + switch (op) { + case SyncOperation::Lock: + case SyncOperation::Unlock: + case SyncOperation::Join: + case SyncOperation::Spawn: + case SyncOperation::Start: + case SyncOperation::End: return true; } + assert(false && "Unknown SyncOperation"); +} + +} // namespace linear + +} // namespace gitmem diff --git a/src/linear/sync_protocol.hh b/src/linear/sync_protocol.hh new file mode 100644 index 0000000..37df6fd --- /dev/null +++ b/src/linear/sync_protocol.hh @@ -0,0 +1,73 @@ +#pragma once + +#include "../sync_protocol.hh" +#include "conflict.hh" +#include "execution_state.hh" +#include "linear/version_store.hh" + +namespace gitmem { + +using LinearConflict = Conflict; + +namespace linear { + +class LinearSyncProtocol final : public SyncProtocol { + GlobalVersionStore _global_store; + + 
std::optional push(LocalVersionStore &local); + std::optional pull(LocalVersionStore &local); + +public: + ~LinearSyncProtocol() override; + + std::unique_ptr clone() const override { + return std::make_unique(); + } + + ReadResult read(ThreadContext &ctx, const std::string &var) override; + + void write(ThreadContext &ctx, const std::string &var, ValueWithSource value) override; + + std::optional> + on_spawn(ThreadContext &parent, ThreadContext &child) override; + + std::optional> + on_join(ThreadContext &joiner, ThreadContext &joinee) override; + + std::optional> + on_start(ThreadContext &thread) override; + + std::optional> + on_end(ThreadContext &thread) override; + + std::optional> + on_lock(ThreadContext &thread, Lock &lock) override; + + std::optional> + on_unlock(ThreadContext &thread, Lock &lock) override; + + std::ostream &print(std::ostream &os) const override; + + std::string build_revision_graph_dot(const std::vector& thread_states) const override; + + bool is_scheduling_point(SyncOperation op) const override; + + std::unique_ptr make_thread_state(ThreadID tid) const override { + return std::make_unique(tid); + } + + std::unique_ptr make_lock_state() const override { + return nullptr; + } +}; + +class LinearSyncProtocolBuilder { +public: + std::unique_ptr build() const { + return std::make_unique(); + } +}; + +} // namespace linear + +} // namespace gitmem \ No newline at end of file diff --git a/src/linear/version_store.cc b/src/linear/version_store.cc new file mode 100644 index 0000000..a6dbf31 --- /dev/null +++ b/src/linear/version_store.cc @@ -0,0 +1,126 @@ +#include +#include + +#include "sync_protocol.hh" +#include "version_store.hh" +#include "thread_trace.hh" +#include "read_result.hh" + +namespace gitmem { + +namespace linear { + +// ----------------------------- +// LocalVersionStore +// ----------------------------- + +void LocalVersionStore::stage(std::string obj, ValueWithSource value) { + _staging[obj] = value; +} + +void 
LocalVersionStore::clear_staging() { + _staging.clear(); +} + +void LocalVersionStore::advance_base(uint64_t ts) { _timestamp = ts; } + +std::optional LocalVersionStore::get_staged(std::string obj) { + auto it = _staging.find(obj); + return it != _staging.end() ? std::make_optional(it->second) : std::nullopt; +} + +bool LocalVersionStore::operator==(const LocalVersionStore& other) const { + return _timestamp == other._timestamp && + _staging == other._staging; +} + +std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { + os << "LocalVersionStore{" + << "base=" << store._timestamp + << ", staged={"; + + bool first = true; + for (const auto& [obj, val] : store._staging) { + if (!first) os << ", "; + first = false; + os << obj << "->" << val.value << " (" << val.source_event << ")"; + } + + os << "}}"; + return os; +} + +// ----------------------------- +// GlobalVersionStore +// ----------------------------- + +std::optional +GlobalVersionStore::get_version_for_timestamp(std::string obj, + uint64_t ts) const { + const auto it = _history.find(obj); + + if (it == _history.end()) + return std::nullopt; + + const VersionHistory &history = it->second; + for (VersionHistory::const_reverse_iterator riter = history.rbegin(); + riter != history.rend(); ++riter) { + if (riter->timestamp().counter <= ts) + return riter->value(); + } + + return std::nullopt; +} + +std::optional GlobalVersionStore::check_conflicts( + uint64_t base, + const std::unordered_map &changes) const { + for (const auto &[obj, _] : changes) { + auto it = _history.find(obj); + if (it == _history.end()) { + continue; + } + + const Version &latest = it->second.back(); + if (latest.timestamp().counter > base) { + return Conflict{ + .object = obj, .local_base = base, .global_head = latest.timestamp()}; + } + } + return std::nullopt; +} + +uint64_t GlobalVersionStore::apply_changes( + ThreadID tid, uint64_t base, + const std::unordered_map &changes) { + if (auto conflict = 
check_conflicts(base, changes)) { + throw std::logic_error("apply_changes called with conflicts"); + } + + // Increment the global counter and create new timestamp with thread info from base + Timestamp new_ts{tid, ++_counter}; + + for (const auto &[obj, value] : changes) { + _history[obj].emplace_back(new_ts, value); + } + + return _counter; +} + +std::ostream& operator<<(std::ostream& os, const GlobalVersionStore& store) { + os << "GlobalVersionStore(counter=" << store._counter << ")\n"; + + for (const auto& [obj_name, history] : store._history) { + os << " Object " << obj_name << ":\n"; + + for (const auto& version : history) { + os << " [" << version.timestamp() << "] = " << version.value().value << "\n"; + } + } + + return os; +} + +} // namespace linear + +} // namespace gitmem \ No newline at end of file diff --git a/src/linear/version_store.hh b/src/linear/version_store.hh new file mode 100644 index 0000000..fd719b8 --- /dev/null +++ b/src/linear/version_store.hh @@ -0,0 +1,136 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include "sync_state.hh" +#include "thread_id.hh" +#include "read_result.hh" + +namespace gitmem { + +struct Event; // Forward declaration + +namespace linear { + +// ----------------------------- +// Timestamp +// ----------------------------- + +struct Timestamp { + size_t thread{0}; + uint64_t counter{0}; + + auto operator<=>(const Timestamp &) const = default; + + friend std::ostream &operator<<(std::ostream &os, const Timestamp &ts) { + os << "t" << ts.thread << ":" << ts.counter; + return os; + } +}; + +using Value = size_t; + +// ----------------------------- +// Version +// ----------------------------- + +class Version { + Timestamp _timestamp; + ValueWithSource _value; + +public: + Version(Timestamp ts, ValueWithSource value) + : _timestamp(ts), _value(value) {} + + Timestamp timestamp() const { return _timestamp; } + ValueWithSource value() const { return _value; } +}; + +using VersionHistory = 
std::vector; + +// ----------------------------- +// Conflict +// ----------------------------- + +struct Conflict { + std::string object; + Timestamp local_base; + Timestamp global_head; +}; + +// ----------------------------- +// LocalVersionStore +// ----------------------------- + +class LocalVersionStore : public ThreadSyncState { + ThreadID tid; + uint64_t _timestamp; + std::unordered_map _staging; + +public: + ~LocalVersionStore() = default; + + LocalVersionStore(ThreadID tid) : tid(tid), _timestamp(0) {} + + ThreadID thread() const { return tid; } + + uint64_t timestamp() const { return _timestamp; } + const auto &staged_changes() const { return _staging; } + + void stage(std::string obj, ValueWithSource value); + void clear_staging(); + void advance_base(uint64_t ts); + std::optional get_staged(std::string obj); + + bool operator==(const LocalVersionStore& other) const; + + bool operator==(const ThreadSyncState& other) const override { + auto* o = dynamic_cast(&other); + if (!o) + return false; + return *this == *o; + } + + friend std::ostream& operator<<(std::ostream&, const LocalVersionStore&); + + std::ostream &print(std::ostream &os) const override { + os << *dynamic_cast(this); + return os; + } +}; + +// ----------------------------- +// GlobalVersionStore +// ----------------------------- + +class GlobalVersionStore { + uint64_t _counter{0}; + std::unordered_map _history; + +public: + uint64_t current_counter() const { return _counter; } + + std::optional get_version_for_timestamp(std::string, uint64_t) const; + + std::optional + check_conflicts(uint64_t base, + const std::unordered_map &changes) const; + + uint64_t + apply_changes(ThreadID tid, uint64_t base, + const std::unordered_map &changes); + + friend std::ostream& operator<<(std::ostream&, const GlobalVersionStore&); + + std::unordered_map get_history() const { + return _history; + } +}; + +} // namespace linear + +} // namespace gitmem \ No newline at end of file diff --git a/src/main.cc 
b/src/main.cc index bb5a948..db19c05 100644 --- a/src/main.cc +++ b/src/main.cc @@ -1,8 +1,6 @@ -#include #include "reader.cc" +#include - -int main(int argc, char** argv) -{ +int main(int argc, char **argv) { return trieste::Driver(grunq::reader()).run(argc, argv); } diff --git a/src/model_checker.cc b/src/model_checker.cc index c7fff60..5126c8f 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -1,216 +1,217 @@ +#include "model_checker.hh" #include "interpreter.hh" +#include "debug.hh" +#include "sync_protocol.hh" + +namespace gitmem { +using namespace trieste; + +/** + * A TraceNode represents a point in the space of possible schedulings. A + * path in a tree of TraceNodes represents a scheduling, with the thread ID + * of each node being the thread that was scheduled at that point. When + * there are no more children to explore, or when one thread has crashed, + * the TraceNode is marked as complete so that the next run will not explore + * it again. + * + */ +struct TraceNode { + size_t tid_; + bool complete; + std::vector> children; + + TraceNode(const size_t tid) : tid_(tid), complete(false) {} + + std::shared_ptr extend(ThreadID tid) { + children.push_back(std::make_shared(tid)); + return children.back(); + } + + bool is_leaf() const { return children.empty(); } +}; + +/** + * Print the traces of the program, one trace per line. Each trace is a + * sequence of thread IDs that were scheduled in that order. + */ +template +void print_traces(S &stream, const std::vector> &traces) { + for (const auto &trace : traces) { + for (const auto &tid : trace) { + stream << tid << " "; + } + stream << std::endl; + } +} -namespace gitmem -{ - using namespace trieste; - - /** - * A TraceNode represents a point in the space of possible schedulings. A - * path in a tree of TraceNodes represents a scheduling, with the thread ID - * of each node being the thread that was scheduled at that point. 
When - * there are no more children to explore, or when one thread has crashed, - * the TraceNode is marked as complete so that the next run will not explore - * it again. - * - */ - struct TraceNode - { - size_t tid_; - bool complete; - std::vector> children; - - TraceNode(const size_t tid) : tid_(tid), complete(false) {} - - std::shared_ptr extend(ThreadID tid) - { - children.push_back(std::make_shared(tid)); - return children.back(); - } +/** Build an output path for the execution graph, appending an index to the + * filename to avoid overwriting previous graphs. */ +std::filesystem::path +build_output_path(const std::filesystem::path &output_path, const size_t idx) { + auto parent = output_path.parent_path(); + auto name = output_path.stem().string(); + auto ext = output_path.extension().string(); + return parent / (name + "_" + std::to_string(idx) + ext); +} - bool is_leaf() const - { - return children.empty(); - } - }; - - /** - * Print the traces of the program, one trace per line. Each trace is a - * sequence of thread IDs that were scheduled in that order. - */ - template - void print_traces(S &stream, const std::vector> &traces) - { - for (const auto &trace : traces) - { - for (const auto &tid : trace) - { - stream << tid << " "; - } - stream << std::endl; +/** + * Explore all possible execution paths of the program, printing one trace + * for each distinct final state that led to an error. 
+ */ +int model_check(const Node ast, const std::filesystem::path &output_path, + std::unique_ptr protocol) { + auto final_contexts = std::vector>{}; + auto failing_contexts = std::vector>{}; + auto deadlocked_contexts = std::vector>{}; + + auto final_traces = std::vector>{}; + auto failing_traces = std::vector>{}; + auto deadlocked_traces = std::vector>{}; + + const auto root = std::make_shared(0); + auto cursor = root; + auto current_trace = std::vector{0}; // Start with the main thread + verbose::out << "==== Thread " << cursor->tid_ << " ====" << std::endl; + + Interpreter interp(GlobalContext(ast, std::move(protocol))); + + // Keep a pointer to the protocol for cloning later + const SyncProtocol* protocol_template = interp.context().protocol.get(); + + GlobalContext& gctx = interp.context(); + interp.progress_thread(gctx.threads[cursor->tid_]); + + while (!root->complete) { + while (!cursor->children.empty() && !cursor->children.back()->complete) { + // We have a child that is not complete, we can extend that trace + cursor = cursor->children.back(); + current_trace.push_back(cursor->tid_); + verbose::out << "==== Thread " << cursor->tid_ + << " (replay) ====" << std::endl; + interp.progress_thread(gctx.threads[cursor->tid_]); + } + + // Try to find a thread to schedule next + size_t start_idx = + cursor->children.empty() ? 
0 : cursor->children.back()->tid_ + 1; + size_t no_threads = gctx.threads.size(); + bool made_progress = false; + for (size_t i = start_idx; i < no_threads && !made_progress; ++i) { + auto& thread = gctx.threads[i]; + if (!thread.terminated) { + // Run the thread to the next sync point + verbose::out << "==== Thread " << i << " ====" << std::endl; + auto prog_or_term = interp.progress_thread(thread); + if (is_terminated(prog_or_term)) { + // Thread terminated, we can extend the trace + made_progress = true; + cursor = cursor->extend(i); + current_trace.push_back(i); + if (!std::holds_alternative(std::get(prog_or_term))) { + // Thread terminated with an error, we can stop here + verbose::out << "Thread " << i << " terminated with an error" + << std::endl; + cursor->complete = true; + } + } else if (std::get(prog_or_term) == + ProgressStatus::progress) { + // Thread made progress, we can continue + made_progress = true; + cursor = cursor->extend(i); + current_trace.push_back(i); } + } } - /** Build an output path for the execution graph, appending an index to the - * filename to avoid overwriting previous graphs. */ - std::filesystem::path build_output_path(const std::filesystem::path &output_path, const size_t idx) - { - auto parent = output_path.parent_path(); - auto name = output_path.stem().string(); - auto ext = output_path.extension().string(); - return parent / (name + "_" + std::to_string(idx) + ext); + if (!made_progress) { + // No threads made progress, we can stop here + cursor->complete = true; } - /** - * Explore all possible execution paths of the program, printing one trace - * for each distinct final state that led to an error. 
- */ - int model_check(const Node ast, const std::filesystem::path &output_path) - { - GlobalContext gctx(ast); - - auto final_contexts = std::vector{}; - auto failing_contexts = std::vector{}; - auto deadlocked_contexts = std::vector{}; - - auto final_traces = std::vector>{}; - auto failing_traces = std::vector>{}; - auto deadlocked_traces = std::vector>{}; - - const auto root = std::make_shared(0); - auto cursor = root; - auto current_trace = std::vector{0}; // Start with the main thread - verbose << "==== Thread " << cursor->tid_ << " ====" << std::endl; - progress_thread(gctx, cursor->tid_, gctx.threads[cursor->tid_]); - - while (!root->complete) - { - while (!cursor->children.empty() && !cursor->children.back()->complete) - { - // We have a child that is not complete, we can extend that trace - cursor = cursor->children.back(); - current_trace.push_back(cursor->tid_); - verbose << "==== Thread " << cursor->tid_ << " (replay) ====" << std::endl; - progress_thread(gctx, cursor->tid_, gctx.threads[cursor->tid_]); - } - - // Try to find a thread to schedule next - size_t start_idx = cursor->children.empty() ? 
0 : cursor->children.back()->tid_ + 1; - size_t no_threads = gctx.threads.size(); - bool made_progress = false; - for (size_t i = start_idx; i < no_threads && !made_progress; ++i) - { - auto thread = gctx.threads[i]; - if (!thread->terminated) - { - // Run the thread to the next sync point - verbose << "==== Thread " << i << " ====" << std::endl; - auto prog_or_term = progress_thread(gctx, i, thread); - if (std::holds_alternative(prog_or_term)) - { - // Thread terminated, we can extend the trace - made_progress = true; - cursor = cursor->extend(i); - current_trace.push_back(i); - if (std::get(prog_or_term) != TerminationStatus::completed) - { - // Thread terminated with an error, we can stop here - verbose << "Thread " << i << " terminated with an error" << std::endl; - cursor->complete = true; - } - } - else if (std::get(prog_or_term) == ProgressStatus::progress) - { - // Thread made progress, we can continue - made_progress = true; - cursor = cursor->extend(i); - current_trace.push_back(i); - } - } - } - - if (!made_progress) - { - // No threads made progress, we can stop here - cursor->complete = true; - } - - bool all_completed = std::all_of(gctx.threads.begin(), gctx.threads.end(), - [](const auto &thread) - { return thread->terminated && *thread->terminated == TerminationStatus::completed; }); - bool any_crashed = - std::any_of(gctx.threads.begin(), gctx.threads.end(), - [](const auto &thread) - { return thread->terminated && *thread->terminated != TerminationStatus::completed; }); - - bool is_deadlock = !all_completed && !made_progress && cursor->is_leaf(); - - if (all_completed || any_crashed || is_deadlock) - { - // Remember final state if it is new - if (!std::any_of(final_contexts.begin(), final_contexts.end(), - [&gctx](const GlobalContext &state) - { return state == gctx; })) - { - final_contexts.push_back(gctx); - final_traces.push_back(current_trace); - if (any_crashed) - { - failing_traces.push_back(current_trace); - 
failing_contexts.push_back(gctx); - } - else if (is_deadlock) - { - deadlocked_traces.push_back(current_trace); - deadlocked_contexts.push_back(gctx); - } - } - - cursor->complete = true; - } - - if (cursor->complete && !root->complete) - { - // Reset the cursor to the root and start a new trace - verbose << std::endl - << "Restarting trace..." << std::endl; - gctx = GlobalContext(ast); - - cursor = root; - current_trace.clear(); - current_trace.push_back(0); // Start with the main thread again - verbose << "==== Thread " << cursor->tid_ << " (replay) ====" << std::endl; - progress_thread(gctx, cursor->tid_, gctx.threads[cursor->tid_]); - } + bool all_completed = std::all_of( + gctx.threads.begin(), gctx.threads.end(), [](const auto &thread) { + return thread.terminated && + std::holds_alternative(*thread.terminated); + }); + bool any_crashed = std::any_of( + gctx.threads.begin(), gctx.threads.end(), [](const auto &thread) { + return thread.terminated && + !std::holds_alternative(*thread.terminated); + }); + + bool is_deadlock = !all_completed && !made_progress && cursor->is_leaf(); + + if (all_completed || any_crashed || is_deadlock) { + // Remember final state if it is new + if (!std::any_of( + final_contexts.begin(), final_contexts.end(), + [&gctx](const std::shared_ptr &state) { return *state == gctx; })) { + // Here we take ownership of the global context + std::shared_ptr gctxp = std::make_shared(std::move(gctx)); + final_contexts.push_back(gctxp); + final_traces.push_back(current_trace); + if (any_crashed) { + failing_traces.push_back(current_trace); + failing_contexts.push_back(gctxp); + } else if (is_deadlock) { + deadlocked_traces.push_back(current_trace); + deadlocked_contexts.push_back(gctxp); } + } - verbose << "Found a total of " << final_traces.size() << " trace(s) with distinct final states:" << std::endl; - print_traces(verbose, final_traces); - - size_t idx = 0; - if (!failing_traces.empty()) - { - std::cout << "Found " << failing_traces.size() << 
" trace(s) with errors:" << std::endl; - print_traces(std::cout, failing_traces); - - for (const auto &ctx : failing_contexts) - { - auto path = build_output_path(output_path, idx++); - ctx.print_execution_graph(path); - } - } + cursor->complete = true; + } - if (!deadlocked_traces.empty()) - { - std::cout << "Found " << deadlocked_traces.size() << " trace(s) leading to deadlock:" << std::endl; - print_traces(std::cout, deadlocked_traces); + if (cursor->complete && !root->complete) { + // Reset the cursor to the root and start a new trace + verbose::out << std::endl << "Restarting trace..." << std::endl; + interp = Interpreter(GlobalContext(ast, protocol_template->clone())); + GlobalContext& gctx = interp.context(); + + cursor = root; + current_trace.clear(); + current_trace.push_back(0); // Start with the main thread again + verbose::out << "==== Thread " << cursor->tid_ + << " (replay) ====" << std::endl; + interp.progress_thread(gctx.threads[cursor->tid_]); + } + } + + std::cout << "Found a total of " << final_traces.size() + << " trace(s) with distinct final states" + << " (errors: " << failing_traces.size() + << ", no errors: " << final_traces.size() - failing_traces.size() << ")" + << std::endl; + print_traces(verbose::out, final_traces); + + size_t idx = 0; + if (!failing_traces.empty()) { + print_traces(std::cout, failing_traces); + + for (const auto &ctx : failing_contexts) { + auto path = build_output_path(output_path, idx++); + + for (size_t tid = 0; tid < ctx->threads.size(); ++tid) { + const auto& thread = ctx->threads[tid]; + verbose::out << "=== Thread " << tid << " ===" << std::endl; + verbose::out << thread.trace; + verbose::out << "====================================\n"; + } + // ctx->print_execution_graph(path); + } + } - for (const auto &ctx : deadlocked_contexts) - { - auto path = build_output_path(output_path, idx++); - ctx.print_execution_graph(path); - } - } + if (!deadlocked_traces.empty()) { + std::cout << "Found " << 
deadlocked_traces.size() + << " trace(s) leading to deadlock:" << std::endl; + print_traces(std::cout, deadlocked_traces); - return deadlocked_traces.empty() && failing_traces.empty() ? 0 : 1; + for (const auto &ctx : deadlocked_contexts) { + auto path = build_output_path(output_path, idx++); + // ctx->print_execution_graph(path); } + } + + return deadlocked_traces.empty() && failing_traces.empty() ? 0 : 1; } +} // namespace gitmem diff --git a/src/model_checker.hh b/src/model_checker.hh new file mode 100644 index 0000000..54d8494 --- /dev/null +++ b/src/model_checker.hh @@ -0,0 +1,10 @@ +#pragma once + +#include +#include "sync_protocol.hh" + +namespace gitmem { + using namespace trieste; + + int model_check(const Node ast, const std::filesystem::path &output_path, std::unique_ptr protocol); +} \ No newline at end of file diff --git a/src/overloaded.hh b/src/overloaded.hh new file mode 100644 index 0000000..689c013 --- /dev/null +++ b/src/overloaded.hh @@ -0,0 +1,8 @@ +#pragma once + +template +struct overloaded : Ts... { + using Ts::operator()...; +}; +template +overloaded(Ts...) -> overloaded; \ No newline at end of file diff --git a/src/parser.cc b/src/parser.cc index 4eb95ff..212e574 100644 --- a/src/parser.cc +++ b/src/parser.cc @@ -1,138 +1,141 @@ -#include "lang.hh" #include "internal.hh" +#include "lang.hh" -namespace gitmem -{ - using namespace trieste; - using namespace trieste::detail; - - Parse parser() - { - Parse p(depth::file, parser_wf); - auto infix = [](Make& m, Token t) { - // This precedence table maps infix operators to the operators that have - // *higher* precedence, and which should therefore be terminated when that - // operator is encountered. Note that operators with the same precedence - // terminate each other. 
(for reasons, it has to be defined inside the lambda) - const auto precedence_table = std::map> { - {Add, {}}, - {Eq, {Add}}, - {Neq, {Add}}, +namespace gitmem { + +namespace lang { + +using namespace trieste; +using namespace trieste::detail; + +Parse parser() { + Parse p(depth::file, parser_wf); + auto infix = [](Make &m, Token t) { + // This precedence table maps infix operators to the operators that have + // *higher* precedence, and which should therefore be terminated when that + // operator is encountered. Note that operators with the same precedence + // terminate each other. (for reasons, it has to be defined inside the + // lambda) + const auto precedence_table = std::map>{ + {Add, {}}, + {Eq, {Add}}, + {Neq, {Add}}, {Assign, {Add, Eq, Neq}}, - }; - - auto skip = precedence_table.at(t); - m.seq(t, skip); - // Push group to be able to check whether an operand follows - m.push(Group); }; -/* - auto pair_with = [pop_until](Make &m, Token preceding, Token following) { - pop_until(m, preceding, {Paren, Brace, File}); - m.term(); + auto skip = precedence_table.at(t); + m.seq(t, skip); + // Push group to be able to check whether an operand follows + m.push(Group); + }; - if (!m.in(preceding)) { - const std::string msg = (std::string) "Unexpected '" + following.str() + "'"; - m.error(msg); - return; - } + /* + auto pair_with = [pop_until](Make &m, Token preceding, Token following) { + pop_until(m, preceding, {Paren, Brace, File}); + m.term(); - m.pop(preceding); - m.push(following); - }; -*/ + if (!m.in(preceding)) { + const std::string msg = (std::string) "Unexpected '" + following.str() + + "'"; m.error(msg); return; + } - auto pop_until = [](Make &m, Token t, std::initializer_list stop = {File}) { - while (!m.in(t) && !m.group_in(t) - && !m.in(stop) && !m.group_in(stop)) { - m.term(); - m.pop(); - } + m.pop(preceding); + m.push(following); + }; + */ - return (m.in(t) || m.group_in(t)); - }; + auto pop_until = [](Make &m, Token t, + std::initializer_list stop = 
{File}) { + while (!m.in(t) && !m.group_in(t) && !m.in(stop) && !m.group_in(stop)) { + m.term(); + m.pop(); + } - p("start", + return (m.in(t) || m.group_in(t)); + }; + + p("start", { // Whitespace - "[[:space:]]+" >> [](auto&) { }, // no-op + "[[:space:]]+" >> [](auto &) {}, // no-op // Line comment - "//[^\n]*" >> [](auto&) { }, // no-op + "//[^\n]*" >> [](auto &) {}, // no-op // Constant - "[[:digit:]]+" >> [](auto& m) { m.add(Const); }, + "[[:digit:]]+" >> [](auto &m) { m.add(Const); }, // Addition - R"(\+)" >> [infix](auto& m) { infix(m, Add); }, + R"(\+)" >> [infix](auto &m) { infix(m, Add); }, // Comparison - "==" >> [infix](auto& m) { infix(m, Eq); }, - "!=" >> [infix](auto& m) { infix(m, Neq); }, + "==" >> [infix](auto &m) { infix(m, Eq); }, + "!=" >> [infix](auto &m) { infix(m, Neq); }, // Statements - ";" >> [](auto& m) { m.seq(Semi, {Assign, Spawn, Join, Lock, Unlock, Assert, If, Else, Eq, Neq, Add, Group}); }, - "=" >> [infix](auto& m) { infix(m, Assign); }, - "spawn" >> [](auto& m) { m.push(Spawn); }, - "join" >> [](auto& m) { m.push(Join); }, - "lock" >> [](auto& m) { m.push(Lock); }, - "unlock" >> [](auto& m) { m.push(Unlock); }, - "assert" >> [](auto& m) { m.push(Assert); }, - "nop" >> [](auto& m) { m.add(Nop); }, - - "if" >> [](auto& m) { m.push(If); }, - "else" >> [pop_until](auto &m) - { - pop_until(m, Semi, {Brace, Paren, File}); - m.push(Else); - }, + ";" >> + [](auto &m) { + m.seq(Semi, {Assign, Spawn, Join, Lock, Unlock, Assert, If, Else, + Eq, Neq, Add, Group}); + }, + "=" >> [infix](auto &m) { infix(m, Assign); }, + "spawn" >> [](auto &m) { m.push(Spawn); }, + "join" >> [](auto &m) { m.push(Join); }, + "lock" >> [](auto &m) { m.push(Lock); }, + "unlock" >> [](auto &m) { m.push(Unlock); }, + "assert" >> [](auto &m) { m.push(Assert); }, + "nop" >> [](auto &m) { m.add(Nop); }, + + "if" >> [](auto &m) { m.push(If); }, + "else" >> + [pop_until](auto &m) { + pop_until(m, Semi, {Brace, Paren, File}); + m.push(Else); + }, // Variables - 
R"(\$[_[:alpha:]][_[:alnum:]]*)" >> [](auto& m) { m.add(Reg); }, - R"([_[:alpha:]][_[:alnum:]]*)" >> [](auto& m) { m.add(Var); }, + R"(\$[_[:alpha:]][_[:alnum:]]*)" >> [](auto &m) { m.add(Reg); }, + R"([_[:alpha:]][_[:alnum:]]*)" >> [](auto &m) { m.add(Var); }, // Grouping - "\\{" >> [](auto& m) { m.push(Brace); }, - "\\}" >> [pop_until](auto& m) - { - pop_until(m, Brace, {Paren}); - m.term(); - m.pop(Brace); - m.extend(Brace); - if (m.group_in(If)) - { - m.term(); - m.pop(If); - } - else if (m.group_in(Else)) - { - m.term(); - m.pop(Else); - } - if (m.group_in({Semi, Brace, File})) - { - m.seq(Semi); - } - }, - - "\\(" >> [](auto& m) { m.push(Paren); }, - "\\)" >> [pop_until](auto& m) - { - pop_until(m, Paren, {Brace}); - m.term(); - m.pop(Paren); - m.extend(Paren); - }, - } - ); - - p.done([pop_until](auto& m) { - if (!m.in(Semi)) - m.error("Expected ';' at end of file"); - pop_until(m, File, {Brace, Paren}); + "\\{" >> [](auto &m) { m.push(Brace); }, + "\\}" >> + [pop_until](auto &m) { + pop_until(m, Brace, {Paren}); + m.term(); + m.pop(Brace); + m.extend(Brace); + if (m.group_in(If)) { + m.term(); + m.pop(If); + } else if (m.group_in(Else)) { + m.term(); + m.pop(Else); + } + if (m.group_in({Semi, Brace, File})) { + m.seq(Semi); + } + }, + + "\\(" >> [](auto &m) { m.push(Paren); }, + "\\)" >> + [pop_until](auto &m) { + pop_until(m, Paren, {Brace}); + m.term(); + m.pop(Paren); + m.extend(Paren); + }, }); - return p; - } + p.done([pop_until](auto &m) { + if (!m.in(Semi)) + m.error("Expected ';' at end of file"); + pop_until(m, File, {Brace, Paren}); + }); + + return p; } + +} // namespace lang + +} // namespace gitmem \ No newline at end of file diff --git a/src/passes/branching.cc b/src/passes/branching.cc index 152cd10..cf589bd 100644 --- a/src/passes/branching.cc +++ b/src/passes/branching.cc @@ -1,31 +1,34 @@ #include "../internal.hh" -namespace gitmem -{ - using namespace trieste; +namespace gitmem { - PassDef branching() - { - return { - "branching", - 
branching_wf, - dir::bottomup | dir::once, - { - T(Stmt) << (T(If)[If] << (T(Expr)[Expr] * T(Block)[Then] * T(Block)[Else])) >> - [](Match &_) -> Node - { - auto then_length = std::to_string(_(Then)->size() + 1 + 1); // +1 for the jump - auto else_length = std::to_string(_(Else)->size() + 1); - auto cond_loc = Location("if (" + std::string(_(Expr)->location().view()) + ") jump " + then_length); - auto jump_loc = Location("jump " + else_length); - auto cond = (Stmt ^ cond_loc) << (Cond << _(Expr) << (Const ^ then_length)); - auto jump = (Stmt ^ jump_loc) << (Jump << (Const ^ else_length)); - return Seq << cond - << *_(Then) - << jump - << *_(Else); - }, - }}; - } +namespace lang { +using namespace trieste; +PassDef branching() { + return {"branching", + branching_wf, + dir::bottomup | dir::once, + { + T(Stmt) << (T(If)[If] << (T(Expr)[Expr] * T(Block)[Then] * + T(Block)[Else])) >> + [](Match &_) -> Node { + auto then_length = + std::to_string(_(Then)->size() + 1 + 1); // +1 for the jump + auto else_length = std::to_string(_(Else)->size() + 1); + auto cond_loc = + Location("if (" + std::string(_(Expr)->location().view()) + + ") jump " + then_length); + auto jump_loc = Location("jump " + else_length); + auto cond = (Stmt ^ cond_loc) + << (Cond << _(Expr) << (Const ^ then_length)); + auto jump = (Stmt ^ jump_loc) + << (Jump << (Const ^ else_length)); + return Seq << cond << *_(Then) << jump << *_(Else); + }, + }}; } + +} // namespace lang + +} // namespace gitmem \ No newline at end of file diff --git a/src/passes/check_refs.cc b/src/passes/check_refs.cc index ca712b0..1cdd448 100644 --- a/src/passes/check_refs.cc +++ b/src/passes/check_refs.cc @@ -1,30 +1,29 @@ #include "../internal.hh" -namespace gitmem -{ - using namespace trieste; +namespace gitmem { - PassDef check_refs() - { - return { - "check_refs", - statements_wf, - dir::bottomup | dir::once, - { - In(Expr) * T(Reg)[Reg] >> - [](Match &_) -> Node - { - auto reg = _(Reg); - auto enclosing_block = reg->scope(); 
- auto bindings = reg->lookup(enclosing_block); - if (bindings.empty()) - { - return Error << (ErrorAst << _(Reg)) - << (ErrorMsg ^ "Register has not been assigned"); - } - return NoChange; - }, - }}; - } +namespace lang { +using namespace trieste; + +PassDef check_refs() { + return {"check_refs", + statements_wf, + dir::bottomup | dir::once, + { + In(Expr) * T(Reg)[Reg] >> [](Match &_) -> Node { + auto reg = _(Reg); + auto enclosing_block = reg->scope(); + auto bindings = reg->lookup(enclosing_block); + if (bindings.empty()) { + return Error << (ErrorAst << _(Reg)) + << (ErrorMsg ^ "Register has not been assigned"); + } + return NoChange; + }, + }}; } + +} // namespace lang + +} // namespace gitmem \ No newline at end of file diff --git a/src/passes/expressions.cc b/src/passes/expressions.cc index 2c9d957..152a413 100644 --- a/src/passes/expressions.cc +++ b/src/passes/expressions.cc @@ -1,139 +1,102 @@ #include "../internal.hh" -namespace gitmem -{ - using namespace trieste; - - PassDef expressions() - { - auto Operand = T(Expr) << (T(Reg, Var, Const, Add)); - return { - "expressions", - expressions_wf, - dir::bottomup, - { - --In(Expr) * T(Const, Reg, Var)[Expr] >> - [](Match &_) -> Node - { - return Expr << _(Expr); - }, - - --In(Expr) * T(Spawn)[Spawn] << (T(Brace) * End) >> - [](Match &_) -> Node - { - return Expr << _(Spawn); - }, - - // Additions must have *at least* two operands - --In(Expr) * T(Add)[Add] << (Operand * Operand) >> - [](Match &_) -> Node - { - return Expr << _(Add); - }, - - --In(Expr) * T(Eq, Neq)[Eq] << (Operand * Operand * End) >> - [](Match &_) -> Node - { - return Expr << _(Eq); - }, - - T(Group) << (T(Brace)[Brace] * End) >> - [](Match &_) -> Node - { - return _(Brace); - }, - - T(Group) << (T(Paren)[Paren] * End) >> - [](Match &_) -> Node - { - return _(Paren); - }, - - T(Group) << (T(Expr)[Expr] * End) >> - [](Match &_) -> Node - { - return _(Expr); - }, - - T(Paren) << (T(Expr)[Expr] * End) >> - [](Match &_) -> Node - { - return 
_(Expr); - }, - - // Error rules - In(Group) * T(Expr) * (!T(Brace))[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Unexpected term (did you forget a brace or a semicolon?)"); - }, - - In(Group) * Any * T(Expr)[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Unexpected expression"); - }, - - T(Spawn)[Spawn] << End >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Spawn)) - << (ErrorMsg ^ "Expected body of spawn"); - }, - - --In(Expr) * T(Spawn) << Any[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Invalid body of spawn"); - }, - - --In(Expr) * T(Add)[Add] << ((T(Group) << End) / (Any * (T(Group) << End))) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Add)) - << (ErrorMsg ^ "Expected operand"); - }, - - --In(Expr) * T(Add)[Add] << (Any) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Add)) - << (ErrorMsg ^ "Invalid operands for addition"); - }, - - - --In(Expr) * T(Eq, Neq)[Eq] << (Any * (T(Group) << End)) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Eq)) - << (ErrorMsg ^ "Expected right-hand side of equality"); - }, - - --In(Expr) * T(Eq, Neq)[Eq] << Any >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Eq)) - << (ErrorMsg ^ "Bad equality"); - }, - - Any * T(Paren)[Paren] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Paren)) - << (ErrorMsg ^ "Unexpected parenthesis"); - }, - - T(Paren) * Any[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Unexpected term (did you forget a brace or semicolon?)"); - }, - - }}; - } - +namespace gitmem { + +namespace lang { + +using namespace trieste; + +PassDef expressions() { + auto Operand = T(Expr) << (T(Reg, Var, Const, Add)); + return {"expressions", + expressions_wf, + dir::bottomup, + { + --In(Expr) * T(Const, Reg, Var)[Expr] >> + [](Match &_) -> Node { 
return Expr << _(Expr); }, + + --In(Expr) * T(Spawn)[Spawn] << (T(Brace) * End) >> + [](Match &_) -> Node { return Expr << _(Spawn); }, + + // Additions must have *at least* two operands + --In(Expr) * T(Add)[Add] << (Operand * Operand) >> + [](Match &_) -> Node { return Expr << _(Add); }, + + --In(Expr) * T(Eq, Neq)[Eq] << (Operand * Operand * End) >> + [](Match &_) -> Node { return Expr << _(Eq); }, + + T(Group) << (T(Brace)[Brace] * End) >> + [](Match &_) -> Node { return _(Brace); }, + + T(Group) << (T(Paren)[Paren] * End) >> + [](Match &_) -> Node { return _(Paren); }, + + T(Group) << (T(Expr)[Expr] * End) >> + [](Match &_) -> Node { return _(Expr); }, + + T(Paren) << (T(Expr)[Expr] * End) >> + [](Match &_) -> Node { return _(Expr); }, + + // Error rules + In(Group) * T(Expr) * (!T(Brace))[Expr] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Unexpected term (did you forget a " + "brace or a semicolon?)"); + }, + + In(Group) * Any *T(Expr)[Expr] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Unexpected expression"); + }, + + T(Spawn)[Spawn] << End >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Spawn)) + << (ErrorMsg ^ "Expected body of spawn"); + }, + + --In(Expr) * T(Spawn) << Any[Expr] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Invalid body of spawn"); + }, + + --In(Expr) * T(Add)[Add] + << ((T(Group) << End) / (Any * (T(Group) << End))) >> + [](Match &_) -> Node { + return Error << (ErrorAst << _(Add)) + << (ErrorMsg ^ "Expected operand"); + }, + + --In(Expr) * T(Add)[Add] << (Any) >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Add)) + << (ErrorMsg ^ "Invalid operands for addition"); + }, + + --In(Expr) * T(Eq, Neq)[Eq] << (Any * (T(Group) << End)) >> + [](Match &_) -> Node { + return Error + << (ErrorAst << _(Eq)) + << (ErrorMsg ^ "Expected right-hand side of equality"); + }, + + --In(Expr) * T(Eq, Neq)[Eq] << Any >> 
[](Match &_) -> Node { + return Error << (ErrorAst << _(Eq)) + << (ErrorMsg ^ "Bad equality"); + }, + + Any *T(Paren)[Paren] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Paren)) + << (ErrorMsg ^ "Unexpected parenthesis"); + }, + + T(Paren) * Any[Expr] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Unexpected term (did you forget a " + "brace or semicolon?)"); + }, + + }}; } + +} // namespace lang + +} // namespace gitmem \ No newline at end of file diff --git a/src/passes/statements.cc b/src/passes/statements.cc index 8b52cdb..91e4a2e 100644 --- a/src/passes/statements.cc +++ b/src/passes/statements.cc @@ -1,203 +1,158 @@ #include "../internal.hh" -namespace gitmem -{ - using namespace trieste; - - PassDef statements() - { - auto RVal = T(Expr) << (T(Reg, Var, Add, Const, Spawn)); - auto Condition = T(Expr) << (T(Eq, Neq)); - return { - "statements", - statements_wf, - dir::bottomup, - { - // Make Semi into Block - In(File) * T(Semi)[Semi] >> - [](Match &_) -> Node - { - return Block << *_(Semi); - }, - - T(Brace) << T(Semi)[Semi] >> - [](Match &_) -> Node - { - return Block << *_(Semi); - }, - - // Statements - --In(Stmt) * T(Nop)[Nop] >> - [](Match &_) -> Node - { - return Stmt << _(Nop); - }, - - --In(Stmt) * T(Join)[Join] << (RVal * End) >> - [](Match &_) -> Node - { - return Stmt << _(Join); - }, - - --In(Stmt) * T(Lock) << ((T(Expr) << T(Var)[Var]) * End) >> - [](Match &_) -> Node - { - return Stmt << (Lock << _(Var)); - }, - - --In(Stmt) * T(Unlock) << ((T(Expr) << T(Var)[Var]) * End) >> - [](Match &_) -> Node - { - return Stmt << (Unlock << _(Var)); - }, - - --In(Stmt) * T(Assign) << ((T(Expr) << (T(Reg, Var)[LVal] * End)) * RVal[Expr] * End) >> - [](Match &_) -> Node - { - return Stmt << (Assign << _(LVal) - << _(Expr)); - }, - - --In(Stmt) * T(Assert) << (Condition[Expr] * End) >> - [](Match &_) -> Node - { - return Stmt << (Assert << _(Expr)); - }, - - --In(Stmt) * (T(Group) << (T(If) << (T(Group) << 
(Condition[Expr] * T(Block)[Then])) * End)) - * (T(Group) << ((T(Else) << T(Block)[Else]) * End)) >> - [](Match &_) -> Node - { - return Stmt << (If << _(Expr) << _(Then) << _(Else)); - }, - - --In(Stmt) * (T(Group) << (T(If) << (T(Group) << (T(Expr)[Expr] * T(Block)[Then])) * End)) >> - [](Match &_) -> Node - { - return Stmt << (If << _(Expr) - << _(Then) - << (Block << ((Stmt ^ "nop") << Nop))); - }, - - T(Group) << (T(Stmt)[Stmt] * End) >> - [](Match &_) -> Node - { - return _(Stmt); - }, - - // Error rules - In(Group) * T(Stmt) * Any[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Unexpected term"); - }, - - T(Brace, File)[Brace] << End >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Brace)) - << (ErrorMsg ^ "Expected statement"); - }, - - T(Paren)[Paren] << End >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Paren)) - << (ErrorMsg ^ "Expected expression"); - }, - - --In(Spawn) * T(Brace)[Brace] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Brace)) - << (ErrorMsg ^ "Unexpected block"); - }, - - --In(Stmt) * T(Join) << End >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Join)) - << (ErrorMsg ^ "Expected thread identifier"); - }, - - --In(Stmt) * T(Join) << Any[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Invalid thread identifier"); - }, - - --In(Stmt) * T(Lock, Unlock) << End >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Lock)) - << (ErrorMsg ^ "Expected lock identifier"); - }, - - --In(Stmt) * T(Lock, Unlock) << Any[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Invalid lock identifier"); - }, - - --In(Stmt) * T(Assign) << (Any * End) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Assign)) - << (ErrorMsg ^ "Expected right-hand side to assignment"); - }, - - --In(Stmt) * T(Assign) << ((T(Expr) << T(Reg, Var)) * Any[Expr]) >> - 
[](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Invalid right-hand side to assignment"); - }, - - --In(Stmt) * T(Assign) << Any[LVal] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(LVal)) - << (ErrorMsg ^ "Invalid left-hand side to assignment"); - }, - - --In(Stmt) * T(Assert)[Assert] << (T(Group) << End) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Assert)) - << (ErrorMsg ^ "Expected condition"); - }, - - --In(Stmt) * T(Assert) << (Any[Expr] * End) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Invalid assertion"); - }, - - In(If) * (Start * T(Block)[Expr]) / (T(Group) << (!Condition)[Expr]) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Invalid condition"); - }, - - In(File, Brace) * T(Stmt)[Stmt] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Stmt)) - << (ErrorMsg ^ "Expected semicolon"); - }, - - In(Brace, File, Semi) * (!T(Stmt, Semi, Block))[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Expected statement"); - }, - }}; - } - +namespace gitmem { + +namespace lang { + +using namespace trieste; + +PassDef statements() { + auto RVal = T(Expr) << (T(Reg, Var, Add, Const, Spawn)); + auto Condition = T(Expr) << (T(Eq, Neq)); + return { + "statements", + statements_wf, + dir::bottomup, + { + // Make Semi into Block + In(File) * T(Semi)[Semi] >> + [](Match &_) -> Node { return Block << *_(Semi); }, + + T(Brace) << T(Semi)[Semi] >> + [](Match &_) -> Node { return Block << *_(Semi); }, + + // Statements + --In(Stmt) * T(Nop)[Nop] >> + [](Match &_) -> Node { return Stmt << _(Nop); }, + + --In(Stmt) * T(Join)[Join] << (RVal * End) >> + [](Match &_) -> Node { return Stmt << _(Join); }, + + --In(Stmt) * T(Lock) << ((T(Expr) << T(Var)[Var]) * End) >> + [](Match &_) -> Node { return Stmt << (Lock << _(Var)); }, + + --In(Stmt) * T(Unlock) << ((T(Expr) << 
T(Var)[Var]) * End) >> + [](Match &_) -> Node { return Stmt << (Unlock << _(Var)); }, + + --In(Stmt) * T(Assign) << ((T(Expr) << (T(Reg, Var)[LVal] * End)) * + RVal[Expr] * End) >> + [](Match &_) -> Node { + return Stmt << (Assign << _(LVal) << _(Expr)); + }, + + --In(Stmt) * T(Assert) << (Condition[Expr] * End) >> + [](Match &_) -> Node { return Stmt << (Assert << _(Expr)); }, + + --In(Stmt) * + (T(Group) << (T(If) << (T(Group) << (Condition[Expr] * + T(Block)[Then])) * + End)) * + (T(Group) << ((T(Else) << T(Block)[Else]) * End)) >> + [](Match &_) -> Node { + return Stmt << (If << _(Expr) << _(Then) << _(Else)); + }, + + --In(Stmt) * (T(Group) << (T(If) << (T(Group) << (T(Expr)[Expr] * + T(Block)[Then])) * + End)) >> + [](Match &_) -> Node { + return Stmt << (If << _(Expr) << _(Then) + << (Block << ((Stmt ^ "nop") << Nop))); + }, + + T(Group) << (T(Stmt)[Stmt] * End) >> + [](Match &_) -> Node { return _(Stmt); }, + + // Error rules + In(Group) * T(Stmt) * Any[Expr] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Unexpected term"); + }, + + T(Brace, File)[Brace] << End >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Brace)) + << (ErrorMsg ^ "Expected statement"); + }, + + T(Paren)[Paren] << End >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Paren)) + << (ErrorMsg ^ "Expected expression"); + }, + + --In(Spawn) * T(Brace)[Brace] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Brace)) + << (ErrorMsg ^ "Unexpected block"); + }, + + --In(Stmt) * T(Join) << End >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Join)) + << (ErrorMsg ^ "Expected thread identifier"); + }, + + --In(Stmt) * T(Join) << Any[Expr] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Invalid thread identifier"); + }, + + --In(Stmt) * T(Lock, Unlock) << End >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Lock)) + << (ErrorMsg ^ "Expected lock identifier"); + }, + + --In(Stmt) * 
T(Lock, Unlock) << Any[Expr] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Invalid lock identifier"); + }, + + --In(Stmt) * T(Assign) << (Any * End) >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Assign)) + << (ErrorMsg ^ + "Expected right-hand side to assignment"); + }, + + --In(Stmt) * T(Assign) << ((T(Expr) << T(Reg, Var)) * Any[Expr]) >> + [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ + "Invalid right-hand side to assignment"); + }, + + --In(Stmt) * T(Assign) << Any[LVal] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(LVal)) + << (ErrorMsg ^ "Invalid left-hand side to assignment"); + }, + + --In(Stmt) * T(Assert)[Assert] << (T(Group) << End) >> + [](Match &_) -> Node { + return Error << (ErrorAst << _(Assert)) + << (ErrorMsg ^ "Expected condition"); + }, + + --In(Stmt) * T(Assert) << (Any[Expr] * End) >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Invalid assertion"); + }, + + In(If) * (Start * T(Block)[Expr]) / + (T(Group) << (!Condition)[Expr]) >> + [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Invalid condition"); + }, + + In(File, Brace) * T(Stmt)[Stmt] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Stmt)) + << (ErrorMsg ^ "Expected semicolon"); + }, + + In(Brace, File, Semi) * (!T(Stmt, Semi, Block))[Expr] >> + [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Expected statement"); + }, + }}; } + +} // namespace lang + +} // namespace gitmem \ No newline at end of file diff --git a/src/progress_status.hh b/src/progress_status.hh new file mode 100644 index 0000000..4da6cba --- /dev/null +++ b/src/progress_status.hh @@ -0,0 +1,19 @@ +#pragma once + +namespace gitmem { + +enum class ProgressStatus { progress, no_progress }; +inline bool operator!(ProgressStatus p) { + return p == ProgressStatus::no_progress; +} +inline ProgressStatus operator||(const 
ProgressStatus &p1, + const ProgressStatus &p2) { + return (p1 == ProgressStatus::progress || p2 == ProgressStatus::progress) + ? ProgressStatus::progress + : ProgressStatus::no_progress; +} +inline void operator|=(ProgressStatus &p1, const ProgressStatus &p2) { + p1 = (p1 || p2); +} + +} \ No newline at end of file diff --git a/src/read_result.hh b/src/read_result.hh new file mode 100644 index 0000000..cb6a9b6 --- /dev/null +++ b/src/read_result.hh @@ -0,0 +1,20 @@ +#pragma once + +#include +#include +#include "conflict.hh" + +namespace gitmem { + +struct Event; + +struct ValueWithSource { + size_t value; + std::shared_ptr source_event; + + auto operator<=>(const ValueWithSource&) const = default; +}; + +using ReadResult = std::variant>; + +} \ No newline at end of file diff --git a/src/reader.cc b/src/reader.cc index e3648be..37bf9fe 100644 --- a/src/reader.cc +++ b/src/reader.cc @@ -2,20 +2,23 @@ namespace gitmem { +namespace lang { + using namespace trieste; -Reader reader() - { - return { +Reader reader() { + return { "gitmem", { - expressions(), - statements(), - check_refs(), - branching(), + expressions(), + statements(), + check_refs(), + branching(), }, - gitmem::parser(), - }; - } - + gitmem::lang::parser(), + }; } + +} // namespace lang + +} // namespace gitmem \ No newline at end of file diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh new file mode 100644 index 0000000..d0cc8a8 --- /dev/null +++ b/src/sync_protocol.hh @@ -0,0 +1,75 @@ +#pragma once + +#include "conflict.hh" +#include "sync_state.hh" +#include "execution_state.hh" +#include "read_result.hh" +#include +#include + +namespace gitmem { + +struct Event; // Forward declaration + +// Forward declaration for builder +class SyncProtocolBuilder; + +// Types of synchronization operations that may be scheduling points +enum class SyncOperation { + Spawn, + Join, + Start, + End, + Lock, + Unlock +}; + +class SyncProtocol { +public: + virtual ~SyncProtocol() = default; + + // Create a fresh 
copy of this protocol with reset state + virtual std::unique_ptr clone() const = 0; + + virtual std::unique_ptr make_thread_state(ThreadID tid) const = 0; + virtual std::unique_ptr make_lock_state() const = 0; + + // Read a shared variable into the thread context + virtual ReadResult read(ThreadContext &ctx, const std::string &var) = 0; + + // Write a shared variable (staged, not committed) + virtual void write(ThreadContext &ctx, const std::string &var, + ValueWithSource value) = 0; + + virtual std::optional> + on_spawn(ThreadContext &parent, ThreadContext &child) = 0; + + virtual std::optional> + on_join(ThreadContext &joiner, ThreadContext &joinee) = 0; + + virtual std::optional> + on_start(ThreadContext &thread) = 0; + + virtual std::optional> + on_end(ThreadContext &thread) = 0; + + virtual std::optional> + on_lock(ThreadContext &thread, Lock &lock) = 0; + + virtual std::optional> + on_unlock(ThreadContext &thread, Lock &lock) = 0; + + // Returns true if the given sync operation is a scheduling point for this protocol + // (i.e., the scheduler should consider switching threads here) + virtual bool is_scheduling_point(SyncOperation op) const = 0; + + virtual std::string build_revision_graph_dot(const std::vector& thread_states) const = 0; + + virtual std::ostream &print(std::ostream &os) const = 0; + friend std::ostream &operator<<(std::ostream &os, + const SyncProtocol &protocol) { + return protocol.print(os); + } +}; + +} // namespace gitmem diff --git a/src/sync_state.hh b/src/sync_state.hh new file mode 100644 index 0000000..2d7d680 --- /dev/null +++ b/src/sync_state.hh @@ -0,0 +1,31 @@ +#pragma once + +#include + +namespace gitmem { + +class ThreadSyncState { +public: + virtual ~ThreadSyncState() = default; + + virtual bool operator==(const ThreadSyncState& other) const = 0; + + virtual std::ostream &print(std::ostream &os) const = 0; + friend std::ostream &operator<<(std::ostream &os, + const ThreadSyncState &state) { + return state.print(os); + } +}; + 
+class LockSyncState { +public: + virtual ~LockSyncState() = default; + + virtual std::ostream &print(std::ostream &os) const = 0; + friend std::ostream &operator<<(std::ostream &os, + const LockSyncState &state) { + return state.print(os); + } +}; + +} \ No newline at end of file diff --git a/src/termination_status.hh b/src/termination_status.hh new file mode 100644 index 0000000..a86323c --- /dev/null +++ b/src/termination_status.hh @@ -0,0 +1,91 @@ +#pragma once + +#include "thread_id.hh" +#include "conflict.hh" + +#include + +namespace gitmem { + +namespace termination { + +struct Completed { + friend std::ostream& operator<<(std::ostream& os, const Completed&) { + os << "Completed successfully"; + return os; + } +}; + +struct DataRace { + std::shared_ptr conflict; + + explicit DataRace(std::shared_ptr conflict): conflict(conflict) {} + + friend std::ostream& operator<<(std::ostream& os, const DataRace& r) { + assert(r.conflict != nullptr); + os << "Data race occurred: " << *r.conflict; + return os; + } +}; + +struct UnlockError { + std::string lock; + friend std::ostream& operator<<(std::ostream& os, const UnlockError& e) { + os << "Attempted to unlock '" << e.lock << "' without ownership"; + return os; + } +}; + +struct AssertionFailure { + std::string expression; + friend std::ostream& operator<<(std::ostream& os, const AssertionFailure& a) { + os << "Assertion failed: " << a.expression; + return os; + } +}; + +struct UnassignedRead { + std::string variable; + friend std::ostream& operator<<(std::ostream& os, const UnassignedRead& u) { + os << "Read of unassigned variable '" << u.variable; + return os; + } +}; + +inline bool operator==(const Completed&, const Completed&) { return true; } + +inline bool operator==(const DataRace& a, const DataRace& b) { + return *a.conflict == *b.conflict; +} + +inline bool operator==(const UnlockError& a, const UnlockError& b) { + return a.lock == b.lock; +} + +inline bool operator==(const AssertionFailure& a, const 
AssertionFailure& b) { + return a.expression == b.expression; +} + +inline bool operator==(const UnassignedRead& a, const UnassignedRead& b) { + return a.variable == b.variable; +} + +// Optional: != operators for convenience +inline bool operator!=(const Completed& a, const Completed& b) { return !(a == b); } +inline bool operator!=(const DataRace& a, const DataRace& b) { return !(a == b); } +inline bool operator!=(const UnlockError& a, const UnlockError& b) { return !(a == b); } +inline bool operator!=(const AssertionFailure& a, const AssertionFailure& b) { return !(a == b); } +inline bool operator!=(const UnassignedRead& a, const UnassignedRead& b) { return !(a == b); } + +using TerminationStatus = + std::variant< + Completed, + DataRace, + UnlockError, + AssertionFailure, + UnassignedRead + >; + +} // end termination + +} // end gitmem \ No newline at end of file diff --git a/src/thread_id.hh b/src/thread_id.hh new file mode 100644 index 0000000..637ef56 --- /dev/null +++ b/src/thread_id.hh @@ -0,0 +1,7 @@ +#pragma once + +#include + +namespace gitmem { + using ThreadID = std::size_t; +} \ No newline at end of file diff --git a/src/thread_trace.hh b/src/thread_trace.hh new file mode 100644 index 0000000..90dba56 --- /dev/null +++ b/src/thread_trace.hh @@ -0,0 +1,198 @@ +#pragma once + +#include "thread_id.hh" +#include "conflict.hh" +#include "overloaded.hh" +#include "read_result.hh" + +namespace gitmem { + +struct Event; + +struct StartEvent {}; +struct SpawnEvent { const ThreadID child_tid; }; +struct ReadValue { const size_t value; const std::shared_ptr source_event; }; +struct ReadEvent { const std::string var; std::variant> value_or_conflict; }; +struct WriteEvent { const std::string var; const size_t value; }; +struct LockEvent { std::string lock_name; std::shared_ptr maybe_conflict; std::shared_ptr last_unlock_event; }; +struct UnlockEvent { const std::string lock_name; std::shared_ptr maybe_conflict; }; +struct JoinEvent { const ThreadID joinee_tid; 
std::shared_ptr maybe_conflict; }; +struct AssertEvent { const std::string condition; bool pass; }; + +struct EndEvent {}; + +using EventID = size_t; + +struct Event { + ThreadID tid; + EventID eid; + std::variant< + StartEvent, + SpawnEvent, + ReadEvent, + WriteEvent, + LockEvent, + UnlockEvent, + JoinEvent, + AssertEvent, + EndEvent + > data; +}; + +inline std::string event_header(const Event& e) { + std::ostringstream oss; + oss << "[tid=" << e.tid << ", eid=" << e.eid << "]"; + return oss.str(); +} + +// --- operator<< overloads for individual event types --- +inline std::ostream& operator<<(std::ostream& os, const StartEvent&) { + return os << "StartEvent"; +} + +inline std::ostream& operator<<(std::ostream& os, const SpawnEvent& e) { + return os << "SpawnEvent(child_tid=" << e.child_tid << ")"; +} + +inline std::ostream& operator<<(std::ostream& os, const ReadEvent& e) { + os << "ReadEvent(var=\"" << e.var << "\", "; + std::visit(overloaded{ + [&os](const ReadValue& val) { os << "value=" << val.value << " (from " << val.source_event->eid << ")"; }, + [&os](const std::shared_ptr&) { os << "conflict"; } + }, e.value_or_conflict); + os << ")"; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const WriteEvent& e) { + return os << "WriteEvent(var=\"" << e.var << "\", value=" << e.value << ")"; +} + +inline std::ostream& operator<<(std::ostream& os, const LockEvent& e) { + os << "LockEvent(lock_name=\"" << e.lock_name << "\""; + if (e.last_unlock_event) + os << ", last unlock " << event_header(*e.last_unlock_event); + if (e.maybe_conflict) + os << ", conflict)"; + else + os << ")"; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const UnlockEvent& e) { + os << "UnlockEvent(lock_name=\"" << e.lock_name << "\""; + if (e.maybe_conflict) + os << ", conflict)"; + else + os << ")"; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const JoinEvent& e) { + os << "JoinEvent(joinee_tid=" << e.joinee_tid; + if 
(e.maybe_conflict) + os << ", conflict)"; + else + os << ")"; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const AssertEvent& e) { + return os << "AssertEvent(condition=\"" << e.condition << "\", " << (e.pass ? "pass" : "fail") << ")"; +} + +inline std::ostream& operator<<(std::ostream& os, const EndEvent&) { + return os << "EndEvent"; +} + +// --- operator<< for the wrapper Event --- +inline std::ostream& operator<<(std::ostream& os, const Event& e) { + os << event_header(e) << " "; + std::visit([&os](auto&& arg) { os << arg; }, e.data); + return os; +} + +static EventID next_eid = 0; + +struct ThreadTrace { + std::vector> trace; + ThreadID tid; + + auto begin() { return trace.begin(); } + auto end() { return trace.end(); } + + auto begin() const { return trace.begin(); } + auto end() const { return trace.end(); } + + explicit ThreadTrace(ThreadID tid) : tid(tid) {} + +private: + template + std::shared_ptr append(Args&&... args) { + auto event = std::make_shared(tid, next_eid++, T(std::forward(args)...)); + trace.push_back(event); + return event; + } + + public: + std::shared_ptr on_start() { + return append(); + } + + std::shared_ptr on_spawn(ThreadID child_tid) { + return append(child_tid); + } + + std::shared_ptr on_read(const std::string text, ValueWithSource value) { + return append(std::move(text), ReadValue{value.value, value.source_event}); + } + + std::shared_ptr on_read(const std::string text, std::shared_ptr conflict) { + return append(std::move(text), conflict); + } + + std::shared_ptr on_write(const std::string text, const size_t value) { + return append(std::move(text), value); + } + + std::shared_ptr on_lock(const std::string lock_name, + std::shared_ptr last_unlock_event, + std::shared_ptr conflict = nullptr) { + return append(std::move(lock_name), std::move(conflict), last_unlock_event); + } + + std::shared_ptr on_unlock(const std::string lock_name, std::shared_ptr conflict = nullptr) { + return append(std::move(lock_name), 
EXAMPLES_DIR = "examples"

# Maps each selectable sync kind to the gitmem command-line options enabling it.
SYNC_KINDS = {
    "linear": {"sync": "linear"},
    "branching-eager": {"sync": "branching", "branching_mode": "eager"},
    "branching-lazy": {"sync": "branching", "branching_mode": "lazy"},
}


def supports_color():
    """Return True when ANSI colour escapes should be written to stdout.

    Colour is used only when stdout is a terminal and the ``NO_COLOR``
    environment variable is unset or empty (the NO_COLOR convention ignores
    an empty value).
    """
    return sys.stdout.isatty() and not os.environ.get("NO_COLOR")


def color(text, code):
    """Wrap *text* in the ANSI SGR sequence *code*, or return it unchanged.

    The text is returned as-is whenever ``supports_color()`` is false.
    """
    if not supports_color():
        return text
    return f"\033[{code}m{text}\033[0m"


def green(text):
    """Return *text* coloured green (SGR code 32) when colour is enabled."""
    return color(text, "32")


def red(text):
    """Return *text* coloured red (SGR code 31) when colour is enabled."""
    return color(text, "31")


def run_gitmem_test(gitmem_path, file_path, should_accept, sync_kind):
    """Run gitmem on one example file and report whether it behaved as expected.

    Args:
        gitmem_path: Path to the gitmem executable.
        file_path: Example program passed to gitmem.
        should_accept: When true the test passes on exit code 0; otherwise
            it passes only on exit code 1 (any other code is a failure).
        sync_kind: Key of ``SYNC_KINDS`` selecting the sync options.

    Returns:
        True if the exit code matched the expectation, False otherwise.

    Raises:
        KeyError: If *sync_kind* is not a key of ``SYNC_KINDS``.
        SystemExit: With status 1 if *gitmem_path* cannot be executed
            because the file does not exist.
    """
    sync_config = SYNC_KINDS[sync_kind]

    cmd = [gitmem_path, file_path, "--sync", sync_config["sync"]]
    if "branching_mode" in sync_config:
        cmd.extend(["--branching-mode", sync_config["branching_mode"]])
    # Output is discarded; only the exit code is inspected.
    cmd.extend(["-e", "-o", os.devnull])

    # Keep the try body minimal: only launching the process can raise here.
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except FileNotFoundError:
        print(f"Error: '{gitmem_path}' executable not found.")
        sys.exit(1)

    expected_code = 0 if should_accept else 1
    accepted = result.returncode == expected_code

    status = green("PASS") if accepted else red("FAIL")
    print(f"[{status}] {file_path} [{sync_kind}] (exit code: {result.returncode})")
    return accepted


def _test_dir_for(base_dir, category, sync_kind):
    """Return the directory holding the tests of *category* for *sync_kind*."""
    if category != "semantics":
        return base_dir
    # Both branching variants share the same "branching" directory.
    subdir = "linear" if sync_kind == "linear" else "branching"
    return os.path.join(base_dir, subdir)


def _iter_test_files(test_dir):
    """Yield every file path below *test_dir* in a deterministic order."""
    for root, dirs, files in os.walk(test_dir):
        dirs.sort()  # in-place sort fixes the traversal order of os.walk
        for name in sorted(files):
            yield os.path.join(root, name)


def _print_summary(results, total_tests, failing_tests):
    """Print the per-group and overall result tables."""
    print("\nDetailed Summary:")
    for expectation, categories in results.items():
        print(f"\n{expectation.upper()}:")
        for category, syncs in categories.items():
            print(f"  {category}:")
            for sync_kind, stats in syncs.items():
                passed = stats["total"] - stats["failed"]
                print(
                    f"    {sync_kind}: "
                    f"{passed}/{stats['total']} passed "
                    f"({stats['failed']} failed)"
                )

    failed_tests = len(failing_tests)
    print("\nOverall Summary:")
    print(f"Total tests run: {total_tests}")
    print(f"Tests failed: {failed_tests}")
    print(f"Tests passed: {total_tests - failed_tests}")

    if failing_tests:
        print("\nFailing tests:")
        for path, sync in failing_tests:
            print(f"  {red(path)} [{sync}]")


def main(argv=None):
    """Run the example suite under the selected sync kinds and print a summary.

    Test files are discovered below
    ``EXAMPLES_DIR/<accept|reject>/<syntax|semantics>``. Semantics tests live
    in a ``linear`` or ``branching`` subdirectory; syntax tests are run once
    per invocation regardless of how many sync kinds are selected.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) makes argparse read ``sys.argv``.

    Raises:
        SystemExit: With status 1 if at least one test failed.
    """
    parser = argparse.ArgumentParser(description="Test runner for gitmem.")
    parser.add_argument(
        "--gitmem", "-g",
        required=True,
        help="Path to the gitmem executable",
    )
    parser.add_argument(
        "--linear",
        action="store_true",
        help="Only run linear sync tests",
    )
    parser.add_argument(
        "--branching-eager",
        action="store_true",
        help="Only run branching-eager tests",
    )
    parser.add_argument(
        "--branching-lazy",
        action="store_true",
        help="Only run branching-lazy tests",
    )

    args = parser.parse_args(argv)
    gitmem_path = args.gitmem

    selected_syncs = [
        kind for kind in SYNC_KINDS
        if getattr(args, kind.replace("-", "_"))
    ]
    # If none specified, run all
    if not selected_syncs:
        selected_syncs = list(SYNC_KINDS)

    results = defaultdict(lambda: defaultdict(lambda: defaultdict(
        lambda: {"total": 0, "failed": 0}
    )))
    total_tests = 0
    failing_tests = []

    for expectation in ("accept", "reject"):
        should_accept = expectation == "accept"

        for category in ("syntax", "semantics"):
            base_dir = os.path.join(EXAMPLES_DIR, expectation, category)
            if not os.path.isdir(base_dir):
                continue

            for sync_kind in selected_syncs:
                # Syntax tests are sync-agnostic, so run them exactly once:
                # under the first selected sync kind. Tying them to "linear"
                # would skip them whenever only a branching kind is selected.
                if category == "syntax" and sync_kind != selected_syncs[0]:
                    continue

                test_dir = _test_dir_for(base_dir, category, sync_kind)
                if not os.path.isdir(test_dir):
                    continue

                for file_path in _iter_test_files(test_dir):
                    stats = results[expectation][category][sync_kind]
                    total_tests += 1
                    stats["total"] += 1

                    if not run_gitmem_test(
                        gitmem_path, file_path, should_accept, sync_kind
                    ):
                        stats["failed"] += 1
                        failing_tests.append((file_path, sync_kind))

    _print_summary(results, total_tests, failing_tests)

    if failing_tests:
        sys.exit(1)


if __name__ == "__main__":
    main()