SCIP Doxygen Documentation
Loading...
Searching...
No Matches
bandit_ucb.c
Go to the documentation of this file.
1/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2/* */
3/* This file is part of the program and library */
4/* SCIP --- Solving Constraint Integer Programs */
5/* */
6/* Copyright (c) 2002-2025 Zuse Institute Berlin (ZIB) */
7/* */
8/* Licensed under the Apache License, Version 2.0 (the "License"); */
9/* you may not use this file except in compliance with the License. */
10/* You may obtain a copy of the License at */
11/* */
12/* http://www.apache.org/licenses/LICENSE-2.0 */
13/* */
14/* Unless required by applicable law or agreed to in writing, software */
15/* distributed under the License is distributed on an "AS IS" BASIS, */
16/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
17/* See the License for the specific language governing permissions and */
18/* limitations under the License. */
19/* */
20/* You should have received a copy of the Apache-2.0 license */
21/* along with SCIP; see the file LICENSE. If not visit scipopt.org. */
22/* */
23/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
24
25/**@file bandit_ucb.c
26 * @ingroup OTHER_CFILES
27 * @brief methods for UCB bandit selection
28 * @author Gregor Hendel
29 */
30
31/*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
32
33#include "scip/bandit.h"
34#include "scip/bandit_ucb.h"
35#include "scip/pub_bandit.h"
36#include "scip/pub_message.h"
37#include "scip/pub_misc.h"
38#include "scip/pub_misc_sort.h"
39#include "scip/scip_bandit.h"
40#include "scip/scip_mem.h"
42
43
/* name of the UCB bandit algorithm; used when looking up / reporting its virtual function table */
#define BANDIT_NAME           "ucb"

