@@ -119,20 +119,28 @@ def strategy(self, strategy):
119
119
' strategy="npoints", or strategy="cycle" is implemented.'
120
120
)
121
121
122
+ def _to_select (self , total_points ):
123
+ to_select = []
124
+ for index , learner in enumerate (self .learners ):
125
+ # Take the points from the cache
126
+ if index not in self ._ask_cache :
127
+ self ._ask_cache [index ] = learner .ask (n = 1 , tell_pending = False )
128
+ points , loss_improvements = self ._ask_cache [index ]
129
+ if not points :
130
+ # cannot ask for more points
131
+ return to_select
132
+ to_select .append (
133
+ ((index , points [0 ]), (loss_improvements [0 ], - total_points [index ]))
134
+ )
135
+ return to_select
136
+
122
137
def _ask_and_tell_based_on_loss_improvements (self , n ):
123
138
selected = [] # tuples ((learner_index, point), loss_improvement)
124
139
total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
125
140
for _ in range (n ):
126
- to_select = []
127
- for index , learner in enumerate (self .learners ):
128
- # Take the points from the cache
129
- if index not in self ._ask_cache :
130
- self ._ask_cache [index ] = learner .ask (n = 1 , tell_pending = False )
131
- points , loss_improvements = self ._ask_cache [index ]
132
- to_select .append (
133
- ((index , points [0 ]), (loss_improvements [0 ], - total_points [index ]))
134
- )
135
-
141
+ to_select = self ._to_select (total_points )
142
+ if not to_select :
143
+ break
136
144
# Choose the optimal improvement.
137
145
(index , point ), (loss_improvement , _ ) = max (to_select , key = itemgetter (1 ))
138
146
total_points [index ] += 1
0 commit comments