cmd/stdiscosrv: New discovery server (fixes #4618)

This is a new revision of the discovery server. Relevant changes and
non-changes:

- Protocol towards clients is unchanged.

- Recommended large scale design is still to be deployed nehind nginx (I
  tested, and it's still a lot faster at terminating TLS).

- Database backend is leveldb again, only. It scales enough, is easy to
  setup, and we don't need any backend to take care of.

- Server supports replication. This is a simple TCP channel - protect it
  with a firewall when deploying over the internet. (We deploy this within
  the same datacenter, and with firewall.) Any incoming client announces
  are sent over the replication channel(s) to other peer discosrvs.
  Incoming replication changes are applied to the database as if they came
  from clients, but without the TLS/certificate overhead.

- Metrics are exposed using the prometheus library, when enabled.

- The database values and replication protocol is protobuf, because JSON
  was quite CPU intensive when I tried that and benchmarked it.

- The "Retry-After" value for failed lookups gets slowly increased from
  a default of 120 seconds, by 5 seconds for each failed lookup,
  independently by each discosrv. This lowers the query load over time for
  clients that are never seen. The Retry-After maxes out at 3600 after a
  couple of weeks of this increase. The number of failed lookups is
  stored in the database, now and then (avoiding making each lookup a
  database put).

All in all this means clients can be pointed towards a cluster using
just multiple A / AAAA records to gain both load sharing and redundancy
(if one is down, clients will talk to the remaining ones).

GitHub-Pull-Request: https://github.com/syncthing/syncthing/pull/4648
This commit is contained in:
Jakob Borg
2018-01-14 08:52:31 +00:00
parent 341b9691a7
commit 916ec63af6
864 changed files with 216825 additions and 64540 deletions

View File

@@ -1,19 +0,0 @@
Copyright (C) 2014-2015 The Discosrv Authors
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
- The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -6,33 +6,4 @@ This is the global discovery server for the `syncthing` project.
Usage
-----
The discovery server supports `ql` and `postgres` backends.
Specify the backend via `-db-backend` and the database DSN via `-db-dsn`.
By default it will use in-memory `ql` backend. If you wish to persist the
information on disk between restarts in `ql`, specify a file DSN:
```bash
$ stdiscosrv -db-dsn="file:///var/run/stdiscosrv.db"
```
For `postgres`, you will need to create a database and a user with permissions
to create tables in it, then start the stdiscosrv as follows:
```bash
$ export STDISCOSRV_DB_DSN="postgres://user:password@localhost/databasename"
$ stdiscosrv -db-backend="postgres"
```
You can pass the DSN as command line option, but the value what you pass in will
be visible in most process managers, potentially exposing the database password
to other users.
In all cases, the appropriate tables and indexes will be created at first
startup. If it doesn't exit with an error, you're fine.
See `stdiscosrv -help` for other options.
##### Third-party attribution
[cznic/lldb](https://github.com/cznic/lldb), Copyright (C) 2014 The lldb Authors.
https://docs.syncthing.net/users/stdiscosrv.html

394
cmd/stdiscosrv/apisrv.go Normal file
View File

@@ -0,0 +1,394 @@
// Copyright (C) 2018 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.
package main
import (
"bytes"
"crypto/tls"
"encoding/json"
"encoding/pem"
"fmt"
"log"
"math/rand"
"net"
"net/http"
"net/url"
"strconv"
"sync"
"time"
"github.com/syncthing/syncthing/lib/protocol"
"golang.org/x/net/context"
)
// announcement is the format received from and sent to clients
type announcement struct {
Seen time.Time `json:"seen"`
Addresses []string `json:"addresses"`
}
type apiSrv struct {
addr string
cert tls.Certificate
db database
listener net.Listener
repl replicator // optional
useHTTP bool
mapsMut sync.Mutex
misses map[string]int32
}
type requestID int64
func (i requestID) String() string {
return fmt.Sprintf("%016x", int64(i))
}
type contextKey int
const idKey contextKey = iota
func newAPISrv(addr string, cert tls.Certificate, db database, repl replicator, useHTTP bool) *apiSrv {
return &apiSrv{
addr: addr,
cert: cert,
db: db,
repl: repl,
useHTTP: useHTTP,
misses: make(map[string]int32),
}
}
func (s *apiSrv) Serve() {
if s.useHTTP {
listener, err := net.Listen("tcp", s.addr)
if err != nil {
log.Println("Listen:", err)
return
}
s.listener = listener
} else {
tlsCfg := &tls.Config{
Certificates: []tls.Certificate{s.cert},
ClientAuth: tls.RequestClientCert,
SessionTicketsDisabled: true,
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
},
}
tlsListener, err := tls.Listen("tcp", s.addr, tlsCfg)
if err != nil {
log.Println("Listen:", err)
return
}
s.listener = tlsListener
}
http.HandleFunc("/", s.handler)
http.HandleFunc("/ping", handlePing)
srv := &http.Server{
ReadTimeout: httpReadTimeout,
WriteTimeout: httpWriteTimeout,
MaxHeaderBytes: httpMaxHeaderBytes,
}
if err := srv.Serve(s.listener); err != nil {
log.Println("Serve:", err)
}
}
var topCtx = context.Background()
func (s *apiSrv) handler(w http.ResponseWriter, req *http.Request) {
t0 := time.Now()
lw := NewLoggingResponseWriter(w)
defer func() {
diff := time.Since(t0)
apiRequestsSeconds.WithLabelValues(req.Method).Observe(diff.Seconds())
apiRequestsTotal.WithLabelValues(req.Method, strconv.Itoa(lw.statusCode)).Inc()
}()
reqID := requestID(rand.Int63())
ctx := context.WithValue(topCtx, idKey, reqID)
if debug {
log.Println(reqID, req.Method, req.URL)
}
var remoteIP net.IP
if s.useHTTP {
remoteIP = net.ParseIP(req.Header.Get("X-Forwarded-For"))
} else {
addr, err := net.ResolveTCPAddr("tcp", req.RemoteAddr)
if err != nil {
log.Println("remoteAddr:", err)
lw.Header().Set("Retry-After", errorRetryAfterString())
http.Error(lw, "Internal Server Error", http.StatusInternalServerError)
apiRequestsTotal.WithLabelValues("no_remote_addr").Inc()
return
}
remoteIP = addr.IP
}
switch req.Method {
case "GET":
s.handleGET(ctx, lw, req)
case "POST":
s.handlePOST(ctx, remoteIP, lw, req)
default:
http.Error(lw, "Method Not Allowed", http.StatusMethodNotAllowed)
}
}
func (s *apiSrv) handleGET(ctx context.Context, w http.ResponseWriter, req *http.Request) {
reqID := ctx.Value(idKey).(requestID)
deviceID, err := protocol.DeviceIDFromString(req.URL.Query().Get("device"))
if err != nil {
if debug {
log.Println(reqID, "bad device param")
}
lookupRequestsTotal.WithLabelValues("bad_request").Inc()
w.Header().Set("Retry-After", errorRetryAfterString())
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
key := deviceID.String()
rec, err := s.db.get(key)
if err != nil {
// some sort of internal error
lookupRequestsTotal.WithLabelValues("internal_error").Inc()
w.Header().Set("Retry-After", errorRetryAfterString())
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
if len(rec.Addresses) == 0 {
lookupRequestsTotal.WithLabelValues("not_found").Inc()
s.mapsMut.Lock()
misses := s.misses[key]
if misses < rec.Misses {
misses = rec.Misses + 1
} else {
misses++
}
s.misses[key] = misses
s.mapsMut.Unlock()
if misses%notFoundMissesWriteInterval == 0 {
rec.Misses = misses
rec.Addresses = nil
// rec.Seen retained from get
s.db.put(key, rec)
}
w.Header().Set("Retry-After", notFoundRetryAfterString(int(misses)))
http.Error(w, "Not Found", http.StatusNotFound)
return
}
lookupRequestsTotal.WithLabelValues("success").Inc()
bs, _ := json.Marshal(announcement{
Seen: time.Unix(0, rec.Seen),
Addresses: addressStrs(rec.Addresses),
})
w.Header().Set("Content-Type", "application/json")
w.Write(bs)
}
func (s *apiSrv) handlePOST(ctx context.Context, remoteIP net.IP, w http.ResponseWriter, req *http.Request) {
reqID := ctx.Value(idKey).(requestID)
rawCert := certificateBytes(req)
if rawCert == nil {
if debug {
log.Println(reqID, "no certificates")
}
announceRequestsTotal.WithLabelValues("no_certificate").Inc()
w.Header().Set("Retry-After", errorRetryAfterString())
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
var ann announcement
if err := json.NewDecoder(req.Body).Decode(&ann); err != nil {
if debug {
log.Println(reqID, "decode:", err)
}
announceRequestsTotal.WithLabelValues("bad_request").Inc()
w.Header().Set("Retry-After", errorRetryAfterString())
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
deviceID := protocol.NewDeviceID(rawCert)
addresses := fixupAddresses(remoteIP, ann.Addresses)
if len(addresses) == 0 {
announceRequestsTotal.WithLabelValues("bad_request").Inc()
w.Header().Set("Retry-After", errorRetryAfterString())
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
if err := s.handleAnnounce(remoteIP, deviceID, addresses); err != nil {
announceRequestsTotal.WithLabelValues("internal_error").Inc()
w.Header().Set("Retry-After", errorRetryAfterString())
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
announceRequestsTotal.WithLabelValues("success").Inc()
w.Header().Set("Reannounce-After", reannounceAfterString())
w.WriteHeader(http.StatusNoContent)
}
func (s *apiSrv) Stop() {
s.listener.Close()
}
func (s *apiSrv) handleAnnounce(remote net.IP, deviceID protocol.DeviceID, addresses []string) error {
key := deviceID.String()
now := time.Now()
expire := now.Add(addressExpiryTime).UnixNano()
dbAddrs := make([]DatabaseAddress, len(addresses))
for i := range addresses {
dbAddrs[i].Address = addresses[i]
dbAddrs[i].Expires = expire
}
seen := now.UnixNano()
if s.repl != nil {
s.repl.send(key, dbAddrs, seen)
}
return s.db.merge(key, dbAddrs, seen)
}
func handlePing(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(204)
}
func certificateBytes(req *http.Request) []byte {
if req.TLS != nil && len(req.TLS.PeerCertificates) > 0 {
return req.TLS.PeerCertificates[0].Raw
}
if hdr := req.Header.Get("X-SSL-Cert"); hdr != "" {
bs := []byte(hdr)
// The certificate is in PEM format but with spaces for newlines. We
// need to reinstate the newlines for the PEM decoder. But we need to
// leave the spaces in the BEGIN and END lines - the first and last
// space - alone.
firstSpace := bytes.Index(bs, []byte(" "))
lastSpace := bytes.LastIndex(bs, []byte(" "))
for i := firstSpace + 1; i < lastSpace; i++ {
if bs[i] == ' ' {
bs[i] = '\n'
}
}
block, _ := pem.Decode(bs)
if block == nil {
// Decoding failed
return nil
}
return block.Bytes
}
return nil
}
// fixupAddresses checks the list of addresses, removing invalid ones and
// replacing unspecified IPs with the given remote IP.
func fixupAddresses(remote net.IP, addresses []string) []string {
fixed := make([]string, 0, len(addresses))
for _, annAddr := range addresses {
uri, err := url.Parse(annAddr)
if err != nil {
continue
}
host, port, err := net.SplitHostPort(uri.Host)
if err != nil {
continue
}
ip := net.ParseIP(host)
if host == "" || ip.IsUnspecified() {
// Do not use IPv6 remote address if requested scheme is tcp4
if uri.Scheme == "tcp4" && remote.To4() == nil {
continue
}
// Do not use IPv4 remote address if requested scheme is tcp6
if uri.Scheme == "tcp6" && remote.To4() != nil {
continue
}
host = remote.String()
}
uri.Host = net.JoinHostPort(host, port)
fixed = append(fixed, uri.String())
}
return fixed
}
type loggingResponseWriter struct {
http.ResponseWriter
statusCode int
}
func NewLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
return &loggingResponseWriter{w, http.StatusOK}
}
func (lrw *loggingResponseWriter) WriteHeader(code int) {
lrw.statusCode = code
lrw.ResponseWriter.WriteHeader(code)
}
func addressStrs(dbAddrs []DatabaseAddress) []string {
res := make([]string, len(dbAddrs))
for i, a := range dbAddrs {
res[i] = a.Address
}
return res
}
func errorRetryAfterString() string {
return strconv.Itoa(errorRetryAfterSeconds + rand.Intn(errorRetryFuzzSeconds))
}
func notFoundRetryAfterString(misses int) string {
retryAfterS := notFoundRetryMinSeconds + notFoundRetryIncSeconds*misses
if retryAfterS > notFoundRetryMaxSeconds {
retryAfterS = notFoundRetryMaxSeconds
}
retryAfterS += rand.Intn(notFoundRetryFuzzSeconds)
return strconv.Itoa(retryAfterS)
}
func reannounceAfterString() string {
return strconv.Itoa(reannounceAfterSeconds + rand.Intn(reannounzeFuzzSeconds))
}

View File

@@ -1,75 +0,0 @@
// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file).
package main
import (
"database/sql"
"log"
"time"
)
type cleansrv struct {
intv time.Duration
db *sql.DB
prep map[string]*sql.Stmt
}
func (s *cleansrv) Serve() {
for {
time.Sleep(next(s.intv))
err := s.cleanOldEntries()
if err != nil {
log.Println("Clean:", err)
}
}
}
func (s *cleansrv) Stop() {
panic("stop unimplemented")
}
func (s *cleansrv) cleanOldEntries() (err error) {
var tx *sql.Tx
tx, err = s.db.Begin()
if err != nil {
return err
}
defer func() {
if err == nil {
err = tx.Commit()
} else {
tx.Rollback()
}
}()
res, err := tx.Stmt(s.prep["cleanAddress"]).Exec()
if err != nil {
return err
}
if rows, _ := res.RowsAffected(); rows > 0 {
log.Printf("Clean: %d old addresses", rows)
}
res, err = tx.Stmt(s.prep["cleanDevice"]).Exec()
if err != nil {
return err
}
if rows, _ := res.RowsAffected(); rows > 0 {
log.Printf("Clean: %d old devices", rows)
}
var devs, addrs int
row := tx.Stmt(s.prep["countDevice"]).QueryRow()
if err = row.Scan(&devs); err != nil {
return err
}
row = tx.Stmt(s.prep["countAddress"]).QueryRow()
if err = row.Scan(&addrs); err != nil {
return err
}
log.Printf("Database: %d devices, %d addresses", devs, addrs)
return nil
}

336
cmd/stdiscosrv/database.go Normal file
View File

@@ -0,0 +1,336 @@
// Copyright (C) 2018 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.
//go:generate go run ../../script/protofmt.go database.proto
//go:generate protoc -I ../../../../../ -I ../../vendor/ -I ../../vendor/github.com/gogo/protobuf/protobuf -I . --gogofast_out=. database.proto
package main
import (
"sort"
"time"
"github.com/syndtr/goleveldb/leveldb"
"github.com/syndtr/goleveldb/leveldb/util"
)
type clock interface {
Now() time.Time
}
type defaultClock struct{}
func (defaultClock) Now() time.Time {
return time.Now()
}
type database interface {
put(key string, rec DatabaseRecord) error
merge(key string, addrs []DatabaseAddress, seen int64) error
get(key string) (DatabaseRecord, error)
}
type levelDBStore struct {
db *leveldb.DB
inbox chan func()
stop chan struct{}
clock clock
marshalBuf []byte
}
func newLevelDBStore(dir string) (*levelDBStore, error) {
db, err := leveldb.OpenFile(dir, levelDBOptions)
if err != nil {
return nil, err
}
return &levelDBStore{
db: db,
inbox: make(chan func(), 16),
stop: make(chan struct{}),
clock: defaultClock{},
}, nil
}
func (s *levelDBStore) put(key string, rec DatabaseRecord) error {
t0 := time.Now()
defer func() {
databaseOperationSeconds.WithLabelValues(dbOpPut).Observe(time.Since(t0).Seconds())
}()
rc := make(chan error)
s.inbox <- func() {
size := rec.Size()
if len(s.marshalBuf) < size {
s.marshalBuf = make([]byte, size)
}
n, _ := rec.MarshalTo(s.marshalBuf)
rc <- s.db.Put([]byte(key), s.marshalBuf[:n], nil)
}
err := <-rc
if err != nil {
databaseOperations.WithLabelValues(dbOpPut, dbResError).Inc()
} else {
databaseOperations.WithLabelValues(dbOpPut, dbResSuccess).Inc()
}
return err
}
func (s *levelDBStore) merge(key string, addrs []DatabaseAddress, seen int64) error {
t0 := time.Now()
defer func() {
databaseOperationSeconds.WithLabelValues(dbOpMerge).Observe(time.Since(t0).Seconds())
}()
rc := make(chan error)
newRec := DatabaseRecord{
Addresses: addrs,
Seen: seen,
}
s.inbox <- func() {
// grab the existing record
oldRec, err := s.get(key)
if err != nil {
// "not found" is not an error from get, so this is serious
// stuff only
rc <- err
return
}
newRec = merge(newRec, oldRec)
// We replicate s.put() functionality here ourselves instead of
// calling it because we want to serialize our get above together
// with the put in the same function.
size := newRec.Size()
if len(s.marshalBuf) < size {
s.marshalBuf = make([]byte, size)
}
n, _ := newRec.MarshalTo(s.marshalBuf)
rc <- s.db.Put([]byte(key), s.marshalBuf[:n], nil)
}
err := <-rc
if err != nil {
databaseOperations.WithLabelValues(dbOpMerge, dbResError).Inc()
} else {
databaseOperations.WithLabelValues(dbOpMerge, dbResSuccess).Inc()
}
return err
}
func (s *levelDBStore) get(key string) (DatabaseRecord, error) {
t0 := time.Now()
defer func() {
databaseOperationSeconds.WithLabelValues(dbOpGet).Observe(time.Since(t0).Seconds())
}()
keyBs := []byte(key)
val, err := s.db.Get(keyBs, nil)
if err == leveldb.ErrNotFound {
databaseOperations.WithLabelValues(dbOpGet, dbResNotFound).Inc()
return DatabaseRecord{}, nil
}
if err != nil {
databaseOperations.WithLabelValues(dbOpGet, dbResError).Inc()
return DatabaseRecord{}, err
}
var rec DatabaseRecord
if err := rec.Unmarshal(val); err != nil {
databaseOperations.WithLabelValues(dbOpGet, dbResUnmarshalError).Inc()
return DatabaseRecord{}, nil
}
rec.Addresses = expire(rec.Addresses, s.clock.Now().UnixNano())
databaseOperations.WithLabelValues(dbOpGet, dbResSuccess).Inc()
return rec, nil
}
func (s *levelDBStore) Serve() {
t := time.NewTimer(0)
defer t.Stop()
defer s.db.Close()
// Start the statistics serve routine. It will exit with us when
// statisticsTrigger is closed.
statisticsTrigger := make(chan struct{})
defer close(statisticsTrigger)
statisticsDone := make(chan struct{})
go s.statisticsServe(statisticsTrigger, statisticsDone)
for {
select {
case fn := <-s.inbox:
// Run function in serialized order.
fn()
case <-t.C:
// Trigger the statistics routine to do its thing in the
// background.
statisticsTrigger <- struct{}{}
case <-statisticsDone:
// The statistics routine is done with one iteratation, schedule
// the next.
t.Reset(databaseStatisticsInterval)
case <-s.stop:
// We're done.
return
}
}
}
func (s *levelDBStore) statisticsServe(trigger <-chan struct{}, done chan<- struct{}) {
for range trigger {
t0 := time.Now()
nowNanos := t0.UnixNano()
cutoff24h := t0.Add(-24 * time.Hour).UnixNano()
cutoff1w := t0.Add(-7 * 24 * time.Hour).UnixNano()
current, last24h, last1w, inactive, errors := 0, 0, 0, 0, 0
iter := s.db.NewIterator(&util.Range{}, nil)
for iter.Next() {
// Attempt to unmarshal the record and count the
// failure if there's something wrong with it.
var rec DatabaseRecord
if err := rec.Unmarshal(iter.Value()); err != nil {
errors++
continue
}
// If there are addresses that have not expired it's a current
// record, otherwise account it based on when it was last seen
// (last 24 hours or last week) or finally as inactice.
switch {
case len(expire(rec.Addresses, nowNanos)) > 0:
current++
case rec.Seen > cutoff24h:
last24h++
case rec.Seen > cutoff1w:
last1w++
default:
inactive++
}
}
iter.Release()
databaseKeys.WithLabelValues("current").Set(float64(current))
databaseKeys.WithLabelValues("last24h").Set(float64(last24h))
databaseKeys.WithLabelValues("last1w").Set(float64(last1w))
databaseKeys.WithLabelValues("inactive").Set(float64(inactive))
databaseKeys.WithLabelValues("error").Set(float64(errors))
databaseStatisticsSeconds.Set(time.Since(t0).Seconds())
// Signal that we are done and can be scheduled again.
done <- struct{}{}
}
}
func (s *levelDBStore) Stop() {
close(s.stop)
}
// merge returns the merged result of the two database records a and b. The
// result is the union of the two address sets, with the newer expiry time
// chosen for any duplicates.
func merge(a, b DatabaseRecord) DatabaseRecord {
// Both lists must be sorted for this to work.
sort.Slice(a.Addresses, func(i, j int) bool {
return a.Addresses[i].Address < a.Addresses[j].Address
})
sort.Slice(b.Addresses, func(i, j int) bool {
return b.Addresses[i].Address < b.Addresses[j].Address
})
res := DatabaseRecord{
Addresses: make([]DatabaseAddress, 0, len(a.Addresses)+len(b.Addresses)),
Seen: a.Seen,
}
if b.Seen > a.Seen {
res.Seen = b.Seen
}
aIdx := 0
bIdx := 0
aAddrs := a.Addresses
bAddrs := b.Addresses
loop:
for {
switch {
case aIdx == len(aAddrs) && bIdx == len(bAddrs):
// both lists are exhausted, we are done
break loop
case aIdx == len(aAddrs):
// a is exhausted, pick from b and continue
res.Addresses = append(res.Addresses, bAddrs[bIdx])
bIdx++
continue
case bIdx == len(bAddrs):
// b is exhausted, pick from a and continue
res.Addresses = append(res.Addresses, aAddrs[aIdx])
aIdx++
continue
}
// We have values left on both sides.
aVal := aAddrs[aIdx]
bVal := bAddrs[bIdx]
switch {
case aVal.Address == bVal.Address:
// update for same address, pick newer
if aVal.Expires > bVal.Expires {
res.Addresses = append(res.Addresses, aVal)
} else {
res.Addresses = append(res.Addresses, bVal)
}
aIdx++
bIdx++
case aVal.Address < bVal.Address:
// a is smallest, pick it and continue
res.Addresses = append(res.Addresses, aVal)
aIdx++
default:
// b is smallest, pick it and continue
res.Addresses = append(res.Addresses, bVal)
bIdx++
}
}
return res
}
// expire returns the list of addresses after removing expired entries.
// Expiration happen in place, so the slice given as the parameter is
// destroyed. Internal order is not preserved.
func expire(addrs []DatabaseAddress, now int64) []DatabaseAddress {
i := 0
for i < len(addrs) {
if addrs[i].Expires < now {
// This item is expired. Replace it with the last in the list
// (noop if we are at the last item).
addrs[i] = addrs[len(addrs)-1]
// Wipe the last item of the list to release references to
// strings and stuff.
addrs[len(addrs)-1] = DatabaseAddress{}
// Shorten the slice.
addrs = addrs[:len(addrs)-1]
continue
}
i++
}
return addrs
}

View File

@@ -0,0 +1,743 @@
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: database.proto
/*
Package main is a generated protocol buffer package.
It is generated from these files:
database.proto
It has these top-level messages:
DatabaseRecord
ReplicationRecord
DatabaseAddress
*/
package main
import proto "github.com/gogo/protobuf/proto"
import fmt "fmt"
import math "math"
import _ "github.com/gogo/protobuf/gogoproto"
import io "io"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package
type DatabaseRecord struct {
Addresses []DatabaseAddress `protobuf:"bytes,1,rep,name=addresses" json:"addresses"`
Misses int32 `protobuf:"varint,2,opt,name=misses,proto3" json:"misses,omitempty"`
Seen int64 `protobuf:"varint,3,opt,name=seen,proto3" json:"seen,omitempty"`
}
func (m *DatabaseRecord) Reset() { *m = DatabaseRecord{} }
func (m *DatabaseRecord) String() string { return proto.CompactTextString(m) }
func (*DatabaseRecord) ProtoMessage() {}
func (*DatabaseRecord) Descriptor() ([]byte, []int) { return fileDescriptorDatabase, []int{0} }
type ReplicationRecord struct {
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
Addresses []DatabaseAddress `protobuf:"bytes,2,rep,name=addresses" json:"addresses"`
Seen int64 `protobuf:"varint,3,opt,name=seen,proto3" json:"seen,omitempty"`
}
func (m *ReplicationRecord) Reset() { *m = ReplicationRecord{} }
func (m *ReplicationRecord) String() string { return proto.CompactTextString(m) }
func (*ReplicationRecord) ProtoMessage() {}
func (*ReplicationRecord) Descriptor() ([]byte, []int) { return fileDescriptorDatabase, []int{1} }
type DatabaseAddress struct {
Address string `protobuf:"bytes,1,opt,name=address,proto3" json:"address,omitempty"`
Expires int64 `protobuf:"varint,2,opt,name=expires,proto3" json:"expires,omitempty"`
}
func (m *DatabaseAddress) Reset() { *m = DatabaseAddress{} }
func (m *DatabaseAddress) String() string { return proto.CompactTextString(m) }
func (*DatabaseAddress) ProtoMessage() {}
func (*DatabaseAddress) Descriptor() ([]byte, []int) { return fileDescriptorDatabase, []int{2} }
func init() {
proto.RegisterType((*DatabaseRecord)(nil), "main.DatabaseRecord")
proto.RegisterType((*ReplicationRecord)(nil), "main.ReplicationRecord")
proto.RegisterType((*DatabaseAddress)(nil), "main.DatabaseAddress")
}
func (m *DatabaseRecord) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalTo(dAtA)
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *DatabaseRecord) MarshalTo(dAtA []byte) (int, error) {
var i int
_ = i
var l int
_ = l
if len(m.Addresses) > 0 {
for _, msg := range m.Addresses {
dAtA[i] = 0xa
i++
i = encodeVarintDatabase(dAtA, i, uint64(msg.Size()))
n, err := msg.MarshalTo(dAtA[i:])
if err != nil {
return 0, err
}
i += n
}
}
if m.Misses != 0 {
dAtA[i] = 0x10
i++
i = encodeVarintDatabase(dAtA, i, uint64(m.Misses))
}
if m.Seen != 0 {
dAtA[i] = 0x18
i++
i = encodeVarintDatabase(dAtA, i, uint64(m.Seen))
}
return i, nil
}
func (m *ReplicationRecord) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalTo(dAtA)
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *ReplicationRecord) MarshalTo(dAtA []byte) (int, error) {
var i int
_ = i
var l int
_ = l
if len(m.Key) > 0 {
dAtA[i] = 0xa
i++
i = encodeVarintDatabase(dAtA, i, uint64(len(m.Key)))
i += copy(dAtA[i:], m.Key)
}
if len(m.Addresses) > 0 {
for _, msg := range m.Addresses {
dAtA[i] = 0x12
i++
i = encodeVarintDatabase(dAtA, i, uint64(msg.Size()))
n, err := msg.MarshalTo(dAtA[i:])
if err != nil {
return 0, err
}
i += n
}
}
if m.Seen != 0 {
dAtA[i] = 0x18
i++
i = encodeVarintDatabase(dAtA, i, uint64(m.Seen))
}
return i, nil
}
func (m *DatabaseAddress) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalTo(dAtA)
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *DatabaseAddress) MarshalTo(dAtA []byte) (int, error) {
var i int
_ = i
var l int
_ = l
if len(m.Address) > 0 {
dAtA[i] = 0xa
i++
i = encodeVarintDatabase(dAtA, i, uint64(len(m.Address)))
i += copy(dAtA[i:], m.Address)
}
if m.Expires != 0 {
dAtA[i] = 0x10
i++
i = encodeVarintDatabase(dAtA, i, uint64(m.Expires))
}
return i, nil
}
func encodeFixed64Database(dAtA []byte, offset int, v uint64) int {
dAtA[offset] = uint8(v)
dAtA[offset+1] = uint8(v >> 8)
dAtA[offset+2] = uint8(v >> 16)
dAtA[offset+3] = uint8(v >> 24)
dAtA[offset+4] = uint8(v >> 32)
dAtA[offset+5] = uint8(v >> 40)
dAtA[offset+6] = uint8(v >> 48)
dAtA[offset+7] = uint8(v >> 56)
return offset + 8
}
func encodeFixed32Database(dAtA []byte, offset int, v uint32) int {
dAtA[offset] = uint8(v)
dAtA[offset+1] = uint8(v >> 8)
dAtA[offset+2] = uint8(v >> 16)
dAtA[offset+3] = uint8(v >> 24)
return offset + 4
}
func encodeVarintDatabase(dAtA []byte, offset int, v uint64) int {
for v >= 1<<7 {
dAtA[offset] = uint8(v&0x7f | 0x80)
v >>= 7
offset++
}
dAtA[offset] = uint8(v)
return offset + 1
}
func (m *DatabaseRecord) Size() (n int) {
var l int
_ = l
if len(m.Addresses) > 0 {
for _, e := range m.Addresses {
l = e.Size()
n += 1 + l + sovDatabase(uint64(l))
}
}
if m.Misses != 0 {
n += 1 + sovDatabase(uint64(m.Misses))
}
if m.Seen != 0 {
n += 1 + sovDatabase(uint64(m.Seen))
}
return n
}
func (m *ReplicationRecord) Size() (n int) {
var l int
_ = l
l = len(m.Key)
if l > 0 {
n += 1 + l + sovDatabase(uint64(l))
}
if len(m.Addresses) > 0 {
for _, e := range m.Addresses {
l = e.Size()
n += 1 + l + sovDatabase(uint64(l))
}
}
if m.Seen != 0 {
n += 1 + sovDatabase(uint64(m.Seen))
}
return n
}
func (m *DatabaseAddress) Size() (n int) {
var l int
_ = l
l = len(m.Address)
if l > 0 {
n += 1 + l + sovDatabase(uint64(l))
}
if m.Expires != 0 {
n += 1 + sovDatabase(uint64(m.Expires))
}
return n
}
func sovDatabase(x uint64) (n int) {
for {
n++
x >>= 7
if x == 0 {
break
}
}
return n
}
func sozDatabase(x uint64) (n int) {
return sovDatabase(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
func (m *DatabaseRecord) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowDatabase
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: DatabaseRecord: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: DatabaseRecord: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Addresses", wireType)
}
var msglen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowDatabase
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
msglen |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
if msglen < 0 {
return ErrInvalidLengthDatabase
}
postIndex := iNdEx + msglen
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Addresses = append(m.Addresses, DatabaseAddress{})
if err := m.Addresses[len(m.Addresses)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil {
return err
}
iNdEx = postIndex
case 2:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Misses", wireType)
}
m.Misses = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowDatabase
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Misses |= (int32(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
case 3:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Seen", wireType)
}
m.Seen = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowDatabase
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Seen |= (int64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
default:
iNdEx = preIndex
skippy, err := skipDatabase(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthDatabase
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func (m *ReplicationRecord) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowDatabase
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: ReplicationRecord: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: ReplicationRecord: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Key", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowDatabase
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthDatabase
}
postIndex := iNdEx + intStringLen
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Key = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
case 2:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Addresses", wireType)
}
var msglen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowDatabase
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
msglen |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
if msglen < 0 {
return ErrInvalidLengthDatabase
}
postIndex := iNdEx + msglen
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Addresses = append(m.Addresses, DatabaseAddress{})
if err := m.Addresses[len(m.Addresses)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil {
return err
}
iNdEx = postIndex
case 3:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Seen", wireType)
}
m.Seen = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowDatabase
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Seen |= (int64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
default:
iNdEx = preIndex
skippy, err := skipDatabase(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthDatabase
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func (m *DatabaseAddress) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowDatabase
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: DatabaseAddress: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: DatabaseAddress: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Address", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowDatabase
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthDatabase
}
postIndex := iNdEx + intStringLen
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Address = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
case 2:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Expires", wireType)
}
m.Expires = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowDatabase
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Expires |= (int64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
default:
iNdEx = preIndex
skippy, err := skipDatabase(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthDatabase
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipDatabase(dAtA []byte) (n int, err error) {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowDatabase
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
wireType := int(wire & 0x7)
switch wireType {
case 0:
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowDatabase
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
iNdEx++
if dAtA[iNdEx-1] < 0x80 {
break
}
}
return iNdEx, nil
case 1:
iNdEx += 8
return iNdEx, nil
case 2:
var length int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowDatabase
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
length |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
iNdEx += length
if length < 0 {
return 0, ErrInvalidLengthDatabase
}
return iNdEx, nil
case 3:
for {
var innerWire uint64
var start int = iNdEx
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowDatabase
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
innerWire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
innerWireType := int(innerWire & 0x7)
if innerWireType == 4 {
break
}
next, err := skipDatabase(dAtA[start:])
if err != nil {
return 0, err
}
iNdEx = start + next
}
return iNdEx, nil
case 4:
return iNdEx, nil
case 5:
iNdEx += 4
return iNdEx, nil
default:
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
}
}
panic("unreachable")
}
var (
ErrInvalidLengthDatabase = fmt.Errorf("proto: negative length found during unmarshaling")
ErrIntOverflowDatabase = fmt.Errorf("proto: integer overflow")
)
func init() { proto.RegisterFile("database.proto", fileDescriptorDatabase) }
var fileDescriptorDatabase = []byte{
// 254 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4b, 0x49, 0x2c, 0x49,
0x4c, 0x4a, 0x2c, 0x4e, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0xc9, 0x4d, 0xcc, 0xcc,
0x93, 0xd2, 0x4d, 0xcf, 0x2c, 0xc9, 0x28, 0x4d, 0xd2, 0x4b, 0xce, 0xcf, 0xd5, 0x4f, 0xcf, 0x4f,
0xcf, 0xd7, 0x07, 0x4b, 0x26, 0x95, 0xa6, 0x81, 0x79, 0x60, 0x0e, 0x98, 0x05, 0xd1, 0xa4, 0x54,
0xce, 0xc5, 0xe7, 0x02, 0x35, 0x26, 0x28, 0x35, 0x39, 0xbf, 0x28, 0x45, 0xc8, 0x92, 0x8b, 0x33,
0x31, 0x25, 0xa5, 0x28, 0xb5, 0xb8, 0x38, 0xb5, 0x58, 0x82, 0x51, 0x81, 0x59, 0x83, 0xdb, 0x48,
0x54, 0x0f, 0x64, 0xb4, 0x1e, 0x4c, 0xa1, 0x23, 0x44, 0xda, 0x89, 0xe5, 0xc4, 0x3d, 0x79, 0x86,
0x20, 0x84, 0x6a, 0x21, 0x31, 0x2e, 0xb6, 0xdc, 0x4c, 0xb0, 0x3e, 0x26, 0x05, 0x46, 0x0d, 0xd6,
0x20, 0x28, 0x4f, 0x48, 0x88, 0x8b, 0xa5, 0x38, 0x35, 0x35, 0x4f, 0x82, 0x59, 0x81, 0x51, 0x83,
0x39, 0x08, 0xcc, 0x56, 0x2a, 0xe1, 0x12, 0x0c, 0x4a, 0x2d, 0xc8, 0xc9, 0x4c, 0x4e, 0x2c, 0xc9,
0xcc, 0xcf, 0x83, 0xda, 0x2d, 0xc0, 0xc5, 0x9c, 0x9d, 0x5a, 0x29, 0xc1, 0xa8, 0xc0, 0xa8, 0xc1,
0x19, 0x04, 0x62, 0xa2, 0xba, 0x86, 0x89, 0x24, 0xd7, 0x60, 0xb3, 0xd5, 0x95, 0x8b, 0x1f, 0x4d,
0x9f, 0x90, 0x04, 0x17, 0x3b, 0x54, 0x0f, 0xd4, 0x5e, 0x18, 0x17, 0x24, 0x93, 0x5a, 0x51, 0x90,
0x59, 0x04, 0xf5, 0x0f, 0x73, 0x10, 0x8c, 0xeb, 0x24, 0x70, 0xe2, 0xa1, 0x1c, 0xc3, 0x89, 0x47,
0x72, 0x8c, 0x17, 0x1e, 0xc9, 0x31, 0x3e, 0x78, 0x24, 0xc7, 0x98, 0xc4, 0x06, 0x0e, 0x4e, 0x63,
0x40, 0x00, 0x00, 0x00, 0xff, 0xff, 0x9e, 0x45, 0x60, 0x7e, 0x95, 0x01, 0x00, 0x00,
}

View File

@@ -0,0 +1,30 @@
// Copyright (C) 2018 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.
syntax = "proto3";
package main;
import "github.com/gogo/protobuf/gogoproto/gogo.proto";
option (gogoproto.goproto_getters_all) = false;
message DatabaseRecord {
repeated DatabaseAddress addresses = 1 [(gogoproto.nullable) = false];
int32 misses = 2; // Number of lookups without hits
int64 seen = 3; // Unix nanos, last device announce
}
message ReplicationRecord {
string key = 1;
repeated DatabaseAddress addresses = 2 [(gogoproto.nullable) = false];
int64 seen = 3; // Unix nanos, last device announce
}
message DatabaseAddress {
string address = 1;
int64 expires = 2; // Unix nanos
}

View File

@@ -0,0 +1,211 @@
// Copyright (C) 2018 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.
package main
import (
"fmt"
"os"
"testing"
"time"
)
func TestDatabaseGetSet(t *testing.T) {
os.RemoveAll("_database")
defer os.RemoveAll("_database")
db, err := newLevelDBStore("_database")
if err != nil {
t.Fatal(err)
}
go db.Serve()
defer db.Stop()
// Check missing record
rec, err := db.get("abcd")
if err != nil {
t.Error("not found should not be an error")
}
if len(rec.Addresses) != 0 {
t.Error("addresses should be empty")
}
if rec.Misses != 0 {
t.Error("missing should be zero")
}
// Set up a clock
now := time.Now()
tc := &testClock{now}
db.clock = tc
// Put a record
rec.Addresses = []DatabaseAddress{
{Address: "tcp://1.2.3.4:5", Expires: tc.Now().Add(time.Minute).UnixNano()},
}
if err := db.put("abcd", rec); err != nil {
t.Fatal(err)
}
// Verify it
rec, err = db.get("abcd")
if err != nil {
t.Fatal(err)
}
if len(rec.Addresses) != 1 {
t.Log(rec.Addresses)
t.Fatal("should have one address")
}
if rec.Addresses[0].Address != "tcp://1.2.3.4:5" {
t.Log(rec.Addresses)
t.Error("incorrect address")
}
// Wind the clock one half expiry, and merge in a new address
tc.wind(30 * time.Second)
addrs := []DatabaseAddress{
{Address: "tcp://6.7.8.9:0", Expires: tc.Now().Add(time.Minute).UnixNano()},
}
if err := db.merge("abcd", addrs, tc.Now().UnixNano()); err != nil {
t.Fatal(err)
}
// Verify it
rec, err = db.get("abcd")
if err != nil {
t.Fatal(err)
}
if len(rec.Addresses) != 2 {
t.Log(rec.Addresses)
t.Fatal("should have two addresses")
}
if rec.Addresses[0].Address != "tcp://1.2.3.4:5" {
t.Log(rec.Addresses)
t.Error("incorrect address[0]")
}
if rec.Addresses[1].Address != "tcp://6.7.8.9:0" {
t.Log(rec.Addresses)
t.Error("incorrect address[1]")
}
// Pass the first expiry time
tc.wind(45 * time.Second)
// Verify it
rec, err = db.get("abcd")
if err != nil {
t.Fatal(err)
}
if len(rec.Addresses) != 1 {
t.Log(rec.Addresses)
t.Fatal("should have one address")
}
if rec.Addresses[0].Address != "tcp://6.7.8.9:0" {
t.Log(rec.Addresses)
t.Error("incorrect address")
}
// Put a record with misses
rec = DatabaseRecord{Misses: 42}
if err := db.put("efgh", rec); err != nil {
t.Fatal(err)
}
// Verify it
rec, err = db.get("efgh")
if err != nil {
t.Fatal(err)
}
if len(rec.Addresses) != 0 {
t.Log(rec.Addresses)
t.Fatal("should have no addresses")
}
if rec.Misses != 42 {
t.Log(rec.Misses)
t.Error("incorrect misses")
}
// Set an address
addrs = []DatabaseAddress{
{Address: "tcp://6.7.8.9:0", Expires: tc.Now().Add(time.Minute).UnixNano()},
}
if err := db.merge("efgh", addrs, tc.Now().UnixNano()); err != nil {
t.Fatal(err)
}
// Verify it
rec, err = db.get("efgh")
if err != nil {
t.Fatal(err)
}
if len(rec.Addresses) != 1 {
t.Log(rec.Addresses)
t.Fatal("should have one addres")
}
if rec.Misses != 0 {
t.Log(rec.Misses)
t.Error("should have no misses")
}
}
func TestFilter(t *testing.T) {
// all cases are expired with t=10
cases := []struct {
a []DatabaseAddress
b []DatabaseAddress
}{
{
a: nil,
b: nil,
},
{
a: []DatabaseAddress{{Address: "a", Expires: 9}, {Address: "b", Expires: 9}, {Address: "c", Expires: 9}},
b: []DatabaseAddress{},
},
{
a: []DatabaseAddress{{Address: "a", Expires: 10}},
b: []DatabaseAddress{{Address: "a", Expires: 10}},
},
{
a: []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 10}, {Address: "c", Expires: 10}},
b: []DatabaseAddress{{Address: "a", Expires: 10}, {Address: "b", Expires: 10}, {Address: "c", Expires: 10}},
},
{
a: []DatabaseAddress{{Address: "a", Expires: 5}, {Address: "b", Expires: 15}, {Address: "c", Expires: 5}, {Address: "d", Expires: 15}, {Address: "e", Expires: 5}},
b: []DatabaseAddress{{Address: "d", Expires: 15}, {Address: "b", Expires: 15}}, // gets reordered
},
}
for _, tc := range cases {
res := expire(tc.a, 10)
if fmt.Sprint(res) != fmt.Sprint(tc.b) {
t.Errorf("Incorrect result %v, expected %v", res, tc.b)
}
}
}
type testClock struct {
now time.Time
}
func (t *testClock) wind(d time.Duration) {
t.now = t.now.Add(d)
}
func (t *testClock) Now() time.Time {
return t.now
}

View File

@@ -1,32 +0,0 @@
// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file).
package main
import (
"database/sql"
"fmt"
)
type setupFunc func(db *sql.DB) error
type compileFunc func(db *sql.DB) (map[string]*sql.Stmt, error)
var (
setupFuncs = make(map[string]setupFunc)
compileFuncs = make(map[string]compileFunc)
)
func register(name string, setup setupFunc, compile compileFunc) {
setupFuncs[name] = setup
compileFuncs[name] = compile
}
func setup(backend string, db *sql.DB) (map[string]*sql.Stmt, error) {
setup, ok := setupFuncs[backend]
if !ok {
return nil, fmt.Errorf("Unsupported backend")
}
if err := setup(db); err != nil {
return nil, err
}
return compileFuncs[backend](db)
}

View File

@@ -1,29 +1,70 @@
// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file).
// Copyright (C) 2018 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.
package main
import (
"crypto/tls"
"database/sql"
"flag"
"fmt"
"log"
"net"
"net/http"
"os"
"runtime"
"strconv"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/syncthing/syncthing/lib/protocol"
"github.com/syncthing/syncthing/lib/tlsutil"
"github.com/syndtr/goleveldb/leveldb/opt"
"github.com/thejerf/suture"
)
const (
minNegCache = 60 // seconds
maxNegCache = 3600 // seconds
maxDeviceAge = 7 * 86400 // one week, in seconds
addressExpiryTime = 2 * time.Hour
databaseStatisticsInterval = 5 * time.Minute
// Reannounce-After is set to reannounceAfterSeconds +
// random(reannounzeFuzzSeconds), similar for Retry-After
reannounceAfterSeconds = 3300
reannounzeFuzzSeconds = 300
errorRetryAfterSeconds = 1500
errorRetryFuzzSeconds = 300
// Retry for not found is minSeconds + failures * incSeconds +
// random(fuzz), where failures is the number of consecutive lookups
// with no answer, up to maxSeconds. The fuzz is applied after capping
// to maxSeconds.
notFoundRetryMinSeconds = 60
notFoundRetryMaxSeconds = 3540
notFoundRetryIncSeconds = 10
notFoundRetryFuzzSeconds = 60
// How often (in requests) we serialize the missed counter to database.
notFoundMissesWriteInterval = 10
httpReadTimeout = 5 * time.Second
httpWriteTimeout = 5 * time.Second
httpMaxHeaderBytes = 1 << 10
// Size of the replication outbox channel
replicationOutboxSize = 10000
)
// These options make the database a little more optimized for writes, at
// the expense of some memory usage and risk of losing writes in a (system)
// crash.
var levelDBOptions = &opt.Options{
NoSync: true,
WriteBuffer: 32 << 20, // default 4<<20
}
var (
Version string
BuildStamp string
@@ -43,17 +84,7 @@ func init() {
}
var (
lruSize = 10240
limitAvg = 5
limitBurst = 20
globalStats stats
statsFile string
backend = "ql"
dsn = getEnvDefault("STDISCOSRV_DB_DSN", "memory://stdiscosrv")
certFile = "cert.pem"
keyFile = "key.pem"
debug = false
useHTTP = false
debug = false
)
func main() {
@@ -63,84 +94,112 @@ func main() {
)
var listen string
var dir string
var metricsListen string
var replicationListen string
var replicationPeers string
var certFile string
var keyFile string
var useHTTP bool
log.SetOutput(os.Stdout)
log.SetFlags(0)
flag.StringVar(&certFile, "cert", "./cert.pem", "Certificate file")
flag.StringVar(&dir, "db-dir", "./discovery.db", "Database directory")
flag.BoolVar(&debug, "debug", false, "Print debug output")
flag.BoolVar(&useHTTP, "http", false, "Listen on HTTP (behind an HTTPS proxy)")
flag.StringVar(&listen, "listen", ":8443", "Listen address")
flag.IntVar(&lruSize, "limit-cache", lruSize, "Limiter cache entries")
flag.IntVar(&limitAvg, "limit-avg", limitAvg, "Allowed average package rate, per 10 s")
flag.IntVar(&limitBurst, "limit-burst", limitBurst, "Allowed burst size, packets")
flag.StringVar(&statsFile, "stats-file", statsFile, "File to write periodic operation stats to")
flag.StringVar(&backend, "db-backend", backend, "Database backend to use")
flag.StringVar(&dsn, "db-dsn", dsn, "Database DSN")
flag.StringVar(&certFile, "cert", certFile, "Certificate file")
flag.StringVar(&keyFile, "key", keyFile, "Key file")
flag.BoolVar(&debug, "debug", debug, "Debug")
flag.BoolVar(&useHTTP, "http", useHTTP, "Listen on HTTP (behind an HTTPS proxy)")
flag.StringVar(&keyFile, "key", "./key.pem", "Key file")
flag.StringVar(&metricsListen, "metrics-listen", "", "Metrics listen address")
flag.StringVar(&replicationPeers, "replicate", "", "Replication peers, id@address, comma separated")
flag.StringVar(&replicationListen, "replication-listen", ":19200", "Replication listen address")
flag.Parse()
log.Println(LongVersion)
var cert tls.Certificate
var err error
if !useHTTP {
cert, err = tls.LoadX509KeyPair(certFile, keyFile)
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
log.Println("Failed to load keypair. Generating one, this might take a while...")
cert, err = tlsutil.NewCertificate(certFile, keyFile, "stdiscosrv", 0)
if err != nil {
log.Println("Failed to load keypair. Generating one, this might take a while...")
cert, err = tlsutil.NewCertificate(certFile, keyFile, "stdiscosrv", 3072)
if err != nil {
log.Fatalln("Failed to generate X509 key pair:", err)
}
log.Fatalln("Failed to generate X509 key pair:", err)
}
devID := protocol.NewDeviceID(cert.Certificate[0])
log.Println("Server device ID is", devID)
}
db, err := sql.Open(backend, dsn)
if err != nil {
log.Fatalln("sql.Open:", err)
}
prep, err := setup(backend, db)
if err != nil {
log.Fatalln("Setup:", err)
devID := protocol.NewDeviceID(cert.Certificate[0])
log.Println("Server device ID is", devID)
// Parse the replication specs, if any.
var allowedReplicationPeers []protocol.DeviceID
var replicationDestinations []string
parts := strings.Split(replicationPeers, ",")
for _, part := range parts {
fields := strings.Split(part, "@")
switch len(fields) {
case 2:
// This is an id@address specification. Grab the address for the
// destination list. Try to resolve it once to catch obvious
// syntax errors here rather than having the sender service fail
// repeatedly later.
_, err := net.ResolveTCPAddr("tcp", fields[1])
if err != nil {
log.Fatalln("Resolving address:", err)
}
replicationDestinations = append(replicationDestinations, fields[1])
fallthrough // N.B.
case 1:
// The first part is always a device ID.
id, err := protocol.DeviceIDFromString(fields[0])
if err != nil {
log.Fatalln("Parsing device ID:", err)
}
allowedReplicationPeers = append(allowedReplicationPeers, id)
default:
log.Fatalln("Unrecognized replication spec:", part)
}
}
// Root of the service tree.
main := suture.NewSimple("main")
main.Add(&querysrv{
addr: listen,
cert: cert,
db: db,
prep: prep,
})
// Start the database.
db, err := newLevelDBStore(dir)
if err != nil {
log.Fatalln("Open database:", err)
}
main.Add(db)
main.Add(&cleansrv{
intv: cleanIntv,
db: db,
prep: prep,
})
// Start any replication senders.
var repl replicationMultiplexer
for _, dst := range replicationDestinations {
rs := newReplicationSender(dst, cert, allowedReplicationPeers)
main.Add(rs)
repl = append(repl, rs)
}
main.Add(&statssrv{
intv: statsIntv,
file: statsFile,
db: db,
})
// If we have replication configured, start the replication listener.
if len(allowedReplicationPeers) > 0 {
rl := newReplicationListener(replicationListen, cert, allowedReplicationPeers, db)
main.Add(rl)
}
globalStats.Reset()
// Start the main API server.
qs := newAPISrv(listen, cert, db, repl, useHTTP)
main.Add(qs)
// If we have a metrics port configured, start a metrics handler.
if metricsListen != "" {
go func() {
mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.Handler())
log.Fatal(http.ListenAndServe(metricsListen, mux))
}()
}
// Engage!
main.Serve()
}
func getEnvDefault(key, def string) string {
if val := os.Getenv(key); val != "" {
return val
}
return def
}
func next(intv time.Duration) time.Duration {
t0 := time.Now()
t1 := t0.Add(intv).Truncate(intv)
return t1.Sub(t0)
}

View File

@@ -1,98 +0,0 @@
// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file).
package main
import (
"database/sql"
"fmt"
_ "github.com/lib/pq"
)
func init() {
register("postgres", postgresSetup, postgresCompile)
}
func postgresSetup(db *sql.DB) error {
var err error
db.SetMaxIdleConns(4)
db.SetMaxOpenConns(8)
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS Devices (
DeviceID CHAR(63) NOT NULL PRIMARY KEY,
Seen TIMESTAMP NOT NULL
)`)
if err != nil {
return err
}
var tmp string
row := db.QueryRow(`SELECT 'DevicesDeviceIDIndex'::regclass`)
if err = row.Scan(&tmp); err != nil {
_, err = db.Exec(`CREATE INDEX DevicesDeviceIDIndex ON Devices (DeviceID)`)
}
if err != nil {
return err
}
row = db.QueryRow(`SELECT 'DevicesSeenIndex'::regclass`)
if err = row.Scan(&tmp); err != nil {
_, err = db.Exec(`CREATE INDEX DevicesSeenIndex ON Devices (Seen)`)
}
if err != nil {
return err
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS Addresses (
DeviceID CHAR(63) NOT NULL,
Seen TIMESTAMP NOT NULL,
Address VARCHAR(2048) NOT NULL
)`)
if err != nil {
return err
}
row = db.QueryRow(`SELECT 'AddressesDeviceIDSeenIndex'::regclass`)
if err = row.Scan(&tmp); err != nil {
_, err = db.Exec(`CREATE INDEX AddressesDeviceIDSeenIndex ON Addresses (DeviceID, Seen)`)
}
if err != nil {
return err
}
row = db.QueryRow(`SELECT 'AddressesDeviceIDAddressIndex'::regclass`)
if err = row.Scan(&tmp); err != nil {
_, err = db.Exec(`CREATE INDEX AddressesDeviceIDAddressIndex ON Addresses (DeviceID, Address)`)
}
if err != nil {
return err
}
return nil
}
func postgresCompile(db *sql.DB) (map[string]*sql.Stmt, error) {
stmts := map[string]string{
"cleanAddress": "DELETE FROM Addresses WHERE Seen < now() - '2 hour'::INTERVAL",
"cleanDevice": fmt.Sprintf("DELETE FROM Devices WHERE Seen < now() - '%d hour'::INTERVAL", maxDeviceAge/3600),
"countAddress": "SELECT count(*) FROM Addresses",
"countDevice": "SELECT count(*) FROM Devices",
"insertAddress": "INSERT INTO Addresses (DeviceID, Seen, Address) VALUES ($1, now(), $2)",
"insertDevice": "INSERT INTO Devices (DeviceID, Seen) VALUES ($1, now())",
"selectAddress": "SELECT Address FROM Addresses WHERE DeviceID=$1 AND Seen > now() - '1 hour'::INTERVAL ORDER BY random() LIMIT 16",
"selectDevice": "SELECT Seen FROM Devices WHERE DeviceID=$1",
"updateAddress": "UPDATE Addresses SET Seen=now() WHERE DeviceID=$1 AND Address=$2",
"updateDevice": "UPDATE Devices SET Seen=now() WHERE DeviceID=$1",
}
res := make(map[string]*sql.Stmt, len(stmts))
for key, stmt := range stmts {
prep, err := db.Prepare(stmt)
if err != nil {
return nil, err
}
res[key] = prep
}
return res, nil
}

View File

@@ -1,81 +0,0 @@
// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).
package main
import (
"database/sql"
"fmt"
"log"
"github.com/cznic/ql"
)
func init() {
ql.RegisterDriver()
register("ql", qlSetup, qlCompile)
}
func qlSetup(db *sql.DB) (err error) {
tx, err := db.Begin()
if err != nil {
return
}
defer func() {
if err == nil {
err = tx.Commit()
} else {
tx.Rollback()
}
}()
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS Devices (
DeviceID STRING NOT NULL,
Seen TIME NOT NULL
)`)
if err != nil {
return
}
if _, err = tx.Exec(`CREATE INDEX IF NOT EXISTS DevicesDeviceIDIndex ON Devices (DeviceID)`); err != nil {
return
}
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS Addresses (
DeviceID STRING NOT NULL,
Seen TIME NOT NULL,
Address STRING NOT NULL,
)`)
if err != nil {
return
}
_, err = tx.Exec(`CREATE INDEX IF NOT EXISTS AddressesDeviceIDAddressIndex ON Addresses (DeviceID, Address)`)
return
}
func qlCompile(db *sql.DB) (map[string]*sql.Stmt, error) {
stmts := map[string]string{
"cleanAddress": `DELETE FROM Addresses WHERE Seen < now() - duration("2h")`,
"cleanDevice": fmt.Sprintf(`DELETE FROM Devices WHERE Seen < now() - duration("%dh")`, maxDeviceAge/3600),
"countAddress": "SELECT count(*) FROM Addresses",
"countDevice": "SELECT count(*) FROM Devices",
"insertAddress": "INSERT INTO Addresses (DeviceID, Seen, Address) VALUES ($1, now(), $2)",
"insertDevice": "INSERT INTO Devices (DeviceID, Seen) VALUES ($1, now())",
"selectAddress": `SELECT Address from Addresses WHERE DeviceID==$1 AND Seen > now() - duration("1h") LIMIT 16`,
"selectDevice": "SELECT Seen FROM Devices WHERE DeviceID==$1",
"updateAddress": "UPDATE Addresses Seen=now() WHERE DeviceID==$1 AND Address==$2",
"updateDevice": "UPDATE Devices Seen=now() WHERE DeviceID==$1",
}
res := make(map[string]*sql.Stmt, len(stmts))
for key, stmt := range stmts {
prep, err := db.Prepare(stmt)
if err != nil {
log.Println("Failed to compile", stmt)
return nil, err
}
res[key] = prep
}
return res, nil
}

View File

@@ -1,492 +0,0 @@
// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file).
package main
import (
"bytes"
"crypto/tls"
"database/sql"
"encoding/json"
"encoding/pem"
"fmt"
"log"
"math/rand"
"net"
"net/http"
"net/url"
"strconv"
"sync"
"time"
"github.com/golang/groupcache/lru"
"github.com/syncthing/syncthing/lib/protocol"
"golang.org/x/net/context"
"golang.org/x/time/rate"
)
type querysrv struct {
addr string
db *sql.DB
prep map[string]*sql.Stmt
limiter *safeCache
cert tls.Certificate
listener net.Listener
}
type announcement struct {
Seen time.Time `json:"seen"`
Addresses []string `json:"addresses"`
}
type safeCache struct {
*lru.Cache
mut sync.Mutex
}
func (s *safeCache) Get(key string) (val interface{}, ok bool) {
s.mut.Lock()
val, ok = s.Cache.Get(key)
s.mut.Unlock()
return
}
func (s *safeCache) Add(key string, val interface{}) {
s.mut.Lock()
s.Cache.Add(key, val)
s.mut.Unlock()
}
type requestID int64
func (i requestID) String() string {
return fmt.Sprintf("%016x", int64(i))
}
type contextKey int
const idKey contextKey = iota
func negCacheFor(lastSeen time.Time) int {
since := time.Since(lastSeen).Seconds()
if since >= maxDeviceAge {
return maxNegCache
}
if since < 0 {
// That's weird
return minNegCache
}
// Return a value linearly scaled from minNegCache (at zero seconds ago)
// to maxNegCache (at maxDeviceAge seconds ago).
r := since / maxDeviceAge
return int(minNegCache + r*(maxNegCache-minNegCache))
}
func (s *querysrv) Serve() {
s.limiter = &safeCache{
Cache: lru.New(lruSize),
}
if useHTTP {
listener, err := net.Listen("tcp", s.addr)
if err != nil {
log.Println("Listen:", err)
return
}
s.listener = listener
} else {
tlsCfg := &tls.Config{
Certificates: []tls.Certificate{s.cert},
ClientAuth: tls.RequestClientCert,
SessionTicketsDisabled: true,
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
},
}
tlsListener, err := tls.Listen("tcp", s.addr, tlsCfg)
if err != nil {
log.Println("Listen:", err)
return
}
s.listener = tlsListener
}
http.HandleFunc("/v2/", s.handler)
http.HandleFunc("/ping", handlePing)
srv := &http.Server{
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
MaxHeaderBytes: 1 << 10,
}
if err := srv.Serve(s.listener); err != nil {
log.Println("Serve:", err)
}
}
var topCtx = context.Background()
func (s *querysrv) handler(w http.ResponseWriter, req *http.Request) {
reqID := requestID(rand.Int63())
ctx := context.WithValue(topCtx, idKey, reqID)
if debug {
log.Println(reqID, req.Method, req.URL)
}
t0 := time.Now()
defer func() {
diff := time.Since(t0)
var comment string
if diff > time.Second {
comment = "(very slow request)"
} else if diff > 100*time.Millisecond {
comment = "(slow request)"
}
if comment != "" || debug {
log.Println(reqID, req.Method, req.URL, "completed in", diff, comment)
}
}()
var remoteIP net.IP
if useHTTP {
remoteIP = net.ParseIP(req.Header.Get("X-Forwarded-For"))
} else {
addr, err := net.ResolveTCPAddr("tcp", req.RemoteAddr)
if err != nil {
log.Println("remoteAddr:", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
remoteIP = addr.IP
}
if s.limit(remoteIP) {
if debug {
log.Println(remoteIP, "is limited")
}
w.Header().Set("Retry-After", "60")
http.Error(w, "Too Many Requests", 429)
return
}
switch req.Method {
case "GET":
s.handleGET(ctx, w, req)
case "POST":
s.handlePOST(ctx, remoteIP, w, req)
default:
globalStats.Error()
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
}
}
func (s *querysrv) handleGET(ctx context.Context, w http.ResponseWriter, req *http.Request) {
reqID := ctx.Value(idKey).(requestID)
deviceID, err := protocol.DeviceIDFromString(req.URL.Query().Get("device"))
if err != nil {
if debug {
log.Println(reqID, "bad device param")
}
globalStats.Error()
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
var ann announcement
ann.Seen, err = s.getDeviceSeen(deviceID)
negCache := strconv.Itoa(negCacheFor(ann.Seen))
w.Header().Set("Retry-After", negCache)
w.Header().Set("Cache-Control", "public, max-age="+negCache)
if err != nil {
// The device is not in the database.
globalStats.Query()
http.Error(w, "Not Found", http.StatusNotFound)
return
}
t0 := time.Now()
ann.Addresses, err = s.getAddresses(ctx, deviceID)
if err != nil {
log.Println(reqID, "getAddresses:", err)
globalStats.Error()
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
if debug {
log.Println(reqID, "getAddresses in", time.Since(t0))
}
globalStats.Query()
if len(ann.Addresses) == 0 {
http.Error(w, "Not Found", http.StatusNotFound)
return
}
globalStats.Answer()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(ann)
}
func (s *querysrv) handlePOST(ctx context.Context, remoteIP net.IP, w http.ResponseWriter, req *http.Request) {
reqID := ctx.Value(idKey).(requestID)
rawCert := certificateBytes(req)
if rawCert == nil {
if debug {
log.Println(reqID, "no certificates")
}
globalStats.Error()
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
var ann announcement
if err := json.NewDecoder(req.Body).Decode(&ann); err != nil {
if debug {
log.Println(reqID, "decode:", err)
}
globalStats.Error()
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
deviceID := protocol.NewDeviceID(rawCert)
// handleAnnounce returns *two* errors. The first indicates a problem with
// something the client posted to us. We should return a 400 Bad Request
// and not worry about it. The second indicates that the request was fine,
// but something internal messed up. We should log it and respond with a
// more apologetic 500 Internal Server Error.
userErr, internalErr := s.handleAnnounce(ctx, remoteIP, deviceID, ann.Addresses)
if userErr != nil {
if debug {
log.Println(reqID, "handleAnnounce:", userErr)
}
globalStats.Error()
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
if internalErr != nil {
log.Println(reqID, "handleAnnounce:", internalErr)
globalStats.Error()
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
globalStats.Announce()
// TODO: Slowly increase this for stable clients
w.Header().Set("Reannounce-After", "1800")
// We could return the lookup result here, but it's kind of unnecessarily
// expensive to go query the database again so we let the client decide to
// do a lookup if they really care.
w.WriteHeader(http.StatusNoContent)
}
func (s *querysrv) Stop() {
s.listener.Close()
}
func (s *querysrv) handleAnnounce(ctx context.Context, remote net.IP, deviceID protocol.DeviceID, addresses []string) (userErr, internalErr error) {
reqID := ctx.Value(idKey).(requestID)
tx, err := s.db.Begin()
if err != nil {
internalErr = err
return
}
defer func() {
// Since we return from a bunch of different places, we handle
// rollback in the defer.
if internalErr != nil || userErr != nil {
tx.Rollback()
}
}()
for _, annAddr := range addresses {
uri, err := url.Parse(annAddr)
if err != nil {
userErr = err
return
}
host, port, err := net.SplitHostPort(uri.Host)
if err != nil {
userErr = err
return
}
ip := net.ParseIP(host)
if host == "" || ip.IsUnspecified() {
// Do not use IPv6 remote address if requested scheme is tcp4
if uri.Scheme == "tcp4" && remote.To4() == nil {
continue
}
// Do not use IPv4 remote address if requested scheme is tcp6
if uri.Scheme == "tcp6" && remote.To4() != nil {
continue
}
host = remote.String()
}
uri.Host = net.JoinHostPort(host, port)
if err := s.updateAddress(ctx, tx, deviceID, uri.String()); err != nil {
internalErr = err
return
}
}
if err := s.updateDevice(ctx, tx, deviceID); err != nil {
internalErr = err
return
}
t0 := time.Now()
internalErr = tx.Commit()
if debug {
log.Println(reqID, "commit in", time.Since(t0))
}
return
}
func (s *querysrv) limit(remote net.IP) bool {
key := remote.String()
bkt, ok := s.limiter.Get(key)
if ok {
bkt := bkt.(*rate.Limiter)
if !bkt.Allow() {
// Rate limit exceeded; ignore packet
return true
}
} else {
// limitAvg is in packets per ten seconds.
s.limiter.Add(key, rate.NewLimiter(rate.Limit(limitAvg)/10, limitBurst))
}
return false
}
func (s *querysrv) updateDevice(ctx context.Context, tx *sql.Tx, device protocol.DeviceID) error {
reqID := ctx.Value(idKey).(requestID)
t0 := time.Now()
res, err := tx.Stmt(s.prep["updateDevice"]).Exec(device.String())
if err != nil {
return err
}
if debug {
log.Println(reqID, "updateDevice in", time.Since(t0))
}
if rows, _ := res.RowsAffected(); rows == 0 {
t0 = time.Now()
_, err := tx.Stmt(s.prep["insertDevice"]).Exec(device.String())
if err != nil {
return err
}
if debug {
log.Println(reqID, "insertDevice in", time.Since(t0))
}
}
return nil
}
func (s *querysrv) updateAddress(ctx context.Context, tx *sql.Tx, device protocol.DeviceID, uri string) error {
res, err := tx.Stmt(s.prep["updateAddress"]).Exec(device.String(), uri)
if err != nil {
return err
}
if rows, _ := res.RowsAffected(); rows == 0 {
_, err := tx.Stmt(s.prep["insertAddress"]).Exec(device.String(), uri)
if err != nil {
return err
}
}
return nil
}
func (s *querysrv) getAddresses(ctx context.Context, device protocol.DeviceID) ([]string, error) {
rows, err := s.prep["selectAddress"].Query(device.String())
if err != nil {
return nil, err
}
defer rows.Close()
var res []string
for rows.Next() {
var addr string
err := rows.Scan(&addr)
if err != nil {
log.Println("Scan:", err)
continue
}
res = append(res, addr)
}
return res, nil
}
func (s *querysrv) getDeviceSeen(device protocol.DeviceID) (time.Time, error) {
row := s.prep["selectDevice"].QueryRow(device.String())
var seen time.Time
if err := row.Scan(&seen); err != nil {
return time.Time{}, err
}
return seen.In(time.UTC), nil
}
func handlePing(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(204)
}
func certificateBytes(req *http.Request) []byte {
if req.TLS != nil && len(req.TLS.PeerCertificates) > 0 {
return req.TLS.PeerCertificates[0].Raw
}
if hdr := req.Header.Get("X-SSL-Cert"); hdr != "" {
bs := []byte(hdr)
// The certificate is in PEM format but with spaces for newlines. We
// need to reinstate the newlines for the PEM decoder. But we need to
// leave the spaces in the BEGIN and END lines - the first and last
// space - alone.
firstSpace := bytes.Index(bs, []byte(" "))
lastSpace := bytes.LastIndex(bs, []byte(" "))
for i := firstSpace + 1; i < lastSpace; i++ {
if bs[i] == ' ' {
bs[i] = '\n'
}
}
block, _ := pem.Decode(bs)
if block == nil {
// Decoding failed
return nil
}
return block.Bytes
}
return nil
}

View File

@@ -0,0 +1,304 @@
// Copyright (C) 2018 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.
package main
import (
"crypto/tls"
"encoding/binary"
"fmt"
io "io"
"log"
"net"
"time"
"github.com/syncthing/syncthing/lib/protocol"
)
type replicator interface {
send(key string, addrs []DatabaseAddress, seen int64)
}
// a replicationSender tries to connect to the remote address and provide
// them with a feed of replication updates.
type replicationSender struct {
dst string
cert tls.Certificate // our certificate
allowedIDs []protocol.DeviceID
outbox chan ReplicationRecord
stop chan struct{}
}
func newReplicationSender(dst string, cert tls.Certificate, allowedIDs []protocol.DeviceID) *replicationSender {
return &replicationSender{
dst: dst,
cert: cert,
allowedIDs: allowedIDs,
outbox: make(chan ReplicationRecord, replicationOutboxSize),
stop: make(chan struct{}),
}
}
func (s *replicationSender) Serve() {
// Sleep a little at startup. Peers often restart at the same time, and
// this avoid the service failing and entering backoff state
// unnecessarily, while also reducing the reconnect rate to something
// reasonable by default.
time.Sleep(2 * time.Second)
tlsCfg := &tls.Config{
Certificates: []tls.Certificate{s.cert},
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: true,
}
// Dial the TLS connection.
conn, err := tls.Dial("tcp", s.dst, tlsCfg)
if err != nil {
log.Println("Replication connect:", err)
return
}
defer func() {
conn.SetWriteDeadline(time.Now().Add(time.Second))
conn.Close()
}()
// Get the other side device ID.
remoteID, err := deviceID(conn)
if err != nil {
log.Println("Replication connect:", err)
return
}
// Verify it's in the set of allowed device IDs.
if !deviceIDIn(remoteID, s.allowedIDs) {
log.Println("Replication connect: unexpected device ID:", remoteID)
return
}
// Send records.
buf := make([]byte, 1024)
for {
select {
case rec := <-s.outbox:
// Buffer must hold record plus four bytes for size
size := rec.Size()
if len(buf) < size+4 {
buf = make([]byte, size+4)
}
// Record comes after the four bytes size
n, err := rec.MarshalTo(buf[4:])
if err != nil {
// odd to get an error here, but we haven't sent anything
// yet so it's not fatal
replicationSendsTotal.WithLabelValues("error").Inc()
log.Println("Replication marshal:", err)
continue
}
binary.BigEndian.PutUint32(buf, uint32(n))
// Send
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
if _, err := conn.Write(buf[:4+n]); err != nil {
replicationSendsTotal.WithLabelValues("error").Inc()
log.Println("Replication write:", err)
return
}
replicationSendsTotal.WithLabelValues("success").Inc()
case <-s.stop:
return
}
}
}
func (s *replicationSender) Stop() {
close(s.stop)
}
func (s *replicationSender) String() string {
return fmt.Sprintf("replicationSender(%q)", s.dst)
}
func (s *replicationSender) send(key string, ps []DatabaseAddress, seen int64) {
item := ReplicationRecord{
Key: key,
Addresses: ps,
}
// The send should never block. The inbox is suitably buffered for at
// least a few seconds of stalls, which shouldn't happen in practice.
select {
case s.outbox <- item:
default:
replicationSendsTotal.WithLabelValues("drop").Inc()
}
}
// a replicationMultiplexer sends to multiple replicators
type replicationMultiplexer []replicator
func (m replicationMultiplexer) send(key string, ps []DatabaseAddress, seen int64) {
for _, s := range m {
// each send is nonblocking
s.send(key, ps, seen)
}
}
// replicationListener acceptes incoming connections and reads replication
// items from them. Incoming items are applied to the KV store.
type replicationListener struct {
addr string
cert tls.Certificate
allowedIDs []protocol.DeviceID
db database
stop chan struct{}
}
func newReplicationListener(addr string, cert tls.Certificate, allowedIDs []protocol.DeviceID, db database) *replicationListener {
return &replicationListener{
addr: addr,
cert: cert,
allowedIDs: allowedIDs,
db: db,
stop: make(chan struct{}),
}
}
func (l *replicationListener) Serve() {
tlsCfg := &tls.Config{
Certificates: []tls.Certificate{l.cert},
ClientAuth: tls.RequestClientCert,
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: true,
}
lst, err := tls.Listen("tcp", l.addr, tlsCfg)
if err != nil {
log.Println("Replication listen:", err)
return
}
defer lst.Close()
for {
select {
case <-l.stop:
return
default:
}
// Accept a connection
conn, err := lst.Accept()
if err != nil {
log.Println("Replication accept:", err)
return
}
// Figure out the other side device ID
remoteID, err := deviceID(conn.(*tls.Conn))
if err != nil {
log.Println("Replication accept:", err)
conn.SetWriteDeadline(time.Now().Add(time.Second))
conn.Close()
continue
}
// Verify it is in the set of allowed device IDs
if !deviceIDIn(remoteID, l.allowedIDs) {
log.Println("Replication accept: unexpected device ID:", remoteID)
conn.SetWriteDeadline(time.Now().Add(time.Second))
conn.Close()
continue
}
go l.handle(conn)
}
}
func (l *replicationListener) Stop() {
close(l.stop)
}
func (l *replicationListener) String() string {
return fmt.Sprintf("replicationListener(%q)", l.addr)
}
func (l *replicationListener) handle(conn net.Conn) {
defer func() {
conn.SetWriteDeadline(time.Now().Add(time.Second))
conn.Close()
}()
buf := make([]byte, 1024)
for {
select {
case <-l.stop:
return
default:
}
conn.SetReadDeadline(time.Now().Add(time.Minute))
// First four bytes are the size
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
log.Println("Replication read size:", err)
replicationRecvsTotal.WithLabelValues("error").Inc()
return
}
// Read the rest of the record
size := int(binary.BigEndian.Uint32(buf[:4]))
if len(buf) < size {
buf = make([]byte, size)
}
if _, err := io.ReadFull(conn, buf[:size]); err != nil {
log.Println("Replication read record:", err)
replicationRecvsTotal.WithLabelValues("error").Inc()
return
}
// Unmarshal
var rec ReplicationRecord
if err := rec.Unmarshal(buf[:size]); err != nil {
log.Println("Replication unmarshal:", err)
replicationRecvsTotal.WithLabelValues("error").Inc()
continue
}
// Store
l.db.merge(rec.Key, rec.Addresses, rec.Seen)
replicationRecvsTotal.WithLabelValues("success").Inc()
}
}
func deviceID(conn *tls.Conn) (protocol.DeviceID, error) {
// Handshake may not be complete on the server side yet, which we need
// to get the client certificate.
if !conn.ConnectionState().HandshakeComplete {
if err := conn.Handshake(); err != nil {
return protocol.DeviceID{}, err
}
}
// We expect exactly one certificate.
certs := conn.ConnectionState().PeerCertificates
if len(certs) != 1 {
return protocol.DeviceID{}, fmt.Errorf("unexpected number of certificates (%d != 1)", len(certs))
}
return protocol.NewDeviceID(certs[0].Raw), nil
}
func deviceIDIn(id protocol.DeviceID, ids []protocol.DeviceID) bool {
for _, candidate := range ids {
if id == candidate {
return true
}
}
return false
}

View File

@@ -1,141 +1,108 @@
// Copyright (C) 2014-2015 Jakob Borg and Contributors (see the CONTRIBUTORS file).
// Copyright (C) 2018 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.
package main
import (
"bytes"
"database/sql"
"fmt"
"io/ioutil"
"log"
"os"
"sync/atomic"
"time"
"github.com/prometheus/client_golang/prometheus"
)
type stats struct {
// Incremented atomically
announces int64
queries int64
answers int64
errors int64
}
func (s *stats) Announce() {
atomic.AddInt64(&s.announces, 1)
}
func (s *stats) Query() {
atomic.AddInt64(&s.queries, 1)
}
func (s *stats) Answer() {
atomic.AddInt64(&s.answers, 1)
}
func (s *stats) Error() {
atomic.AddInt64(&s.errors, 1)
}
// Reset returns a copy of the current stats and resets the counters to
// zero.
func (s *stats) Reset() stats {
// Create a copy of the stats using atomic reads
copy := stats{
announces: atomic.LoadInt64(&s.announces),
queries: atomic.LoadInt64(&s.queries),
answers: atomic.LoadInt64(&s.answers),
errors: atomic.LoadInt64(&s.errors),
}
// Reset the stats by subtracting the values that we copied
atomic.AddInt64(&s.announces, -copy.announces)
atomic.AddInt64(&s.queries, -copy.queries)
atomic.AddInt64(&s.answers, -copy.answers)
atomic.AddInt64(&s.errors, -copy.errors)
return copy
}
type statssrv struct {
intv time.Duration
file string
db *sql.DB
}
func (s *statssrv) Serve() {
lastReset := time.Now()
for {
time.Sleep(next(s.intv))
stats := globalStats.Reset()
d := time.Since(lastReset).Seconds()
lastReset = time.Now()
log.Printf("Stats: %.02f announces/s, %.02f queries/s, %.02f answers/s, %.02f errors/s",
float64(stats.announces)/d, float64(stats.queries)/d, float64(stats.answers)/d, float64(stats.errors)/d)
if s.file != "" {
s.writeToFile(stats, d)
}
}
}
func (s *statssrv) Stop() {
panic("stop unimplemented")
}
func (s *statssrv) writeToFile(stats stats, secs float64) {
newLine := []byte("\n")
var addrs int
row := s.db.QueryRow("SELECT COUNT(*) FROM Addresses")
if err := row.Scan(&addrs); err != nil {
log.Println("stats query:", err)
return
}
fd, err := os.OpenFile(s.file, os.O_RDWR|os.O_CREATE, 0666)
if err != nil {
log.Println("stats file:", err)
return
}
defer func() {
err = fd.Close()
if err != nil {
log.Println("stats file:", err)
}
}()
bs, err := ioutil.ReadAll(fd)
if err != nil {
log.Println("stats file:", err)
return
}
lines := bytes.Split(bytes.TrimSpace(bs), newLine)
if len(lines) > 12 {
lines = lines[len(lines)-12:]
}
latest := fmt.Sprintf("%v: %6d addresses, %8.02f announces/s, %8.02f queries/s, %8.02f answers/s, %8.02f errors/s\n",
time.Now().UTC().Format(time.RFC3339), addrs,
float64(stats.announces)/secs, float64(stats.queries)/secs, float64(stats.answers)/secs, float64(stats.errors)/secs)
lines = append(lines, []byte(latest))
_, err = fd.Seek(0, 0)
if err != nil {
log.Println("stats file:", err)
return
}
err = fd.Truncate(0)
if err != nil {
log.Println("stats file:", err)
return
}
_, err = fd.Write(bytes.Join(lines, newLine))
if err != nil {
log.Println("stats file:", err)
return
}
var (
apiRequestsTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: "syncthing",
Subsystem: "discovery",
Name: "api_requests_total",
Help: "Number of API requests.",
}, []string{"type", "result"})
apiRequestsSeconds = prometheus.NewSummaryVec(
prometheus.SummaryOpts{
Namespace: "syncthing",
Subsystem: "discovery",
Name: "api_requests_seconds",
Help: "Latency of API requests.",
Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.99: 0.001},
}, []string{"type"})
lookupRequestsTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: "syncthing",
Subsystem: "discovery",
Name: "lookup_requests_total",
Help: "Number of lookup requests.",
}, []string{"result"})
announceRequestsTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: "syncthing",
Subsystem: "discovery",
Name: "announcement_requests_total",
Help: "Number of announcement requests.",
}, []string{"result"})
replicationSendsTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: "syncthing",
Subsystem: "discovery",
Name: "replication_sends_total",
Help: "Number of replication sends.",
}, []string{"result"})
replicationRecvsTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: "syncthing",
Subsystem: "discovery",
Name: "replication_recvs_total",
Help: "Number of replication receives.",
}, []string{"result"})
databaseKeys = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: "syncthing",
Subsystem: "discovery",
Name: "database_keys",
Help: "Number of database keys at last count.",
}, []string{"category"})
databaseStatisticsSeconds = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "syncthing",
Subsystem: "discovery",
Name: "database_statistics_seconds",
Help: "Time spent running the statistics routine.",
})
databaseOperations = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: "syncthing",
Subsystem: "discovery",
Name: "database_operations_total",
Help: "Number of database operations.",
}, []string{"operation", "result"})
databaseOperationSeconds = prometheus.NewSummaryVec(
prometheus.SummaryOpts{
Namespace: "syncthing",
Subsystem: "discovery",
Name: "database_operation_seconds",
Help: "Latency of database operations.",
Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.99: 0.001},
}, []string{"operation"})
)
const (
dbOpGet = "get"
dbOpPut = "put"
dbOpMerge = "merge"
dbResSuccess = "success"
dbResNotFound = "not_found"
dbResError = "error"
dbResUnmarshalError = "unmarsh_err"
)
func init() {
prometheus.MustRegister(apiRequestsTotal, apiRequestsSeconds,
lookupRequestsTotal, announceRequestsTotal,
replicationSendsTotal, replicationRecvsTotal,
databaseKeys, databaseStatisticsSeconds,
databaseOperations, databaseOperationSeconds)
}