From c4f645589c23da8cab9dddfa6ed7bb90ce182f52 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 8 May 2024 18:13:27 +0200 Subject: [PATCH] Add internal method to get device handle Signed-off-by: Evan Lezar --- gen/nvml/generateapi.go | 11 +++++++---- pkg/nvml/device.go | 12 +++++++++++- pkg/nvml/zz_generated.api.go | 1 + 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/gen/nvml/generateapi.go b/gen/nvml/generateapi.go index 5ec4a07..f12c4fe 100644 --- a/gen/nvml/generateapi.go +++ b/gen/nvml/generateapi.go @@ -36,6 +36,7 @@ type GeneratableInterfacePoperties struct { Type string Interface string Exclude []string + IncludePrivate []string PackageMethodsAliasedFrom string } @@ -47,8 +48,9 @@ var GeneratableInterfaces = []GeneratableInterfacePoperties{ PackageMethodsAliasedFrom: "libnvml", }, { - Type: "nvmlDevice", - Interface: "Device", + Type: "nvmlDevice", + Interface: "Device", + IncludePrivate: []string{"nvmlDeviceHandle"}, }, { Type: "nvmlGpuInstance", @@ -340,8 +342,9 @@ func extractMethods(sourceFile string, sourceContent []byte, input GeneratableIn continue } - // Ignore non-public methods - if !isPublic(funcDecl.Name.Name) { + // Ignore non-public methods unless these are forced. + forced := slices.Contains(input.Include, funcDecl.Name.Name) + if !forced && !isPublic(funcDecl.Name.Name) { continue } diff --git a/pkg/nvml/device.go b/pkg/nvml/device.go index 7ee5e55..3883164 100644 --- a/pkg/nvml/device.go +++ b/pkg/nvml/device.go @@ -18,6 +18,11 @@ import ( "unsafe" ) +// nvmlDeviceHandle provides an explicit function to return the underlying nvmlDevice handle. +func (d nvmlDevice) nvmlDeviceHandle() *nvmlDevice { + return &d +} + // EccBitType type EccBitType = MemoryErrorType @@ -219,8 +224,13 @@ func (l *library) DeviceGetTopologyCommonAncestor(device1 Device, device2 Device } func (device1 nvmlDevice) GetTopologyCommonAncestor(device2 Device) (GpuTopologyLevel, Return) { + other := device2.nvmlDeviceHandle() + if other == nil { + return 0, ERROR_INVALID_ARGUMENT + } + var pathInfo GpuTopologyLevel - ret := nvmlDeviceGetTopologyCommonAncestor(device1, device2.(nvmlDevice), &pathInfo) + ret := nvmlDeviceGetTopologyCommonAncestor(device1, other, &pathInfo) return pathInfo, ret } diff --git a/pkg/nvml/zz_generated.api.go b/pkg/nvml/zz_generated.api.go index 9997a27..cd81423 100644 --- a/pkg/nvml/zz_generated.api.go +++ b/pkg/nvml/zz_generated.api.go @@ -815,6 +815,7 @@ type Device interface { SetVirtualizationMode(GpuVirtualizationMode) Return ValidateInforom() Return VgpuTypeGetMaxInstances(VgpuTypeId) (int, Return) + nvmlDeviceHandle() *nvmlDevice } // GpuInstance represents the interface for the nvmlGpuInstance type.