# -*- coding: utf-8 -*-
import torch
import math


class LegendrePolynomial3(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for the backward computation. You can cache tensors
        for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return 0.5 * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        return grad_output * 1.5 * (5 * input ** 2 - 1)


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create Tensors for the weights. For this example, we need 4 weights:
# y = a + b * P3(c + d * x). These weights need to be initialized not too far
# from the correct result to ensure convergence.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)

learning_rate = 5e-6
for t in range(2000):
    # To apply our Function, we use the Function.apply method. We alias this as 'P3'.
    P3 = LegendrePolynomial3.apply

    # Forward pass: compute predicted y using operations; we compute
    # P3 using our custom autograd operation.
    y_pred = a + b * P3(c + d * x)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')
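
# A quick sanity check of the custom backward pass (a minimal sketch, not part
# of the original example): torch.autograd.gradcheck compares the analytical
# gradient defined in backward() against a numerical finite-difference
# estimate. It needs double-precision inputs to keep the numerical error
# small; 'test_input' is an illustrative name introduced here.
test_input = torch.randn(20, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(LegendrePolynomial3.apply, (test_input,))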
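
# As an aside (a sketch, not part of the original example): the manual
# no_grad update above is plain stochastic gradient descent, so the same
# training loop can be written with torch.optim.SGD handling the weight
# update and gradient zeroing; 'optimizer' is an illustrative name. Running
# this as-is would continue training from the weights fitted above.
optimizer = torch.optim.SGD([a, b, c, d], lr=learning_rate)
for t in range(2000):
    y_pred = a + b * LegendrePolynomial3.apply(c + d * x)
    loss = (y_pred - y).pow(2).sum()
    optimizer.zero_grad()   # clears the .grad fields of a, b, c, d
    loss.backward()         # accumulates fresh gradients via our custom backward
    optimizer.step()        # applies w -= learning_rate * w.grad to each weight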