diff --git a/go.mod b/go.mod index 5940d46..b4632cf 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,10 @@ module github.com/harmony-one/vdf -go 1.12 +go 1.22 + +require github.com/stretchr/testify v1.3.0 require ( - github.com/stretchr/testify v1.3.0 - golang.org/x/tools v0.0.0-20190924052046-3ac2a5bbd98a // indirect + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect ) diff --git a/go.sum b/go.sum index f05ddd3..4347755 100644 --- a/go.sum +++ b/go.sum @@ -5,11 +5,3 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190924052046-3ac2a5bbd98a h1:DJzZ1GRmbjp7ihxzAN6UTVpVMi6k4CXZEr7A3wi2kRA= -golang.org/x/tools v0.0.0-20190924052046-3ac2a5bbd98a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/src/test/fast_vdf_test.go b/src/test/fast_vdf_test.go index d5ddd5a..dc63d3b 100644 --- a/src/test/fast_vdf_test.go +++ b/src/test/fast_vdf_test.go @@ -2,9 +2,9 @@ package main import ( "encoding/hex" + "github.com/harmony-one/vdf/src/vdf_go" "github.com/stretchr/testify/assert" "testing" - "github.com/harmony-one/vdf/src/vdf_go" ) func Test2047(t *testing.T) { @@ -18,6 +18,8 @@ func Test2047(t *testing.T) { } func Test1000(t *testing.T) { + skipIfShort(t, "skipping target=1000 generation in -short mode") + seed := []byte{0xde, 0xad, 0xbe, 0xef} ref_output := "0011c26e62c608dba629ce37953b2c7765f6c3c48f58ae5dc6ebc19206ca3135f8a240538a42f989d990488185f2d10a2504838f4f2e4dd933119088aa0e5b506bfd835d03147b03d5111e6ca135bf297435faf27a8ccbdb0c7598934fdccde6c9afbdc0488662618fc3934eaa9913f97559fb119ff959fc5f35a71da783c64af0000461c617ca4fd50ab15bf8c62963b043e1920b619402aa11a7fb82b793e9fca643bb8b8026e09493e6ed0f69ad7dafef7938f7c78d7067247d43ce2cf73174ffd78d2d4107a0421cf16a7fb118978b4903425bb84dcc4d0102267103494b798247cabc65caff373c368530fb7d869317d86a279eb55facf75a430109b5343875003fc63ce964ad0fc804687fb21b9322d672299cf0eeb53f5f426a4123e44db2ca593b50c026e54c079cd79634cd3969941aba18edae5fb51792776a2ed9076c79a456bb783ad87cdad013ce8a933c0c1a787a0232205dfa34b8ab65c1bd06f4004a3ffd5aec0c9cabedd081228c0b8c59e2bc2487f1fa2344288a8e9d7eefd169003fb7e55b707e9c5d76c84fa510647ec6c392f19f1a4ce98e71a601c1ee2479a93e5b9e4512b7c4cffc18b3498a36334e49db29aa56d487dec7b9dcc3128f722888903f10fd468a62ebb599eaced4114e36df647cc60b16c15f33b2f1c96bfe7c33d274ca57a456448ac7ac5539d10d71f72c0561d0460b9ab8f338898c8b8203" @@ -28,6 +30,8 @@ func Test1000(t *testing.T) { } func Test5000(t *testing.T) { + skipIfShort(t, "skipping target=5000 generation in -short mode") + seed := []byte{0xde, 0xad, 0xbe, 0xef} ref_output := "00624b507eae90e9ac51a8ee427a0195e843642992cc34fa2c42a43ab75ec098a4d85e64b45fa5e48f3241495761a8af35e6e17d1259ea6d2b173c58c925c47a64eedfadccdff41841e20dd8a50c8a8dd87f0a0160e4f68f548f2a8ec6994bc826de9fee7e3fcd9d526808667327bf5f4cefd905117932b1cb737d74853f89206cffe2d0bd37520dcced47fd18a0e894d200c7d583ee7d55d84be5332a9c21bb6726e6ae64f294726a6e88be77ccd10e4139ed3c6309cfa2763b958149ff948233db76fc05a7abdfb08411e3f5a6567ab9247ebdbe310b0bb453340d385bdf6a4d67d454b1cdc2fbd82b12ae474bfe3038fd8fa8552f692a62df48dc208c18ac638d00748ead55bfc86ac0bd40e3038ea2d6e13cea20d3df4abf8923ada9927bacc531f35fcbf74080fdf5bb25f22b4186e941b8585f5ecc533b588dcf46b18e0871781a6cfe3b1f7f8dbc47369334ff7b32294453162a20d6c2d5711558cce36a9db0507e665922a8b99fab12145098f930b1c11afaf6f1e78256b536de1513b64d90ff8ba3b7b91d5989bb1aaf557063938df474e73d0568a51b2a5ad89b058f3adfaea137c9c1925feeb80b68aac4cb4335224efc365c81db004826c69b9659cd0a3e46bd4af5ccfa55bd53f3b7ddbae2a0802ada1076397cb6c1d7c9b6a4239db430c66fd90a24c21602723d0b895708d49cb92e8aeefc9a3760997b5b6760378f0b" @@ -38,10 +42,11 @@ func Test5000(t *testing.T) { } func Test5000Verify(t *testing.T) { + skipIfShort(t, "skipping target=5000 verification in -short mode") + seed := []byte{0xde, 0xad, 0xbe, 0xef} ref_output := "00624b507eae90e9ac51a8ee427a0195e843642992cc34fa2c42a43ab75ec098a4d85e64b45fa5e48f3241495761a8af35e6e17d1259ea6d2b173c58c925c47a64eedfadccdff41841e20dd8a50c8a8dd87f0a0160e4f68f548f2a8ec6994bc826de9fee7e3fcd9d526808667327bf5f4cefd905117932b1cb737d74853f89206cffe2d0bd37520dcced47fd18a0e894d200c7d583ee7d55d84be5332a9c21bb6726e6ae64f294726a6e88be77ccd10e4139ed3c6309cfa2763b958149ff948233db76fc05a7abdfb08411e3f5a6567ab9247ebdbe310b0bb453340d385bdf6a4d67d454b1cdc2fbd82b12ae474bfe3038fd8fa8552f692a62df48dc208c18ac638d00748ead55bfc86ac0bd40e3038ea2d6e13cea20d3df4abf8923ada9927bacc531f35fcbf74080fdf5bb25f22b4186e941b8585f5ecc533b588dcf46b18e0871781a6cfe3b1f7f8dbc47369334ff7b32294453162a20d6c2d5711558cce36a9db0507e665922a8b99fab12145098f930b1c11afaf6f1e78256b536de1513b64d90ff8ba3b7b91d5989bb1aaf557063938df474e73d0568a51b2a5ad89b058f3adfaea137c9c1925feeb80b68aac4cb4335224efc365c81db004826c69b9659cd0a3e46bd4af5ccfa55bd53f3b7ddbae2a0802ada1076397cb6c1d7c9b6a4239db430c66fd90a24c21602723d0b895708d49cb92e8aeefc9a3760997b5b6760378f0b" buf, _ := hex.DecodeString(ref_output) assert.Equal(t, true, vdf_go.VerifyVDF(seed, buf, 5000, 2048), "must be true") } - diff --git a/src/test/proof_external_vectors_test.go b/src/test/proof_external_vectors_test.go new file mode 100644 index 0000000..5f14e2d --- /dev/null +++ b/src/test/proof_external_vectors_test.go @@ -0,0 +1,58 @@ +package main + +import ( + "crypto/sha256" + "encoding/hex" + "testing" + + "github.com/harmony-one/vdf/src/vdf_go" + "github.com/stretchr/testify/assert" +) + +// Source vectors (Wesolowski): +// https://github.com/poanetwork/vdf/blob/master/wesolowski.csv +// (fetched on 2026-04-05, first three rows). +func TestRustReferenceWesolowskiCanaryVectors(t *testing.T) { + type vector struct { + seedHex string + iterations int + outputSHA256Hex string + } + + vectors := []vector{ + { + seedHex: "acc390feb1fbbe70d6a7ad2203a8b8c3c93e52d9782886124606686bc5716fed", + iterations: 66, + outputSHA256Hex: "3d72a0852b85b87e466c1268233ae02c31295497dd88b92b63561141b1b68b29", + }, + { + seedHex: "c86348990a772d094da4b5d4a3014a2d981c4b36fa1df672456d98d69bfb3f9e", + iterations: 68, + outputSHA256Hex: "4ea8c6df1a0a85e12f82320f631f4f276a5a373a0de4f65d7193ae1f1e66d132", + }, + { + seedHex: "7bd5c3dba392a0ef4b02052200748d5065e4d5d2f3bbf11ca5edeaccfc44f7aa", + iterations: 70, + outputSHA256Hex: "ccf96d2bd863cf7fa5ab992be2d6dd2f5f500da069ac07b34c06f6c367e67e75", + }, + } + + for _, tc := range vectors { + tc := tc + t.Run(tc.seedHex, func(t *testing.T) { + seed, err := hex.DecodeString(tc.seedHex) + assert.NoError(t, err) + + yBuf, proofBuf := vdf_go.GenerateVDF(seed, tc.iterations, 2048) + assert.NotNil(t, yBuf) + assert.NotNil(t, proofBuf) + + output := append(yBuf, proofBuf...) + + assert.True(t, vdf_go.VerifyVDF(seed, output, tc.iterations, 2048)) + + digest := sha256.Sum256(output) + assert.Equal(t, tc.outputSHA256Hex, hex.EncodeToString(digest[:])) + }) + } +} diff --git a/src/test/proof_negative_test.go b/src/test/proof_negative_test.go new file mode 100644 index 0000000..f59a1fc --- /dev/null +++ b/src/test/proof_negative_test.go @@ -0,0 +1,70 @@ +package main + +import ( + "testing" + + "github.com/harmony-one/vdf/src/vdf_go" + "github.com/stretchr/testify/assert" +) + +func buildValidProofFixture(t *testing.T) (seed []byte, proofBlob []byte, iterations int, intSizeBits int) { + seed = []byte{ + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + } + iterations = 20 + intSizeBits = 2048 + + yBuf, proofBuf := vdf_go.GenerateVDF(seed, iterations, intSizeBits) + assert.NotNil(t, yBuf) + assert.NotNil(t, proofBuf) + + proofBlob = append(append([]byte{}, yBuf...), proofBuf...) + assert.Equal(t, true, vdf_go.VerifyVDF(seed, proofBlob, iterations, intSizeBits), "fixture proof should be valid") + + return seed, proofBlob, iterations, intSizeBits +} + +func TestVerifyProofTamperedBlob(t *testing.T) { + seed, proofBlob, iterations, intSizeBits := buildValidProofFixture(t) + + tampered := append([]byte{}, proofBlob...) + tampered[len(tampered)-1] ^= 0x01 + + assert.False(t, vdf_go.VerifyVDF(seed, tampered, iterations, intSizeBits), "tampered proof must fail") +} + +func TestVerifyProofWrongTarget(t *testing.T) { + seed, proofBlob, iterations, intSizeBits := buildValidProofFixture(t) + + assert.False(t, vdf_go.VerifyVDF(seed, proofBlob, iterations+1, intSizeBits), "wrong iteration target must fail") +} + +func TestVerifyProofWrongSeed(t *testing.T) { + seed, proofBlob, iterations, intSizeBits := buildValidProofFixture(t) + + wrongSeed := append([]byte{}, seed...) + wrongSeed[0] ^= 0x01 + + assert.False(t, vdf_go.VerifyVDF(wrongSeed, proofBlob, iterations, intSizeBits), "wrong seed must fail") +} + +func TestVerifyProofWrongIntSizeBits(t *testing.T) { + seed, proofBlob, iterations, intSizeBits := buildValidProofFixture(t) + + assert.Panics(t, func() { + vdf_go.VerifyVDF(seed, proofBlob, iterations, intSizeBits-1) + }, "wrong int_size_bits currently panics with current API contract") +} + +func TestVerifyProofShortBlobPanics(t *testing.T) { + seed, _, iterations, intSizeBits := buildValidProofFixture(t) + + shortBlob := []byte{0x01, 0x02, 0x03} + + assert.Panics(t, func() { + vdf_go.VerifyVDF(seed, shortBlob, iterations, intSizeBits) + }, "too short proof blob should panic with current API contract") +} diff --git a/src/test/proof_wesolowski_test.go b/src/test/proof_wesolowski_test.go index c8062e6..b2bcc95 100644 --- a/src/test/proof_wesolowski_test.go +++ b/src/test/proof_wesolowski_test.go @@ -56,6 +56,8 @@ func TestGenerateAndVerifyIntSize(t *testing.T) { } func TestGenerateAndVerifyProof(t *testing.T) { + skipIfShort(t, "skipping exhaustive proof generation sweep in -short mode") + seed := []byte{0xde, 0xad, 0xbe, 0xef} for T := 5; T < 100; T++ { @@ -65,6 +67,8 @@ func TestGenerateAndVerifyProof(t *testing.T) { } func TestGenerateAndVerifyProof100(t *testing.T) { + skipIfShort(t, "skipping target 100..199 sweep in -short mode") + seed := []byte{0xde, 0xad, 0xbe, 0xef} for T := 101; T < 200; T++ { @@ -74,6 +78,8 @@ func TestGenerateAndVerifyProof100(t *testing.T) { } func TestGenerateAndVerifyProof200(t *testing.T) { + skipIfShort(t, "skipping target 200..299 sweep in -short mode") + seed := []byte{0xde, 0xad, 0xbe, 0xef} for T := 201; T < 300; T++ { @@ -83,6 +89,8 @@ func TestGenerateAndVerifyProof200(t *testing.T) { } func TestGenerateAndVerifyProof300(t *testing.T) { + skipIfShort(t, "skipping target 300..399 sweep in -short mode") + seed := []byte{0xde, 0xad, 0xbe, 0xef} for T := 301; T < 400; T++ { @@ -92,6 +100,8 @@ func TestGenerateAndVerifyProof300(t *testing.T) { } func TestGenerateAndVerifyProof400(t *testing.T) { + skipIfShort(t, "skipping target 400..499 sweep in -short mode") + seed := []byte{0xde, 0xad, 0xbe, 0xef} for T := 401; T < 500; T++ { @@ -101,6 +111,8 @@ func TestGenerateAndVerifyProof400(t *testing.T) { } func TestGenerateAndVerifyProof1000(t *testing.T) { + skipIfShort(t, "skipping target 1000 sweep in -short mode") + seed := []byte{0xde, 0xad, 0xbe, 0xef} for T := 1001; T < 1010; T++ { @@ -110,6 +122,7 @@ func TestGenerateAndVerifyProof1000(t *testing.T) { } func TestRandomInput(t *testing.T) { + skipIfShort(t, "skipping randomized long-running vectors in -short mode") for i := 0; i < 5; i++ { seed := make([]byte, 32) @@ -123,6 +136,8 @@ func TestRandomInput(t *testing.T) { } func TestInterruptibleGenerator(t *testing.T) { + skipIfShort(t, "skipping interruptible generator timing test in -short mode") + seed := []byte{0xde, 0xad, 0xbe, 0xef} stop := make(chan struct{}) diff --git a/src/test/slow_vdf_test.go b/src/test/slow_vdf_test.go index d0712f2..50bc032 100644 --- a/src/test/slow_vdf_test.go +++ b/src/test/slow_vdf_test.go @@ -5,16 +5,17 @@ import ( "encoding/csv" "encoding/hex" "fmt" + "github.com/harmony-one/vdf/src/vdf_go" "github.com/stretchr/testify/assert" "io" "log" "os" "strconv" "testing" - "github.com/harmony-one/vdf/src/vdf_go" ) func TestCreateProofCSV(t *testing.T) { + skipIfShort(t, "skipping CSV proof generation in -short mode") csvFile, _ := os.Open("wesolowski.csv") reader := csv.NewReader(bufio.NewReader(csvFile)) @@ -39,6 +40,7 @@ func TestCreateProofCSV(t *testing.T) { } func TestVerifyProofCSV(t *testing.T) { + skipIfShort(t, "skipping CSV proof verification in -short mode") csvFile, _ := os.Open("wesolowski.csv") reader := csv.NewReader(bufio.NewReader(csvFile)) diff --git a/src/test/square_test.go b/src/test/square_test.go index 51a29bd..470bc8b 100644 --- a/src/test/square_test.go +++ b/src/test/square_test.go @@ -36,6 +36,8 @@ func RepeatedSquareSlow(x *vdf_go.ClassGroup, k int) *vdf_go.ClassGroup { } func TestTwoSquarePerformance(t *testing.T) { + skipIfShort(t, "skipping square performance comparison in -short mode") + for k := 0; k < 10; k++ { seed := make([]byte, 32) rand.Read(seed) diff --git a/src/test/test_helpers_test.go b/src/test/test_helpers_test.go new file mode 100644 index 0000000..3599dba --- /dev/null +++ b/src/test/test_helpers_test.go @@ -0,0 +1,10 @@ +package main + +import "testing" + +func skipIfShort(t *testing.T, reason string) { + t.Helper() + if testing.Short() { + t.Skip(reason) + } +} diff --git a/src/test/vdf_benchmark_test.go b/src/test/vdf_benchmark_test.go new file mode 100644 index 0000000..13ee37a --- /dev/null +++ b/src/test/vdf_benchmark_test.go @@ -0,0 +1,78 @@ +package main + +import ( + "fmt" + "io" + "log" + "testing" + + "github.com/harmony-one/vdf/src/vdf_go" +) + +const benchmarkIntSizeBits = 2048 + +var benchmarkTargets = []int{500, 1000, 10000} + +var benchmarkSeed = []byte{ + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, +} + +func silenceVDFLogs(b *testing.B) { + b.Helper() + + originalWriter := log.Writer() + log.SetOutput(io.Discard) + b.Cleanup(func() { + log.SetOutput(originalWriter) + }) +} + +func BenchmarkGenerateVDFByTarget(b *testing.B) { + silenceVDFLogs(b) + + for _, target := range benchmarkTargets { + target := target + b.Run(fmt.Sprintf("target_%d", target), func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + yBuf, proofBuf := vdf_go.GenerateVDF(benchmarkSeed, target, benchmarkIntSizeBits) + if yBuf == nil || proofBuf == nil { + b.Fatalf("GenerateVDF returned nil output for target=%d", target) + } + } + }) + } +} + +func BenchmarkVerifyVDFByTarget(b *testing.B) { + silenceVDFLogs(b) + + for _, target := range benchmarkTargets { + target := target + b.Run(fmt.Sprintf("target_%d", target), func(b *testing.B) { + yBuf, proofBuf := vdf_go.GenerateVDF(benchmarkSeed, target, benchmarkIntSizeBits) + if yBuf == nil || proofBuf == nil { + b.Fatalf("GenerateVDF returned nil output for target=%d", target) + } + + proofBlob := append(yBuf, proofBuf...) + if !vdf_go.VerifyVDF(benchmarkSeed, proofBlob, target, benchmarkIntSizeBits) { + b.Fatalf("precomputed proof is invalid for target=%d", target) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if !vdf_go.VerifyVDF(benchmarkSeed, proofBlob, target, benchmarkIntSizeBits) { + b.Fatalf("VerifyVDF failed for target=%d", target) + } + } + }) + } +} diff --git a/src/test/vdf_module_test.go b/src/test/vdf_module_test.go index f9d50fc..249bf6a 100644 --- a/src/test/vdf_module_test.go +++ b/src/test/vdf_module_test.go @@ -4,16 +4,18 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "github.com/harmony-one/vdf/src/vdf_go" "github.com/stretchr/testify/assert" "log" "testing" "time" - "github.com/harmony-one/vdf/src/vdf_go" ) func TestGenerateVDFAndVerify(t *testing.T) { - input := [32]byte{0xde, 0xad, 0xbe, 0xef,0xde, 0xad, 0xbe, 0xef,0xde, 0xad, 0xbe, 0xef,0xde, 0xad, 0xbe, - 0xef,0xde, 0xad, 0xbe, 0xef,0xde, 0xad, 0xbe, 0xef,0xde, 0xad, 0xbe, 0xef,0xde, 0xad, 0xbe, 0xef,} + skipIfShort(t, "skipping end-to-end VDF generation in -short mode") + + input := [32]byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, + 0xef, 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef} vdf := vdf_go.New(100, input) outputChannel := vdf.GetOutputChannel() @@ -23,7 +25,6 @@ func TestGenerateVDFAndVerify(t *testing.T) { duration := time.Now().Sub(start) - output := <-outputChannel log.Println(fmt.Sprintf("VDF computation finished, result is %s", hex.EncodeToString(output[:]))) @@ -32,9 +33,9 @@ func TestGenerateVDFAndVerify(t *testing.T) { } func TestVerifyVDF(t *testing.T) { - input := [32]byte{0xde, 0xad, 0xbe, 0xef,0xde, 0xad, 0xbe, 0xef,0xde, 0xad, 0xbe, 0xef,0xde, 0xad, 0xbe, - 0xef,0xde, 0xad, 0xbe, 0xef,0xde, 0xad, 0xbe, 0xef,0xde, 0xad, 0xbe, 0xef,0xde, 0xad, 0xbe, 0xef,} - inputVDF,_ := hex.DecodeString("0028f5de49d93dff7e2080a9bdadff1d63a2a4a143e6acedb814b78b49154ba6eb77d96d8c4ebefb2ae3f4b51af64219067c26693384eedffeca103767c2a4f4f0dd753a1e778aa372463f80a3fe01b2ca85a3be1707a8b82eeccffd0bc183a7f4c3c8854d3f46ec19bc797835e497b49db57b8a0fb0b87c3f3cfb3a631d12ee40ffe1bc410a72dd4804613e0bf6bf5968b75cbdc76ab45dae141b53645b9bfd5ffd667787b4941d1e1f306929844ced0fe90bf5e62632cb32e24f0f7dd276348dd3f769391da74456473513efd85b340f28504844b470187fdb5eccb9bf9e98897f1fba85f49f6fdbecaf6e18e12c34e4e525667f47de506cd5921ce818e026a06b000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001") + input := [32]byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, + 0xef, 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef} + inputVDF, _ := hex.DecodeString("0028f5de49d93dff7e2080a9bdadff1d63a2a4a143e6acedb814b78b49154ba6eb77d96d8c4ebefb2ae3f4b51af64219067c26693384eedffeca103767c2a4f4f0dd753a1e778aa372463f80a3fe01b2ca85a3be1707a8b82eeccffd0bc183a7f4c3c8854d3f46ec19bc797835e497b49db57b8a0fb0b87c3f3cfb3a631d12ee40ffe1bc410a72dd4804613e0bf6bf5968b75cbdc76ab45dae141b53645b9bfd5ffd667787b4941d1e1f306929844ced0fe90bf5e62632cb32e24f0f7dd276348dd3f769391da74456473513efd85b340f28504844b470187fdb5eccb9bf9e98897f1fba85f49f6fdbecaf6e18e12c34e4e525667f47de506cd5921ce818e026a06b000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001") var vdfBytes [516]byte copy(vdfBytes[:], inputVDF) @@ -51,8 +52,8 @@ func TestVerifyVDF(t *testing.T) { assert.Equal(t, true, result, "failed verifying vdf proof") } - func TestVDFModuleRandomSeed(t *testing.T) { + skipIfShort(t, "skipping random-seed end-to-end VDF generation in -short mode") input := [32]byte{} rand.Read(input[:]) @@ -74,4 +75,3 @@ func TestVDFModuleRandomSeed(t *testing.T) { assert.Equal(t, true, vdf.Verify(output), "failed verifying proof") } - diff --git a/src/vdf_go/classgroup.go b/src/vdf_go/classgroup.go index 6e5ac2c..20fb07f 100644 --- a/src/vdf_go/classgroup.go +++ b/src/vdf_go/classgroup.go @@ -4,6 +4,11 @@ import ( "math/big" ) +var ( + bigTwoValue = big.NewInt(2) + bigFourValue = big.NewInt(4) +) + type ClassGroup struct { a *big.Int b *big.Int @@ -21,10 +26,11 @@ func NewClassGroup(a, b, c *big.Int) *ClassGroup { func NewClassGroupFromAbDiscriminant(a, b, discriminant *big.Int) *ClassGroup { //z = b*b-discriminant - z := new(big.Int).Sub(new(big.Int).Mul(b, b), discriminant) + z := new(big.Int).Mul(b, b) + z.Sub(z, discriminant) //z = z // 4a - c := floorDivision(z, new(big.Int).Mul(a, big.NewInt(4))) + c := floorDivision(z, new(big.Int).Lsh(new(big.Int).Set(a), 2)) return NewClassGroup(a, b, c) } @@ -47,27 +53,32 @@ func NewClassGroupFromBytesDiscriminant(buf []byte, discriminant *big.Int) (*Cla } func IdentityForDiscriminant(d *big.Int) *ClassGroup { - return NewClassGroupFromAbDiscriminant(big.NewInt(1), big.NewInt(1), d) + c := new(big.Int).Sub(big.NewInt(1), d) + c.Rsh(c, 2) + return &ClassGroup{ + a: big.NewInt(1), + b: big.NewInt(1), + c: c, + } } func (group *ClassGroup) Normalized() *ClassGroup { + if group.isNormalized() { + return group + } + a := new(big.Int).Set(group.a) b := new(big.Int).Set(group.b) c := new(big.Int).Set(group.c) - //if b > -a && b <= a: - if (b.Cmp(new(big.Int).Neg(a)) == 1) && (b.Cmp(a) < 1) { - return group - } - //r = (a - b) // (2 * a) r := new(big.Int).Sub(a, b) - r = floorDivision(r, new(big.Int).Mul(a, big.NewInt(2))) + r = floorDivision(r, new(big.Int).Lsh(new(big.Int).Set(a), 1)) //b, c = b + 2 * r * a, a * r * r + b * r + c - t := new(big.Int).Mul(big.NewInt(2), r) - t.Mul(t, a) oldB := new(big.Int).Set(b) + t := new(big.Int).Mul(r, a) + t.Lsh(t, 1) b.Add(b, t) x := new(big.Int).Mul(a, r) @@ -80,6 +91,10 @@ func (group *ClassGroup) Normalized() *ClassGroup { } func (group *ClassGroup) Reduced() *ClassGroup { + if group.isReduced() { + return group + } + g := group.Normalized() a := new(big.Int).Set(g.a) b := new(big.Int).Set(g.b) @@ -89,7 +104,7 @@ func (group *ClassGroup) Reduced() *ClassGroup { for (a.Cmp(c) == 1) || ((a.Cmp(c) == 0) && (b.Sign() == -1)) { //s = (c + b) // (c + c) s := new(big.Int).Add(c, b) - s = floorDivision(s, new(big.Int).Add(c, c)) + s = floorDivision(s, new(big.Int).Lsh(new(big.Int).Set(c), 1)) //a, b, c = c, -b + 2 * s * c, c * s * s - b * s + a oldA := new(big.Int).Set(a) @@ -97,8 +112,8 @@ func (group *ClassGroup) Reduced() *ClassGroup { a = new(big.Int).Set(c) b.Neg(b) - x := new(big.Int).Mul(big.NewInt(2), s) - x.Mul(x, c) + x := new(big.Int).Mul(s, c) + x.Lsh(x, 1) b.Add(b, x) c.Mul(c, s) @@ -112,7 +127,29 @@ func (group *ClassGroup) Reduced() *ClassGroup { } func (group *ClassGroup) identity() *ClassGroup { - return NewClassGroupFromAbDiscriminant(big.NewInt(1), big.NewInt(1), group.Discriminant()) + return IdentityForDiscriminant(group.Discriminant()) +} + +func (group *ClassGroup) isIdentityReduced() bool { + return group.a.Cmp(bigOne) == 0 && group.b.Cmp(bigOne) == 0 +} + +func (group *ClassGroup) isNormalized() bool { + if group.b.Cmp(group.a) > 0 { + return false + } + if group.b.Sign() >= 0 { + return true + } + return group.b.CmpAbs(group.a) < 0 +} + +func (group *ClassGroup) isReduced() bool { + if !group.isNormalized() { + return false + } + cmp := group.a.Cmp(group.c) + return cmp < 0 || (cmp == 0 && group.b.Sign() >= 0) } func (group *ClassGroup) Discriminant() *big.Int { @@ -121,7 +158,7 @@ func (group *ClassGroup) Discriminant() *big.Int { d.Mul(d, d) a := new(big.Int).Set(group.a) a.Mul(a, group.c) - a.Mul(a, big.NewInt(4)) + a.Mul(a, bigFourValue) d.Sub(d, a) group.d = d @@ -129,39 +166,40 @@ func (group *ClassGroup) Discriminant() *big.Int { return group.d } -func (group *ClassGroup) Multiply(other *ClassGroup) *ClassGroup { - //a1, b1, c1 = self.reduced() - x := group.Reduced() - - //a2, b2, c2 = other.reduced() - y := other.Reduced() +func (group *ClassGroup) multiplyReduced(other *ClassGroup) *ClassGroup { + if group.isIdentityReduced() { + return other + } + if other.isIdentityReduced() { + return group + } //g = (b2 + b1) // 2 - g := new(big.Int).Add(x.b, y.b) - g = floorDivision(g, big.NewInt(2)) + g := new(big.Int).Add(group.b, other.b) + g = floorDivision(g, bigTwoValue) //h = (b2 - b1) // 2 - h := new(big.Int).Sub(y.b, x.b) - h = floorDivision(h, big.NewInt(2)) + h := new(big.Int).Sub(other.b, group.b) + h = floorDivision(h, bigTwoValue) //w = mod.gcd(a1, a2, g) - w1 := allInputValueGCD(y.a, g) - w := allInputValueGCD(x.a, w1) + w1 := allInputValueGCD(other.a, g) + w := allInputValueGCD(group.a, w1) //j = w j := new(big.Int).Set(w) //r = 0 r := big.NewInt(0) //s = a1 // w - s := floorDivision(x.a, w) + s := floorDivision(group.a, w) //t = a2 // w - t := floorDivision(y.a, w) + t := floorDivision(other.a, w) //u = g // w u := floorDivision(g, w) //k_temp, constant_factor = mod.solve_mod(t * u, h * u + s * c1, s * t) b := new(big.Int).Mul(h, u) - sc := new(big.Int).Mul(s, x.c) + sc := new(big.Int).Mul(s, group.c) b.Add(b, sc) k_temp, constant_factor, solvable := SolveMod(new(big.Int).Mul(t, u), b, new(big.Int).Mul(s, t)) if !solvable { @@ -215,13 +253,27 @@ func (group *ClassGroup) Multiply(other *ClassGroup) *ClassGroup { return NewClassGroup(a3, b3, c3).Reduced() } -func (group *ClassGroup) Pow(n int64) *ClassGroup { - x := CloneClassGroup(group) +func (group *ClassGroup) Multiply(other *ClassGroup) *ClassGroup { + return group.Reduced().multiplyReduced(other.Reduced()) +} + +func (group *ClassGroup) powReduced(n int64) *ClassGroup { + if n == 0 { + return group.identity() + } + if n == 1 { + return group + } + if n == 2 { + return group.Square() + } + + x := group items_prod := group.identity() for n > 0 { if n&1 == 1 { - items_prod = items_prod.Multiply(x) + items_prod = items_prod.multiplyReduced(x) if items_prod == nil { return nil } @@ -235,14 +287,27 @@ func (group *ClassGroup) Pow(n int64) *ClassGroup { return items_prod } -func (group *ClassGroup) BigPow(n *big.Int) *ClassGroup { - x := CloneClassGroup(group) +func (group *ClassGroup) Pow(n int64) *ClassGroup { + return group.Reduced().powReduced(n) +} + +func (group *ClassGroup) bigPowReduced(n *big.Int) *ClassGroup { + switch n.Sign() { + case 0: + return group.identity() + case 1: + if n.Cmp(bigOne) == 0 { + return group + } + } + + x := group items_prod := group.identity() p := new(big.Int).Set(n) for p.Sign() > 0 { if p.Bit(0) == 1 { - items_prod = items_prod.Multiply(x) + items_prod = items_prod.multiplyReduced(x) if items_prod == nil { return nil } @@ -256,6 +321,10 @@ func (group *ClassGroup) BigPow(n *big.Int) *ClassGroup { return items_prod } +func (group *ClassGroup) BigPow(n *big.Int) *ClassGroup { + return group.Reduced().bigPowReduced(n) +} + func (group *ClassGroup) Square() *ClassGroup { u, _, solvable := SolveMod(group.b, group.c, group.a) if !solvable { @@ -267,14 +336,15 @@ func (group *ClassGroup) Square() *ClassGroup { //B = b − 2aµ, au := new(big.Int).Mul(group.a, u) - B := new(big.Int).Sub(group.b, new(big.Int).Mul(au, big.NewInt(2))) + au.Lsh(au, 1) + B := new(big.Int).Sub(group.b, au) //C = µ ^ 2 - (bµ−c)//a C := new(big.Int).Mul(u, u) m := new(big.Int).Mul(group.b, u) - m = new(big.Int).Sub(m, group.c) + m.Sub(m, group.c) m = floorDivision(m, group.a) - C = new(big.Int).Sub(C, m) + C.Sub(C, m) return NewClassGroup(A, B, C).Reduced() } @@ -367,8 +437,8 @@ func (group *ClassGroup) Serialize() []byte { int_size := (int_size_bits + 16) >> 4 buf := make([]byte, int_size*2) - copy(buf[:int_size], signBitFill(encodeTwosComplement(r.a), int_size)) - copy(buf[int_size:], signBitFill(encodeTwosComplement(r.b), int_size)) + fillTwosComplement(buf[:int_size], r.a) + fillTwosComplement(buf[int_size:], r.b) return buf } @@ -379,3 +449,7 @@ func (group *ClassGroup) Equal(other *ClassGroup) bool { return (g.a.Cmp(o.a) == 0 && g.b.Cmp(o.b) == 0 && g.c.Cmp(o.c) == 0) } + +func (group *ClassGroup) equalReduced(other *ClassGroup) bool { + return group.a.Cmp(other.a) == 0 && group.b.Cmp(other.b) == 0 && group.c.Cmp(other.c) == 0 +} diff --git a/src/vdf_go/discriminant.go b/src/vdf_go/discriminant.go index 0e6cae5..8be78ca 100644 --- a/src/vdf_go/discriminant.go +++ b/src/vdf_go/discriminant.go @@ -1,7 +1,6 @@ package vdf_go import ( - "bytes" "crypto/sha256" "encoding/binary" "math/big" @@ -14,50 +13,67 @@ type Pair struct { var m = 8 * 3 * 5 * 7 * 11 * 13 -func EntropyFromSeed(seed []byte, byte_count int) []byte { - buffer := bytes.Buffer{} - bufferSize := 0 +var ( + bigMValue = big.NewInt(int64(m)) + bigSieveStep = new(big.Int).Lsh(new(big.Int).Set(bigMValue), 16) + sieveInfoBigP = buildSieveInfoBigP() +) + +func buildSieveInfoBigP() []*big.Int { + values := make([]*big.Int, len(sieve_info)) + for i, pair := range sieve_info { + values[i] = big.NewInt(pair.p) + } + return values +} +func EntropyFromSeed(seed []byte, byte_count int) []byte { extra := uint16(0) - bytes := make([]byte, len(seed)+2) - copy(bytes, seed) - for bufferSize <= byte_count { - binary.BigEndian.PutUint16(bytes[len(seed):], extra) - more_entropy := sha256.Sum256(bytes) - buffer.Write(more_entropy[:]) - bufferSize += sha256.Size - extra += 1 + input := make([]byte, len(seed)+2) + copy(input, seed) + output := make([]byte, byte_count+sha256.Size) + + offset := 0 + for offset <= byte_count { + binary.BigEndian.PutUint16(input[len(seed):], extra) + sum := sha256.Sum256(input) + copy(output[offset:], sum[:]) + offset += sha256.Size + extra++ } - return buffer.Bytes()[:byte_count] + return output[:byte_count] } -//Return a discriminant of the given length using the given seed -//It is a random prime p between 13 - 2^2K -//return -p, where p % 8 == 7 +// Return a discriminant of the given length using the given seed +// It is a random prime p between 13 - 2^2K +// return -p, where p % 8 == 7 func CreateDiscriminant(seed []byte, length int) *big.Int { extra := uint8(length) & 7 byte_count := ((length + 7) >> 3) + 2 entropy := EntropyFromSeed(seed, byte_count) - n := new(big.Int) - n.SetBytes(entropy[:len(entropy)-2]) - n = new(big.Int).Rsh(n, uint(((8 - extra) & 7))) - n = new(big.Int).SetBit(n, length-1, 1) - n = new(big.Int).Sub(n, new(big.Int).Mod(n, big.NewInt(int64(m)))) - n = new(big.Int).Add(n, big.NewInt(int64(residues[int(binary.BigEndian.Uint16(entropy[len(entropy)-2:len(entropy)]))%len(residues)]))) + n := new(big.Int).SetBytes(entropy[:len(entropy)-2]) + n.Rsh(n, uint((8-extra)&7)) + n.SetBit(n, length-1, 1) + + mod := new(big.Int).Mod(n, bigMValue) + n.Sub(n, mod) + n.Add(n, big.NewInt(int64(residues[int(binary.BigEndian.Uint16(entropy[len(entropy)-2:]))%len(residues)]))) negN := new(big.Int).Neg(n) + modP := new(big.Int) + sieve := make([]bool, 1<<16) // Find the smallest prime >= n of the form n + m*x for { - sieve := make([]bool, (1 << 16)) + clear(sieve) - for _, v := range sieve_info { + for idx, v := range sieve_info { // q = m^-1 (mod p) // i = -n / m, so that m*i is -n (mod p) //i := ((-n % v.p) * v.q) % v.p - i := (new(big.Int).Mod(negN, big.NewInt(v.p)).Int64() * v.q) % v.p + i := (modP.Mod(negN, sieveInfoBigP[idx]).Int64() * v.q) % v.p for i < int64(len(sieve)) { sieve[i] = true @@ -65,16 +81,16 @@ func CreateDiscriminant(seed []byte, length int) *big.Int { } } - for i, v := range sieve { - t := new(big.Int).Add(n, big.NewInt(int64(m)*int64(i))) - if !v && t.ProbablyPrime(1) { + t := new(big.Int).Set(n) + for _, composite := range sieve { + if !composite && t.ProbablyPrime(1) { return new(big.Int).Neg(t) } + t.Add(t, bigMValue) } //n += m * (1 << 16) - bigM := big.NewInt(int64(m)) - n = new(big.Int).Add(n, bigM.Mul(bigM, big.NewInt(int64(1<<16)))) - + n.Add(n, bigSieveStep) + negN.Neg(n) } } diff --git a/src/vdf_go/division.go b/src/vdf_go/division.go index 628b401..e6a0344 100644 --- a/src/vdf_go/division.go +++ b/src/vdf_go/division.go @@ -2,16 +2,18 @@ package vdf_go import "math/big" -//Floor Division for big.Int -//Reference : Division and Modulus for Computer Scientists -//https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/divmodnote.pdf -//Golang only has Euclid division and T-division +var bigMinusOne = big.NewInt(-1) + +// Floor Division for big.Int +// Reference : Division and Modulus for Computer Scientists +// https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/divmodnote.pdf +// Golang only has Euclid division and T-division func floorDivision(x, y *big.Int) *big.Int { var r big.Int q, _ := new(big.Int).QuoRem(x, y, &r) if (r.Sign() == 1 && y.Sign() == -1) || (r.Sign() == -1 && y.Sign() == 1) { - q.Sub(q, big.NewInt(1)) + q.Add(q, bigMinusOne) } return q diff --git a/src/vdf_go/encode.go b/src/vdf_go/encode.go index db81f6c..a94f765 100644 --- a/src/vdf_go/encode.go +++ b/src/vdf_go/encode.go @@ -70,6 +70,20 @@ func signBitFill(bytes []byte, targetLen int) []byte { return buf } +func fillTwosComplement(dst []byte, n *big.Int) { + encoded := encodeTwosComplement(n) + fill := byte(0) + if len(encoded) > 0 && encoded[0]&0x80 != 0 { + fill = 0xff + } + + padding := len(dst) - len(encoded) + for i := 0; i < padding; i++ { + dst[i] = fill + } + copy(dst[padding:], encoded) +} + func EncodeBigIntBigEndian(a *big.Int) []byte { int_size_bits := a.BitLen() int_size := (int_size_bits + 16) >> 3 diff --git a/src/vdf_go/gcd.go b/src/vdf_go/gcd.go index c2207bf..ef9e6ce 100644 --- a/src/vdf_go/gcd.go +++ b/src/vdf_go/gcd.go @@ -4,59 +4,19 @@ import ( "math/big" ) -//Return r, s, t such that gcd(a, b) = r = a * s + b * t -func extendedGCD(a, b *big.Int) (r, s, t *big.Int) { - //r0, r1 = a, b - r0 := new(big.Int).Set(a) - r1 := new(big.Int).Set(b) - - //s0, s1, t0, t1 = 1, 0, 0, 1 - s0 := big.NewInt(1) - s1 := big.NewInt(0) - t0 := big.NewInt(0) - t1 := big.NewInt(1) - - //if r0 > r1: - //r0, r1, s0, s1, t0, t1 = r1, r0, t0, t1, s0, s1 - if r0.Cmp(r1) == 1 { - oldR0 := new(big.Int).Set(r0) - r0 = r1 - r1 = oldR0 - oldS0 := new(big.Int).Set(s0) - s0 = t0 - oldS1 := new(big.Int).Set(s1) - s1 = t1 - t0 = oldS0 - t1 = oldS1 +func absBigForGCD(x *big.Int) *big.Int { + if x.Sign() >= 0 { + return x } - - //while r1 > 0: - for r1.Sign() == 1 { - //q, r = divmod(r0, r1) - r := big.NewInt(1) - bb := new(big.Int).Set(b) - q, r := bb.DivMod(r0, r1, r) - - //r0, r1, s0, s1, t0, t1 = r1, r, s1, s0 - q * s1, t1, t0 - q * t1 - r0 = r1 - r1 = r - oldS0 := new(big.Int).Set(s0) - s0 = s1 - s1 = new(big.Int).Sub(oldS0, new(big.Int).Mul(q, s1)) - oldT0 := new(big.Int).Set(t0) - t0 = t1 - t1 = new(big.Int).Sub(oldT0, new(big.Int).Mul(q, t1)) - - } - return r0, s0, t0 + return new(big.Int).Neg(x) } -//wrapper around big.Int GCD to allow all input values for GCD -//as Golang big.Int GCD requires both a, b > 0 -//If a == b == 0, GCD sets r = 0. -//If a == 0 and b != 0, GCD sets r = |b| -//If a != 0 and b == 0, GCD sets r = |a| -//Otherwise r = GCD(|a|, |b|) +// wrapper around big.Int GCD to allow all input values for GCD +// as Golang big.Int GCD requires both a, b > 0 +// If a == b == 0, GCD sets r = 0. +// If a == 0 and b != 0, GCD sets r = |b| +// If a != 0 and b == 0, GCD sets r = |a| +// Otherwise r = GCD(|a|, |b|) func allInputValueGCD(a, b *big.Int) (r *big.Int) { if a.Sign() == 0 { return new(big.Int).Abs(b) @@ -66,27 +26,26 @@ func allInputValueGCD(a, b *big.Int) (r *big.Int) { return new(big.Int).Abs(a) } - return new(big.Int).GCD(nil, nil, new(big.Int).Abs(a), new(big.Int).Abs(b)) + return new(big.Int).GCD(nil, nil, absBigForGCD(a), absBigForGCD(b)) } -//Solve ax == b mod m for x. -//Return s, t where x = s + k * t for integer k yields all solutions. +// Solve ax == b mod m for x. +// Return s, t where x = s + k * t for integer k yields all solutions. func SolveMod(a, b, m *big.Int) (s, t *big.Int, solvable bool) { - //g, d, e = extended_gcd(a, m) - //TODO: golang 1.x big.int GCD requires both a > 0 and m > 0, so we can't use it :( - //d := big.NewInt(0) - //e := big.NewInt(0) - //g := new(big.Int).GCD(d, e, a, m) - g, d, _ := extendedGCD(a, m) + absA := absBigForGCD(a) + absM := absBigForGCD(m) + d := new(big.Int) + g := new(big.Int).GCD(d, nil, absA, absM) + if a.Sign() < 0 { + d.Neg(d) + } - //q, r = divmod(b, g) - r := big.NewInt(1) - bb := new(big.Int).Set(b) - q, r := bb.DivMod(b, g, r) + var r big.Int + q, rem := new(big.Int).QuoRem(b, g, &r) //TODO: replace with utils.GetLogInstance().Error(...) //if r != 0: - if r.Cmp(big.NewInt(0)) != 0 { + if rem.Sign() != 0 { //panic(fmt.Sprintf("no solution to %s x = %s mod %s", a.String(), b.String(), m.String())) return nil, nil, false } @@ -94,7 +53,7 @@ func SolveMod(a, b, m *big.Int) (s, t *big.Int, solvable bool) { //assert b == q * g //return (q * d) % m, m // g q.Mul(q, d) - s = q.Mod(q, m) - t = floorDivision(m, g) + s = q.Mod(q, absM) + t = new(big.Int).Quo(absM, g) return s, t, true } diff --git a/src/vdf_go/proof_wesolowski.go b/src/vdf_go/proof_wesolowski.go index 2620cd8..3cfd134 100644 --- a/src/vdf_go/proof_wesolowski.go +++ b/src/vdf_go/proof_wesolowski.go @@ -3,18 +3,23 @@ package vdf_go import ( "crypto/sha256" "encoding/binary" - "fmt" "log" "math" "math/big" "regexp" "runtime" - "sort" + "sync/atomic" "time" ) -//Creates L and k parameters from papers, based on how many iterations need to be -//performed, and how much memory should be used. +var ( + timingEnabled atomic.Bool + runtimeFuncNameRE = regexp.MustCompile(`^.*\.(.*)$`) + primePrefix = []byte("prime") +) + +// Creates L and k parameters from papers, based on how many iterations need to be +// performed, and how much memory should be used. func approximateParameters(T int) (int, int, int) { //log_memory = math.log(10000000, 2) log_memory := math.Log(10000000) / math.Log(2) @@ -38,32 +43,59 @@ func approximateParameters(T int) (int, int, int) { return L, k, w } -func iterateSquarings(x *ClassGroup, powers_to_calculate []int, stop <-chan struct{}) map[int]*ClassGroup { - powers_calculated := make(map[int]*ClassGroup) - - previous_power := 0 - currX := CloneClassGroup(x) - sort.Ints(powers_to_calculate) - for _, current_power := range powers_to_calculate { - - for i := 0; i < current_power-previous_power; i++ { - currX = currX.Pow(2) - if currX == nil { +func repeatedSquare(x *ClassGroup, count int, stop <-chan struct{}) *ClassGroup { + if stop == nil { + for i := 0; i < count; i++ { + x = x.Square() + if x == nil { return nil } } - previous_power = current_power - powers_calculated[current_power] = currX + return x + } + for i := 0; i < count; i++ { select { case <-stop: return nil default: } + + x = x.Square() + if x == nil { + return nil + } + } + + return x +} + +func iterateSquarings(x *ClassGroup, step, iterations int, stop <-chan struct{}) ([]*ClassGroup, *ClassGroup) { + if iterations <= 0 { + return []*ClassGroup{x}, x + } + + loopCount := int(math.Ceil(float64(iterations) / float64(step))) + powersCalculated := make([]*ClassGroup, loopCount) + powersCalculated[0] = x + + currX := x + for i := 1; i < loopCount; i++ { + currX = repeatedSquare(currX, step, stop) + if currX == nil { + return nil, nil + } + powersCalculated[i] = currX } - return powers_calculated + remaining := iterations - (loopCount-1)*step + y := repeatedSquare(currX, remaining, stop) + if y == nil { + return nil, nil + } + + return powersCalculated, y } func GenerateVDF(seed []byte, iterations, int_size_bits int) ([]byte, []byte) { @@ -71,7 +103,9 @@ func GenerateVDF(seed []byte, iterations, int_size_bits int) ([]byte, []byte) { } func GenerateVDFWithStopChan(seed []byte, iterations, int_size_bits int, stop <-chan struct{}) ([]byte, []byte) { - defer timeTrack(time.Now()) + if timingEnabled.Load() { + defer timeTrack(time.Now()) + } D := CreateDiscriminant(seed, int_size_bits) x := NewClassGroupFromAbDiscriminant(big.NewInt(2), big.NewInt(1), D) @@ -86,31 +120,24 @@ func GenerateVDFWithStopChan(seed []byte, iterations, int_size_bits int, stop <- } func VerifyVDF(seed, proof_blob []byte, iterations, int_size_bits int) bool { - defer timeTrack(time.Now()) - - int_size := (int_size_bits + 16) >> 4 - - D := CreateDiscriminant(seed, int_size_bits) - x := NewClassGroupFromAbDiscriminant(big.NewInt(2), big.NewInt(1), D) - y, _ := NewClassGroupFromBytesDiscriminant(proof_blob[:(2*int_size)], D) - proof, _ := NewClassGroupFromBytesDiscriminant(proof_blob[2*int_size:], D) + if timingEnabled.Load() { + defer timeTrack(time.Now()) + } - return verifyProof(x, y, proof, iterations) + return verifyProofBlob(newVerifierState(seed, int_size_bits), proof_blob, iterations) } // Creates a random prime based on input x, y func hashPrime(x, y []byte) *big.Int { - var j uint64 = 0 - - jBuf := make([]byte, 8) + var j uint64 + buf := make([]byte, len(primePrefix)+8+len(x)+len(y)) + copy(buf, primePrefix) + copy(buf[len(primePrefix)+8:], x) + copy(buf[len(primePrefix)+8+len(x):], y) z := new(big.Int) for { - binary.BigEndian.PutUint64(jBuf, j) - s := append([]byte("prime"), jBuf...) - s = append(s, x...) - s = append(s, y...) - - checkSum := sha256.Sum256(s[:]) + binary.BigEndian.PutUint64(buf[len(primePrefix):], j) + checkSum := sha256.Sum256(buf) z.SetBytes(checkSum[:16]) if z.ProbablyPrime(1) { @@ -123,64 +150,60 @@ func hashPrime(x, y []byte) *big.Int { // Get's the ith block of 2^T // B // such that sum(get_block(i) * 2^ki) = t^T // B -func getBlock(i, k, T int, B *big.Int) *big.Int { - //(pow(2, k) * pow(2, T - k * (i + 1), B)) // B - p1 := big.NewInt(int64(math.Pow(2, float64(k)))) - p2 := new(big.Int).Exp(big.NewInt(2), big.NewInt(int64(T-k*(i+1))), B) - return floorDivision(new(big.Int).Mul(p1, p2), B) +func getBlock(i, k, T int, B, twoPowK *big.Int) *big.Int { + p2 := new(big.Int).Exp(bigTwoValue, big.NewInt(int64(T-k*(i+1))), B) + p2.Mul(p2, twoPowK) + return p2.Quo(p2, B) } -//Optimized evalutation of h ^ (2^T // B) -func evalOptimized(identity, h *ClassGroup, B *big.Int, T, k, l int, C map[int]*ClassGroup) *ClassGroup { +// Optimized evalutation of h ^ (2^T // B) +func evalOptimized(identity, h *ClassGroup, B *big.Int, T, k, l int, C []*ClassGroup) *ClassGroup { //k1 = k//2 var k1 int = k / 2 k0 := k - k1 + outerIterations := int(math.Ceil(float64(T) / float64(k*l))) + bLimit := 1 << uint(k) + b1Limit := 1 << uint(k1) + b0Limit := 1 << uint(k0) + b0Stride := int64(b0Limit) + twoPowK := new(big.Int).Lsh(new(big.Int).Set(bigOne), uint(k)) //x = identity - x := CloneClassGroup(identity) + x := identity + ys := make([]*ClassGroup, bLimit) for j := l - 1; j > -1; j-- { - //x = pow(x, pow(2, k)) - b_limit := int64(math.Pow(2, float64(k))) - x = x.Pow(b_limit) + x = repeatedSquare(x, k, nil) if x == nil { return nil } - //ys = {} - ys := make([]*ClassGroup, b_limit) - for b := int64(0); b < b_limit; b++ { + for b := range ys { ys[b] = identity } - //for i in range(0, math.ceil((T)/(k*l))): - for i := 0; i < int(math.Ceil(float64(T)/float64(k*l))); i++ { + for i := 0; i < outerIterations; i++ { if T-k*(i*l+j+1) < 0 { continue } - ///TODO: carefully check big.Int to int64 value conversion...might cause serious issues later - b := getBlock(i*l+j, k, T, B).Int64() - ys[b] = ys[b].Multiply(C[i*k*l]) + b := int(getBlock(i*l+j, k, T, B, twoPowK).Int64()) + ys[b] = ys[b].Multiply(C[i]) if ys[b] == nil { return nil } } - //for b1 in range(0, pow(2, k1)): - for b1 := 0; b1 < int(math.Pow(float64(2), float64(k1))); b1++ { + for b1 := 0; b1 < b1Limit; b1++ { z := identity - //for b0 in range(0, pow(2, k0)): - for b0 := 0; b0 < int(math.Pow(float64(2), float64((k0)))); b0++ { - //z *= ys[b1 * pow(2, k0) + b0] - z = z.Multiply(ys[int64(b1)*int64(math.Pow(float64(2), float64(k0)))+int64(b0)]) + for b0 := 0; b0 < b0Limit; b0++ { + z = z.Multiply(ys[b1*b0Limit+b0]) if z == nil { return nil } } - //x *= pow(z, b1 * pow(2, k0)) - c := z.Pow(int64(b1) * int64(math.Pow(float64(2), float64(k0)))) + c := z.Pow(int64(b1) * b0Stride) if c == nil { return nil } @@ -190,18 +213,14 @@ func evalOptimized(identity, h *ClassGroup, B *big.Int, T, k, l int, C map[int]* } } - //for b0 in range(0, pow(2, k0)): - for b0 := 0; b0 < int(math.Pow(float64(2), float64(k0))); b0++ { + for b0 := 0; b0 < b0Limit; b0++ { z := identity - //for b1 in range(0, pow(2, k1)): - for b1 := 0; b1 < int(math.Pow(float64(2), float64(k1))); b1++ { - //z *= ys[b1 * pow(2, k0) + b0] - z = z.Multiply(ys[int64(b1)*int64(math.Pow(float64(2), float64(k0)))+int64(b0)]) + for b1 := 0; b1 < b1Limit; b1++ { + z = z.Multiply(ys[b1*b0Limit+b0]) if z == nil { return nil } } - //x *= pow(z, b0) d := z.Pow(int64(b0)) if d == nil { return nil @@ -216,8 +235,8 @@ func evalOptimized(identity, h *ClassGroup, B *big.Int, T, k, l int, C map[int]* return x } -//generate y = x ^ (2 ^T) and pi -func generateProof(identity, x, y *ClassGroup, T, k, l int, powers map[int]*ClassGroup) *ClassGroup { +// generate y = x ^ (2 ^T) and pi +func generateProof(identity, x, y *ClassGroup, T, k, l int, powers []*ClassGroup) *ClassGroup { //x_s = x.serialize() x_s := x.Serialize() @@ -234,23 +253,11 @@ func generateProof(identity, x, y *ClassGroup, T, k, l int, powers map[int]*Clas func calculateVDF(discriminant *big.Int, x *ClassGroup, iterations, int_size_bits int, stop <-chan struct{}) (y, proof *ClassGroup) { L, k, _ := approximateParameters(iterations) - loopCount := int(math.Ceil(float64(iterations) / float64(k*L))) - powers_to_calculate := make([]int, loopCount+2) - - for i := 0; i < loopCount+1; i++ { - powers_to_calculate[i] = i * k * L - } - - powers_to_calculate[loopCount+1] = iterations - - powers := iterateSquarings(x, powers_to_calculate, stop) - - if powers == nil { + powers, y := iterateSquarings(x, k*L, iterations, stop) + if powers == nil || y == nil { return nil, nil } - y = powers[iterations] - identity := IdentityForDiscriminant(discriminant) proof = generateProof(identity, x, y, iterations, k, L, powers) @@ -287,6 +294,10 @@ func verifyProof(x, y, proof *ClassGroup, T int) bool { } } +func SetTimingEnabled(enabled bool) { + timingEnabled.Store(enabled) +} + func timeTrack(start time.Time) { elapsed := time.Since(start) @@ -297,8 +308,7 @@ func timeTrack(start time.Time) { funcObj := runtime.FuncForPC(pc) // Regex to extract just the function name (and not the module path). - runtimeFunc := regexp.MustCompile(`^.*\.(.*)$`) - name := runtimeFunc.ReplaceAllString(funcObj.Name(), "$1") + name := runtimeFuncNameRE.ReplaceAllString(funcObj.Name(), "$1") - log.Println(fmt.Sprintf("%s took %s", name, elapsed)) + log.Printf("%s took %s", name, elapsed) } diff --git a/src/vdf_go/vdf.go b/src/vdf_go/vdf.go index 9e646f3..5d4410b 100644 --- a/src/vdf_go/vdf.go +++ b/src/vdf_go/vdf.go @@ -7,9 +7,10 @@ type VDF struct { output [516]byte outputChan chan [516]byte finished bool + verifier *verifierState } -//size of long integers in quadratic function group +// size of long integers in quadratic function group const sizeInBits = 2048 // New create a new instance of VDF. @@ -18,6 +19,7 @@ func New(difficulty int, input [32]byte) *VDF { difficulty: difficulty, input: input, outputChan: make(chan [516]byte), + verifier: newVerifierState(input[:], sizeInBits), } } @@ -32,7 +34,13 @@ func (vdf *VDF) GetOutputChannel() chan [516]byte { func (vdf *VDF) Execute() { vdf.finished = false - yBuf, proofBuf := GenerateVDF(vdf.input[:], vdf.difficulty, sizeInBits) + if vdf.verifier == nil { + vdf.verifier = newVerifierState(vdf.input[:], sizeInBits) + } + + y, proof := calculateVDF(vdf.verifier.discriminant, vdf.verifier.base, vdf.difficulty, sizeInBits, nil) + yBuf := y.Serialize() + proofBuf := proof.Serialize() copy(vdf.output[:], yBuf) copy(vdf.output[258:], proofBuf) @@ -47,7 +55,7 @@ func (vdf *VDF) Execute() { // Verify runs the verification of generated proof // currently on i7-6700K, verification takes about 350 ms func (vdf *VDF) Verify(proof [516]byte) bool { - return VerifyVDF(vdf.input[:], proof[:], vdf.difficulty, sizeInBits) + return verifyProofBlob(vdf.verifier, proof[:], vdf.difficulty) } // IsFinished returns whether the vdf execution is finished or not. diff --git a/src/vdf_go/verifier.go b/src/vdf_go/verifier.go new file mode 100644 index 0000000..3864db7 --- /dev/null +++ b/src/vdf_go/verifier.go @@ -0,0 +1,135 @@ +package vdf_go + +import "math/big" + +type verifierState struct { + discriminant *big.Int + base *ClassGroup + baseSerialized []byte + intSize int +} + +func serializeReducedClassGroup(group *ClassGroup, intSize int) []byte { + buf := make([]byte, intSize*2) + fillTwosComplement(buf[:intSize], group.a) + fillTwosComplement(buf[intSize:], group.b) + return buf +} + +func newVerifierState(seed []byte, intSizeBits int) *verifierState { + intSize := (intSizeBits + 16) >> 4 + discriminant := CreateDiscriminant(seed, intSizeBits) + base := NewClassGroupFromAbDiscriminant(big.NewInt(2), big.NewInt(1), discriminant) + + return &verifierState{ + discriminant: discriminant, + base: base, + baseSerialized: serializeReducedClassGroup(base, intSize), + intSize: intSize, + } +} + +func multiExpReduced(a *ClassGroup, eA *big.Int, b *ClassGroup, eB *big.Int) *ClassGroup { + maxBits := eA.BitLen() + if eB.BitLen() > maxBits { + maxBits = eB.BitLen() + } + if maxBits == 0 { + return a.identity() + } + + id := a.identity() + powA := [4]*ClassGroup{id, a} + powB := [4]*ClassGroup{id, b} + + powA[2] = a.Square() + if powA[2] == nil { + return nil + } + powA[3] = powA[2].multiplyReduced(a) + if powA[3] == nil { + return nil + } + + powB[2] = b.Square() + if powB[2] == nil { + return nil + } + powB[3] = powB[2].multiplyReduced(b) + if powB[3] == nil { + return nil + } + + table := [16]*ClassGroup{} + table[1] = powB[1] + table[2] = powB[2] + table[3] = powB[3] + table[4] = powA[1] + table[8] = powA[2] + table[12] = powA[3] + + for ea := 1; ea < 4; ea++ { + for eb := 1; eb < 4; eb++ { + idx := (ea << 2) | eb + table[idx] = powA[ea].multiplyReduced(powB[eb]) + if table[idx] == nil { + return nil + } + } + } + + windows := (maxBits + 1) >> 1 + var acc *ClassGroup + started := false + + for w := windows - 1; w >= 0; w-- { + if started { + acc = acc.Square() + if acc == nil { + return nil + } + acc = acc.Square() + if acc == nil { + return nil + } + } + + bit := w << 1 + eaDigit := (eA.Bit(bit+1) << 1) | eA.Bit(bit) + ebDigit := (eB.Bit(bit+1) << 1) | eB.Bit(bit) + key := (eaDigit << 2) | ebDigit + if key == 0 { + continue + } + + factor := table[key] + if !started { + acc = factor + started = true + continue + } + + acc = acc.multiplyReduced(factor) + if acc == nil { + return nil + } + } + + if !started { + return id + } + + return acc +} + +func verifyProofBlob(state *verifierState, proofBlob []byte, iterations int) bool { + ySerialized := proofBlob[:2*state.intSize] + y, _ := NewClassGroupFromBytesDiscriminant(ySerialized, state.discriminant) + proof, _ := NewClassGroupFromBytesDiscriminant(proofBlob[2*state.intSize:], state.discriminant) + + B := hashPrime(state.baseSerialized, ySerialized) + r := new(big.Int).Exp(big.NewInt(2), big.NewInt(int64(iterations)), B) + + z := multiExpReduced(proof, B, state.base, r) + return z != nil && z.equalReduced(y) +} diff --git a/src/vdf_go/verifier_internal_test.go b/src/vdf_go/verifier_internal_test.go new file mode 100644 index 0000000..e1f09ec --- /dev/null +++ b/src/vdf_go/verifier_internal_test.go @@ -0,0 +1,63 @@ +package vdf_go + +import ( + "math/big" + "testing" +) + +func TestMultiExpReducedMatchesSeparateExponentiation(t *testing.T) { + seed := []byte{ + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + } + + discriminant := CreateDiscriminant(seed, 2048) + base := NewClassGroupFromAbDiscriminant(big.NewInt(2), big.NewInt(1), discriminant) + other := base.Pow(17) + if other == nil { + t.Fatal("failed to build comparison class group") + } + + cases := []struct { + a string + b string + }{ + {"0", "0"}, + {"1", "1"}, + {"123456789abcdef", "fedcba987654321"}, + {"ffffffffffffffffffffffffffffffff", "123456789abcdef123456789abcdef"}, + } + + for _, tc := range cases { + eA, ok := new(big.Int).SetString(tc.a, 16) + if !ok { + t.Fatalf("invalid exponent A: %s", tc.a) + } + eB, ok := new(big.Int).SetString(tc.b, 16) + if !ok { + t.Fatalf("invalid exponent B: %s", tc.b) + } + + got := multiExpReduced(other, eA, base, eB) + if got == nil { + t.Fatalf("multiExpReduced returned nil for %s/%s", tc.a, tc.b) + } + + left := other.bigPowReduced(eA) + right := base.bigPowReduced(eB) + if left == nil || right == nil { + t.Fatalf("separate exponentiation returned nil for %s/%s", tc.a, tc.b) + } + + want := left.multiplyReduced(right) + if want == nil { + t.Fatalf("separate multiplication returned nil for %s/%s", tc.a, tc.b) + } + + if !got.equalReduced(want) { + t.Fatalf("mismatch for exponents %s/%s", tc.a, tc.b) + } + } +}