Skip to content

Commit

Permalink
Revert pyi.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Apr 1, 2024
1 parent be08b14 commit 652306c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 9 deletions.
45 changes: 38 additions & 7 deletions candle-pyo3/py_src/candle/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,21 @@ class i64(DType):
pass

@staticmethod
def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
def ones(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor filled with ones.
"""
pass

@staticmethod
def rand(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
def rand(*shape: Shape, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor with random values.
"""
pass

@staticmethod
def randn(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
def randn(*shape: Shape, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor with random values from a normal distribution.
"""
Expand All @@ -67,7 +67,7 @@ class u8(DType):
pass

@staticmethod
def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
def zeros(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor filled with zeros.
"""
Expand Down Expand Up @@ -208,6 +208,12 @@ class Tensor:
"""
pass

def abs(self) -> Tensor:
"""
Performs the `abs` operation on the tensor.
"""
pass

def argmax_keepdim(self, dim: int) -> Tensor:
"""
Returns the indices of the maximum value(s) across the selected dimension.
Expand All @@ -226,7 +232,7 @@ class Tensor:
"""
pass

def broadcast_as(self, shape: Sequence[int]) -> Tensor:
def broadcast_as(self, *shape: Shape) -> Tensor:
"""
Broadcasts the tensor to the given shape.
"""
Expand All @@ -238,7 +244,7 @@ class Tensor:
"""
pass

def broadcast_left(self, shape: Sequence[int]) -> Tensor:
def broadcast_left(self, *shape: Shape) -> Tensor:
"""
Broadcasts the tensor to the given shape, adding new dimensions on the left.
"""
Expand Down Expand Up @@ -318,6 +324,12 @@ class Tensor:
"""
pass

def gather(self, index, dim):
"""
Gathers values along an axis specified by dim.
"""
pass

def get(self, index: int) -> Tensor:
"""
Gets the value at the specified index.
Expand Down Expand Up @@ -385,6 +397,13 @@ class Tensor:
"""
pass

@property
def nelement(self) -> int:
"""
Gets the tensor's element count.
"""
pass

def powf(self, p: float) -> Tensor:
"""
Performs the `pow` operation on the tensor with the given exponent.
Expand All @@ -410,7 +429,7 @@ class Tensor:
"""
pass

def reshape(self, shape: Sequence[int]) -> Tensor:
def reshape(self, *shape: Shape) -> Tensor:
"""
Reshapes the tensor to the given shape.
"""
Expand Down Expand Up @@ -472,6 +491,12 @@ class Tensor:
"""
pass

def to(self, *args, **kwargs) -> Tensor:
"""
Performs Tensor dtype and/or device conversion.
"""
pass

def to_device(self, device: Union[str, Device]) -> Tensor:
"""
Move the tensor to a new device.
Expand All @@ -484,6 +509,12 @@ class Tensor:
"""
pass

def to_torch(self) -> torch.Tensor:
"""
Converts candle's tensor to pytorch's tensor
"""
pass

def transpose(self, dim1: int, dim2: int) -> Tensor:
"""
Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
Expand Down
8 changes: 6 additions & 2 deletions candle-pyo3/py_src/candle/utils/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,19 @@ def has_mkl() -> bool:
pass

@staticmethod
def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
def load_ggml(
path: Union[str, PathLike], device: Optional[Device] = None
) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
"""
Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
"""
pass

@staticmethod
def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
def load_gguf(
path: Union[str, PathLike], device: Optional[Device] = None
) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
"""
Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
and the second maps metadata keys to metadata values.
Expand Down

0 comments on commit 652306c

Please sign in to comment.