diff --git a/cmd/discosrv/LICENSE b/cmd/discosrv/LICENSE new file mode 100644 index 00000000..9dd6db41 --- /dev/null +++ b/cmd/discosrv/LICENSE @@ -0,0 +1,19 @@ +Copyright (C) 2014-2015 The Discosrv Authors + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +- The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/cmd/discosrv/README.md b/cmd/discosrv/README.md new file mode 100644 index 00000000..8269c607 --- /dev/null +++ b/cmd/discosrv/README.md @@ -0,0 +1,40 @@ +discosrv +======== + +[![Latest Build](http://img.shields.io/jenkins/s/http/build.syncthing.net/discosrv.svg?style=flat-square)](http://build.syncthing.net/job/discosrv/lastBuild/) + +This is the global discovery server for the `syncthing` project. + +To get it, run `go get github.com/syncthing/discosrv` or download the +[latest build](http://build.syncthing.net/job/discosrv/lastSuccessfulBuild/artifact/) +from the build server. + +Usage +----- + +The discovery server supports `ql` and `postgres` backends. +Specify the backend via `-db-backend` and the database DSN via `-db-dsn`. + +By default it will use in-memory `ql` backend. If you wish to persist the +information on disk between restarts in `ql`, specify a file DSN: + +```bash +$ discosrv -db-dsn="file:///var/run/discosrv.db" +``` + +For `postgres`, you will need to create a database and a user with permissions +to create tables in it, then start the discosrv as follows: + +```bash +$ export DISCOSRV_DB_DSN="postgres://user:password@localhost/databasename" +$ discosrv -db-backend="postgres" +``` + +You can pass the DSN as command line option, but the value what you pass in will +be visible in most process managers, potentially exposing the database password +to other users. + +In all cases, the appropriate tables and indexes will be created at first +startup. If it doesn't exit with an error, you're fine. + +See `discosrv -help` for other options. diff --git a/cmd/discosrv/clean.go b/cmd/discosrv/clean.go new file mode 100644 index 00000000..962f773f --- /dev/null +++ b/cmd/discosrv/clean.go @@ -0,0 +1,75 @@ +// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "database/sql" + "log" + "time" +) + +type cleansrv struct { + intv time.Duration + db *sql.DB + prep map[string]*sql.Stmt +} + +func (s *cleansrv) Serve() { + for { + time.Sleep(next(s.intv)) + + err := s.cleanOldEntries() + if err != nil { + log.Println("Clean:", err) + } + } +} + +func (s *cleansrv) Stop() { + panic("stop unimplemented") +} + +func (s *cleansrv) cleanOldEntries() (err error) { + var tx *sql.Tx + tx, err = s.db.Begin() + if err != nil { + return err + } + + defer func() { + if err == nil { + err = tx.Commit() + } else { + tx.Rollback() + } + }() + + res, err := tx.Stmt(s.prep["cleanAddress"]).Exec() + if err != nil { + return err + } + if rows, _ := res.RowsAffected(); rows > 0 { + log.Printf("Clean: %d old addresses", rows) + } + + res, err = tx.Stmt(s.prep["cleanDevice"]).Exec() + if err != nil { + return err + } + if rows, _ := res.RowsAffected(); rows > 0 { + log.Printf("Clean: %d old devices", rows) + } + + var devs, addrs int + row := tx.Stmt(s.prep["countDevice"]).QueryRow() + if err = row.Scan(&devs); err != nil { + return err + } + row = tx.Stmt(s.prep["countAddress"]).QueryRow() + if err = row.Scan(&addrs); err != nil { + return err + } + + log.Printf("Database: %d devices, %d addresses", devs, addrs) + return nil +} diff --git a/cmd/discosrv/db.go b/cmd/discosrv/db.go new file mode 100644 index 00000000..34162d58 --- /dev/null +++ b/cmd/discosrv/db.go @@ -0,0 +1,32 @@ +// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "database/sql" + "fmt" +) + +type setupFunc func(db *sql.DB) error +type compileFunc func(db *sql.DB) (map[string]*sql.Stmt, error) + +var ( + setupFuncs = make(map[string]setupFunc) + compileFuncs = make(map[string]compileFunc) +) + +func register(name string, setup setupFunc, compile compileFunc) { + setupFuncs[name] = setup + compileFuncs[name] = compile +} + +func setup(backend string, db *sql.DB) (map[string]*sql.Stmt, error) { + setup, ok := setupFuncs[backend] + if !ok { + return nil, fmt.Errorf("Unsupported backend") + } + if err := setup(db); err != nil { + return nil, err + } + return compileFuncs[backend](db) +} diff --git a/cmd/discosrv/main.go b/cmd/discosrv/main.go new file mode 100644 index 00000000..ae3d3719 --- /dev/null +++ b/cmd/discosrv/main.go @@ -0,0 +1,118 @@ +// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "crypto/tls" + "database/sql" + "flag" + "log" + "os" + "time" + + "github.com/syncthing/syncthing/lib/protocol" + "github.com/thejerf/suture" +) + +const ( + minNegCache = 60 // seconds + maxNegCache = 3600 // seconds + maxDeviceAge = 7 * 86400 // one week, in seconds +) + +var ( + lruSize = 10240 + limitAvg = 5 + limitBurst = 20 + globalStats stats + statsFile string + backend = "ql" + dsn = getEnvDefault("DISCOSRV_DB_DSN", "memory://discosrv") + certFile = "cert.pem" + keyFile = "key.pem" + debug = false + useHttp = false +) + +func main() { + const ( + cleanIntv = 1 * time.Hour + statsIntv = 5 * time.Minute + ) + + var listen string + + log.SetOutput(os.Stdout) + log.SetFlags(0) + + flag.StringVar(&listen, "listen", ":8443", "Listen address") + flag.IntVar(&lruSize, "limit-cache", lruSize, "Limiter cache entries") + flag.IntVar(&limitAvg, "limit-avg", limitAvg, "Allowed average package rate, per 10 s") + flag.IntVar(&limitBurst, "limit-burst", limitBurst, "Allowed burst size, packets") + flag.StringVar(&statsFile, "stats-file", statsFile, "File to write periodic operation stats to") + flag.StringVar(&backend, "db-backend", backend, "Database backend to use") + flag.StringVar(&dsn, "db-dsn", dsn, "Database DSN") + flag.StringVar(&certFile, "cert", certFile, "Certificate file") + flag.StringVar(&keyFile, "key", keyFile, "Key file") + flag.BoolVar(&debug, "debug", debug, "Debug") + flag.BoolVar(&useHttp, "http", useHttp, "Listen on HTTP (behind an HTTPS proxy)") + flag.Parse() + + var cert tls.Certificate + var err error + if !useHttp { + cert, err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.Fatalln("Failed to load X509 key pair:", err) + } + + devID := protocol.NewDeviceID(cert.Certificate[0]) + log.Println("Server device ID is", devID) + } + + db, err := sql.Open(backend, dsn) + if err != nil { + log.Fatalln("sql.Open:", err) + } + prep, err := setup(backend, db) + if err != nil { + log.Fatalln("Setup:", err) + } + + main := suture.NewSimple("main") + + main.Add(&querysrv{ + addr: listen, + cert: cert, + db: db, + prep: prep, + }) + + main.Add(&cleansrv{ + intv: cleanIntv, + db: db, + prep: prep, + }) + + main.Add(&statssrv{ + intv: statsIntv, + file: statsFile, + db: db, + }) + + globalStats.Reset() + main.Serve() +} + +func getEnvDefault(key, def string) string { + if val := os.Getenv(key); val != "" { + return val + } + return def +} + +func next(intv time.Duration) time.Duration { + t0 := time.Now() + t1 := t0.Add(intv).Truncate(intv) + return t1.Sub(t0) +} diff --git a/cmd/discosrv/psql.go b/cmd/discosrv/psql.go new file mode 100644 index 00000000..d4af7fe6 --- /dev/null +++ b/cmd/discosrv/psql.go @@ -0,0 +1,97 @@ +// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "database/sql" + "fmt" + + _ "github.com/lib/pq" +) + +func init() { + register("postgres", postgresSetup, postgresCompile) +} + +func postgresSetup(db *sql.DB) error { + var err error + + db.SetMaxIdleConns(4) + db.SetMaxOpenConns(8) + + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS Devices ( + DeviceID CHAR(63) NOT NULL PRIMARY KEY, + Seen TIMESTAMP NOT NULL + )`) + if err != nil { + return err + } + + row := db.QueryRow(`SELECT 'DevicesDeviceIDIndex'::regclass`) + if err := row.Scan(nil); err != nil { + _, err = db.Exec(`CREATE INDEX DevicesDeviceIDIndex ON Devices (DeviceID)`) + } + if err != nil { + return err + } + + row = db.QueryRow(`SELECT 'DevicesSeenIndex'::regclass`) + if err := row.Scan(nil); err != nil { + _, err = db.Exec(`CREATE INDEX DevicesSeenIndex ON Devices (Seen)`) + } + if err != nil { + return err + } + + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS Addresses ( + DeviceID CHAR(63) NOT NULL, + Seen TIMESTAMP NOT NULL, + Address VARCHAR(256) NOT NULL + )`) + if err != nil { + return err + } + + row = db.QueryRow(`SELECT 'AddressesDeviceIDSeenIndex'::regclass`) + if err := row.Scan(nil); err != nil { + _, err = db.Exec(`CREATE INDEX AddressesDeviceIDSeenIndex ON Addresses (DeviceID, Seen)`) + } + if err != nil { + return err + } + + row = db.QueryRow(`SELECT 'AddressesDeviceIDAddressIndex'::regclass`) + if err := row.Scan(nil); err != nil { + _, err = db.Exec(`CREATE INDEX AddressesDeviceIDAddressIndex ON Addresses (DeviceID, Address)`) + } + if err != nil { + return err + } + + return nil +} + +func postgresCompile(db *sql.DB) (map[string]*sql.Stmt, error) { + stmts := map[string]string{ + "cleanAddress": "DELETE FROM Addresses WHERE Seen < now() - '2 hour'::INTERVAL", + "cleanDevice": fmt.Sprintf("DELETE FROM Devices WHERE Seen < now() - '%d hour'::INTERVAL", maxDeviceAge/3600), + "countAddress": "SELECT count(*) FROM Addresses", + "countDevice": "SELECT count(*) FROM Devices", + "insertAddress": "INSERT INTO Addresses (DeviceID, Seen, Address) VALUES ($1, now(), $2)", + "insertDevice": "INSERT INTO Devices (DeviceID, Seen) VALUES ($1, now())", + "selectAddress": "SELECT Address FROM Addresses WHERE DeviceID=$1 AND Seen > now() - '1 hour'::INTERVAL ORDER BY random() LIMIT 16", + "selectDevice": "SELECT Seen FROM Devices WHERE DeviceID=$1", + "updateAddress": "UPDATE Addresses SET Seen=now() WHERE DeviceID=$1 AND Address=$2", + "updateDevice": "UPDATE Devices SET Seen=now() WHERE DeviceID=$1", + } + + res := make(map[string]*sql.Stmt, len(stmts)) + for key, stmt := range stmts { + prep, err := db.Prepare(stmt) + if err != nil { + return nil, err + } + res[key] = prep + } + return res, nil +} diff --git a/cmd/discosrv/ql.go b/cmd/discosrv/ql.go new file mode 100644 index 00000000..971ee6f2 --- /dev/null +++ b/cmd/discosrv/ql.go @@ -0,0 +1,81 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "database/sql" + "fmt" + "log" + + "github.com/cznic/ql" +) + +func init() { + ql.RegisterDriver() + register("ql", qlSetup, qlCompile) +} + +func qlSetup(db *sql.DB) (err error) { + tx, err := db.Begin() + if err != nil { + return + } + + defer func() { + if err == nil { + err = tx.Commit() + } else { + tx.Rollback() + } + }() + + _, err = tx.Exec(`CREATE TABLE IF NOT EXISTS Devices ( + DeviceID STRING NOT NULL, + Seen TIME NOT NULL + )`) + if err != nil { + return + } + + if _, err = tx.Exec(`CREATE INDEX IF NOT EXISTS DevicesDeviceIDIndex ON Devices (DeviceID)`); err != nil { + return + } + + _, err = tx.Exec(`CREATE TABLE IF NOT EXISTS Addresses ( + DeviceID STRING NOT NULL, + Seen TIME NOT NULL, + Address STRING NOT NULL, + )`) + if err != nil { + return + } + + _, err = tx.Exec(`CREATE INDEX IF NOT EXISTS AddressesDeviceIDAddressIndex ON Addresses (DeviceID, Address)`) + return +} + +func qlCompile(db *sql.DB) (map[string]*sql.Stmt, error) { + stmts := map[string]string{ + "cleanAddress": `DELETE FROM Addresses WHERE Seen < now() - duration("2h")`, + "cleanDevice": fmt.Sprintf(`DELETE FROM Devices WHERE Seen < now() - duration("%dh")`, maxDeviceAge/3600), + "countAddress": "SELECT count(*) FROM Addresses", + "countDevice": "SELECT count(*) FROM Devices", + "insertAddress": "INSERT INTO Addresses (DeviceID, Seen, Address) VALUES ($1, now(), $2)", + "insertDevice": "INSERT INTO Devices (DeviceID, Seen) VALUES ($1, now())", + "selectAddress": `SELECT Address from Addresses WHERE DeviceID==$1 AND Seen > now() - duration("1h") LIMIT 16`, + "selectDevice": "SELECT Seen FROM Devices WHERE DeviceID==$1", + "updateAddress": "UPDATE Addresses Seen=now() WHERE DeviceID==$1 AND Address==$2", + "updateDevice": "UPDATE Devices Seen=now() WHERE DeviceID==$1", + } + + res := make(map[string]*sql.Stmt, len(stmts)) + for key, stmt := range stmts { + prep, err := db.Prepare(stmt) + if err != nil { + log.Println("Failed to compile", stmt) + return nil, err + } + res[key] = prep + } + return res, nil +} diff --git a/cmd/discosrv/querysrv.go b/cmd/discosrv/querysrv.go new file mode 100644 index 00000000..9cd5a891 --- /dev/null +++ b/cmd/discosrv/querysrv.go @@ -0,0 +1,476 @@ +// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "bytes" + "crypto/tls" + "database/sql" + "encoding/json" + "encoding/pem" + "fmt" + "log" + "math/rand" + "net" + "net/http" + "net/url" + "strconv" + "sync" + "time" + + "github.com/golang/groupcache/lru" + "github.com/juju/ratelimit" + "github.com/syncthing/syncthing/lib/protocol" + "golang.org/x/net/context" +) + +type querysrv struct { + addr string + db *sql.DB + prep map[string]*sql.Stmt + limiter *safeCache + cert tls.Certificate + listener net.Listener +} + +type announcement struct { + Seen time.Time `json:"seen"` + Addresses []string `json:"addresses"` +} + +type safeCache struct { + *lru.Cache + mut sync.Mutex +} + +func (s *safeCache) Get(key string) (val interface{}, ok bool) { + s.mut.Lock() + val, ok = s.Cache.Get(key) + s.mut.Unlock() + return +} + +func (s *safeCache) Add(key string, val interface{}) { + s.mut.Lock() + s.Cache.Add(key, val) + s.mut.Unlock() +} + +type requestID int64 + +func (i requestID) String() string { + return fmt.Sprintf("%016x", int64(i)) +} + +func negCacheFor(lastSeen time.Time) int { + since := time.Since(lastSeen).Seconds() + if since >= maxDeviceAge { + return maxNegCache + } + if since < 0 { + // That's weird + return minNegCache + } + + // Return a value linearly scaled from minNegCache (at zero seconds ago) + // to maxNegCache (at maxDeviceAge seconds ago). + r := since / maxDeviceAge + return int(minNegCache + r*(maxNegCache-minNegCache)) +} + +func (s *querysrv) Serve() { + s.limiter = &safeCache{ + Cache: lru.New(lruSize), + } + + if useHttp { + listener, err := net.Listen("tcp", s.addr) + if err != nil { + log.Println("Listen:", err) + return + } + s.listener = listener + } else { + tlsCfg := &tls.Config{ + Certificates: []tls.Certificate{s.cert}, + ClientAuth: tls.RequestClientCert, + SessionTicketsDisabled: true, + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + }, + } + + tlsListener, err := tls.Listen("tcp", s.addr, tlsCfg) + if err != nil { + log.Println("Listen:", err) + return + } + s.listener = tlsListener + } + + http.HandleFunc("/v2/", s.handler) + http.HandleFunc("/ping", handlePing) + + srv := &http.Server{ + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + MaxHeaderBytes: 1 << 10, + } + + if err := srv.Serve(s.listener); err != nil { + log.Println("Serve:", err) + } +} + +var topCtx = context.Background() + +func (s *querysrv) handler(w http.ResponseWriter, req *http.Request) { + reqID := requestID(rand.Int63()) + ctx := context.WithValue(topCtx, "id", reqID) + + if debug { + log.Println(reqID, req.Method, req.URL) + } + + t0 := time.Now() + defer func() { + diff := time.Since(t0) + var comment string + if diff > time.Second { + comment = "(very slow request)" + } else if diff > 100*time.Millisecond { + comment = "(slow request)" + } + if comment != "" || debug { + log.Println(reqID, req.Method, req.URL, "completed in", diff, comment) + } + }() + + var remoteIP net.IP + if useHttp { + remoteIP = net.ParseIP(req.Header.Get("X-Forwarded-For")) + } else { + addr, err := net.ResolveTCPAddr("tcp", req.RemoteAddr) + if err != nil { + log.Println("remoteAddr:", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + remoteIP = addr.IP + } + + if s.limit(remoteIP) { + if debug { + log.Println(remoteIP, "is limited") + } + w.Header().Set("Retry-After", "60") + http.Error(w, "Too Many Requests", 429) + return + } + + switch req.Method { + case "GET": + s.handleGET(ctx, w, req) + case "POST": + s.handlePOST(ctx, remoteIP, w, req) + default: + globalStats.Error() + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + } +} + +func (s *querysrv) handleGET(ctx context.Context, w http.ResponseWriter, req *http.Request) { + reqID := ctx.Value("id").(requestID) + + deviceID, err := protocol.DeviceIDFromString(req.URL.Query().Get("device")) + if err != nil { + if debug { + log.Println(reqID, "bad device param") + } + globalStats.Error() + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + var ann announcement + + ann.Seen, err = s.getDeviceSeen(deviceID) + negCache := strconv.Itoa(negCacheFor(ann.Seen)) + w.Header().Set("Retry-After", negCache) + w.Header().Set("Cache-Control", "public, max-age="+negCache) + + if err != nil { + // The device is not in the database. + globalStats.Query() + http.Error(w, "Not Found", http.StatusNotFound) + return + } + + t0 := time.Now() + ann.Addresses, err = s.getAddresses(ctx, deviceID) + if err != nil { + log.Println(reqID, "getAddresses:", err) + globalStats.Error() + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + if debug { + log.Println(reqID, "getAddresses in", time.Since(t0)) + } + + globalStats.Query() + + if len(ann.Addresses) == 0 { + http.Error(w, "Not Found", http.StatusNotFound) + return + } + + globalStats.Answer() + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(ann) +} + +func (s *querysrv) handlePOST(ctx context.Context, remoteIP net.IP, w http.ResponseWriter, req *http.Request) { + reqID := ctx.Value("id").(requestID) + + rawCert := certificateBytes(req) + if rawCert == nil { + if debug { + log.Println(reqID, "no certificates") + } + globalStats.Error() + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + var ann announcement + if err := json.NewDecoder(req.Body).Decode(&ann); err != nil { + if debug { + log.Println(reqID, "decode:", err) + } + globalStats.Error() + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + deviceID := protocol.NewDeviceID(rawCert) + + // handleAnnounce returns *two* errors. The first indicates a problem with + // something the client posted to us. We should return a 400 Bad Request + // and not worry about it. The second indicates that the request was fine, + // but something internal messed up. We should log it and respond with a + // more apologetic 500 Internal Server Error. + userErr, internalErr := s.handleAnnounce(ctx, remoteIP, deviceID, ann.Addresses) + if userErr != nil { + if debug { + log.Println(reqID, "handleAnnounce:", userErr) + } + globalStats.Error() + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + if internalErr != nil { + log.Println(reqID, "handleAnnounce:", internalErr) + globalStats.Error() + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + globalStats.Announce() + + // TODO: Slowly increase this for stable clients + w.Header().Set("Reannounce-After", "1800") + + // We could return the lookup result here, but it's kind of unnecessarily + // expensive to go query the database again so we let the client decide to + // do a lookup if they really care. + w.WriteHeader(http.StatusNoContent) +} + +func (s *querysrv) Stop() { + s.listener.Close() +} + +func (s *querysrv) handleAnnounce(ctx context.Context, remote net.IP, deviceID protocol.DeviceID, addresses []string) (userErr, internalErr error) { + reqID := ctx.Value("id").(requestID) + + tx, err := s.db.Begin() + if err != nil { + internalErr = err + return + } + + defer func() { + // Since we return from a bunch of different places, we handle + // rollback in the defer. + if internalErr != nil || userErr != nil { + tx.Rollback() + } + }() + + for _, annAddr := range addresses { + uri, err := url.Parse(annAddr) + if err != nil { + userErr = err + return + } + + host, port, err := net.SplitHostPort(uri.Host) + if err != nil { + userErr = err + return + } + + ip := net.ParseIP(host) + if len(ip) == 0 || ip.IsUnspecified() { + uri.Host = net.JoinHostPort(remote.String(), port) + } + + if err := s.updateAddress(ctx, tx, deviceID, uri.String()); err != nil { + internalErr = err + return + } + } + + if err := s.updateDevice(ctx, tx, deviceID); err != nil { + internalErr = err + return + } + + t0 := time.Now() + internalErr = tx.Commit() + if debug { + log.Println(reqID, "commit in", time.Since(t0)) + } + return +} + +func (s *querysrv) limit(remote net.IP) bool { + key := remote.String() + + bkt, ok := s.limiter.Get(key) + if ok { + bkt := bkt.(*ratelimit.Bucket) + if bkt.TakeAvailable(1) != 1 { + // Rate limit exceeded; ignore packet + return true + } + } else { + // One packet per ten seconds average rate, burst ten packets + s.limiter.Add(key, ratelimit.NewBucket(10*time.Second/time.Duration(limitAvg), int64(limitBurst))) + } + + return false +} + +func (s *querysrv) updateDevice(ctx context.Context, tx *sql.Tx, device protocol.DeviceID) error { + reqID := ctx.Value("id").(requestID) + t0 := time.Now() + res, err := tx.Stmt(s.prep["updateDevice"]).Exec(device.String()) + if err != nil { + return err + } + if debug { + log.Println(reqID, "updateDevice in", time.Since(t0)) + } + + if rows, _ := res.RowsAffected(); rows == 0 { + t0 = time.Now() + _, err := tx.Stmt(s.prep["insertDevice"]).Exec(device.String()) + if err != nil { + return err + } + if debug { + log.Println(reqID, "insertDevice in", time.Since(t0)) + } + } + + return nil +} + +func (s *querysrv) updateAddress(ctx context.Context, tx *sql.Tx, device protocol.DeviceID, uri string) error { + res, err := tx.Stmt(s.prep["updateAddress"]).Exec(device.String(), uri) + if err != nil { + return err + } + + if rows, _ := res.RowsAffected(); rows == 0 { + _, err := tx.Stmt(s.prep["insertAddress"]).Exec(device.String(), uri) + if err != nil { + return err + } + } + + return nil +} + +func (s *querysrv) getAddresses(ctx context.Context, device protocol.DeviceID) ([]string, error) { + rows, err := s.prep["selectAddress"].Query(device.String()) + if err != nil { + return nil, err + } + defer rows.Close() + + var res []string + for rows.Next() { + var addr string + + err := rows.Scan(&addr) + if err != nil { + log.Println("Scan:", err) + continue + } + res = append(res, addr) + } + + return res, nil +} + +func (s *querysrv) getDeviceSeen(device protocol.DeviceID) (time.Time, error) { + row := s.prep["selectDevice"].QueryRow(device.String()) + var seen time.Time + if err := row.Scan(&seen); err != nil { + return time.Time{}, err + } + return seen, nil +} + +func handlePing(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(204) +} + +func certificateBytes(req *http.Request) []byte { + if req.TLS != nil && len(req.TLS.PeerCertificates) > 0 { + return req.TLS.PeerCertificates[0].Raw + } + + if hdr := req.Header.Get("X-SSL-Cert"); hdr != "" { + bs := []byte(hdr) + // The certificate is in PEM format but with spaces for newlines. We + // need to reinstate the newlines for the PEM decoder. But we need to + // leave the spaces in the BEGIN and END lines - the first and last + // space - alone. + firstSpace := bytes.Index(bs, []byte(" ")) + lastSpace := bytes.LastIndex(bs, []byte(" ")) + for i := firstSpace + 1; i < lastSpace; i++ { + if bs[i] == ' ' { + bs[i] = '\n' + } + } + block, _ := pem.Decode(bs) + if block == nil { + // Decoding failed + return nil + } + return block.Bytes + } + + return nil +} diff --git a/cmd/discosrv/stats.go b/cmd/discosrv/stats.go new file mode 100644 index 00000000..cdc8c301 --- /dev/null +++ b/cmd/discosrv/stats.go @@ -0,0 +1,141 @@ +// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "bytes" + "database/sql" + "fmt" + "io/ioutil" + "log" + "os" + "sync/atomic" + "time" +) + +type stats struct { + // Incremented atomically + announces int64 + queries int64 + answers int64 + errors int64 +} + +func (s *stats) Announce() { + atomic.AddInt64(&s.announces, 1) +} + +func (s *stats) Query() { + atomic.AddInt64(&s.queries, 1) +} + +func (s *stats) Answer() { + atomic.AddInt64(&s.answers, 1) +} + +func (s *stats) Error() { + atomic.AddInt64(&s.errors, 1) +} + +// Reset returns a copy of the current stats and resets the counters to +// zero. +func (s *stats) Reset() stats { + // Create a copy of the stats using atomic reads + copy := stats{ + announces: atomic.LoadInt64(&s.announces), + queries: atomic.LoadInt64(&s.queries), + answers: atomic.LoadInt64(&s.answers), + errors: atomic.LoadInt64(&s.errors), + } + + // Reset the stats by subtracting the values that we copied + atomic.AddInt64(&s.announces, -copy.announces) + atomic.AddInt64(&s.queries, -copy.queries) + atomic.AddInt64(&s.answers, -copy.answers) + atomic.AddInt64(&s.errors, -copy.errors) + + return copy +} + +type statssrv struct { + intv time.Duration + file string + db *sql.DB +} + +func (s *statssrv) Serve() { + lastReset := time.Now() + for { + time.Sleep(next(s.intv)) + + stats := globalStats.Reset() + d := time.Since(lastReset).Seconds() + lastReset = time.Now() + + log.Printf("Stats: %.02f announces/s, %.02f queries/s, %.02f answers/s, %.02f errors/s", + float64(stats.announces)/d, float64(stats.queries)/d, float64(stats.answers)/d, float64(stats.errors)/d) + + if s.file != "" { + s.writeToFile(stats, d) + } + } +} + +func (s *statssrv) Stop() { + panic("stop unimplemented") +} + +func (s *statssrv) writeToFile(stats stats, secs float64) { + newLine := []byte("\n") + + var addrs int + row := s.db.QueryRow("SELECT COUNT(*) FROM Addresses") + if err := row.Scan(&addrs); err != nil { + log.Println("stats query:", err) + return + } + + fd, err := os.OpenFile(s.file, os.O_RDWR|os.O_CREATE, 0666) + if err != nil { + log.Println("stats file:", err) + return + } + defer func() { + err = fd.Close() + if err != nil { + log.Println("stats file:", err) + } + }() + + bs, err := ioutil.ReadAll(fd) + if err != nil { + log.Println("stats file:", err) + return + } + lines := bytes.Split(bytes.TrimSpace(bs), newLine) + if len(lines) > 12 { + lines = lines[len(lines)-12:] + } + + latest := fmt.Sprintf("%v: %6d addresses, %8.02f announces/s, %8.02f queries/s, %8.02f answers/s, %8.02f errors/s\n", + time.Now().UTC().Format(time.RFC3339), addrs, + float64(stats.announces)/secs, float64(stats.queries)/secs, float64(stats.answers)/secs, float64(stats.errors)/secs) + lines = append(lines, []byte(latest)) + + _, err = fd.Seek(0, 0) + if err != nil { + log.Println("stats file:", err) + return + } + err = fd.Truncate(0) + if err != nil { + log.Println("stats file:", err) + return + } + + _, err = fd.Write(bytes.Join(lines, newLine)) + if err != nil { + log.Println("stats file:", err) + return + } +}