Thanks for your great work! However, in the paper, the input channels of each head are first re-weighted by the SE block, and then only the top-k subset of channels is passed to the normal convolution to produce that head's output. In your code, the whole set of re-weighted input channels is passed to the normal convolution, whose weight has shape (C_out // num_heads, C_in, k, k). If that is the case, neither the computation nor the number of parameters will decrease. I couldn't find where the selection step happens. Could you please explain?
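To make the question concrete, here is a minimal pure-Python sketch of how I understand the selection step. It is not taken from your code or the paper. The function names (`select_top_k`, `conv2d_single`, `head_forward_top_k`) are hypothetical, and it assumes a single input sample, stride 1 and no padding. Each head keeps its full (C_out // num_heads, C_in, k, k) weight but gathers only the input-channel slices of the top-k channels, so the multiply-accumulate count scales with top_k instead of C_in.

```python
def select_top_k(se_scores, top_k):
    """Indices of the top_k largest SE scores, returned in ascending channel order."""
    ranked = sorted(range(len(se_scores)), key=lambda c: se_scores[c], reverse=True)
    return sorted(ranked[:top_k])


def conv2d_single(x, weight):
    """Valid, stride-1 2D convolution (cross-correlation) for one head.

    x:      [C][H][W] nested lists
    weight: [C_out][C][k][k] nested lists (C must match len(x))
    Returns the output [C_out][H-k+1][W-k+1] and the multiply-accumulate count.
    """
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    k = len(weight[0][0])
    macs = 0
    out = []
    for w_o in weight:
        plane = []
        for i in range(h - k + 1):
            row = []
            for j in range(w - k + 1):
                acc = 0.0
                for c in range(c_in):
                    for di in range(k):
                        for dj in range(k):
                            acc += w_o[c][di][dj] * x[c][i + di][j + dj]
                            macs += 1
                row.append(acc)
            plane.append(row)
        out.append(plane)
    return out, macs


def head_forward_top_k(x, weight, se_scores, top_k):
    """SE re-weighting, then top-k channel selection, then conv on the subset only."""
    idx = select_top_k(se_scores, top_k)
    # Gather and re-weight only the selected input channels.
    x_sel = [[[v * se_scores[c] for v in row] for row in x[c]] for c in idx]
    # Gather the matching input-channel slices of the head's full weight.
    w_sel = [[w_o[c] for c in idx] for w_o in weight]
    out, macs = conv2d_single(x_sel, w_sel)
    return out, macs, idx
```

With C_in = 4 and top_k = 2, the selected path should use half the multiply-accumulates of a convolution over all re-weighted channels, while producing the same output as a full convolution in which the unselected channels' SE scores are set to zero.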