/* numerical tolerance: half-width of the random priority perturbation and epsilon for comparing confidence bounds */
#define NUMEPS                1e-6
46
47/*
48 * Data structures
49 */
50
51/** implementation specific data of UCB bandit algorithm */
52struct SCIP_BanditData
53{
54 int nselections; /**< counter for the number of selections */
55 int* counter; /**< array of counters how often every action has been chosen */
56 int* startperm; /**< indices for starting permutation */
57 SCIP_Real* meanscores; /**< array of average scores for the actions */
58 SCIP_Real alpha; /**< parameter to increase confidence width */
59};
60
61
62/*
63 * Local methods
64 */
65
66/** data reset method */
67static
69 BMS_BUFMEM* bufmem, /**< buffer memory */
70 SCIP_BANDIT* ucb, /**< ucb bandit algorithm */
71 SCIP_BANDITDATA* banditdata, /**< UCB bandit data structure */
72 SCIP_Real* priorities, /**< priorities for start permutation, or NULL */
73 int nactions /**< number of actions */
74 )
75{
76 int i;
77 SCIP_RANDNUMGEN* rng;
78
79 assert(bufmem != NULL);
80 assert(ucb != NULL);
81 assert(nactions > 0);
82
83 /* clear counters and scores */
84 BMSclearMemoryArray(banditdata->counter, nactions);
85 BMSclearMemoryArray(banditdata->meanscores, nactions);
86 banditdata->nselections = 0;
87
88 rng = SCIPbanditGetRandnumgen(ucb);
89 assert(rng != NULL);
90
91 /* initialize start permutation as identity */
92 for( i = 0; i < nactions; ++i )
93 banditdata->startperm[i] = i;
94
95 /* prepare the start permutation in decreasing order of priority */
96 if( priorities != NULL )
97 {
98 SCIP_Real* prioritycopy;
99
100 SCIP_ALLOC( BMSduplicateBufferMemoryArray(bufmem, &prioritycopy, priorities, nactions) );
101
102 /* randomly wiggle priorities a little bit to make them unique */
103 for( i = 0; i < nactions; ++i )
104 prioritycopy[i] += SCIPrandomGetReal(rng, -NUMEPS, NUMEPS);
105
106 SCIPsortDownRealInt(prioritycopy, banditdata->startperm, nactions);
107
108 BMSfreeBufferMemoryArray(bufmem, &prioritycopy);
109 }
110 else
111 {
112 /* use a random start permutation */
113 SCIPrandomPermuteIntArray(rng, banditdata->startperm, 0, nactions);
114 }
115
116 return SCIP_OKAY;
117}
118
119
120/*
121 * Callback methods of bandit algorithm
122 */
123
124/** callback to free bandit specific data structures */
125SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
126{ /*lint --e{715}*/
127 SCIP_BANDITDATA* banditdata;
128 int nactions;
129 assert(bandit != NULL);
130
131 banditdata = SCIPbanditGetData(bandit);
132 assert(banditdata != NULL);
133 nactions = SCIPbanditGetNActions(bandit);
134
135 BMSfreeBlockMemoryArray(blkmem, &banditdata->counter, nactions);
136 BMSfreeBlockMemoryArray(blkmem, &banditdata->startperm, nactions);
137 BMSfreeBlockMemoryArray(blkmem, &banditdata->meanscores, nactions);
138 BMSfreeBlockMemory(blkmem, &banditdata);
139
140 SCIPbanditSetData(bandit, NULL);
141
142 return SCIP_OKAY;
143}
144
145/** selection callback for bandit selector */
146SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
147{ /*lint --e{715}*/
148 SCIP_BANDITDATA* banditdata;
149 int nactions;
150 int* counter;
151
152 assert(bandit != NULL);
154
155 banditdata = SCIPbanditGetData(bandit);
156 assert(banditdata != NULL);
157 nactions = SCIPbanditGetNActions(bandit);
158
159 counter = banditdata->counter;
160 /* select the next uninitialized action from the start permutation */
161 if( banditdata->nselections < nactions )
162 {
163 *selection = banditdata->startperm[banditdata->nselections];
164 assert(counter[*selection] == 0);
165 }
166 else
167 {
168 /* select the action with the highest upper confidence bound */
169 SCIP_Real* meanscores;
170 SCIP_Real widthfactor;
171 SCIP_Real maxucb;
172 int i;
174 meanscores = banditdata->meanscores;
175
176 assert(rng != NULL);
177 assert(meanscores != NULL);
178
179 /* compute the confidence width factor that is common for all actions */
180 widthfactor = banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections);
181 widthfactor = sqrt(widthfactor);
182 maxucb = -1.0;
183
184 /* loop over the actions and determine the maximum upper confidence bound.
185 * The upper confidence bound of an action is the sum of its mean score
186 * plus a confidence term that decreases with increasing number of observations of
187 * this action.
188 */
189 for( i = 0; i < nactions; ++i )
190 {
191 SCIP_Real uppercb;
192 SCIP_Real rootcount;
193 assert(counter[i] > 0);
194
195 /* compute the upper confidence bound for action i */
196 uppercb = meanscores[i];
197 rootcount = sqrt((SCIP_Real)counter[i]);
198 uppercb += widthfactor / rootcount;
199 assert(uppercb > 0);
200
201 /* update maximum, breaking ties uniformly at random */
202 if( EPSGT(uppercb, maxucb, NUMEPS) || (EPSEQ(uppercb, maxucb, NUMEPS) && SCIPrandomGetReal(rng, 0.0, 1.0) >= 0.5) )
203 {
204 maxucb = uppercb;
205 *selection = i;
206 }
207 }
208 }
209
210 assert(*selection >= 0);
211 assert(*selection < nactions);
212
213 return SCIP_OKAY;
214}
215
216/** update callback for bandit algorithm */
217SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
218{ /*lint --e{715}*/
219 SCIP_BANDITDATA* banditdata;
220 SCIP_Real delta;
221
222 assert(bandit != NULL);
223
224 banditdata = SCIPbanditGetData(bandit);
225 assert(banditdata != NULL);
226 assert(selection >= 0);
228
229 /* increase the mean by the incremental formula: A_n = A_n-1 + 1/n (a_n - A_n-1) */
230 delta = score - banditdata->meanscores[selection];
231 ++banditdata->counter[selection];
232 banditdata->meanscores[selection] += delta / (SCIP_Real)banditdata->counter[selection];
233
234 banditdata->nselections++;
235
236 return SCIP_OKAY;
237}
238
239/** reset callback for bandit algorithm */
240SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
241{ /*lint --e{715}*/
242 SCIP_BANDITDATA* banditdata;
243 int nactions;
244
245 assert(bufmem != NULL);
246 assert(bandit != NULL);
247
248 banditdata = SCIPbanditGetData(bandit);
249 assert(banditdata != NULL);
250 nactions = SCIPbanditGetNActions(bandit);
251
252 /* call the data reset for the given priorities */
253 SCIP_CALL( dataReset(bufmem, bandit, banditdata, priorities, nactions) );
254
255 return SCIP_OKAY;
256}
257
258/*
259 * bandit algorithm specific interface methods
260 */
261
262/** returns the upper confidence bound of a selected action */
264 SCIP_BANDIT* ucb, /**< UCB bandit algorithm */
265 int action /**< index of the queried action */
266 )
267{
268 SCIP_Real uppercb;
269 SCIP_BANDITDATA* banditdata;
270 int nactions;
271
272 assert(ucb != NULL);
273 banditdata = SCIPbanditGetData(ucb);
274 nactions = SCIPbanditGetNActions(ucb);
275 assert(action < nactions);
276
277 /* since only scores between 0 and 1 are allowed, 1.0 is a sure upper confidence bound */
278 if( banditdata->nselections < nactions )
279 return 1.0;
280
281 /* the bandit algorithm must have picked every action once */
282 assert(banditdata->counter[action] > 0);
283 uppercb = banditdata->meanscores[action];
284
285 uppercb += sqrt(banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections) / (SCIP_Real)banditdata->counter[action]);
286
287 return uppercb;
288}
289
290/** return start permutation of the UCB bandit algorithm */
292 SCIP_BANDIT* ucb /**< UCB bandit algorithm */
293 )
294{
295 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(ucb);
296
297 assert(banditdata != NULL);
298
299 return banditdata->startperm;
300}
301
302/** internal method to create and reset UCB bandit algorithm */
304 BMS_BLKMEM* blkmem, /**< block memory */
305 BMS_BUFMEM* bufmem, /**< buffer memory */
306 SCIP_BANDITVTABLE* vtable, /**< virtual function table for UCB bandit algorithm */
307 SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
308 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
309 SCIP_Real alpha, /**< parameter to increase confidence width */
310 int nactions, /**< the positive number of actions for this bandit algorithm */
311 unsigned int initseed /**< initial random seed */
312 )
313{
314 SCIP_BANDITDATA* banditdata;
315
316 if( alpha < 0.0 )
317 {
318 SCIPerrorMessage("UCB requires nonnegative alpha parameter, have %f\n", alpha);
319 return SCIP_INVALIDDATA;
320 }
321
322 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
323 assert(banditdata != NULL);
324
325 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->counter, nactions) );
326 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->startperm, nactions) );
327 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->meanscores, nactions) );
328
329 banditdata->alpha = alpha;
330
331 SCIP_CALL( SCIPbanditCreate(ucb, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
332
333 return SCIP_OKAY;
334}
335
336/** create and reset UCB bandit algorithm */
338 SCIP* scip, /**< SCIP data structure */
339 SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
340 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
341 SCIP_Real alpha, /**< parameter to increase confidence width */
342 int nactions, /**< the positive number of actions for this bandit algorithm */
343 unsigned int initseed /**< initial random number seed */
344 )
345{
346 SCIP_BANDITVTABLE* vtable;
347
349 if( vtable == NULL )
350 {
351 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
352 return SCIP_INVALIDDATA;
353 }
354
356 priorities, alpha, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
357
358 return SCIP_OKAY;
359}
360
361/** include virtual function table for UCB bandit algorithms */
363 SCIP* scip /**< SCIP data structure */
364 )
365{
366 SCIP_BANDITVTABLE* vtable;
367
369 SCIPbanditFreeUcb, SCIPbanditSelectUcb, SCIPbanditUpdateUcb, SCIPbanditResetUcb) );
370 assert(vtable != NULL);
371
372 return SCIP_OKAY;
373}
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition bandit.c:200
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition bandit.c:42
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition bandit.c:190
internal methods for bandit algorithms
#define BANDIT_NAME
static SCIP_RETCODE dataReset(BMS_BUFMEM *bufmem, SCIP_BANDIT *ucb, SCIP_BANDITDATA *banditdata, SCIP_Real *priorities, int nactions)
Definition bandit_ucb.c:68
#define NUMEPS
Definition bandit_ucb.c:45
SCIP_RETCODE SCIPincludeBanditvtableUcb(SCIP *scip)
Definition bandit_ucb.c:362
SCIP_RETCODE SCIPbanditCreateUcb(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition bandit_ucb.c:303
internal methods for UCB bandit algorithm
#define NULL
Definition def.h:262
#define LOG1P(x)
Definition def.h:218
#define SCIP_ALLOC(x)
Definition def.h:380
#define SCIP_Real
Definition def.h:172
#define EPSEQ(x, y, eps)
Definition def.h:197
#define EPSGT(x, y, eps)
Definition def.h:200
#define SCIP_CALL(x)
Definition def.h:369
void SCIPrandomPermuteIntArray(SCIP_RANDNUMGEN *randnumgen, int *array, int begin, int end)
Definition misc.c:10150
int * SCIPgetStartPermutationUcb(SCIP_BANDIT *ucb)
Definition bandit_ucb.c:291
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition bandit.c:303
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition bandit.c:293
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition scip_bandit.c:80
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
Definition scip_bandit.c:48
SCIP_Real SCIPgetConfidenceBoundUcb(SCIP_BANDIT *ucb, int action)
Definition bandit_ucb.c:263
SCIP_RETCODE SCIPcreateBanditUcb(SCIP *scip, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition bandit_ucb.c:337
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition scip_mem.c:72
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition misc.c:10131
unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
void SCIPsortDownRealInt(SCIP_Real *realarray, int *intarray, int len)
return SCIP_OKAY
int selection
assert(minobj< SCIPgetCutoffbound(scip))
SCIP_Real alpha
#define BMSfreeBlockMemory(mem, ptr)
Definition memory.h:465
#define BMSduplicateBufferMemoryArray(mem, ptr, source, num)
Definition memory.h:737
#define BMSallocBlockMemory(mem, ptr)
Definition memory.h:451
#define BMSfreeBufferMemoryArray(mem, ptr)
Definition memory.h:742
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition memory.h:454
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition memory.h:467
#define BMSclearMemoryArray(ptr, num)
Definition memory.h:130
struct BMS_BufMem BMS_BUFMEM
Definition memory.h:721
struct BMS_BlkMem BMS_BLKMEM
Definition memory.h:437
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition scip_mem.c:57
public methods for bandit algorithms
public methods for message output
#define SCIPerrorMessage
Definition pub_message.h:64
public data structures and miscellaneous methods
methods for sorting joint arrays of various types
public methods for bandit algorithms
public methods for memory management
public methods for random numbers
#define SCIP_DECL_BANDITUPDATE(x)
Definition type_bandit.h:75
#define SCIP_DECL_BANDITFREE(x)
Definition type_bandit.h:63
struct SCIP_Bandit SCIP_BANDIT
Definition type_bandit.h:50
struct SCIP_BanditData SCIP_BANDITDATA
Definition type_bandit.h:56
#define SCIP_DECL_BANDITSELECT(x)
Definition type_bandit.h:69
struct SCIP_BanditVTable SCIP_BANDITVTABLE
Definition type_bandit.h:53
#define SCIP_DECL_BANDITRESET(x)
Definition type_bandit.h:82
struct SCIP_RandNumGen SCIP_RANDNUMGEN
Definition type_misc.h:126
@ SCIP_INVALIDDATA
enum SCIP_Retcode SCIP_RETCODE
struct Scip SCIP
Definition type_scip.h:39