diff --git a/cmd/syncthing/gui.go b/cmd/syncthing/gui.go index 0c27c213..4e333c47 100644 --- a/cmd/syncthing/gui.go +++ b/cmd/syncthing/gui.go @@ -34,6 +34,7 @@ import ( "github.com/syncthing/syncthing/lib/model" "github.com/syncthing/syncthing/lib/osutil" "github.com/syncthing/syncthing/lib/sync" + "github.com/syncthing/syncthing/lib/tlsutil" "github.com/syncthing/syncthing/lib/upgrade" "github.com/vitrun/qart/qr" "golang.org/x/crypto/bcrypt" @@ -92,7 +93,7 @@ func (s *apiSvc) getListener(cfg config.GUIConfiguration) (net.Listener, error) name = tlsDefaultCommonName } - cert, err = newCertificate(locations[locHTTPSCertFile], locations[locHTTPSKeyFile], name) + cert, err = tlsutil.NewCertificate(locations[locHTTPSCertFile], locations[locHTTPSKeyFile], name, tlsRSABits) } if err != nil { return nil, err @@ -120,7 +121,7 @@ func (s *apiSvc) getListener(cfg config.GUIConfiguration) (net.Listener, error) return nil, err } - listener := &DowngradingListener{rawListener, tlsCfg} + listener := &tlsutil.DowngradingListener{rawListener, tlsCfg} return listener, nil } diff --git a/cmd/syncthing/main.go b/cmd/syncthing/main.go index 4518732e..4dce776c 100644 --- a/cmd/syncthing/main.go +++ b/cmd/syncthing/main.go @@ -36,6 +36,7 @@ import ( "github.com/syncthing/syncthing/lib/osutil" "github.com/syncthing/syncthing/lib/relay" "github.com/syncthing/syncthing/lib/symlinks" + "github.com/syncthing/syncthing/lib/tlsutil" "github.com/syncthing/syncthing/lib/upgrade" "github.com/syndtr/goleveldb/leveldb" @@ -67,8 +68,10 @@ const ( ) const ( - bepProtocolName = "bep/1.0" - pingEventInterval = time.Minute + bepProtocolName = "bep/1.0" + tlsDefaultCommonName = "syncthing" + tlsRSABits = 3072 + pingEventInterval = time.Minute ) var l = logger.DefaultLogger @@ -298,7 +301,7 @@ func main() { l.Warnln("Key exists; will not overwrite.") l.Infoln("Device ID:", protocol.NewDeviceID(cert.Certificate[0])) } else { - cert, err = newCertificate(certFile, keyFile, tlsDefaultCommonName) + cert, err = tlsutil.NewCertificate(certFile, keyFile, tlsDefaultCommonName, tlsRSABits) myID = protocol.NewDeviceID(cert.Certificate[0]) if err != nil { l.Fatalln("load cert:", err) @@ -464,9 +467,10 @@ func syncthingMain() { // Ensure that that we have a certificate and key. cert, err := tls.LoadX509KeyPair(locations[locCertFile], locations[locKeyFile]) if err != nil { - cert, err = newCertificate(locations[locCertFile], locations[locKeyFile], tlsDefaultCommonName) + l.Infof("Generating RSA key and certificate for %s...", tlsDefaultCommonName) + cert, err = tlsutil.NewCertificate(locations[locCertFile], locations[locKeyFile], tlsDefaultCommonName, tlsRSABits) if err != nil { - l.Fatalln("load cert:", err) + l.Fatalln(err) } } diff --git a/cmd/syncthing/tls.go b/lib/tlsutil/tlsutil.go similarity index 79% rename from cmd/syncthing/tls.go rename to lib/tlsutil/tlsutil.go index d8b8be05..e836247c 100644 --- a/cmd/syncthing/tls.go +++ b/lib/tlsutil/tlsutil.go @@ -4,7 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -package main +package tlsutil import ( "bufio" @@ -14,6 +14,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "fmt" "io" "math/big" mr "math/rand" @@ -22,17 +23,10 @@ import ( "time" ) -const ( - tlsRSABits = 3072 - tlsDefaultCommonName = "syncthing" -) - -func newCertificate(certFile, keyFile, name string) (tls.Certificate, error) { - l.Infof("Generating RSA key and certificate for %s...", name) - +func NewCertificate(certFile, keyFile, tlsDefaultCommonName string, tlsRSABits int) (tls.Certificate, error) { priv, err := rsa.GenerateKey(rand.Reader, tlsRSABits) if err != nil { - l.Fatalln("generate key:", err) + return tls.Certificate{}, fmt.Errorf("generate key: %s", err) } notBefore := time.Now() @@ -41,7 +35,7 @@ func newCertificate(certFile, keyFile, name string) (tls.Certificate, error) { template := x509.Certificate{ SerialNumber: new(big.Int).SetInt64(mr.Int63()), Subject: pkix.Name{ - CommonName: name, + CommonName: tlsDefaultCommonName, }, NotBefore: notBefore, NotAfter: notAfter, @@ -53,33 +47,33 @@ func newCertificate(certFile, keyFile, name string) (tls.Certificate, error) { derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { - l.Fatalln("create cert:", err) + return tls.Certificate{}, fmt.Errorf("create cert: %s", err) } certOut, err := os.Create(certFile) if err != nil { - l.Fatalln("save cert:", err) + return tls.Certificate{}, fmt.Errorf("save cert: %s", err) } err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) if err != nil { - l.Fatalln("save cert:", err) + return tls.Certificate{}, fmt.Errorf("save cert: %s", err) } err = certOut.Close() if err != nil { - l.Fatalln("save cert:", err) + return tls.Certificate{}, fmt.Errorf("save cert: %s", err) } keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { - l.Fatalln("save key:", err) + return tls.Certificate{}, fmt.Errorf("save key: %s", err) } err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) if err != nil { - l.Fatalln("save key:", err) + return tls.Certificate{}, fmt.Errorf("save key: %s", err) } err = keyOut.Close() if err != nil { - l.Fatalln("save key:", err) + return tls.Certificate{}, fmt.Errorf("save key: %s", err) } return tls.LoadX509KeyPair(certFile, keyFile) @@ -90,11 +84,6 @@ type DowngradingListener struct { TLSConfig *tls.Config } -type WrappedConnection struct { - io.Reader - net.Conn -} - func (l *DowngradingListener) Accept() (net.Conn, error) { conn, err := l.Listener.Accept() if err != nil { @@ -121,6 +110,11 @@ func (l *DowngradingListener) Accept() (net.Conn, error) { return wrapper, nil } +type WrappedConnection struct { + io.Reader + net.Conn +} + func (c *WrappedConnection) Read(b []byte) (n int, err error) { return c.Reader.Read(b) }