diff --git a/cmd/strelaypoolsrv/main.go b/cmd/strelaypoolsrv/main.go index 98e3e702..db4ac81a 100644 --- a/cmd/strelaypoolsrv/main.go +++ b/cmd/strelaypoolsrv/main.go @@ -19,7 +19,6 @@ import ( "github.com/golang/groupcache/lru" "github.com/juju/ratelimit" - "github.com/kardianos/osext" "github.com/syncthing/relaysrv/client" "github.com/syncthing/syncthing/lib/sync" @@ -48,7 +47,7 @@ type result struct { var ( binDir string - testCert []tls.Certificate + testCert tls.Certificate listen string = ":80" dir string = "" evictionTime time.Duration = time.Hour @@ -61,6 +60,7 @@ var ( postLimitAvg = 1 getLimit time.Duration postLimit time.Duration + permRelaysFile string getMut sync.RWMutex = sync.NewRWMutex() getLRUCache *lru.Cache @@ -87,6 +87,7 @@ func main() { flag.IntVar(&postLRUSize, "post-limit-cache", postLRUSize, "Post request limiter cache size") flag.IntVar(&postLimitAvg, "post-limit-avg", 2, "Allowed average post request rate, per minute") flag.Int64Var(&postLimitBurst, "post-limit-burst", postLimitBurst, "Allowed burst post requests") + flag.StringVar(&permRelaysFile, "perm-relays", "", "Path to list of permanent relays") flag.Parse() @@ -99,13 +100,12 @@ func main() { var listener net.Listener var err error - binDir, err = osext.ExecutableFolder() - if err != nil { - log.Fatalln("Failed to locate executable directory") + if permRelaysFile != "" { + loadPermanentRelays(permRelaysFile) } - loadPermanentRelays() - loadOrCreateTestCertificate() + testCert = createTestCertificate() + go requestProcessor() if dir != "" { @@ -292,7 +292,7 @@ func requestProcessor() { if debug { log.Println("Request for", request.relay) } - if !client.TestRelay(request.uri, testCert, 250*time.Millisecond, 4) { + if !client.TestRelay(request.uri, []tls.Certificate{testCert}, 250*time.Millisecond, 4) { if debug { log.Println("Test for relay", request.relay, "failed") } @@ -375,16 +375,10 @@ func limit(addr string, cache *lru.Cache, lock sync.RWMutex, rate time.Duration, return false } -func loadPermanentRelays() { - path, err := osext.ExecutableFolder() +func loadPermanentRelays(file string) { + content, err := ioutil.ReadFile(file) if err != nil { - log.Println("Failed to locate executable directory") - return - } - - content, err := ioutil.ReadFile(filepath.Join(path, "relays")) - if err != nil { - return + log.Fatal(err) } for _, line := range strings.Split(string(content), "\n") { @@ -398,6 +392,7 @@ func loadPermanentRelays() { log.Println("Skipping permanent relay", line, "due to parse error", err) } continue + } permanentRelays = append(permanentRelays, relay{ @@ -410,17 +405,17 @@ func loadPermanentRelays() { } } -func loadOrCreateTestCertificate() { - certFile, keyFile := filepath.Join(binDir, "cert.pem"), filepath.Join(binDir, "key.pem") - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err == nil { - testCert = []tls.Certificate{cert} - return +func createTestCertificate() tls.Certificate { + tmpDir, err := ioutil.TempDir("", "relaypoolsrv") + if err != nil { + log.Fatal(err) } - cert, err = tlsutil.NewCertificate(certFile, keyFile, "relaypoolsrv", 3072) + certFile, keyFile := filepath.Join(tmpDir, "cert.pem"), filepath.Join(tmpDir, "key.pem") + cert, err := tlsutil.NewCertificate(certFile, keyFile, "relaypoolsrv", 3072) if err != nil { log.Fatalln("Failed to create test X509 key pair:", err) } - testCert = []tls.Certificate{cert} + + return cert }