Description
In the Model Optimizations chapter, in the Pruning section, there is a mistake in Listing 1:
The results section states that 4 out of 9 weights remain after pruning. Because of how the mask is built, this is not the case.
torch.abs(weights) takes the absolute value of each weight, so negative numbers like -0.9 become 0.9.
'>=threshold' then keeps every weight whose magnitude is greater than or equal to 0.1, regardless of sign, so large positive and large negative weights both survive. In short, the code leaves 6 of the 9 weights, not 4.
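To make the count concrete, here is a minimal plain-Python sketch of the magnitude mask described above. The 3x3 weight values are hypothetical (they are not the values from Listing 1); they were picked so that six entries, three of them negative, have magnitude of at least 0.1. The list comprehensions stand in for `torch.abs(weights) >= threshold` and `weights * mask`.

```python
# Hypothetical 3x3 weight matrix, NOT the values used in Listing 1.
weights = [
    [0.8, -0.05, 0.3],
    [-0.9, 0.02, -0.4],
    [0.5, -0.6, 0.01],
]
threshold = 0.1

# Plain-Python equivalent of: mask = torch.abs(weights) >= threshold
# abs() discards the sign, so -0.9 is compared as 0.9 and is kept.
mask = [[abs(w) >= threshold for w in row] for row in weights]

# Plain-Python equivalent of: pruned = weights * mask
pruned = [
    [w if keep else 0.0 for w, keep in zip(row, mask_row)]
    for row, mask_row in zip(weights, mask)
]

# Count surviving weights (True counts as 1 when summed).
kept = sum(keep for mask_row in mask for keep in mask_row)
total = sum(len(row) for row in weights)

print(f"{kept} of {total} weights remain")  # 6 of 9 weights remain
print(pruned)
```

With these example values, the three large negative weights (-0.9, -0.4, -0.6) pass the magnitude test alongside the three large positive ones, which is why the count comes out above 4.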
Here is a correction:
The corrected code ensures negative weights are removed. This leaves 3 out of 9 weights remaining (33% of the weights kept, i.e. about 67% sparsity, instead of the 4 of 9, or 44% kept, stated in the results). I am aware the listing says this example demonstrates magnitude-based pruning, but because the stated results did not match the code, I am providing this correction, which takes the sign of each weight into account.
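For comparison, here is a sketch of a sign-aware mask, assuming the correction amounts to comparing the raw weights with the threshold instead of their absolute values. The weight values are the same kind of hypothetical 3x3 example (not Listing 1's values). Note that this is no longer magnitude-based pruning: every negative weight fails the test. The sketch also computes sparsity as the fraction of zeroed weights, which is where 3 kept out of 9 corresponds to roughly 67% sparsity.

```python
# Hypothetical 3x3 weight matrix, NOT the values used in Listing 1.
weights = [
    [0.8, -0.05, 0.3],
    [-0.9, 0.02, -0.4],
    [0.5, -0.6, 0.01],
]
threshold = 0.1

# Sign-aware mask (assumed form of the correction): the equivalent of
# mask = weights >= threshold, with no torch.abs, so negatives are zeroed.
mask = [[w >= threshold for w in row] for row in weights]

kept = sum(keep for mask_row in mask for keep in mask_row)
total = sum(len(row) for row in weights)

density = kept / total    # fraction of weights kept
sparsity = 1 - density    # fraction of weights zeroed

print(f"{kept} of {total} kept, sparsity {sparsity:.0%}")  # 3 of 9 kept, sparsity 67%
```

The design trade-off is that keeping or dropping negatives changes which technique is being shown; the sign-aware version matches a "3 remain" count, while the magnitude version is what the listing's description names.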
Hope this helps!
-Tess