@@ -237,16 +237,27 @@ def _numpy(self, data, weights, shape):
237237 self ._checkNPQuantity (q , shape )
238238 self ._checkNPWeights (weights , shape )
239239 weights = self ._makeNPWeights (weights , shape )
240+ newentries = weights .sum ()
241+
242+ subweights = weights .copy ()
243+ subweights [weights < 0.0 ] = 0.0
240244
241- # no possibility of exception from here on out (for rollback)
242- for x , w in zip (q , weights ):
243- if w > 0.0 :
244- if x not in self .bins :
245- self .bins [x ] = self .value .zero ()
246- self .bins [x ].fill (x , w )
245+ import numpy
246+ selection = numpy .empty (q .shape , dtype = numpy .bool )
247+
248+ uniques , inverse = numpy .unique (q , return_inverse = True )
247249
248250 # no possibility of exception from here on out (for rollback)
249- self .entries += float (weights .sum ())
251+ for i , x in enumerate (uniques ):
252+ if x not in self .bins :
253+ self .bins [x ] = self .value .zero ()
254+
255+ numpy .not_equal (inverse , i , selection )
256+ subweights [:] = weights
257+ subweights [selection ] = 0.0
258+ self .bins [x ]._numpy (data , subweights , shape )
259+
260+ self .entries += float (newentries )
250261
251262 def _sparksql (self , jvm , converter ):
252263 return converter .Categorize (self .quantity .asSparkSQL (), self .value ._sparksql (jvm , converter ))
0 commit comments