@@ -14,6 +14,7 @@ import (
1414 "golang.org/x/net/proxy"
1515 "google.golang.org/grpc"
1616 "google.golang.org/grpc/codes"
17+ "google.golang.org/grpc/connectivity"
1718 "google.golang.org/grpc/credentials"
1819 "google.golang.org/grpc/credentials/insecure"
1920 "google.golang.org/grpc/keepalive"
@@ -30,73 +31,55 @@ func LoggerRecoveryHandler(log *logrus.Entry) recovery.RecoveryHandlerFunc {
3031 }
3132}
3233
33- // BlockingDial is a helper method to dial the given address, using optional TLS credentials,
34+ // BlockingNewClient is a helper method to dial the given address, using optional TLS credentials,
3435// and blocking until the returned connection is ready. If the given credentials are nil, the
3536// connection will be insecure (plain-text).
3637// Lifted from: https://github.com/fullstorydev/grpcurl/blob/master/grpcurl.go
37- func BlockingDial (ctx context.Context , network , address string , creds credentials.TransportCredentials , opts ... grpc.DialOption ) (* grpc.ClientConn , error ) {
38- // grpc.Dial doesn't provide any information on permanent connection errors (like
39- // TLS handshake failures). So in order to provide good error messages, we need a
40- // custom dialer that can provide that info. That means we manage the TLS handshake.
41- result := make (chan any , 1 )
42- writeResult := func (res any ) {
43- // non-blocking write: we only need the first result
44- select {
45- case result <- res :
46- default :
47- }
38+ func BlockingNewClient (ctx context.Context , network , address string , creds credentials.TransportCredentials , opts ... grpc.DialOption ) (* grpc.ClientConn , error ) {
39+ proxyDialer := proxy .FromEnvironment ()
40+ rawConn , err := proxyDialer .Dial (network , address )
41+ if err != nil {
42+ return nil , fmt .Errorf ("error dial proxy: %w" , err )
4843 }
4944
50- dialer := func (ctx context.Context , address string ) (net.Conn , error ) {
51- proxyDialer := proxy .FromEnvironment ()
52- conn , err := proxyDialer .Dial (network , address )
45+ if creds != nil {
46+ rawConn , _ , err = creds .ClientHandshake (ctx , address , rawConn )
5347 if err != nil {
54- writeResult (err )
55- return nil , fmt .Errorf ("error dial proxy: %w" , err )
56- }
57- if creds != nil {
58- conn , _ , err = creds .ClientHandshake (ctx , address , conn )
59- if err != nil {
60- writeResult (err )
61- return nil , fmt .Errorf ("error creating connection: %w" , err )
62- }
48+ return nil , fmt .Errorf ("error creating connection: %w" , err )
6349 }
64- return conn , nil
50+ }
51+ customDialer := func (_ context.Context , _ string ) (net.Conn , error ) {
52+ return rawConn , nil
6553 }
6654
67- // Even with grpc.FailOnNonTempDialError, this call will usually timeout in
68- // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
69- // know when we're done. So we run it in a goroutine and then use result
70- // channel to either get the channel or fail-fast.
71- go func () {
72- opts = append (opts ,
73- //nolint:staticcheck
74- grpc .WithBlock (),
75- //nolint:staticcheck
76- grpc .FailOnNonTempDialError (true ),
77- grpc .WithContextDialer (dialer ),
78- grpc .WithTransportCredentials (insecure .NewCredentials ()), // we are handling TLS, so tell grpc not to
79- grpc .WithKeepaliveParams (keepalive.ClientParameters {Time : common .GetGRPCKeepAliveTime ()}),
80- )
81- //nolint:staticcheck
82- conn , err := grpc .DialContext (ctx , address , opts ... )
83- var res any
84- if err != nil {
85- res = err
86- } else {
87- res = conn
88- }
89- writeResult (res )
90- }()
55+ opts = append (opts ,
56+ grpc .WithContextDialer (customDialer ),
57+ grpc .WithTransportCredentials (insecure .NewCredentials ()),
58+ grpc .WithKeepaliveParams (keepalive.ClientParameters {Time : common .GetGRPCKeepAliveTime ()}),
59+ )
9160
92- select {
93- case res := <- result :
94- if conn , ok := res .(* grpc.ClientConn ); ok {
95- return conn , nil
61+ conn , err := grpc .NewClient ("passthrough:" + address , opts ... )
62+ if err != nil {
63+ return nil , fmt .Errorf ("grpc.NewClient failed: %w" , err )
64+ }
65+
66+ conn .Connect ()
67+ if err := waitForReady (ctx , conn ); err != nil {
68+ return nil , fmt .Errorf ("gRPC connection not ready: %w" , err )
69+ }
70+
71+ return conn , nil
72+ }
73+
74+ func waitForReady (ctx context.Context , conn * grpc.ClientConn ) error {
75+ for {
76+ state := conn .GetState ()
77+ if state == connectivity .Ready {
78+ return nil
79+ }
80+ if ! conn .WaitForStateChange (ctx , state ) {
81+ return ctx .Err () // context timeout or cancellation
9682 }
97- return nil , res .(error )
98- case <- ctx .Done ():
99- return nil , ctx .Err ()
10083 }
10184}
10285
@@ -120,15 +103,15 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) {
120103 ctx , cancel := context .WithTimeout (context .Background (), dialTime )
121104 defer cancel ()
122105
123- conn , err := BlockingDial (ctx , "tcp" , address , creds )
106+ conn , err := BlockingNewClient (ctx , "tcp" , address , creds )
124107 if err == nil {
125108 _ = conn .Close ()
126109 testResult .TLS = true
127110 creds := credentials .NewTLS (& tls.Config {})
128111 ctx , cancel := context .WithTimeout (context .Background (), dialTime )
129112 defer cancel ()
130113
131- conn , err := BlockingDial (ctx , "tcp" , address , creds )
114+ conn , err := BlockingNewClient (ctx , "tcp" , address , creds )
132115 if err == nil {
133116 _ = conn .Close ()
134117 } else {
@@ -143,7 +126,7 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) {
143126 // refused). Test if server accepts plain-text connections
144127 ctx , cancel = context .WithTimeout (context .Background (), dialTime )
145128 defer cancel ()
146- conn , err = BlockingDial (ctx , "tcp" , address , nil )
129+ conn , err = BlockingNewClient (ctx , "tcp" , address , nil )
147130 if err == nil {
148131 _ = conn .Close ()
149132 testResult .TLS = false
0 commit comments