PyTorch: purpose of addmm function

Posted 2019-08-16 01:41

Question:

What is the purpose of the following PyTorch function (doc):

torch.addmm(beta=1, mat, alpha=1, mat1, mat2, out=None)

More specifically, is there any reason to prefer this function instead of just using

beta * mat + alpha * (mat1 @ mat2)
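To make the comparison concrete, here is a minimal sketch that checks the two forms produce the same result. It assumes a recent PyTorch release, where the scalars are passed as keyword arguments (`torch.addmm(input, mat1, mat2, beta=..., alpha=...)`) rather than in the positional order quoted above. The shapes and scalar values are illustrative choices, not taken from the question.

```python
import torch

torch.manual_seed(0)

# Illustrative shapes and scalars (assumptions, not from the original post).
mat = torch.randn(4, 5)    # matrix to be added, shape (n, p)
mat1 = torch.randn(4, 3)   # left factor, shape (n, m)
mat2 = torch.randn(3, 5)   # right factor, shape (m, p)
beta, alpha = 0.5, 2.0

# Single call: beta * mat + alpha * (mat1 @ mat2)
fused = torch.addmm(mat, mat1, mat2, beta=beta, alpha=alpha)

# The same quantity written out as separate operations.
manual = beta * mat + alpha * (mat1 @ mat2)

# The two results should agree up to floating-point rounding.
same = torch.allclose(fused, manual, rtol=1e-4, atol=1e-5)
print(same)
```

Numerically the two are interchangeable, so any reason to prefer one over the other comes down to performance and memory use rather than the result.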

Answer 1:

The addmm function is an optimized version of the expression beta * mat + alpha * (mat1 @ mat2). I ran some tests and timed both forms:

  • If beta=1 and alpha=1, both statements (addmm and the manual expression) take approximately the same time, with addmm only slightly faster, regardless of matrix size.

  • If beta and alpha are not 1, addmm is about twice as fast as the manual expression for smaller matrices (total elements on the order of 10^5). For large matrices (on the order of 10^6 elements), the speedup is negligible (39 ms vs. 41 ms).
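A comparison along these lines can be reproduced with a small timing harness. The following is only a sketch, not the script used for the numbers above: the helper name `compare`, the square matrix sizes, the repeat count, and the use of CPU tensors with `timeit` are all assumptions. Results will vary with hardware, thread settings, and PyTorch version.

```python
import timeit

import torch


def compare(n, beta, alpha, repeats=5):
    """Return average seconds per call (addmm, manual) for n x n CPU matrices.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    mat = torch.randn(n, n)
    mat1 = torch.randn(n, n)
    mat2 = torch.randn(n, n)

    def fused():
        return torch.addmm(mat, mat1, mat2, beta=beta, alpha=alpha)

    def manual():
        return beta * mat + alpha * (mat1 @ mat2)

    # Warm-up calls so one-time setup cost is not included in the timing.
    fused()
    manual()

    t_fused = timeit.timeit(fused, number=repeats) / repeats
    t_manual = timeit.timeit(manual, number=repeats) / repeats
    return t_fused, t_manual


# n = 316 gives roughly 10^5 elements per matrix; n = 1000 gives 10^6.
for n in (316, 1000):
    for beta, alpha in ((1, 1), (0.5, 2.0)):
        t_fused, t_manual = compare(n, beta, alpha)
        print(f"n={n} beta={beta} alpha={alpha}: "
              f"addmm {t_fused * 1e3:.3f} ms, manual {t_manual * 1e3:.3f} ms")
```

For more careful measurements (for example on a GPU, where kernels run asynchronously), a dedicated tool such as `torch.utils.benchmark` is likely a better fit than plain `timeit`.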



Tags: pytorch