Skip to content

Commit

Permalink
PJRT C API v0.35 (iree-org#15269)
Browse files Browse the repository at this point in the history
  • Loading branch information
rsuderman authored Oct 24, 2023
1 parent bb1efe8 commit 2bfc636
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 10 deletions.
127 changes: 126 additions & 1 deletion integrations/pjrt/src/iree_pjrt/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1817,19 +1817,144 @@ iree_status_t LoadedExecutableInstance::BatchExecute(
return status;
}

static void BindUndefineds(PJRT_Api* api) {
#define _STUB(API) \
api->API = +[](API##_Args* args) -> PJRT_Error* { \
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED, #API)); \
}

_STUB(PJRT_Plugin_Initialize);
_STUB(PJRT_Plugin_Attributes);

_STUB(PJRT_Event_Destroy);
_STUB(PJRT_Event_IsReady);
_STUB(PJRT_Event_Error);
_STUB(PJRT_Event_Await);
_STUB(PJRT_Event_OnReady);

_STUB(PJRT_Client_Create);
_STUB(PJRT_Client_Destroy);
_STUB(PJRT_Client_PlatformName);
_STUB(PJRT_Client_ProcessIndex);
_STUB(PJRT_Client_PlatformVersion);
_STUB(PJRT_Client_Devices);
_STUB(PJRT_Client_AddressableDevices);
_STUB(PJRT_Client_LookupDevice);
_STUB(PJRT_Client_LookupAddressableDevice);
_STUB(PJRT_Client_AddressableMemories);
_STUB(PJRT_Client_Compile);
_STUB(PJRT_Client_DefaultDeviceAssignment);
_STUB(PJRT_Client_BufferFromHostBuffer);

_STUB(PJRT_DeviceDescription_Id);
_STUB(PJRT_DeviceDescription_ProcessIndex);
_STUB(PJRT_DeviceDescription_Attributes);
_STUB(PJRT_DeviceDescription_Kind);
_STUB(PJRT_DeviceDescription_DebugString);
_STUB(PJRT_DeviceDescription_ToString);

_STUB(PJRT_Device_GetDescription);
_STUB(PJRT_Device_IsAddressable);
_STUB(PJRT_Device_LocalHardwareId);
_STUB(PJRT_Device_AddressableMemories);
_STUB(PJRT_Device_DefaultMemory);
_STUB(PJRT_Device_MemoryStats);

_STUB(PJRT_Memory_Id);
_STUB(PJRT_Memory_Kind);
_STUB(PJRT_Memory_DebugString);
_STUB(PJRT_Memory_ToString);
_STUB(PJRT_Memory_AddressableByDevices);

_STUB(PJRT_Executable_Destroy);
_STUB(PJRT_Executable_Name);
_STUB(PJRT_Executable_NumReplicas);
_STUB(PJRT_Executable_NumPartitions);
_STUB(PJRT_Executable_NumOutputs);
_STUB(PJRT_Executable_SizeOfGeneratedCodeInBytes);
_STUB(PJRT_Executable_GetCostAnalysis);
_STUB(PJRT_Executable_OutputMemoryKinds);
_STUB(PJRT_Executable_OptimizedProgram);
_STUB(PJRT_Executable_Serialize);

_STUB(PJRT_LoadedExecutable_Destroy);
_STUB(PJRT_LoadedExecutable_GetExecutable);
_STUB(PJRT_LoadedExecutable_AddressableDevices);
_STUB(PJRT_LoadedExecutable_Delete);
_STUB(PJRT_LoadedExecutable_IsDeleted);
_STUB(PJRT_LoadedExecutable_Execute);
_STUB(PJRT_Executable_DeserializeAndLoad);
_STUB(PJRT_LoadedExecutable_Fingerprint);

_STUB(PJRT_Buffer_Destroy);
_STUB(PJRT_Buffer_ElementType);
_STUB(PJRT_Buffer_Dimensions);
_STUB(PJRT_Buffer_UnpaddedDimensions);
_STUB(PJRT_Buffer_DynamicDimensionIndices);
_STUB(PJRT_Buffer_GetMemoryLayout);
_STUB(PJRT_Buffer_OnDeviceSizeInBytes);
_STUB(PJRT_Buffer_Device);
_STUB(PJRT_Buffer_Memory);
_STUB(PJRT_Buffer_Delete);
_STUB(PJRT_Buffer_IsDeleted);
_STUB(PJRT_Buffer_CopyToDevice);
_STUB(PJRT_Buffer_ToHostBuffer);
_STUB(PJRT_Buffer_IsOnCpu);
_STUB(PJRT_Buffer_ReadyEvent);
_STUB(PJRT_Buffer_UnsafePointer);
_STUB(PJRT_Buffer_IncreaseExternalReferenceCount);
_STUB(PJRT_Buffer_DecreaseExternalReferenceCount);
_STUB(PJRT_Buffer_OpaqueDeviceMemoryDataPointer);

_STUB(PJRT_CopyToDeviceStream_Destroy);
_STUB(PJRT_CopyToDeviceStream_AddChunk);
_STUB(PJRT_CopyToDeviceStream_TotalBytes);
_STUB(PJRT_CopyToDeviceStream_GranuleSize);
_STUB(PJRT_CopyToDeviceStream_CurrentBytes);

_STUB(PJRT_TopologyDescription_Create);
_STUB(PJRT_TopologyDescription_Destroy);
_STUB(PJRT_TopologyDescription_PlatformName);
_STUB(PJRT_TopologyDescription_PlatformVersion);
_STUB(PJRT_TopologyDescription_GetDeviceDescriptions);
_STUB(PJRT_TopologyDescription_Serialize);
_STUB(PJRT_TopologyDescription_Attributes);

_STUB(PJRT_Compile);

// Always add new fields to the end of the struct. Move fields below to their
// corresponding places after each major version bump.
_STUB(PJRT_Executable_OutputElementTypes);
_STUB(PJRT_Executable_OutputDimensions);

_STUB(PJRT_Buffer_CopyToMemory);

_STUB(PJRT_Client_CreateViewOfDeviceBuffer);
}

//===----------------------------------------------------------------------===//
// Top-level API binding.
//===----------------------------------------------------------------------===//

void BindMonomorphicApi(PJRT_Api* api) {
api->struct_size = PJRT_Api_STRUCT_SIZE;
api->extension_start = nullptr;
api->pjrt_api_version.major_version = PJRT_API_MAJOR;
api->pjrt_api_version.minor_version = PJRT_API_MINOR;

// This is a bare implementation throwing UNDEFINED errors. This way new
// functions will not segmentation fault on invocation.
BindUndefineds(api);
ErrorInstance::BindApi(api);

api->PJRT_Plugin_Initialize =
+[](PJRT_Plugin_Initialize_Args* args) -> PJRT_Error* { return nullptr; };

// Bind by object types.
BufferInstance::BindApi(api);
ClientInstance::BindApi(api);
DeviceDescription::BindApi(api);
DeviceInstance::BindApi(api);
ErrorInstance::BindApi(api);
EventInstance::BindApi(api);
ExecutableImage::BindApi(api);
LoadedExecutableInstance::BindApi(api);
Expand Down
94 changes: 85 additions & 9 deletions integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ extern "C" {
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 31
#define PJRT_API_MINOR 35

// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
Expand Down Expand Up @@ -718,6 +718,43 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_BufferFromHostBuffer_Args, buffer);
typedef PJRT_Error* PJRT_Client_BufferFromHostBuffer(
PJRT_Client_BufferFromHostBuffer_Args* args);

struct PJRT_Client_CreateViewOfDeviceBuffer_Args {
size_t struct_size;
void* priv;
PJRT_Client* client;
// A pointer to a non-owned device buffer. A PJRT_Buffer that is a non-owned
// view of this device buffer will be created.
void* device_buffer_ptr;
const int64_t* dims;
size_t num_dims;
PJRT_Buffer_Type element_type;
PJRT_Buffer_MemoryLayout* layout;
// The device that `device_buffer_ptr` is on.
PJRT_Device* device;
// A callback to be performed when the PJRT_Buffer is done with the on-device
// buffer. This callback is optional and can be a nullptr.
void (*on_delete_callback)(void* device_buffer_ptr, void* user_arg);
// `on_delete_callback_arg` will be passed to `on_delete_callback` as
// `user_arg` argument.
void* on_delete_callback_arg;
// A platform-specific stream handle that should contain the work or events
// needed to materialize the on-device buffer. It is optional and can be
// casted from a nullptr. PJRT_Client_CreateViewOfDeviceBuffer_Args will
// append an event to `stream` that indicates when the returned buffer is
// ready to use. This is intended to support dlpack on GPU and is not expected
// to be supported on all hardware platforms.
intptr_t stream;
PJRT_Buffer* buffer; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_CreateViewOfDeviceBuffer_Args, buffer);

// Creates a PJRT buffer that is a non-owned view of an on-device buffer
// (typically allocated by another library). The buffer may be mutated,
// for example, if the buffer is donated to an Execute operation. This method is
// not required on all hardware platforms.
typedef PJRT_Error* PJRT_Client_CreateViewOfDeviceBuffer(
PJRT_Client_CreateViewOfDeviceBuffer_Args* args);

// -------------------------- Device Descriptions ------------------------------

// Device descriptions may be associated with an actual device
Expand Down Expand Up @@ -1278,6 +1315,24 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_SizeOfGeneratedCodeInBytes_Args,
typedef PJRT_Error* PJRT_Executable_SizeOfGeneratedCodeInBytes(
PJRT_Executable_SizeOfGeneratedCodeInBytes_Args* args);

struct PJRT_Executable_Fingerprint_Args {
size_t struct_size;
void* priv;
PJRT_Executable* executable;
// Has the lifetime of `executable`
const char* executable_fingerprint; // out
size_t executable_fingerprint_size; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Fingerprint_Args,
executable_fingerprint_size);

// A unique fingerprint for `executable`. Two executables that were produced by
// compiling with identical inputs (same program, compile options, compiler
// version, etc.) should have the same fingerprint. May not be implemented by
// all platforms.
typedef PJRT_Error* PJRT_Executable_Fingerprint(
PJRT_Executable_Fingerprint_Args* args);

struct PJRT_Executable_GetCostAnalysis_Args {
size_t struct_size;
void* priv;
Expand Down Expand Up @@ -1397,10 +1452,11 @@ struct PJRT_LoadedExecutable_Fingerprint_Args {
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Fingerprint_Args,
executable_fingerprint_size);
// A unique fingerprint for `executable`. Two executables that were produced by
// compiling with identical inputs (same program, compile options, compiler
// version, etc.) should have the same fingerprint. May not be implemented by
// all platforms.
// DEPRECATED. Will be removed in PJRT version 2.0. Please use
// PJRT_Executable_Fingerprint instead. A unique fingerprint for `executable`.
// Two executables that were produced by compiling with identical inputs (same
// program, compile options, compiler version, etc.) should have the same
// fingerprint. May not be implemented by all platforms.
typedef PJRT_Error* PJRT_LoadedExecutable_Fingerprint(
PJRT_LoadedExecutable_Fingerprint_Args* args);

Expand Down Expand Up @@ -1565,12 +1621,27 @@ struct PJRT_Buffer_CopyToDevice_Args {
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_CopyToDevice_Args, dst_buffer);

// Copies the buffer to device `dst_device`. Caller is responsible for freeing
// returned `dst_buffer` with PJRT_Buffer_Destroy. Returns an error if the
// buffer is already on `dst_device`.
// Copies the buffer to device `dst_device` within the same client. Caller is
// responsible for freeing returned `dst_buffer` with PJRT_Buffer_Destroy.
// Returns an error if the buffer is already on `dst_device`.
typedef PJRT_Error* PJRT_Buffer_CopyToDevice(
PJRT_Buffer_CopyToDevice_Args* args);

struct PJRT_Buffer_CopyToMemory_Args {
size_t struct_size;
void* priv;
PJRT_Buffer* buffer;
PJRT_Memory* dst_memory;
PJRT_Buffer* dst_buffer; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_CopyToMemory_Args, dst_buffer);

// Copies the buffer to memory `dst_memory` within the same client. Caller is
// responsible for freeing returned `dst_buffer` with PJRT_Buffer_Destroy.
// Returns an error if the buffer is already on `dst_memory`.
typedef PJRT_Error* PJRT_Buffer_CopyToMemory(
PJRT_Buffer_CopyToMemory_Args* args);

struct PJRT_Buffer_IsOnCpu_Args {
size_t struct_size;
void* priv;
Expand Down Expand Up @@ -1905,6 +1976,7 @@ typedef PJRT_Error* PJRT_Compile(PJRT_Compile_Args* args);

typedef enum {
PJRT_Structure_Type_Gpu_Custom_Call = 0,
PJRT_Structure_Type_Profiler,
} PJRT_Structure_Type;

// PJRT_Structure_Base contains a type and a pointer to next
Expand Down Expand Up @@ -2033,10 +2105,14 @@ typedef struct {
// corresponding places after each major version bump.
_PJRT_API_STRUCT_FIELD(PJRT_Executable_OutputElementTypes);
_PJRT_API_STRUCT_FIELD(PJRT_Executable_OutputDimensions);

_PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyToMemory);
_PJRT_API_STRUCT_FIELD(PJRT_Client_CreateViewOfDeviceBuffer);
_PJRT_API_STRUCT_FIELD(PJRT_Executable_Fingerprint);
} PJRT_Api;

const size_t PJRT_Api_STRUCT_SIZE =
PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Executable_OutputDimensions);
PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_CreateViewOfDeviceBuffer);

#undef _PJRT_API_STRUCT_FIELD

Expand Down

0 comments on commit 2bfc636

Please sign in to comment.