diff --git a/main.go b/main.go index b08af6f..564877c 100644 --- a/main.go +++ b/main.go @@ -2,12 +2,14 @@ package main import ( "context" + "errors" "flag" "fmt" "log" "math/rand" "net/url" "os" + "os/exec" "os/signal" "syscall" "time" @@ -21,6 +23,15 @@ type opts struct { checkFrequency string destinationDir string skipWaiting bool + + plugin string + pluginArgs []string +} + +type plugin struct { + binPath string + name string + args []string } const ( @@ -33,15 +44,22 @@ func main() { log.Print("gokrazy's selfupdate service starting up..") + gokrazy.WaitForClock() + var o opts flag.StringVar(&o.gusServer, "gus_server", "", "the HTTP/S endpoint of the GUS (gokrazy Update System) server (required)") flag.StringVar(&o.checkFrequency, "check_frequency", "1h", "the time frequency for checks to the update service. default: 1h") flag.StringVar(&o.destinationDir, "destination_dir", "/tmp/selfupdate", "the destination directory for the fetched update file. default: /tmp/selfupdate") - flag.BoolVar(&o.skipWaiting, "skip_waiting", false, "skips the time frequency check and jitter waits, and immediately performs an update check. default: false") + flag.BoolVar(&o.skipWaiting, "skip_waiting", false, "for the first update check it skips the time frequency check and jitter, useful for testing. default: false") + flag.StringVar(&o.plugin, "plugin", "", "name of the desired plugin to be loaded (this will be used when needed). default: ''") flag.Parse() + // Gather args after flag parsing termination "--". + // They will be directly passed to the plugin binary. + o.pluginArgs = flag.Args() + if err := logic(ctx, o); err != nil { log.Fatal(err) } @@ -79,36 +97,38 @@ func logic(ctx context.Context, o opts) error { return fmt.Errorf("error joining gus server url: %w", err) } + plugins := make(map[string]plugin) + if err := loadPlugin(plugins, o.plugin, o.pluginArgs); err != nil { + return fmt.Errorf("error loading plugin %s: %w", o.plugin, err) + } + gusCfg := gusapi.NewConfiguration() gusCfg.BasePath = gusBasePath gusCli := gusapi.NewAPIClient(gusCfg) if o.skipWaiting { - log.Print("skipping waiting, performing an immediate updateProcess") - if err := updateProcess(ctx, gusCli, machineID, o.gusServer, sbomHash, o.destinationDir, httpPassword, httpPort); err != nil { + log.Print("skipping waiting, performing an immediate update check") + if err := updateProcess(ctx, gusCli, plugins, machineID, o.gusServer, sbomHash, o.destinationDir, httpPassword, httpPort); err != nil { // If the updateProcess fails we exit with an error // so that gokrazy supervisor will restart the process. return fmt.Errorf("error performing updateProcess: %v", err) } - - // If the updateProcess doesn't error - // we happily return to terminate the process. - return nil } - log.Print("entering update checking loop") - ticker := time.NewTicker(frequency) - - for { + for c := time.Tick(frequency); ; { select { case <-ctx.Done(): - log.Print("stopping update checking") + log.Print("shutting down...") return nil + case <-c: + if o.skipWaiting { + // Re-introduce jitter after first run skip. + o.skipWaiting = false + jitter := time.Duration(rand.Int63n(250)) * time.Second + time.Sleep(jitter) + } - case <-ticker.C: - jitter := time.Duration(rand.Int63n(250)) * time.Second - time.Sleep(jitter) - if err := updateProcess(ctx, gusCli, machineID, sbomHash, o.gusServer, o.destinationDir, httpPassword, httpPort); err != nil { + if err := updateProcess(ctx, gusCli, plugins, machineID, o.gusServer, sbomHash, o.destinationDir, httpPassword, httpPort); err != nil { log.Printf("error performing updateProcess: %v", err) continue } @@ -116,7 +136,7 @@ func logic(ctx context.Context, o opts) error { } } -func updateProcess(ctx context.Context, gusCli *gusapi.APIClient, machineID, gusServer, sbomHash, destinationDir, httpPassword, httpPort string) error { +func updateProcess(ctx context.Context, gusCli *gusapi.APIClient, plugins map[string]plugin, machineID, gusServer, sbomHash, destinationDir, httpPassword, httpPort string) error { response, err := checkForUpdates(ctx, gusCli, machineID) if err != nil { return fmt.Errorf("unable to check for updates: %w", err) @@ -129,7 +149,7 @@ func updateProcess(ctx context.Context, gusCli *gusapi.APIClient, machineID, gus } // The SBOMHash differs, start the selfupdate procedure. - if err := selfupdate(ctx, gusCli, gusServer, machineID, destinationDir, response, httpPassword, httpPort); err != nil { + if err := selfupdate(ctx, gusCli, plugins, gusServer, machineID, destinationDir, response, httpPassword, httpPort); err != nil { return fmt.Errorf("unable to perform the selfupdate procedure: %w", err) } @@ -141,3 +161,24 @@ func updateProcess(ctx context.Context, gusCli *gusapi.APIClient, machineID, gus return nil } + +func loadPlugin(plugins map[string]plugin, pluginName string, pluginArgs []string) error { + var binPath string + + // Try to find the plugin binary in PATH. + fullPluginName := fmt.Sprintf("gokplugin-%s", pluginName) + if p, err := exec.LookPath(fullPluginName); err == nil { + binPath = p + } else { + // The binary can't be found in PATH. + // Fall back to checking in the well known gokrazy's /user/ path. + fallbackPath := fmt.Sprintf("/user/%s", fullPluginName) + if _, err := os.Stat(fallbackPath); errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("unable to find %s", fullPluginName) + } + binPath = fallbackPath + } + plugins[pluginName] = plugin{binPath: binPath, name: pluginName, args: pluginArgs} + + return nil +} diff --git a/selfupdate.go b/selfupdate.go index 8975ac3..c5e6ac6 100644 --- a/selfupdate.go +++ b/selfupdate.go @@ -35,7 +35,7 @@ func shouldUpdate(response gusapi.UpdateResponse, sbomHash string) bool { return true } -func selfupdate(ctx context.Context, gusCli *gusapi.APIClient, gusServer, machineID, destinationDir string, response gusapi.UpdateResponse, httpPassword, httpPort string) error { +func selfupdate(ctx context.Context, gusCli *gusapi.APIClient, plugins map[string]plugin, gusServer, machineID, destinationDir string, response gusapi.UpdateResponse, httpPassword, httpPort string) error { log.Print("starting self-update procedure") if _, _, err := gusCli.UpdateApi.Attempt(ctx, &gusapi.UpdateApiAttemptOpts{ @@ -52,12 +52,19 @@ func selfupdate(ctx context.Context, gusCli *gusapi.APIClient, gusServer, machin switch response.RegistryType { case "http", "localdisk": - readClosers, err = httpFetcher(response, gusServer, destinationDir) + readClosers, err = httpUpdateFetch(response, gusServer, destinationDir) if err != nil { return fmt.Errorf("error fetching %q update from link %q: %w", response.RegistryType, response.DownloadLink, err) } default: - return fmt.Errorf("unrecognized registry type %q", response.RegistryType) + if _, ok := plugins[response.RegistryType]; !ok { + return fmt.Errorf("error %q is not a loaded plugin", response.RegistryType) + } + + readClosers, err = pluginFetchUpdate(ctx, plugins[response.RegistryType], destinationDir, response.DownloadLink) + if err != nil { + return fmt.Errorf("error fetching %q update from link %q: %w", response.RegistryType, response.DownloadLink, err) + } } uri := fmt.Sprintf("http://gokrazy:%s@localhost:%s/", httpPassword, httpPort) @@ -98,7 +105,13 @@ func selfupdate(ctx context.Context, gusCli *gusapi.APIClient, gusServer, machin return fmt.Errorf("switching to non-active partition: %v", err) } - log.Print("reboot") + log.Print("requesting reboot") + // TODO: change call from target.Reboot to something like target.ScheduleReboot + // which can asyncronouly perform the task instead of waiting + // for all the services to shutdown and then ack the reboot, + // othewise this causes a deadlock as the selfupdate service won't SIGTERM cleanly + // until the reboot ack is received, but won't receive the ack until all services shut down, + // causing a delayed reboot until SIGKILL kicks in. if err := target.Reboot(); err != nil { return fmt.Errorf("reboot: %v", err) } diff --git a/update_handlers.go b/update_fetchers.go similarity index 64% rename from update_handlers.go rename to update_fetchers.go index 4faaeff..f98be01 100644 --- a/update_handlers.go +++ b/update_fetchers.go @@ -2,12 +2,14 @@ package main import ( "archive/zip" + "context" "fmt" "io" "log" "net/http" "net/url" "os" + "os/exec" "path/filepath" "strconv" "syscall" @@ -28,8 +30,8 @@ type rcs struct { root io.ReadCloser } -// httpFetcher handles a http update link. -func httpFetcher(response gusapi.UpdateResponse, gusServer, destinationDir string) (rcs, error) { +// httpUpdateFetch fetches the update payload via HTTP. +func httpUpdateFetch(response gusapi.UpdateResponse, gusServer, destinationDir string) (rcs, error) { // The link may be a relative url if the server's backend registry is its local disk. // Ensure we have an absolute url by adding the base (gusServer) url // when necessary. @@ -43,7 +45,7 @@ func httpFetcher(response gusapi.UpdateResponse, gusServer, destinationDir strin return rcs{}, fmt.Errorf("error ensuring destination directory exists: %w", err) } - log.Printf("downloading update file from registry %q with url: %s", response.RegistryType, link) + log.Printf("downloading update from %q registry with url: %s", response.RegistryType, link) filePath := filepath.Join(destinationDir, "disk.gaf") if err := httpDownloadFile(destinationDir, filePath, link); err != nil { @@ -129,6 +131,7 @@ func httpDownloadFile(destinationDir, filePath string, url string) error { return nil } +// ensureAbsoluteHTTPLink ensures an HTTP link is absolute. func ensureAbsoluteHTTPLink(baseURL string, link string) (string, error) { base, err := url.Parse(baseURL) if err != nil { @@ -143,7 +146,7 @@ func ensureAbsoluteHTTPLink(baseURL string, link string) (string, error) { return u.String(), nil } -// Function to get available disk space for path. +// diskSpaceAvailable gets available disk space for path. func diskSpaceAvailable(path string) (uint64, error) { fs := syscall.Statfs_t{} err := syscall.Statfs(path, &fs) @@ -152,3 +155,55 @@ func diskSpaceAvailable(path string) (uint64, error) { } return fs.Bfree * uint64(fs.Bsize), nil } + +// pluginFetchUpdate fetches the update payload via plugin. +func pluginFetchUpdate(ctx context.Context, p plugin, destinationDir string, link string) (rcs, error) { + if err := os.MkdirAll(destinationDir, 0755); err != nil { + return rcs{}, fmt.Errorf("error ensuring destination directory exists: %w", err) + } + + log.Printf("downloading update from %q registry with url: %q", p.name, link) + + filePath := filepath.Join(destinationDir, "disk.gaf") + + // TODO: add update size gather + check against available disk space. + + args := append(p.args, []string{"--url", link, "--output", destinationDir}...) + + cmd := exec.CommandContext(ctx, p.binPath, args...) + + if err := cmd.Run(); err != nil { + return rcs{}, fmt.Errorf("error running plugin command %q: %w", p.binPath, err) + } + + log.Print("finished downloading update file") + log.Print("loading disk partitions from update file") + + r, err := zip.OpenReader(filePath) + if err != nil { + return rcs{}, fmt.Errorf("error opening downloaded file %q: %w", filePath, err) + } + + var mbrReader, bootReader, rootReader io.ReadCloser + for _, f := range r.File { + switch f.Name { + case mbrPartitionName: + mbrReader, err = f.Open() + if err != nil { + return rcs{}, fmt.Errorf("error reading %s within update file: %w", mbrPartitionName, err) + } + case bootPartitionName: + bootReader, err = f.Open() + if err != nil { + return rcs{}, fmt.Errorf("error reading %s within update file: %w", bootPartitionName, err) + } + case rootPartitionName: + rootReader, err = f.Open() + if err != nil { + return rcs{}, fmt.Errorf("error reading %s within update file: %w", rootPartitionName, err) + } + } + } + + return rcs{r, mbrReader, bootReader, rootReader}, nil +}