Skip to content

Commit 884674a

Browse files
committed
fix problem when learner has no more points for BalancingLearner, closes #213
1 parent c602a86 commit 884674a

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

adaptive/learner/balancing_learner.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -119,20 +119,28 @@ def strategy(self, strategy):
119119
' strategy="npoints", or strategy="cycle" is implemented.'
120120
)
121121

122+
def _to_select(self, total_points):
123+
to_select = []
124+
for index, learner in enumerate(self.learners):
125+
# Take the points from the cache
126+
if index not in self._ask_cache:
127+
self._ask_cache[index] = learner.ask(n=1, tell_pending=False)
128+
points, loss_improvements = self._ask_cache[index]
129+
if not points:
130+
# cannot ask for more points
131+
return to_select
132+
to_select.append(
133+
((index, points[0]), (loss_improvements[0], -total_points[index]))
134+
)
135+
return to_select
136+
122137
def _ask_and_tell_based_on_loss_improvements(self, n):
123138
selected = [] # tuples ((learner_index, point), loss_improvement)
124139
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
125140
for _ in range(n):
126-
to_select = []
127-
for index, learner in enumerate(self.learners):
128-
# Take the points from the cache
129-
if index not in self._ask_cache:
130-
self._ask_cache[index] = learner.ask(n=1, tell_pending=False)
131-
points, loss_improvements = self._ask_cache[index]
132-
to_select.append(
133-
((index, points[0]), (loss_improvements[0], -total_points[index]))
134-
)
135-
141+
to_select = self._to_select(total_points)
142+
if not to_select:
143+
break
136144
# Choose the optimal improvement.
137145
(index, point), (loss_improvement, _) = max(to_select, key=itemgetter(1))
138146
total_points[index] += 1

0 commit comments

Comments
 (0)