Remove martini, use standard http mux

This commit is contained in:
Jakob Borg
2014-07-05 21:40:29 +02:00
parent 2d272a3cac
commit ee10295d04
29 changed files with 154 additions and 3153 deletions

View File

@@ -18,6 +18,7 @@ import (
"path/filepath"
"reflect"
"runtime"
"strings"
"sync"
"time"
@@ -27,7 +28,6 @@ import (
"github.com/calmh/syncthing/config"
"github.com/calmh/syncthing/logger"
"github.com/calmh/syncthing/model"
"github.com/codegangsta/martini"
"github.com/vitrun/qart/qr"
)
@@ -42,6 +42,7 @@ var (
guiErrorsMut sync.Mutex
static func(http.ResponseWriter, *http.Request, *log.Logger)
apiKey string
modt = time.Now().UTC().Format(http.TimeFormat)
)
const (
@@ -81,68 +82,90 @@ func startGUI(cfg config.GUIConfiguration, assetDir string, m *model.Model) erro
}
}
if len(assetDir) > 0 {
static = martini.Static(assetDir).(func(http.ResponseWriter, *http.Request, *log.Logger))
} else {
static = embeddedStatic()
}
router := martini.NewRouter()
router.Get("/", getRoot)
router.Get("/rest/version", restGetVersion)
router.Get("/rest/model", restGetModel)
router.Get("/rest/model/version", restGetModelVersion)
router.Get("/rest/need", restGetNeed)
router.Get("/rest/connections", restGetConnections)
router.Get("/rest/config", restGetConfig)
router.Get("/rest/config/sync", restGetConfigInSync)
router.Get("/rest/system", restGetSystem)
router.Get("/rest/errors", restGetErrors)
router.Get("/rest/discovery", restGetDiscovery)
router.Get("/rest/report", restGetReport)
router.Get("/qr/:text", getQR)
router.Post("/rest/config", restPostConfig)
router.Post("/rest/restart", restPostRestart)
router.Post("/rest/reset", restPostReset)
router.Post("/rest/shutdown", restPostShutdown)
router.Post("/rest/error", restPostError)
router.Post("/rest/error/clear", restClearErrors)
router.Post("/rest/discovery/hint", restPostDiscoveryHint)
router.Post("/rest/model/override", restPostOverride)
mr := martini.New()
mr.Use(csrfMiddleware)
if len(cfg.User) > 0 && len(cfg.Password) > 0 {
mr.Use(basic(cfg.User, cfg.Password))
}
mr.Use(static)
mr.Use(martini.Recovery())
mr.Use(restMiddleware)
mr.Action(router.Handle)
mr.Map(m)
apiKey = cfg.APIKey
loadCsrfTokens()
go http.Serve(listener, mr)
// The GET handlers
getRestMux := http.NewServeMux()
getRestMux.HandleFunc("/rest/version", restGetVersion)
getRestMux.HandleFunc("/rest/model", withModel(m, restGetModel))
getRestMux.HandleFunc("/rest/model/version", withModel(m, restGetModelVersion))
getRestMux.HandleFunc("/rest/need", withModel(m, restGetNeed))
getRestMux.HandleFunc("/rest/connections", withModel(m, restGetConnections))
getRestMux.HandleFunc("/rest/config", restGetConfig)
getRestMux.HandleFunc("/rest/config/sync", restGetConfigInSync)
getRestMux.HandleFunc("/rest/system", restGetSystem)
getRestMux.HandleFunc("/rest/errors", restGetErrors)
getRestMux.HandleFunc("/rest/discovery", restGetDiscovery)
getRestMux.HandleFunc("/rest/report", withModel(m, restGetReport))
// The POST handlers
postRestMux := http.NewServeMux()
postRestMux.HandleFunc("/rest/config", withModel(m, restPostConfig))
postRestMux.HandleFunc("/rest/restart", restPostRestart)
postRestMux.HandleFunc("/rest/reset", restPostReset)
postRestMux.HandleFunc("/rest/shutdown", restPostShutdown)
postRestMux.HandleFunc("/rest/error", restPostError)
postRestMux.HandleFunc("/rest/error/clear", restClearErrors)
postRestMux.HandleFunc("/rest/discovery/hint", restPostDiscoveryHint)
postRestMux.HandleFunc("/rest/model/override", withModel(m, restPostOverride))
// A handler that splits requests between the two above and disables
// caching
restMux := noCacheMiddleware(getPostHandler(getRestMux, postRestMux))
// The main routing handler
mux := http.NewServeMux()
mux.Handle("/rest/", restMux)
mux.HandleFunc("/qr/", getQR)
// Serve compiled in assets unless an asset directory was set (for development)
if len(assetDir) > 0 {
mux.Handle("/", http.FileServer(http.Dir(assetDir)))
} else {
mux.HandleFunc("/", embeddedStatic)
}
// Wrap everything in CSRF protection
handler := csrfMiddleware(mux)
// Wrap everything in basic auth, if user/password is set.
if len(cfg.User) > 0 {
handler = basicAuthMiddleware(cfg.User, cfg.Password, handler)
}
go http.Serve(listener, handler)
return nil
}
func getRoot(w http.ResponseWriter, r *http.Request) {
r.URL.Path = "/index.html"
static(w, r, nil)
func getPostHandler(get, post http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case "GET":
get.ServeHTTP(w, r)
case "POST":
post.ServeHTTP(w, r)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
}
func restMiddleware(w http.ResponseWriter, r *http.Request) {
if len(r.URL.Path) >= 6 && r.URL.Path[:6] == "/rest/" {
func noCacheMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-cache")
h.ServeHTTP(w, r)
})
}
func withModel(m *model.Model, h func(m *model.Model, w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h(m, w, r)
}
}
func restGetVersion() string {
return Version
func restGetVersion(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(Version))
}
func restGetModelVersion(m *model.Model, w http.ResponseWriter, r *http.Request) {
@@ -186,7 +209,7 @@ func restGetModel(m *model.Model, w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(res)
}
func restPostOverride(m *model.Model, r *http.Request) {
func restPostOverride(m *model.Model, w http.ResponseWriter, r *http.Request) {
var qs = r.URL.Query()
var repo = qs.Get("repo")
m.Override(repo)
@@ -202,13 +225,13 @@ func restGetNeed(m *model.Model, w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(files)
}
func restGetConnections(m *model.Model, w http.ResponseWriter) {
func restGetConnections(m *model.Model, w http.ResponseWriter, r *http.Request) {
var res = m.ConnectionStats()
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(res)
}
func restGetConfig(w http.ResponseWriter) {
func restGetConfig(w http.ResponseWriter, r *http.Request) {
encCfg := cfg
if encCfg.GUI.Password != "" {
encCfg.GUI.Password = unchangedPassword
@@ -217,9 +240,9 @@ func restGetConfig(w http.ResponseWriter) {
json.NewEncoder(w).Encode(encCfg)
}
func restPostConfig(req *http.Request, m *model.Model) {
func restPostConfig(m *model.Model, w http.ResponseWriter, r *http.Request) {
var newCfg config.Configuration
err := json.NewDecoder(req.Body).Decode(&newCfg)
err := json.NewDecoder(r.Body).Decode(&newCfg)
if err != nil {
l.Warnln(err)
} else {
@@ -289,26 +312,23 @@ func restPostConfig(req *http.Request, m *model.Model) {
}
}
func restGetConfigInSync(w http.ResponseWriter) {
func restGetConfigInSync(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(map[string]bool{"configInSync": configInSync})
}
func restPostRestart(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
func restPostRestart(w http.ResponseWriter, r *http.Request) {
flushResponse(`{"ok": "restarting"}`, w)
go restart()
}
func restPostReset(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
func restPostReset(w http.ResponseWriter, r *http.Request) {
flushResponse(`{"ok": "resetting repos"}`, w)
resetRepositories()
go restart()
}
func restPostShutdown(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
func restPostShutdown(w http.ResponseWriter, r *http.Request) {
flushResponse(`{"ok": "shutting down"}`, w)
go shutdown()
}
@@ -322,7 +342,7 @@ func flushResponse(s string, w http.ResponseWriter) {
var cpuUsagePercent [10]float64 // The last ten seconds
var cpuUsageLock sync.RWMutex
func restGetSystem(w http.ResponseWriter) {
func restGetSystem(w http.ResponseWriter, r *http.Request) {
var m runtime.MemStats
runtime.ReadMemStats(&m)
@@ -347,20 +367,20 @@ func restGetSystem(w http.ResponseWriter) {
json.NewEncoder(w).Encode(res)
}
func restGetErrors(w http.ResponseWriter) {
func restGetErrors(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
guiErrorsMut.Lock()
json.NewEncoder(w).Encode(guiErrors)
guiErrorsMut.Unlock()
}
func restPostError(req *http.Request) {
bs, _ := ioutil.ReadAll(req.Body)
req.Body.Close()
func restPostError(w http.ResponseWriter, r *http.Request) {
bs, _ := ioutil.ReadAll(r.Body)
r.Body.Close()
showGuiError(0, string(bs))
}
func restClearErrors() {
func restClearErrors(w http.ResponseWriter, r *http.Request) {
guiErrorsMut.Lock()
guiErrors = []guiError{}
guiErrorsMut.Unlock()
@@ -375,7 +395,7 @@ func showGuiError(l logger.LogLevel, err string) {
guiErrorsMut.Unlock()
}
func restPostDiscoveryHint(r *http.Request) {
func restPostDiscoveryHint(w http.ResponseWriter, r *http.Request) {
var qs = r.URL.Query()
var node = qs.Get("node")
var addr = qs.Get("addr")
@@ -384,18 +404,19 @@ func restPostDiscoveryHint(r *http.Request) {
}
}
func restGetDiscovery(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
func restGetDiscovery(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(discoverer.All())
}
func restGetReport(w http.ResponseWriter, m *model.Model) {
func restGetReport(m *model.Model, w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(reportData(m))
}
func getQR(w http.ResponseWriter, params martini.Params) {
code, err := qr.Encode(params["text"], qr.M)
func getQR(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
text := r.FormValue("text")
code, err := qr.Encode(text, qr.M)
if err != nil {
http.Error(w, "Invalid", 500)
return
@@ -405,20 +426,21 @@ func getQR(w http.ResponseWriter, params martini.Params) {
w.Write(code.PNG())
}
func basic(username string, passhash string) http.HandlerFunc {
return func(res http.ResponseWriter, req *http.Request) {
if validAPIKey(req.Header.Get("X-API-Key")) {
func basicAuthMiddleware(username string, passhash string, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if validAPIKey(r.Header.Get("X-API-Key")) {
next.ServeHTTP(w, r)
return
}
error := func() {
time.Sleep(time.Duration(rand.Intn(100)+100) * time.Millisecond)
res.Header().Set("WWW-Authenticate", "Basic realm=\"Authorization Required\"")
http.Error(res, "Not Authorized", http.StatusUnauthorized)
w.Header().Set("WWW-Authenticate", "Basic realm=\"Authorization Required\"")
http.Error(w, "Not Authorized", http.StatusUnauthorized)
}
hdr := req.Header.Get("Authorization")
if len(hdr) < len("Basic ") || hdr[:6] != "Basic " {
hdr := r.Header.Get("Authorization")
if !strings.HasPrefix(hdr, "Basic ") {
error()
return
}
@@ -445,35 +467,37 @@ func basic(username string, passhash string) http.HandlerFunc {
error()
return
}
}
next.ServeHTTP(w, r)
})
}
func validAPIKey(k string) bool {
return len(apiKey) > 0 && k == apiKey
}
func embeddedStatic() func(http.ResponseWriter, *http.Request, *log.Logger) {
var modt = time.Now().UTC().Format(http.TimeFormat)
func embeddedStatic(w http.ResponseWriter, r *http.Request) {
file := r.URL.Path
return func(res http.ResponseWriter, req *http.Request, log *log.Logger) {
file := req.URL.Path
if file[0] == '/' {
file = file[1:]
}
bs, ok := auto.Assets[file]
if !ok {
return
}
mtype := mime.TypeByExtension(filepath.Ext(req.URL.Path))
if len(mtype) != 0 {
res.Header().Set("Content-Type", mtype)
}
res.Header().Set("Content-Length", fmt.Sprintf("%d", len(bs)))
res.Header().Set("Last-Modified", modt)
res.Write(bs)
if file[0] == '/' {
file = file[1:]
}
if len(file) == 0 {
file = "index.html"
}
bs, ok := auto.Assets[file]
if !ok {
return
}
mtype := mime.TypeByExtension(filepath.Ext(r.URL.Path))
if len(mtype) != 0 {
w.Header().Set("Content-Type", mtype)
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(bs)))
w.Header().Set("Last-Modified", modt)
w.Write(bs)
}