Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions pkg/dmsg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func (ce *Client) Serve(ctx context.Context) {
}

for ind, entry := range entries {
if entry.Static.Hex() == ctx.Value("dmsgServer").(string) {
if dmsgServer, ok := ctx.Value("dmsgServer").(string); ok && entry.Static.Hex() == dmsgServer {
entries = entries[ind : ind+1]
}
}
Expand Down Expand Up @@ -585,12 +585,20 @@ func (ce *Client) dialSession(ctx context.Context, entry *disc.Entry) (cs Client
}

go func() {
defer func() {
if r := recover(); r != nil {
ce.log.Warnf("recovered panic in session serve goroutine: %v", r)
}
}()
ce.log.WithField("remote_pk", dSes.RemotePK()).Debug("Serving session.")
err := dSes.serve()
if !isClosed(ce.done) {
// We should only report an error when client is not closed.
// Also, when the client is closed, it will automatically delete all sessions.
ce.errCh <- fmt.Errorf("failed to serve dialed session to %s: %v", dSes.RemotePK(), err)
ce.sesMx.Lock()
select {
case ce.errCh <- fmt.Errorf("failed to serve dialed session to %s: %v", dSes.RemotePK(), err):
default:
}
ce.sesMx.Unlock()
ce.delSession(ctx, dSes.RemotePK())
}

Expand Down
18 changes: 8 additions & 10 deletions pkg/dmsg/server_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,12 @@ func (ss *ServerSession) Serve() {
for {
sStr, err := ss.sm.smux.AcceptStream()
if err != nil {
switch err {
case io.EOF:
if err == io.EOF || err == smux.ErrInvalidProtocol || ss.sm.smux.IsClosed() {
ss.log.WithError(err).Info("Stopping session...")
default:
ss.log.WithError(err).Warn("Failed to accept stream, stopping session...")
return
}
return
ss.log.WithError(err).Warn("Failed to accept smux stream, continuing...")
continue
}

log := ss.log.WithField("smux_id", sStr.ID())
Expand All @@ -75,13 +74,12 @@ func (ss *ServerSession) Serve() {
for {
yStr, err := ss.sm.yamux.AcceptStream()
if err != nil {
switch err {
case yamux.ErrSessionShutdown, io.EOF:
if err == yamux.ErrSessionShutdown || err == io.EOF || ss.sm.yamux.IsClosed() {
ss.log.WithError(err).Info("Stopping session...")
default:
ss.log.WithError(err).Warn("Failed to accept stream, stopping session...")
return
}
return
ss.log.WithError(err).Warn("Failed to accept yamux stream, continuing...")
continue
}

log := ss.log.WithField("yamux_id", yStr.StreamID())
Expand Down
47 changes: 27 additions & 20 deletions pkg/dmsg/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dmsg

import (
"context"
"fmt"
"net"
"time"

Expand Down Expand Up @@ -82,16 +83,18 @@ func (s *Stream) writeRequest(rAddr Addr) (req StreamRequest, err error) {
// Reserve stream in porter.
var lPort uint16
if lPort, s.close, err = s.ses.porter.ReserveEphemeral(context.Background(), s); err != nil {
return
return req, err
}

// Prepare fields.
s.prepareFields(true, Addr{PK: s.ses.LocalPK(), Port: lPort}, rAddr)
if err = s.prepareFields(true, Addr{PK: s.ses.LocalPK(), Port: lPort}, rAddr); err != nil {
return req, err
}

// Prepare request.
var nsMsg []byte
if nsMsg, err = s.ns.MakeHandshakeMessage(); err != nil {
return
return req, err
}
req = StreamRequest{
Timestamp: time.Now().UnixNano(),
Expand All @@ -104,21 +107,23 @@ func (s *Stream) writeRequest(rAddr Addr) (req StreamRequest, err error) {
// Write request.
if s.sStr != nil {
err = s.ses.writeObject(s.sStr, obj)
return
return req, err
}
err = s.ses.writeObject(s.yStr, obj)
return
return req, err
}

func (s *Stream) writeIPRequest(rAddr Addr) (req StreamRequest, err error) {
// Reserve stream in porter.
var lPort uint16
if lPort, s.close, err = s.ses.porter.ReserveEphemeral(context.Background(), s); err != nil {
return
return req, err
}

// Prepare fields.
s.prepareFields(true, Addr{PK: s.ses.LocalPK(), Port: lPort}, rAddr)
if err = s.prepareFields(true, Addr{PK: s.ses.LocalPK(), Port: lPort}, rAddr); err != nil {
return req, err
}

req = StreamRequest{
Timestamp: time.Now().UnixNano(),
Expand All @@ -131,42 +136,43 @@ func (s *Stream) writeIPRequest(rAddr Addr) (req StreamRequest, err error) {
// Write request.
if s.sStr != nil {
err = s.ses.writeObject(s.sStr, obj)
return
return req, err
}
err = s.ses.writeObject(s.yStr, obj)
return
return req, err
}

func (s *Stream) readRequest() (req StreamRequest, err error) {
var obj SignedObject
if s.sStr != nil {
if obj, err = s.ses.readObject(s.sStr); err != nil {
return
return req, err
}
} else {
if obj, err = s.ses.readObject(s.yStr); err != nil {
return
return req, err
}
}

if req, err = obj.ObtainStreamRequest(); err != nil {
return
return req, err
}
if err = req.Verify(0); err != nil {
return
return req, err
}
if req.DstAddr.PK != s.ses.LocalPK() {
err = ErrReqInvalidDstPK
return
return req, ErrReqInvalidDstPK
}

// Prepare fields.
s.prepareFields(false, req.DstAddr, req.SrcAddr)
if err = s.prepareFields(false, req.DstAddr, req.SrcAddr); err != nil {
return req, err
}

if err = s.ns.ProcessHandshakeMessage(req.NoiseMsg); err != nil {
return
return req, err
}
return
return req, nil
}

func (s *Stream) writeResponse(reqHash cipher.SHA256) error {
Expand Down Expand Up @@ -254,15 +260,15 @@ func (s *Stream) readIPResponse(req StreamRequest) (net.IP, error) {
return resp.IP, nil
}

func (s *Stream) prepareFields(init bool, lAddr, rAddr Addr) {
func (s *Stream) prepareFields(init bool, lAddr, rAddr Addr) error {
ns, err := noise.New(noise.HandshakeKK, noise.Config{
LocalPK: s.ses.LocalPK(),
LocalSK: s.ses.localSK(),
RemotePK: rAddr.PK,
Initiator: init,
})
if err != nil {
s.log.WithError(err).Panic("Failed to prepare stream noise object.")
return fmt.Errorf("failed to prepare stream noise object: %w", err)
}

s.lAddr = lAddr
Expand All @@ -274,6 +280,7 @@ func (s *Stream) prepareFields(init bool, lAddr, rAddr Addr) {
s.nsConn = noise.NewReadWriter(s.yStr, s.ns)
}
s.log = s.ses.log.WithField("stream", s.lAddr.ShortString()+"->"+s.rAddr.ShortString())
return nil
}

// LocalAddr returns the local address of the dmsg stream.
Expand Down
11 changes: 10 additions & 1 deletion pkg/dmsgctrl/serve_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@ package dmsgctrl

import (
"net"
"strings"

"github.com/skycoin/skywire/pkg/skywire-utilities/pkg/logging"
)

var log = logging.MustGetLogger("dmsgctrl")

// ServeListener serves a listener with dmsgctrl.Control.
// It returns a channel for incoming Controls.
func ServeListener(l net.Listener, chanLen int) <-chan *Control {
Expand All @@ -16,7 +21,11 @@ func ServeListener(l net.Listener, chanLen int) <-chan *Control {
for {
conn, err := l.Accept()
if err != nil {
return
if strings.Contains(err.Error(), "use of closed") {
return
}
log.Warnf("Failed to accept dmsgctrl conn, continuing: %v", err)
continue
}
if ctrl := ControlStream(conn); ch != nil && len(ch) < cap(ch) {
ch <- ctrl
Expand Down
2 changes: 1 addition & 1 deletion pkg/dmsgpty/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func (h *Host) serveConn(ctx context.Context, log logrus.FieldLogger, mux *hostM
func (h *Host) authorize(log logrus.FieldLogger, rPK cipher.PubKey) bool {
ok, err := h.wl.Get(rPK)
if err != nil {
log.WithError(err).Panic("dmsgpty.Whitelist error.")
log.WithError(err).Error("dmsgpty.Whitelist error.")
return false
}
if !ok {
Expand Down
2 changes: 1 addition & 1 deletion pkg/dmsgpty/ui.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (ui *UI) Handler(customCommands map[string][]string) http.HandlerFunc {
go func() {
// Buffer PTY output and flush periodically to reduce WebSocket message count
bw := newBufferedWSWriter(wsConn, 16*time.Millisecond)
defer bw.Close()
defer bw.Close() //nolint:errcheck
_, _ = io.Copy(bw, ptyC) //nolint:errcheck
closeDone()
}()
Expand Down
Loading
Loading