@@ -14,6 +14,7 @@ import (
1414 "golang.org/x/net/proxy"
1515 "google.golang.org/grpc"
1616 "google.golang.org/grpc/codes"
17+ "google.golang.org/grpc/connectivity"
1718 "google.golang.org/grpc/credentials"
1819 "google.golang.org/grpc/credentials/insecure"
1920 "google.golang.org/grpc/keepalive"
@@ -30,73 +31,56 @@ func LoggerRecoveryHandler(log *logrus.Entry) recovery.RecoveryHandlerFunc {
3031 }
3132}
3233
33- // BlockingDial is a helper method to dial the given address, using optional TLS credentials,
34+ // BlockingNewClient is a helper method to dial the given address, using optional TLS credentials,
3435// and blocking until the returned connection is ready. If the given credentials are nil, the
3536// connection will be insecure (plain-text).
3637// Lifted from: https://github.com/fullstorydev/grpcurl/blob/master/grpcurl.go
37- func BlockingDial (ctx context.Context , network , address string , creds credentials.TransportCredentials , opts ... grpc.DialOption ) (* grpc.ClientConn , error ) {
38- // grpc.Dial doesn't provide any information on permanent connection errors (like
39- // TLS handshake failures). So in order to provide good error messages, we need a
40- // custom dialer that can provide that info. That means we manage the TLS handshake.
41- result := make (chan any , 1 )
42- writeResult := func (res any ) {
43- // non-blocking write: we only need the first result
44- select {
45- case result <- res :
46- default :
47- }
48- }
49-
50- dialer := func (ctx context.Context , address string ) (net.Conn , error ) {
38+ func BlockingNewClient (ctx context.Context , network , address string , creds credentials.TransportCredentials , opts ... grpc.DialOption ) (* grpc.ClientConn , error ) {
39+ customDialer := func (ctx context.Context , address string ) (net.Conn , error ) {
5140 proxyDialer := proxy .FromEnvironment ()
5241 conn , err := proxyDialer .Dial (network , address )
5342 if err != nil {
54- writeResult (err )
5543 return nil , fmt .Errorf ("error dial proxy: %w" , err )
5644 }
45+
5746 if creds != nil {
5847 conn , _ , err = creds .ClientHandshake (ctx , address , conn )
5948 if err != nil {
60- writeResult (err )
6149 return nil , fmt .Errorf ("error creating connection: %w" , err )
6250 }
6351 }
52+
6453 return conn , nil
6554 }
6655
67- // Even with grpc.FailOnNonTempDialError, this call will usually timeout in
68- // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
69- // know when we're done. So we run it in a goroutine and then use result
70- // channel to either get the channel or fail-fast.
71- go func () {
72- opts = append (opts ,
73- //nolint:staticcheck
74- grpc .WithBlock (),
75- //nolint:staticcheck
76- grpc .FailOnNonTempDialError (true ),
77- grpc .WithContextDialer (dialer ),
78- grpc .WithTransportCredentials (insecure .NewCredentials ()), // we are handling TLS, so tell grpc not to
79- grpc .WithKeepaliveParams (keepalive.ClientParameters {Time : common .GetGRPCKeepAliveTime ()}),
80- )
81- //nolint:staticcheck
82- conn , err := grpc .DialContext (ctx , address , opts ... )
83- var res any
84- if err != nil {
85- res = err
86- } else {
87- res = conn
88- }
89- writeResult (res )
90- }()
56+ opts = append (opts ,
57+ grpc .WithContextDialer (customDialer ),
58+ grpc .WithTransportCredentials (insecure .NewCredentials ()),
59+ grpc .WithKeepaliveParams (keepalive.ClientParameters {Time : common .GetGRPCKeepAliveTime ()}),
60+ )
9161
92- select {
93- case res := <- result :
94- if conn , ok := res .(* grpc.ClientConn ); ok {
95- return conn , nil
62+ cc , err := grpc .NewClient ("passthrough:" + address , opts ... )
63+ if err != nil {
64+ return nil , fmt .Errorf ("grpc.NewClient failed: %w" , err )
65+ }
66+
67+ cc .Connect ()
68+ if err := waitForReady (ctx , cc ); err != nil {
69+ return nil , fmt .Errorf ("gRPC connection not ready: %w" , err )
70+ }
71+
72+ return cc , nil
73+ }
74+
75+ func waitForReady (ctx context.Context , conn * grpc.ClientConn ) error {
76+ for {
77+ state := conn .GetState ()
78+ if state == connectivity .Ready {
79+ return nil
80+ }
81+ if ! conn .WaitForStateChange (ctx , state ) {
82+ return ctx .Err () // context timeout or cancellation
9683 }
97- return nil , res .(error )
98- case <- ctx .Done ():
99- return nil , ctx .Err ()
10084 }
10185}
10286
@@ -120,15 +104,15 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) {
120104 ctx , cancel := context .WithTimeout (context .Background (), dialTime )
121105 defer cancel ()
122106
123- conn , err := BlockingDial (ctx , "tcp" , address , creds )
107+ conn , err := BlockingNewClient (ctx , "tcp" , address , creds )
124108 if err == nil {
125109 _ = conn .Close ()
126110 testResult .TLS = true
127111 creds := credentials .NewTLS (& tls.Config {})
128112 ctx , cancel := context .WithTimeout (context .Background (), dialTime )
129113 defer cancel ()
130114
131- conn , err := BlockingDial (ctx , "tcp" , address , creds )
115+ conn , err := BlockingNewClient (ctx , "tcp" , address , creds )
132116 if err == nil {
133117 _ = conn .Close ()
134118 } else {
@@ -143,7 +127,7 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) {
143127 // refused). Test if server accepts plain-text connections
144128 ctx , cancel = context .WithTimeout (context .Background (), dialTime )
145129 defer cancel ()
146- conn , err = BlockingDial (ctx , "tcp" , address , nil )
130+ conn , err = BlockingNewClient (ctx , "tcp" , address , nil )
147131 if err == nil {
148132 _ = conn .Close ()
149133 testResult .TLS = false
0 commit comments