Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plugins support #4

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 59 additions & 18 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package main

import (
"context"
"errors"
"flag"
"fmt"
"log"
"math/rand"
"net/url"
"os"
"os/exec"
"os/signal"
"syscall"
"time"
Expand All @@ -21,6 +23,15 @@ type opts struct {
checkFrequency string
destinationDir string
skipWaiting bool

plugin string
pluginArgs []string
}

type plugin struct {
binPath string
name string
args []string
}

const (
Expand All @@ -33,15 +44,22 @@ func main() {

log.Print("gokrazy's selfupdate service starting up..")

gokrazy.WaitForClock()

var o opts

flag.StringVar(&o.gusServer, "gus_server", "", "the HTTP/S endpoint of the GUS (gokrazy Update System) server (required)")
flag.StringVar(&o.checkFrequency, "check_frequency", "1h", "the time frequency for checks to the update service. default: 1h")
flag.StringVar(&o.destinationDir, "destination_dir", "/tmp/selfupdate", "the destination directory for the fetched update file. default: /tmp/selfupdate")
flag.BoolVar(&o.skipWaiting, "skip_waiting", false, "skips the time frequency check and jitter waits, and immediately performs an update check. default: false")
flag.BoolVar(&o.skipWaiting, "skip_waiting", false, "for the first update check it skips the time frequency check and jitter, useful for testing. default: false")
flag.StringVar(&o.plugin, "plugin", "", "name of the desired plugin to be loaded (this will be used when needed). default: ''")

flag.Parse()

// Gather args after flag parsing termination "--".
// They will be directly passed to the plugin binary.
o.pluginArgs = flag.Args()

if err := logic(ctx, o); err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -79,44 +97,46 @@ func logic(ctx context.Context, o opts) error {
return fmt.Errorf("error joining gus server url: %w", err)
}

plugins := make(map[string]plugin)
if err := loadPlugin(plugins, o.plugin, o.pluginArgs); err != nil {
return fmt.Errorf("error loading plugin %s: %w", o.plugin, err)
}

gusCfg := gusapi.NewConfiguration()
gusCfg.BasePath = gusBasePath
gusCli := gusapi.NewAPIClient(gusCfg)

if o.skipWaiting {
log.Print("skipping waiting, performing an immediate updateProcess")
if err := updateProcess(ctx, gusCli, machineID, o.gusServer, sbomHash, o.destinationDir, httpPassword, httpPort); err != nil {
log.Print("skipping waiting, performing an immediate update check")
if err := updateProcess(ctx, gusCli, plugins, machineID, o.gusServer, sbomHash, o.destinationDir, httpPassword, httpPort); err != nil {
// If the updateProcess fails we exit with an error
// so that gokrazy supervisor will restart the process.
return fmt.Errorf("error performing updateProcess: %v", err)
}

// If the updateProcess doesn't error
// we happily return to terminate the process.
return nil
}

log.Print("entering update checking loop")
ticker := time.NewTicker(frequency)

for {
for c := time.Tick(frequency); ; {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing this so then we get an immediate first check without waiting an extra frequency interval.

Also changing the skipWaiting to only skip the Jitter once, the first time.

These two changes combined still achieve the same initially defined behaviour for skipWaiting, but improve the first check also in cases where skipWaiting is set to false

select {
case <-ctx.Done():
log.Print("stopping update checking")
log.Print("shutting down...")
return nil
case <-c:
if o.skipWaiting {
// Re-introduce jitter after first run skip.
o.skipWaiting = false
jitter := time.Duration(rand.Int63n(250)) * time.Second
time.Sleep(jitter)
}

case <-ticker.C:
jitter := time.Duration(rand.Int63n(250)) * time.Second
time.Sleep(jitter)
if err := updateProcess(ctx, gusCli, machineID, sbomHash, o.gusServer, o.destinationDir, httpPassword, httpPort); err != nil {
if err := updateProcess(ctx, gusCli, plugins, machineID, o.gusServer, sbomHash, o.destinationDir, httpPassword, httpPort); err != nil {
log.Printf("error performing updateProcess: %v", err)
continue
}
}
}
}

func updateProcess(ctx context.Context, gusCli *gusapi.APIClient, machineID, gusServer, sbomHash, destinationDir, httpPassword, httpPort string) error {
func updateProcess(ctx context.Context, gusCli *gusapi.APIClient, plugins map[string]plugin, machineID, gusServer, sbomHash, destinationDir, httpPassword, httpPort string) error {
response, err := checkForUpdates(ctx, gusCli, machineID)
if err != nil {
return fmt.Errorf("unable to check for updates: %w", err)
Expand All @@ -129,7 +149,7 @@ func updateProcess(ctx context.Context, gusCli *gusapi.APIClient, machineID, gus
}

// The SBOMHash differs, start the selfupdate procedure.
if err := selfupdate(ctx, gusCli, gusServer, machineID, destinationDir, response, httpPassword, httpPort); err != nil {
if err := selfupdate(ctx, gusCli, plugins, gusServer, machineID, destinationDir, response, httpPassword, httpPort); err != nil {
return fmt.Errorf("unable to perform the selfupdate procedure: %w", err)
}

Expand All @@ -141,3 +161,24 @@ func updateProcess(ctx context.Context, gusCli *gusapi.APIClient, machineID, gus

return nil
}

func loadPlugin(plugins map[string]plugin, pluginName string, pluginArgs []string) error {
var binPath string

// Try to find the plugin binary in PATH.
fullPluginName := fmt.Sprintf("gokplugin-%s", pluginName)
if p, err := exec.LookPath(fullPluginName); err == nil {
binPath = p
} else {
// The binary can't be found in PATH.
// Fall back to checking in the well known gokrazy's /user/ path.
fallbackPath := fmt.Sprintf("/user/%s", fullPluginName)
if _, err := os.Stat(fallbackPath); errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("unable to find %s", fullPluginName)
}
binPath = fallbackPath
}
plugins[pluginName] = plugin{binPath: binPath, name: pluginName, args: pluginArgs}

return nil
}
21 changes: 17 additions & 4 deletions selfupdate.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func shouldUpdate(response gusapi.UpdateResponse, sbomHash string) bool {
return true
}

func selfupdate(ctx context.Context, gusCli *gusapi.APIClient, gusServer, machineID, destinationDir string, response gusapi.UpdateResponse, httpPassword, httpPort string) error {
func selfupdate(ctx context.Context, gusCli *gusapi.APIClient, plugins map[string]plugin, gusServer, machineID, destinationDir string, response gusapi.UpdateResponse, httpPassword, httpPort string) error {
log.Print("starting self-update procedure")

if _, _, err := gusCli.UpdateApi.Attempt(ctx, &gusapi.UpdateApiAttemptOpts{
Expand All @@ -52,12 +52,19 @@ func selfupdate(ctx context.Context, gusCli *gusapi.APIClient, gusServer, machin

switch response.RegistryType {
case "http", "localdisk":
readClosers, err = httpFetcher(response, gusServer, destinationDir)
readClosers, err = httpUpdateFetch(response, gusServer, destinationDir)
if err != nil {
return fmt.Errorf("error fetching %q update from link %q: %w", response.RegistryType, response.DownloadLink, err)
}
default:
return fmt.Errorf("unrecognized registry type %q", response.RegistryType)
if _, ok := plugins[response.RegistryType]; !ok {
return fmt.Errorf("error %q is not a loaded plugin", response.RegistryType)
}

readClosers, err = pluginFetchUpdate(ctx, plugins[response.RegistryType], destinationDir, response.DownloadLink)
if err != nil {
return fmt.Errorf("error fetching %q update from link %q: %w", response.RegistryType, response.DownloadLink, err)
}
}

uri := fmt.Sprintf("http://gokrazy:%s@localhost:%s/", httpPassword, httpPort)
Expand Down Expand Up @@ -98,7 +105,13 @@ func selfupdate(ctx context.Context, gusCli *gusapi.APIClient, gusServer, machin
return fmt.Errorf("switching to non-active partition: %v", err)
}

log.Print("reboot")
log.Print("requesting reboot")
// TODO: change call from target.Reboot to something like target.ScheduleReboot
// which can asyncronouly perform the task instead of waiting
// for all the services to shutdown and then ack the reboot,
// othewise this causes a deadlock as the selfupdate service won't SIGTERM cleanly
// until the reboot ack is received, but won't receive the ack until all services shut down,
// causing a delayed reboot until SIGKILL kicks in.
if err := target.Reboot(); err != nil {
return fmt.Errorf("reboot: %v", err)
}
Comment on lines +108 to 117
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one will need some thought

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, why will selfupdate not SIGTERM cleanly in this state? I would have expected it to. SIGINT and SIGTERM are hooked up to context cancellation in selfupdate’s main function.

Expand Down
63 changes: 59 additions & 4 deletions update_handlers.go → update_fetchers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package main

import (
"archive/zip"
"context"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"strconv"
"syscall"
Expand All @@ -28,8 +30,8 @@ type rcs struct {
root io.ReadCloser
}

// httpFetcher handles a http update link.
func httpFetcher(response gusapi.UpdateResponse, gusServer, destinationDir string) (rcs, error) {
// httpUpdateFetch fetches the update payload via HTTP.
func httpUpdateFetch(response gusapi.UpdateResponse, gusServer, destinationDir string) (rcs, error) {
// The link may be a relative url if the server's backend registry is its local disk.
// Ensure we have an absolute url by adding the base (gusServer) url
// when necessary.
Expand All @@ -43,7 +45,7 @@ func httpFetcher(response gusapi.UpdateResponse, gusServer, destinationDir strin
return rcs{}, fmt.Errorf("error ensuring destination directory exists: %w", err)
}

log.Printf("downloading update file from registry %q with url: %s", response.RegistryType, link)
log.Printf("downloading update from %q registry with url: %s", response.RegistryType, link)

filePath := filepath.Join(destinationDir, "disk.gaf")
if err := httpDownloadFile(destinationDir, filePath, link); err != nil {
Expand Down Expand Up @@ -129,6 +131,7 @@ func httpDownloadFile(destinationDir, filePath string, url string) error {
return nil
}

// ensureAbsoluteHTTPLink ensures an HTTP link is absolute.
func ensureAbsoluteHTTPLink(baseURL string, link string) (string, error) {
base, err := url.Parse(baseURL)
if err != nil {
Expand All @@ -143,7 +146,7 @@ func ensureAbsoluteHTTPLink(baseURL string, link string) (string, error) {
return u.String(), nil
}

// Function to get available disk space for path.
// diskSpaceAvailable gets available disk space for path.
func diskSpaceAvailable(path string) (uint64, error) {
fs := syscall.Statfs_t{}
err := syscall.Statfs(path, &fs)
Expand All @@ -152,3 +155,55 @@ func diskSpaceAvailable(path string) (uint64, error) {
}
return fs.Bfree * uint64(fs.Bsize), nil
}

// pluginFetchUpdate fetches the update payload via plugin.
func pluginFetchUpdate(ctx context.Context, p plugin, destinationDir string, link string) (rcs, error) {
if err := os.MkdirAll(destinationDir, 0755); err != nil {
return rcs{}, fmt.Errorf("error ensuring destination directory exists: %w", err)
}

log.Printf("downloading update from %q registry with url: %q", p.name, link)

filePath := filepath.Join(destinationDir, "disk.gaf")

// TODO: add update size gather + check against available disk space.

args := append(p.args, []string{"--url", link, "--output", destinationDir}...)

cmd := exec.CommandContext(ctx, p.binPath, args...)

if err := cmd.Run(); err != nil {
return rcs{}, fmt.Errorf("error running plugin command %q: %w", p.binPath, err)
}

log.Print("finished downloading update file")
log.Print("loading disk partitions from update file")

r, err := zip.OpenReader(filePath)
if err != nil {
return rcs{}, fmt.Errorf("error opening downloaded file %q: %w", filePath, err)
}

var mbrReader, bootReader, rootReader io.ReadCloser
for _, f := range r.File {
switch f.Name {
case mbrPartitionName:
mbrReader, err = f.Open()
if err != nil {
return rcs{}, fmt.Errorf("error reading %s within update file: %w", mbrPartitionName, err)
}
case bootPartitionName:
bootReader, err = f.Open()
if err != nil {
return rcs{}, fmt.Errorf("error reading %s within update file: %w", bootPartitionName, err)
}
case rootPartitionName:
rootReader, err = f.Open()
if err != nil {
return rcs{}, fmt.Errorf("error reading %s within update file: %w", rootPartitionName, err)
}
}
}

return rcs{r, mbrReader, bootReader, rootReader}, nil
}