Issue
I was watching a PyTorch tutorial and coding along, and got stuck on the function torch.randint.
According to the documentation:
torch.randint(low=0, high, size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
Here, size is documented as:
size (tuple) – a tuple defining the shape of the output tensor.
The YouTuber wrote:
random_idx = torch.randint(0, len(train_data), size=[1]).item()
But [1] is not a tuple, it is a list. How is this possible? I also tested it with a tuple and it worked just fine, and every usage of randint() I found on the internet passes a tuple for size, e.g. size = (1, 2) or size = (1, 1).
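For reference, here is a minimal sketch of that test (using a small placeholder upper bound in place of len(train_data) from the tutorial); both the list and the tuple form run without error:

import torch

# Placeholder upper bound standing in for len(train_data) from the tutorial.
high = 10

# Both calls draw a single random index in [0, high); a list and a
# tuple are accepted interchangeably for the size argument.
idx_from_list = torch.randint(0, high, size=[1]).item()
idx_from_tuple = torch.randint(0, high, size=(1,)).item()

print(idx_from_list, idx_from_tuple)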
I searched for the source code of torch.randint but could not find it. I looked on GitHub, in the PyTorch docs, and even in my local PyTorch installation.
Solution
The documentation states that this should be a tuple; however, in practice, the definition of randint() is:
def randint(low: _int, high: _int, size: _size, ... )
where _size is defined as:
(type alias) _size: Type[Size] | Type[List[int]] | Type[Tuple[int, ...]]
So in practice, the size parameter can be a Size, a list of int, or a tuple of int, all of which behave the same in this case.
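A quick sketch of this, using nothing beyond the standard torch package: all three accepted forms describe the same shape and produce equivalent results.

import torch

# All three size arguments describe the same 2x3 shape.
a = torch.randint(0, 100, size=(2, 3))              # tuple of int
b = torch.randint(0, 100, size=[2, 3])              # list of int
c = torch.randint(0, 100, size=torch.Size([2, 3]))  # torch.Size

print(a.shape == b.shape == c.shape)  # True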
EDIT: As stated above, type hints are only indicative in Python, so passing a value of a different type causes no problem as long as the function itself doesn't raise an error. As for why the function still behaves correctly and returns the expected result, that follows from the first part of the answer :)
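As a hypothetical illustration of that point (the function below is made up and not part of PyTorch), a plain Python function annotated to take a tuple still accepts a list at runtime, because annotations are never enforced by the interpreter:

from typing import Tuple

def shape_volume(size: Tuple[int, ...]) -> int:
    # The annotation says Tuple, but Python does not check it at runtime;
    # any iterable of ints works as long as the body can handle it.
    result = 1
    for dim in size:
        result *= dim
    return result

print(shape_volume((2, 3)))  # 6
print(shape_volume([2, 3]))  # 6 -- a list is accepted just as well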
Answered By - PlainRavioli
Answer Checked By - Pedro (PHPFixing Volunteer)