diff --git a/internal/upgrade/upgrade_supported.go b/internal/upgrade/upgrade_supported.go index fe44a734..b80c3fe8 100644 --- a/internal/upgrade/upgrade_supported.go +++ b/internal/upgrade/upgrade_supported.go @@ -13,13 +13,16 @@ // You should have received a copy of the GNU General Public License along // with this program. If not, see . -// +build !windows,!noupgrade +// +build !noupgrade package upgrade import ( "archive/tar" + "archive/zip" + "bytes" "compress/gzip" + "crypto/md5" "encoding/json" "fmt" "io" @@ -28,43 +31,10 @@ import ( "os" "path" "path/filepath" + "runtime" "strings" ) -// Upgrade to the given release, saving the previous binary with a ".old" extension. -func upgradeTo(path string, rel Release) error { - expectedRelease := releaseName(rel.Tag) - if debug { - l.Debugf("expected release asset %q", expectedRelease) - } - for _, asset := range rel.Assets { - if debug { - l.Debugln("considering release", asset) - } - if strings.HasPrefix(asset.Name, expectedRelease) { - if strings.HasSuffix(asset.Name, ".tar.gz") { - fname, err := readTarGZ(asset.URL, filepath.Dir(path)) - if err != nil { - return err - } - - old := path + ".old" - err = os.Rename(path, old) - if err != nil { - return err - } - err = os.Rename(fname, path) - if err != nil { - return err - } - return nil - } - } - } - - return ErrVersionUnknown -} - // Returns the latest release, including prereleases or not depending on the argument func LatestRelease(prerelease bool) (Release, error) { resp, err := http.Get("https://api.github.com/repos/syncthing/syncthing/releases?per_page=10") @@ -97,7 +67,42 @@ func LatestRelease(prerelease bool) (Release, error) { return Release{}, ErrVersionUnknown } -func readTarGZ(url string, dir string) (string, error) { +// Upgrade to the given release, saving the previous binary with a ".old" extension. +func upgradeTo(binary string, rel Release) error { + expectedRelease := releaseName(rel.Tag) + if debug { + l.Debugf("expected release asset %q", expectedRelease) + } + for _, asset := range rel.Assets { + assetName := path.Base(asset.Name) + if debug { + l.Debugln("considering release", assetName) + } + + if strings.HasPrefix(assetName, expectedRelease) { + fname, err := readRelease(filepath.Dir(binary), asset.URL) + if err != nil { + return err + } + + old := binary + ".old" + _ = os.Remove(old) + err = os.Rename(binary, old) + if err != nil { + return err + } + err = os.Rename(fname, binary) + if err != nil { + return err + } + return nil + } + } + + return ErrVersionUnknown +} + +func readRelease(dir, url string) (string, error) { if debug { l.Debugf("loading %q", url) } @@ -114,14 +119,26 @@ func readTarGZ(url string, dir string) (string, error) { } defer resp.Body.Close() - gr, err := gzip.NewReader(resp.Body) + switch runtime.GOOS { + case "windows": + return readZip(dir, resp.Body) + default: + return readTarGz(dir, resp.Body) + } +} + +func readTarGz(dir string, r io.Reader) (string, error) { + gr, err := gzip.NewReader(r) if err != nil { return "", err } tr := tar.NewReader(gr) + var tempName, actualMD5, expectedMD5 string + // Iterate through the files in the archive. +fileLoop: for { hdr, err := tr.Next() if err == io.EOF { @@ -131,37 +148,177 @@ func readTarGZ(url string, dir string) (string, error) { if err != nil { return "", err } + + shortName := path.Base(hdr.Name) + if debug { - l.Debugf("considering file %q", hdr.Name) + l.Debugf("considering file %q", shortName) } - if path.Base(hdr.Name) == "syncthing" { - of, err := ioutil.TempFile(dir, "syncthing") + switch shortName { + case "syncthing": + if debug { + l.Debugln("writing and hashing binary") + } + tempName, actualMD5, err = writeBinary(dir, tr) if err != nil { return "", err } - _, err = io.Copy(of, tr) + if expectedMD5 != "" { + // We're done + break fileLoop + } + + case "syncthing.md5": + bs, err := ioutil.ReadAll(tr) if err != nil { - os.Remove(of.Name()) return "", err } - err = of.Close() - if err != nil { - os.Remove(of.Name()) - return "", err + expectedMD5 = strings.TrimSpace(string(bs)) + if debug { + l.Debugln("expected md5 is", actualMD5) } - err = os.Chmod(of.Name(), os.FileMode(hdr.Mode)) - if err != nil { - os.Remove(of.Name()) - return "", err + if actualMD5 != "" { + // We're done + break fileLoop } - - return of.Name(), nil } } + if tempName != "" && actualMD5 != "" { + // We found and saved something to disk. + if expectedMD5 == "" { + if debug { + l.Debugln("there is no md5 to compare with") + } + } else if actualMD5 != expectedMD5 { + // There was an md5 file included in the archive, and it doesn't + // match what we just wrote to disk. + return "", fmt.Errorf("incorrect MD5 checksum") + } + return tempName, nil + } + + return "", fmt.Errorf("no upgrade found") +} + +func readZip(dir string, r io.Reader) (string, error) { + body, err := ioutil.ReadAll(r) + if err != nil { + return "", err + } + + archive, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) + if err != nil { + return "", err + } + + var tempName, actualMD5, expectedMD5 string + + // Iterate through the files in the archive. +fileLoop: + for _, file := range archive.File { + shortName := path.Base(file.Name) + + if debug { + l.Debugf("considering file %q", shortName) + } + + switch shortName { + case "syncthing.exe": + if debug { + l.Debugln("writing and hashing binary") + } + + inFile, err := file.Open() + if err != nil { + return "", err + } + tempName, actualMD5, err = writeBinary(dir, inFile) + if err != nil { + return "", err + } + + if expectedMD5 != "" { + // We're done + break fileLoop + } + + case "syncthing.exe.md5": + inFile, err := file.Open() + if err != nil { + return "", err + } + bs, err := ioutil.ReadAll(inFile) + if err != nil { + return "", err + } + + expectedMD5 = strings.TrimSpace(string(bs)) + if debug { + l.Debugln("expected md5 is", actualMD5) + } + + if actualMD5 != "" { + // We're done + break fileLoop + } + } + } + + if tempName != "" && actualMD5 != "" { + // We found and saved something to disk. + if expectedMD5 == "" { + if debug { + l.Debugln("there is no md5 to compare with") + } + } else if actualMD5 != expectedMD5 { + // There was an md5 file included in the archive, and it doesn't + // match what we just wrote to disk. + return "", fmt.Errorf("incorrect MD5 checksum") + } + return tempName, nil + } + return "", fmt.Errorf("No upgrade found") } + +func writeBinary(dir string, inFile io.Reader) (filename, md5sum string, err error) { + outFile, err := ioutil.TempFile(dir, "syncthing") + if err != nil { + return "", "", err + } + + // Write the binary both a temporary file and to the MD5 hasher. + + h := md5.New() + mw := io.MultiWriter(h, outFile) + + _, err = io.Copy(mw, inFile) + if err != nil { + os.Remove(outFile.Name()) + return "", "", err + } + + err = outFile.Close() + if err != nil { + os.Remove(outFile.Name()) + return "", "", err + } + + err = os.Chmod(outFile.Name(), os.FileMode(0755)) + if err != nil { + os.Remove(outFile.Name()) + return "", "", err + } + + actualMD5 := fmt.Sprintf("%x", h.Sum(nil)) + if debug { + l.Debugln("actual md5 is", actualMD5) + } + + return outFile.Name(), actualMD5, nil +} diff --git a/internal/upgrade/upgrade_windows.go b/internal/upgrade/upgrade_windows.go deleted file mode 100644 index 9627f7cf..00000000 --- a/internal/upgrade/upgrade_windows.go +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright (C) 2014 The Syncthing Authors. -// -// This program is free software: you can redistribute it and/or modify it -// under the terms of the GNU General Public License as published by the Free -// Software Foundation, either version 3 of the License, or (at your option) -// any later version. -// -// This program is distributed in the hope that it will be useful, but WITHOUT -// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or -// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for -// more details. -// -// You should have received a copy of the GNU General Public License along -// with this program. If not, see . - -// +build windows,!noupgrade - -package upgrade - -import ( - "archive/zip" - "bytes" - "encoding/json" - "fmt" - "io" - "io/ioutil" - "net/http" - "os" - "path" - "path/filepath" - "strings" -) - -// Upgrade to the given release, saving the previous binary with a ".old" extension. -func upgradeTo(path string, rel Release) error { - expectedRelease := releaseName(rel.Tag) - if debug { - l.Debugf("expected release asset %q", expectedRelease) - } - for _, asset := range rel.Assets { - if debug { - l.Debugln("considering release", asset) - } - if strings.HasPrefix(asset.Name, expectedRelease) { - if strings.HasSuffix(asset.Name, ".zip") { - fname, err := readZip(asset.URL, filepath.Dir(path)) - if err != nil { - return err - } - - old := path + ".old" - - os.Remove(old) - err = os.Rename(path, old) - if err != nil { - return err - } - err = os.Rename(fname, path) - if err != nil { - return err - } - return nil - } - } - } - - return ErrVersionUnknown -} - -// Returns the latest release, including prereleases or not depending on the argument -func LatestRelease(prerelease bool) (Release, error) { - resp, err := http.Get("https://api.github.com/repos/syncthing/syncthing/releases?per_page=10") - if err != nil { - return Release{}, err - } - if resp.StatusCode > 299 { - return Release{}, fmt.Errorf("API call returned HTTP error: %s", resp.Status) - } - - var rels []Release - json.NewDecoder(resp.Body).Decode(&rels) - resp.Body.Close() - - if len(rels) == 0 { - return Release{}, ErrVersionUnknown - } - - if prerelease { - // We are a beta version. Use the latest. - return rels[0], nil - } else { - // We are a regular release. Only consider non-prerelease versions for upgrade. - for _, rel := range rels { - if !rel.Prerelease { - return rel, nil - } - } - return Release{}, ErrVersionUnknown - } -} - -func readZip(url, dir string) (string, error) { - if debug { - l.Debugf("loading %q", url) - } - - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return "", err - } - - req.Header.Add("Accept", "application/octet-stream") - resp, err := http.DefaultClient.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return "", err - } - - archive, err := zip.NewReader(bytes.NewReader(body), resp.ContentLength) - if err != nil { - return "", err - } - - // Iterate through the files in the archive. - for _, file := range archive.File { - - if debug { - l.Debugf("considering file %q", file.Name) - } - - if path.Base(file.Name) == "syncthing.exe" { - infile, err := file.Open() - if err != nil { - return "", err - } - - outfile, err := ioutil.TempFile(dir, "syncthing") - if err != nil { - return "", err - } - - _, err = io.Copy(outfile, infile) - if err != nil { - return "", err - } - - err = infile.Close() - if err != nil { - return "", err - } - - err = outfile.Close() - if err != nil { - os.Remove(outfile.Name()) - return "", err - } - - os.Chmod(outfile.Name(), file.Mode()) - return outfile.Name(), nil - } - } - - return "", fmt.Errorf("No upgrade found") -}