diff --git a/lib/tlsutil/tlsutil.go b/lib/tlsutil/tlsutil.go index e836247c..2c169b6b 100644 --- a/lib/tlsutil/tlsutil.go +++ b/lib/tlsutil/tlsutil.go @@ -23,6 +23,10 @@ import ( "time" ) +var ( + ErrIdentificationFailed = fmt.Errorf("failed to identify socket type") +) + func NewCertificate(certFile, keyFile, tlsDefaultCommonName string, tlsRSABits int) (tls.Certificate, error) { priv, err := rsa.GenerateKey(rand.Reader, tlsRSABits) if err != nil { @@ -85,9 +89,28 @@ type DowngradingListener struct { } func (l *DowngradingListener) Accept() (net.Conn, error) { + conn, isTLS, err := l.AcceptNoWrap() + + // We failed to identify the socket type, pretend that everything is fine, + // and pass it to the underlying handler, and let them deal with it. + if err == ErrIdentificationFailed { + return conn, nil + } + + if err != nil { + return conn, err + } + + if isTLS { + return tls.Server(conn, l.TLSConfig), nil + } + return conn, nil +} + +func (l *DowngradingListener) AcceptNoWrap() (net.Conn, bool, error) { conn, err := l.Listener.Accept() if err != nil { - return nil, err + return nil, false, err } br := bufio.NewReader(conn) @@ -96,18 +119,12 @@ func (l *DowngradingListener) Accept() (net.Conn, error) { conn.SetReadDeadline(time.Time{}) if err != nil { // We hit a read error here, but the Accept() call succeeded so we must not return an error. - // We return the connection as is and let whoever tries to use it deal with the error. - return conn, nil + // We return the connection as is with a special error which handles this + // special case in Accept(). + return conn, false, ErrIdentificationFailed } - wrapper := &WrappedConnection{br, conn} - - // 0x16 is the first byte of a TLS handshake - if bs[0] == 0x16 { - return tls.Server(wrapper, l.TLSConfig), nil - } - - return wrapper, nil + return &WrappedConnection{br, conn}, bs[0] == 0x16, nil } type WrappedConnection struct {