diff --git a/gen/nvml/generateapi.go b/gen/nvml/generateapi.go index c99fd78..5ec4a07 100644 --- a/gen/nvml/generateapi.go +++ b/gen/nvml/generateapi.go @@ -43,7 +43,7 @@ var GeneratableInterfaces = []GeneratableInterfacePoperties{ { Type: "library", Interface: "Interface", - Exclude: []string{"Lookup"}, + Exclude: []string{"LookupSymbol"}, PackageMethodsAliasedFrom: "libnvml", }, { diff --git a/pkg/nvml/api.go b/pkg/nvml/api.go index 9ede936..fdf27bd 100644 --- a/pkg/nvml/api.go +++ b/pkg/nvml/api.go @@ -16,6 +16,17 @@ package nvml +// ExtendedInterface defines a set of extensions to the core NVML API. +// +// TODO: For now the list of methods in this interface need to be kept in sync +// with the list of excluded methods for the Interface type in +// gen/nvml/generateapi.go. In the future we should automate this. +// +//go:generate moq -out mock/extendedinterface.go -pkg mock . ExtendedInterface:ExtendedInterface +type ExtendedInterface interface { + LookupSymbol(string) error +} + // libraryOptions hold the paramaters than can be set by a LibraryOption type libraryOptions struct { path string @@ -25,11 +36,6 @@ type libraryOptions struct { // LibraryOption represents a functional option to configure the underlying NVML library type LibraryOption func(*libraryOptions) -// Library defines a set of functions defined on the underlying dynamic library. -type Library interface { - Lookup(string) error -} - // WithLibraryPath provides an option to set the library name to be used by the NVML library. func WithLibraryPath(path string) LibraryOption { return func(o *libraryOptions) { diff --git a/pkg/nvml/lib.go b/pkg/nvml/lib.go index ed4f469..4d26531 100644 --- a/pkg/nvml/lib.go +++ b/pkg/nvml/lib.go @@ -85,13 +85,13 @@ func (l *library) init(opts ...LibraryOption) { l.dl = dl.New(o.path, o.flags) } -func (l *library) GetLibrary() Library { +func (l *library) Extensions() ExtendedInterface { return l } -// Lookup checks whether the specified library symbol exists in the library. +// LookupSymbol checks whether the specified library symbol exists in the library. // Note that this requires that the library be loaded. -func (l *library) Lookup(name string) error { +func (l *library) LookupSymbol(name string) error { if l == nil || l.refcount == 0 { return fmt.Errorf("error looking up %s: %w", name, errLibraryNotLoaded) } @@ -198,93 +198,93 @@ func (pis ProcessInfo_v2Slice) ToProcessInfoSlice() []ProcessInfo { // When new versioned symbols are added, these would have to be initialized above and have // corresponding checks and subsequent assignments added below. func (l *library) updateVersionedSymbols() { - err := l.Lookup("nvmlInit_v2") + err := l.LookupSymbol("nvmlInit_v2") if err == nil { nvmlInit = nvmlInit_v2 } - err = l.Lookup("nvmlDeviceGetPciInfo_v2") + err = l.LookupSymbol("nvmlDeviceGetPciInfo_v2") if err == nil { nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v2 } - err = l.Lookup("nvmlDeviceGetPciInfo_v3") + err = l.LookupSymbol("nvmlDeviceGetPciInfo_v3") if err == nil { nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v3 } - err = l.Lookup("nvmlDeviceGetCount_v2") + err = l.LookupSymbol("nvmlDeviceGetCount_v2") if err == nil { nvmlDeviceGetCount = nvmlDeviceGetCount_v2 } - err = l.Lookup("nvmlDeviceGetHandleByIndex_v2") + err = l.LookupSymbol("nvmlDeviceGetHandleByIndex_v2") if err == nil { nvmlDeviceGetHandleByIndex = nvmlDeviceGetHandleByIndex_v2 } - err = l.Lookup("nvmlDeviceGetHandleByPciBusId_v2") + err = l.LookupSymbol("nvmlDeviceGetHandleByPciBusId_v2") if err == nil { nvmlDeviceGetHandleByPciBusId = nvmlDeviceGetHandleByPciBusId_v2 } - err = l.Lookup("nvmlDeviceGetNvLinkRemotePciInfo_v2") + err = l.LookupSymbol("nvmlDeviceGetNvLinkRemotePciInfo_v2") if err == nil { nvmlDeviceGetNvLinkRemotePciInfo = nvmlDeviceGetNvLinkRemotePciInfo_v2 } // Unable to overwrite nvmlDeviceRemoveGpu() because the v2 function takes // a different set of parameters than the v1 function. - //err = l.Lookup("nvmlDeviceRemoveGpu_v2") + //err = l.LookupSymbol("nvmlDeviceRemoveGpu_v2") //if err == nil { // nvmlDeviceRemoveGpu = nvmlDeviceRemoveGpu_v2 //} - err = l.Lookup("nvmlDeviceGetGridLicensableFeatures_v2") + err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v2") if err == nil { nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v2 } - err = l.Lookup("nvmlDeviceGetGridLicensableFeatures_v3") + err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v3") if err == nil { nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v3 } - err = l.Lookup("nvmlDeviceGetGridLicensableFeatures_v4") + err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v4") if err == nil { nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v4 } - err = l.Lookup("nvmlEventSetWait_v2") + err = l.LookupSymbol("nvmlEventSetWait_v2") if err == nil { nvmlEventSetWait = nvmlEventSetWait_v2 } - err = l.Lookup("nvmlDeviceGetAttributes_v2") + err = l.LookupSymbol("nvmlDeviceGetAttributes_v2") if err == nil { nvmlDeviceGetAttributes = nvmlDeviceGetAttributes_v2 } - err = l.Lookup("nvmlComputeInstanceGetInfo_v2") + err = l.LookupSymbol("nvmlComputeInstanceGetInfo_v2") if err == nil { nvmlComputeInstanceGetInfo = nvmlComputeInstanceGetInfo_v2 } - err = l.Lookup("nvmlDeviceGetComputeRunningProcesses_v2") + err = l.LookupSymbol("nvmlDeviceGetComputeRunningProcesses_v2") if err == nil { deviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v2 } - err = l.Lookup("nvmlDeviceGetComputeRunningProcesses_v3") + err = l.LookupSymbol("nvmlDeviceGetComputeRunningProcesses_v3") if err == nil { deviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v3 } - err = l.Lookup("nvmlDeviceGetGraphicsRunningProcesses_v2") + err = l.LookupSymbol("nvmlDeviceGetGraphicsRunningProcesses_v2") if err == nil { deviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v2 } - err = l.Lookup("nvmlDeviceGetGraphicsRunningProcesses_v3") + err = l.LookupSymbol("nvmlDeviceGetGraphicsRunningProcesses_v3") if err == nil { deviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v3 } - err = l.Lookup("nvmlDeviceGetMPSComputeRunningProcesses_v2") + err = l.LookupSymbol("nvmlDeviceGetMPSComputeRunningProcesses_v2") if err == nil { deviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v2 } - err = l.Lookup("nvmlDeviceGetMPSComputeRunningProcesses_v3") + err = l.LookupSymbol("nvmlDeviceGetMPSComputeRunningProcesses_v3") if err == nil { deviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v3 } - err = l.Lookup("nvmlDeviceGetGpuInstancePossiblePlacements_v2") + err = l.LookupSymbol("nvmlDeviceGetGpuInstancePossiblePlacements_v2") if err == nil { nvmlDeviceGetGpuInstancePossiblePlacements = nvmlDeviceGetGpuInstancePossiblePlacements_v2 } - err = l.Lookup("nvmlVgpuInstanceGetLicenseInfo_v2") + err = l.LookupSymbol("nvmlVgpuInstanceGetLicenseInfo_v2") if err == nil { nvmlVgpuInstanceGetLicenseInfo = nvmlVgpuInstanceGetLicenseInfo_v2 } diff --git a/pkg/nvml/lib_test.go b/pkg/nvml/lib_test.go index 460ed1d..5e233e8 100644 --- a/pkg/nvml/lib_test.go +++ b/pkg/nvml/lib_test.go @@ -125,7 +125,7 @@ func TestLookupFromDefault(t *testing.T) { if !tc.skipLoadLibrary { require.ErrorIs(t, l.load(), tc.expectedLoadError) } - require.ErrorIs(t, l.Lookup("symbol"), tc.expectedLookupErrror) + require.ErrorIs(t, l.LookupSymbol("symbol"), tc.expectedLookupErrror) require.ErrorIs(t, l.close(), tc.expectedCloseError) if tc.expectedCloseError == nil { require.Equal(t, 0, int(l.refcount)) diff --git a/pkg/nvml/mock/extendedinterface.go b/pkg/nvml/mock/extendedinterface.go new file mode 100644 index 0000000..71634bf --- /dev/null +++ b/pkg/nvml/mock/extendedinterface.go @@ -0,0 +1,75 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mock + +import ( + "github.com/NVIDIA/go-nvml/pkg/nvml" + "sync" +) + +// Ensure, that ExtendedInterface does implement nvml.ExtendedInterface. +// If this is not the case, regenerate this file with moq. +var _ nvml.ExtendedInterface = &ExtendedInterface{} + +// ExtendedInterface is a mock implementation of nvml.ExtendedInterface. +// +// func TestSomethingThatUsesExtendedInterface(t *testing.T) { +// +// // make and configure a mocked nvml.ExtendedInterface +// mockedExtendedInterface := &ExtendedInterface{ +// LookupSymbolFunc: func(s string) error { +// panic("mock out the LookupSymbol method") +// }, +// } +// +// // use mockedExtendedInterface in code that requires nvml.ExtendedInterface +// // and then make assertions. +// +// } +type ExtendedInterface struct { + // LookupSymbolFunc mocks the LookupSymbol method. + LookupSymbolFunc func(s string) error + + // calls tracks calls to the methods. + calls struct { + // LookupSymbol holds details about calls to the LookupSymbol method. + LookupSymbol []struct { + // S is the s argument value. + S string + } + } + lockLookupSymbol sync.RWMutex +} + +// LookupSymbol calls LookupSymbolFunc. +func (mock *ExtendedInterface) LookupSymbol(s string) error { + if mock.LookupSymbolFunc == nil { + panic("ExtendedInterface.LookupSymbolFunc: method is nil but ExtendedInterface.LookupSymbol was just called") + } + callInfo := struct { + S string + }{ + S: s, + } + mock.lockLookupSymbol.Lock() + mock.calls.LookupSymbol = append(mock.calls.LookupSymbol, callInfo) + mock.lockLookupSymbol.Unlock() + return mock.LookupSymbolFunc(s) +} + +// LookupSymbolCalls gets all the calls that were made to LookupSymbol. +// Check the length with: +// +// len(mockedExtendedInterface.LookupSymbolCalls()) +func (mock *ExtendedInterface) LookupSymbolCalls() []struct { + S string +} { + var calls []struct { + S string + } + mock.lockLookupSymbol.RLock() + calls = mock.calls.LookupSymbol + mock.lockLookupSymbol.RUnlock() + return calls +} diff --git a/pkg/nvml/mock/interface.go b/pkg/nvml/mock/interface.go index 6e2f3d5..96739dd 100644 --- a/pkg/nvml/mock/interface.go +++ b/pkg/nvml/mock/interface.go @@ -654,15 +654,15 @@ var _ nvml.Interface = &Interface{} // EventSetWaitFunc: func(eventSet nvml.EventSet, v uint32) (nvml.EventData, nvml.Return) { // panic("mock out the EventSetWait method") // }, +// ExtensionsFunc: func() nvml.ExtendedInterface { +// panic("mock out the Extensions method") +// }, // GetExcludedDeviceCountFunc: func() (int, nvml.Return) { // panic("mock out the GetExcludedDeviceCount method") // }, // GetExcludedDeviceInfoByIndexFunc: func(n int) (nvml.ExcludedDeviceInfo, nvml.Return) { // panic("mock out the GetExcludedDeviceInfoByIndex method") // }, -// GetLibraryFunc: func() nvml.Library { -// panic("mock out the GetLibrary method") -// }, // GetVgpuCompatibilityFunc: func(vgpuMetadata *nvml.VgpuMetadata, vgpuPgpuMetadata *nvml.VgpuPgpuMetadata) (nvml.VgpuPgpuCompatibility, nvml.Return) { // panic("mock out the GetVgpuCompatibility method") // }, @@ -1534,15 +1534,15 @@ type Interface struct { // EventSetWaitFunc mocks the EventSetWait method. EventSetWaitFunc func(eventSet nvml.EventSet, v uint32) (nvml.EventData, nvml.Return) + // ExtensionsFunc mocks the Extensions method. + ExtensionsFunc func() nvml.ExtendedInterface + // GetExcludedDeviceCountFunc mocks the GetExcludedDeviceCount method. GetExcludedDeviceCountFunc func() (int, nvml.Return) // GetExcludedDeviceInfoByIndexFunc mocks the GetExcludedDeviceInfoByIndex method. GetExcludedDeviceInfoByIndexFunc func(n int) (nvml.ExcludedDeviceInfo, nvml.Return) - // GetLibraryFunc mocks the GetLibrary method. - GetLibraryFunc func() nvml.Library - // GetVgpuCompatibilityFunc mocks the GetVgpuCompatibility method. GetVgpuCompatibilityFunc func(vgpuMetadata *nvml.VgpuMetadata, vgpuPgpuMetadata *nvml.VgpuPgpuMetadata) (nvml.VgpuPgpuCompatibility, nvml.Return) @@ -3069,6 +3069,9 @@ type Interface struct { // V is the v argument value. V uint32 } + // Extensions holds details about calls to the Extensions method. + Extensions []struct { + } // GetExcludedDeviceCount holds details about calls to the GetExcludedDeviceCount method. GetExcludedDeviceCount []struct { } @@ -3077,9 +3080,6 @@ type Interface struct { // N is the n argument value. N int } - // GetLibrary holds details about calls to the GetLibrary method. - GetLibrary []struct { - } // GetVgpuCompatibility holds details about calls to the GetVgpuCompatibility method. GetVgpuCompatibility []struct { // VgpuMetadata is the vgpuMetadata argument value. @@ -3697,9 +3697,9 @@ type Interface struct { lockEventSetCreate sync.RWMutex lockEventSetFree sync.RWMutex lockEventSetWait sync.RWMutex + lockExtensions sync.RWMutex lockGetExcludedDeviceCount sync.RWMutex lockGetExcludedDeviceInfoByIndex sync.RWMutex - lockGetLibrary sync.RWMutex lockGetVgpuCompatibility sync.RWMutex lockGetVgpuDriverCapabilities sync.RWMutex lockGetVgpuVersion sync.RWMutex @@ -11031,6 +11031,33 @@ func (mock *Interface) EventSetWaitCalls() []struct { return calls } +// Extensions calls ExtensionsFunc. +func (mock *Interface) Extensions() nvml.ExtendedInterface { + if mock.ExtensionsFunc == nil { + panic("Interface.ExtensionsFunc: method is nil but Interface.Extensions was just called") + } + callInfo := struct { + }{} + mock.lockExtensions.Lock() + mock.calls.Extensions = append(mock.calls.Extensions, callInfo) + mock.lockExtensions.Unlock() + return mock.ExtensionsFunc() +} + +// ExtensionsCalls gets all the calls that were made to Extensions. +// Check the length with: +// +// len(mockedInterface.ExtensionsCalls()) +func (mock *Interface) ExtensionsCalls() []struct { +} { + var calls []struct { + } + mock.lockExtensions.RLock() + calls = mock.calls.Extensions + mock.lockExtensions.RUnlock() + return calls +} + // GetExcludedDeviceCount calls GetExcludedDeviceCountFunc. func (mock *Interface) GetExcludedDeviceCount() (int, nvml.Return) { if mock.GetExcludedDeviceCountFunc == nil { @@ -11090,33 +11117,6 @@ func (mock *Interface) GetExcludedDeviceInfoByIndexCalls() []struct { return calls } -// GetLibrary calls GetLibraryFunc. -func (mock *Interface) GetLibrary() nvml.Library { - if mock.GetLibraryFunc == nil { - panic("Interface.GetLibraryFunc: method is nil but Interface.GetLibrary was just called") - } - callInfo := struct { - }{} - mock.lockGetLibrary.Lock() - mock.calls.GetLibrary = append(mock.calls.GetLibrary, callInfo) - mock.lockGetLibrary.Unlock() - return mock.GetLibraryFunc() -} - -// GetLibraryCalls gets all the calls that were made to GetLibrary. -// Check the length with: -// -// len(mockedInterface.GetLibraryCalls()) -func (mock *Interface) GetLibraryCalls() []struct { -} { - var calls []struct { - } - mock.lockGetLibrary.RLock() - calls = mock.calls.GetLibrary - mock.lockGetLibrary.RUnlock() - return calls -} - // GetVgpuCompatibility calls GetVgpuCompatibilityFunc. func (mock *Interface) GetVgpuCompatibility(vgpuMetadata *nvml.VgpuMetadata, vgpuPgpuMetadata *nvml.VgpuPgpuMetadata) (nvml.VgpuPgpuCompatibility, nvml.Return) { if mock.GetVgpuCompatibilityFunc == nil { diff --git a/pkg/nvml/zz_generated.api.go b/pkg/nvml/zz_generated.api.go index 76bf64b..9997a27 100644 --- a/pkg/nvml/zz_generated.api.go +++ b/pkg/nvml/zz_generated.api.go @@ -232,9 +232,9 @@ var ( EventSetCreate = libnvml.EventSetCreate EventSetFree = libnvml.EventSetFree EventSetWait = libnvml.EventSetWait + Extensions = libnvml.Extensions GetExcludedDeviceCount = libnvml.GetExcludedDeviceCount GetExcludedDeviceInfoByIndex = libnvml.GetExcludedDeviceInfoByIndex - GetLibrary = libnvml.GetLibrary GetVgpuCompatibility = libnvml.GetVgpuCompatibility GetVgpuDriverCapabilities = libnvml.GetVgpuDriverCapabilities GetVgpuVersion = libnvml.GetVgpuVersion @@ -529,9 +529,9 @@ type Interface interface { EventSetCreate() (EventSet, Return) EventSetFree(EventSet) Return EventSetWait(EventSet, uint32) (EventData, Return) + Extensions() ExtendedInterface GetExcludedDeviceCount() (int, Return) GetExcludedDeviceInfoByIndex(int) (ExcludedDeviceInfo, Return) - GetLibrary() Library GetVgpuCompatibility(*VgpuMetadata, *VgpuPgpuMetadata) (VgpuPgpuCompatibility, Return) GetVgpuDriverCapabilities(VgpuDriverCapability) (bool, Return) GetVgpuVersion() (VgpuVersion, VgpuVersion, Return)