diff --git a/cmd/strelaypoolsrv/main.go b/cmd/strelaypoolsrv/main.go index 6b84f434..5bcdd327 100644 --- a/cmd/strelaypoolsrv/main.go +++ b/cmd/strelaypoolsrv/main.go @@ -25,6 +25,7 @@ import ( type relay struct { URL string `json:"url"` + uri *url.URL } func (r relay) String() string { @@ -80,7 +81,7 @@ func main() { if dir != "" { if debug { - log.Println("Starting TLS listener") + log.Println("Starting TLS listener on", listen) } certFile, keyFile := filepath.Join(dir, "http-cert.pem"), filepath.Join(dir, "http-key.pem") cert, err := tls.LoadX509KeyPair(certFile, keyFile) @@ -109,7 +110,7 @@ func main() { listener, err = tls.Listen("tcp", listen, tlsCfg) } else { if debug { - log.Println("Starting plain listener") + log.Println("Starting plain listener on", listen) } listener, err = net.Listen("tcp", listen) } @@ -176,16 +177,6 @@ func handlePostRequest(w http.ResponseWriter, r *http.Request) { return } - for _, current := range permanentRelays { - if current == newRelay { - if debug { - log.Println("Asked to add a relay", newRelay, "which exists in permanent list") - } - http.Error(w, "Invalid request", 500) - return - } - } - uri, err := url.Parse(newRelay.URL) if err != nil { if debug { @@ -195,6 +186,40 @@ func handlePostRequest(w http.ResponseWriter, r *http.Request) { return } + host, port, err := net.SplitHostPort(uri.Host) + if err != nil { + if debug { + log.Println("Failed to split URI", newRelay.URL) + } + http.Error(w, err.Error(), 500) + return + } + + // The client did not provide an IP address, work it out. + if host == "" { + rhost, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + if debug { + log.Println("Failed to split remote address", r.RemoteAddr) + } + http.Error(w, err.Error(), 500) + return + } + uri.Host = net.JoinHostPort(rhost, port) + newRelay.URL = uri.String() + } + newRelay.uri = uri + + for _, current := range permanentRelays { + if current.uri.Host == newRelay.uri.Host { + if debug { + log.Println("Asked to add a relay", newRelay, "which exists in permanent list") + } + http.Error(w, "Invalid request", 500) + return + } + } + reschan := make(chan result) select { @@ -233,8 +258,18 @@ func loadPermanentRelays() { if len(line) == 0 { continue } + + uri, err := url.Parse(line) + if err != nil { + if debug { + log.Println("Skipping permanent relay", line, "due to parse error", err) + } + continue + } + permanentRelays = append(permanentRelays, relay{ URL: line, + uri: uri, }) if debug { log.Println("Adding permanent relay", line) @@ -271,7 +306,7 @@ func requestProcessor() { } mut.Lock() - timer, ok := evictionTimers[request.relay.URL] + timer, ok := evictionTimers[request.relay.uri.Host] if ok { if debug { log.Println("Stopping existing timer for", request.relay) @@ -280,7 +315,7 @@ func requestProcessor() { } for _, current := range knownRelays { - if current == request.relay { + if current.uri.Host == request.relay.uri.Host { if debug { log.Println("Relay", request.relay, "already exists") } @@ -294,7 +329,7 @@ func requestProcessor() { knownRelays = append(knownRelays, request.relay) found: - evictionTimers[request.relay.URL] = time.AfterFunc(evictionTime, evict(request.relay)) + evictionTimers[request.relay.uri.Host] = time.AfterFunc(evictionTime, evict(request.relay)) mut.Unlock() request.result <- result{nil, evictionTime} } @@ -309,7 +344,7 @@ func evict(relay relay) func() { log.Println("Evicting", relay) } for i, current := range knownRelays { - if current == relay { + if current.uri.Host == relay.uri.Host { if debug { log.Println("Evicted", relay) } @@ -318,6 +353,6 @@ func evict(relay relay) func() { knownRelays = knownRelays[:last] } } - delete(evictionTimers, relay.URL) + delete(evictionTimers, relay.uri.Host) } }