diff --git a/cmd/wsl2host/pkg/service/service.go b/cmd/wsl2host/pkg/service/service.go index 0d15297..1edd5d5 100644 --- a/cmd/wsl2host/pkg/service/service.go +++ b/cmd/wsl2host/pkg/service/service.go @@ -11,6 +11,7 @@ import ( "golang.org/x/sys/windows/svc/debug" "github.com/shayne/go-wsl2-host/pkg/hostsapi" + "github.com/shayne/go-wsl2-host/pkg/hypervapi" "github.com/shayne/go-wsl2-host/pkg/wslapi" ) @@ -30,6 +31,22 @@ func distroNameToHostname(distroname string) string { // Run main entry point to service logic func Run(elog debug.Log) error { + err := RunHyperVMCheckAndUpdate(elog) + if err != nil { + fmt.Println("update hyper-v vm with err", err) + elog.Error(1, fmt.Sprintf("failed to update hyper-v vm: %v", err)) + return err + } + err = RunWSLCheckAndUpdate(elog) + if err != nil { + fmt.Println("update wsl with err", err) + elog.Error(1, fmt.Sprintf("failed to update wsl: %v", err)) + return err + } + return nil +} + +func RunWSLCheckAndUpdate(elog debug.Log) error { // Then get all wsl info. and run them with config. infos, err := wslapi.GetAllInfo() if err != nil { @@ -57,6 +74,65 @@ func Run(elog debug.Log) error { return nil } +func RunHyperVMCheckAndUpdate(elog debug.Log) error { + infos, err := hypervapi.NewHyperVManager().GetRunningVMs() + if err != nil { + elog.Error(1, fmt.Sprintf("failed to get infos: %v", err)) + return fmt.Errorf("failed to get infos: %w", err) + } + err = updateHostIPOfVm(elog, infos) + if err != nil { + elog.Error(1, fmt.Sprintf("failed to update host IP info: %s", err)) + return err + } + return nil +} + +func updateHostIPOfVm(elog debug.Log, vmInfos []*hypervapi.VMInfo) error { + // update the ip to the vm + hapi, err := hostsapi.CreateAPI("hyper-vm") // filtere only managed host entries + if err != nil { + elog.Error(1, fmt.Sprintf("failed to create hosts api: %v", err)) + return fmt.Errorf("failed to create hosts api: %w", err) + } + + fmt.Printf("old hapi entry info:%v\n", hapi.Entries()) + + updated := false + + // update the vm ip to host + for _, info := range vmInfos { + hostname := info.GeDefaulttDomainName() + for _, ip := range info.GetIPV4() { + // update IPs of running distros + // add running distros not present + isUpsert := hapi.IsUpsertEntry(&hostsapi.HostEntry{ + Hostname: hostname, + IP: ip, + Comment: info.GetComent(), + }) + if isUpsert { + updated = true + } + } + } + + if updated { + err = hapi.Write() + if err != nil { + elog.Error(1, fmt.Sprintf("failed to write hosts file: %v", err)) + return fmt.Errorf("failed to write hosts file: %w", err) + } + + // restart the IP Helper service (iphlpsvc) for port forwarding + exec.Command("C:\\Windows\\System32\\cmd.exe", "/C net stop iphlpsvc").Run() + exec.Command("C:\\Windows\\System32\\cmd.exe", "/C net start iphlpsvc").Run() + } + + return nil + +} + func updateHostIP(elog debug.Log, distros []*wslapi.DistroInfo) error { // update the ip to the wsl hapi, err := hostsapi.CreateAPI("wsl2-host") // filtere only managed host entries @@ -149,7 +225,7 @@ func updateHostIP(elog debug.Log, distros []*wslapi.DistroInfo) error { } } - hostIP, err := hostsapi.GetHostIP() + hostIP, err := hostsapi.GetHostIPV2() if err == nil { hostname, err := os.Hostname() @@ -182,7 +258,7 @@ func updateHostIP(elog debug.Log, distros []*wslapi.DistroInfo) error { /// Write all other distro and host into the hosts file for each distro. func updateDistroIP(elog debug.Log, distros []*wslapi.DistroInfo, distro string) error { - host_ip, err := hostsapi.GetHostIP() + host_ip, err := hostsapi.GetHostIPV2() if err != nil { return err } diff --git a/go.sum b/go.sum index b209985..184d9e1 100644 --- a/go.sum +++ b/go.sum @@ -2,7 +2,6 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= diff --git a/pkg/hostsapi/hostsapi.go b/pkg/hostsapi/hostsapi.go index 79b307f..6d0ea38 100644 --- a/pkg/hostsapi/hostsapi.go +++ b/pkg/hostsapi/hostsapi.go @@ -21,6 +21,10 @@ type HostEntry struct { Comment string } +func (e *HostEntry) String() string { + return fmt.Sprintf("id:%d, ip:%s, hostname:%s, comment:%s", e.idx, e.IP, e.Hostname, e.Comment) +} + // HostsAPI data structure type HostsAPI struct { filter string @@ -154,6 +158,18 @@ func (h *HostsAPI) AddEntry(entry *HostEntry) error { return nil } +func (h *HostsAPI) IsUpsertEntry(entry *HostEntry) bool { + if oldEntry, exists := h.entries[entry.Hostname]; exists { + if oldEntry == entry { + fmt.Printf("entry is same, no need to update, entry:%s\n", entry) + return false + } + } + fmt.Printf("upsert entry:%s\n", entry) + h.entries[entry.Hostname] = entry + return true +} + // Write func (h *HostsAPI) Write() error { var outbuf bytes.Buffer @@ -211,3 +227,22 @@ func GetHostIP() (string, error) { } return ipString[1], nil } + +func GetHostIPV2() (string, error) { + cmd := exec.Command("powershell", "-Command", "(Get-NetIPAddress | Where-Object {$_.InterfaceAlias -like '*WSL*' -and $_.AddressFamily -eq 'IPv4'}).IPAddress") + + out, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("failed to execute PowerShell command with output: %v", string(out)) + } + + // Trim any whitespace or newline characters from the output + ip := strings.TrimSpace(string(out)) + + if ip == "" { + return "", errors.New("no WSL IP address found") + } + + return ip, nil +} + diff --git a/pkg/hypervapi/hypervapi.go b/pkg/hypervapi/hypervapi.go new file mode 100644 index 0000000..be44bf9 --- /dev/null +++ b/pkg/hypervapi/hypervapi.go @@ -0,0 +1,144 @@ +package hypervapi + +import ( + "bytes" + "fmt" + "net" + "os/exec" + "strings" +) + +const ( + defaultDomainNameSuffix = ".exmaple.com" + defaultComment = "managed by api - hyper-vm" +) + +type VMManager interface { + GetRunningVMs() ([]VMInfo, error) + GetVMIPs(vmName string) ([]string, error) +} + +type IP struct { + IPv4List []string + IPv6List []string +} + +func (ip *IP) String() string { + if ip == nil { + return "" + } + return fmt.Sprintf("ipv4:%s, ipv6:%s", ip.IPv4List, ip.IPv6List) +} + +type VMInfo struct { + Name string + IPInfo *IP +} + +func (v *VMInfo) GeDefaulttDomainName() string { + return v.Name + defaultDomainNameSuffix +} + +func (v *VMInfo) GetIP() []string { + if v.IPInfo == nil { + return []string{} + } + // 预分配足够的容量以避免多次分配 + ips := make([]string, 0, len(v.IPInfo.IPv4List)+len(v.IPInfo.IPv6List)) + ips = append(ips, v.IPInfo.IPv4List...) + ips = append(ips, v.IPInfo.IPv6List...) + return ips +} + +func (v *VMInfo) GetIPV4() []string { + if v.IPInfo == nil { + return []string{} + } + // 预分配足够的容量以避免多次分配 + ips := make([]string, 0, len(v.IPInfo.IPv4List)) + ips = append(ips, v.IPInfo.IPv4List...) + return ips +} + +func (v *VMInfo) GetComent() string { + return defaultComment +} + +type HyperVManager struct{} + +func NewHyperVManager() *HyperVManager { + return &HyperVManager{} +} + +func (h *HyperVManager) GetRunningVMs() ([]*VMInfo, error) { + vmNames, err := h.GetRunningVMNames() + if err != nil { + return nil, fmt.Errorf("get VM names failed: %w", err) + } + + vmInfos := make([]*VMInfo, 0, len(vmNames)) + var errs []string + + for _, vmName := range vmNames { + vmInfo := &VMInfo{Name: vmName} + vmIPs, err := h.GetVMIPByVMName(vmName) + if err != nil { + errs = append(errs, fmt.Sprintf("get IP for VM %q failed: %v", vmName, err)) + continue + } + vmInfo.IPInfo = vmIPs + fmt.Printf("vm:%v, ipInfo:%s\n", vmName, vmIPs) + vmInfos = append(vmInfos, vmInfo) + } + + if len(errs) > 0 { + return vmInfos, fmt.Errorf("errors occurred while getting VM information: %s", strings.Join(errs, "; ")) + } + + return vmInfos, nil +} + +func (h *HyperVManager) GetRunningVMNames() ([]string, error) { + // Implementation to get running VMs using PowerShell commands + cmd := exec.Command("powershell", "Get-VM | Where-Object {$_.State -eq 'Running'} | Select-Object -ExpandProperty Name") + var out bytes.Buffer + var errStr bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &errStr + err := cmd.Run() + if err != nil { + fmt.Printf("Running powershell 'powershell Get-VM | Where-Object {$_.State -eq 'Running'} | Select-Object -ExpandProperty Name' failed with err:%v\n, output:%v\n", errStr.String(), out.String()) + return nil, err + } + vms := strings.Split(strings.TrimSpace(out.String()), "\r\n") + fmt.Println("Get running vm names", vms) + return vms, nil +} + +func (h *HyperVManager) GetVMIPByVMName(vmName string) (*IP, error) { + // Implementation to get IPs of a VM using PowerShell commands + cmd := exec.Command("powershell", fmt.Sprintf("Get-VMNetworkAdapter -VMName %s | Select-Object -ExpandProperty IPAddresses", vmName)) + var out bytes.Buffer + var errStr bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &errStr + err := cmd.Run() + if err != nil { + fmt.Printf("Running powershell 'powershell Get-VMNetworkAdapter -VMName %s | Select-Object -ExpandProperty IPAddresses failed with err:%v\n, output:%v\n", vmName, errStr.String(), out.String()) + return nil, err + } + ipInfo := &IP{} + ipList := strings.Split(strings.TrimSpace(out.String()), "\r\n") + for _, ip := range ipList { + ip = strings.TrimSpace(ip) + if parsedIP := net.ParseIP(ip); parsedIP != nil { + if parsedIP.To4() != nil { + ipInfo.IPv4List = append(ipInfo.IPv4List, ip) + } else { + ipInfo.IPv6List = append(ipInfo.IPv6List, ip) + } + } + } + + return ipInfo, nil +}