主要观点总结
本文介绍了如何使用torch.library.custom_op在PyTorch中创建Python自定义运算符,包括其与torch.compile和autograd的协同工作。文章通过示例解释了如何封装Python函数为自定义运算符,如何添加训练支持,以及如何进行测试。
关键观点总结
关键观点1: 使用torch.library.custom_op创建自定义运算符
介绍了如何使用torch.library.custom_op定义新的自定义运算符,将用Python编写的函数封装为类似于PyTorch原生运算符的行为。
关键观点2: 创建自定义运算符的原因
包括将任意Python函数视为不透明的可调用对象,与torch.compile相对应(即防止torch.compile跟踪进入函数);为任意Python函数添加训练支持;以及当操作可以表示为现有PyTorch运算符的组合时,提高效率和方便性。
关键观点3: 如何封装Python函数为自定义运算符
通过示例(如crop功能)详细解释了如何封装Python函数为自定义运算符,包括使用torch.library.custom_op和添加FakeTensor kernel。
关键观点4: 为自定义运算符添加训练支持
介绍了如何使用torch.library.register_autograd为运算符添加训练支持,以及如何通过注册autograd来为自定义运算符指定梯度公式。
关键观点5: 测试Python自定义运算符
阐述了如何使用torch.library.opcheck来测试自定义运算符是否正确注册,以及如何编写单独的测试来验证梯度的正确性。
文章预览
前言 在vllm里面看到flash attention包了一层 @torch.library.custom_op 装饰器(https://github.com/vllm-project/vllm/pull/7536),查阅了一下资料,发现这个是torch 2.4之后的新feature,防止打算torch compile的graph,翻译一下官方教程稍微了解一下这个用法。 来源:https://pytorch.org/tutorials/advanced/python_custom_ops.html Python Custom Operators 教程 这个教程介绍了Python自定义运算符的主题。它列出了我们将从这一教程中学习到的内容,包括如何将用Python编写的自定义运算符与PyTorch集成,以及如何使用torch.library.opcheck来测试自定义运算符。所需的先决条件是安装了PyTorch 2.4或更高版本。 PyTorch提供了大量可以在Tensor上运行的运算符(例如torch.add、torch.sum等)。但是,您可能希望在PyTorch中使用一个新的自定义运算符,可能是由第三方库编写的。本教程展示了如何封装Python函数,使它们的行为类似于PyTor
………………………………