Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions rl.install
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ function rl_install() {
'not null' => TRUE,
'description' => 'Module that owns this experiment',
],
'experiment_name' => [
'type' => 'varchar',
'length' => 255,
'not null' => FALSE,
'description' => 'Human-readable experiment name',
],
'registered_at' => [
'type' => 'int',
'unsigned' => TRUE,
Expand Down Expand Up @@ -326,6 +332,12 @@ function rl_schema() {
'not null' => TRUE,
'description' => 'Module that owns this experiment',
],
'experiment_name' => [
'type' => 'varchar',
'length' => 255,
'not null' => FALSE,
'description' => 'Human-readable experiment name',
],
'registered_at' => [
'type' => 'int',
'unsigned' => TRUE,
Expand Down
44 changes: 24 additions & 20 deletions rl.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,62 +11,70 @@
use Drupal\Core\DrupalKernel;
use Symfony\Component\HttpFoundation\Request;

// CRITICAL: Only accept POST requests for security and caching reasons.
$action = filter_input(INPUT_POST, 'action', FILTER_SANITIZE_FULL_SPECIAL_CHARS);
$experiment_uuid = filter_input(INPUT_POST, 'experiment_uuid', FILTER_SANITIZE_FULL_SPECIAL_CHARS);
$arm_id = filter_input(INPUT_POST, 'arm_id', FILTER_SANITIZE_FULL_SPECIAL_CHARS);

// Validate inputs more strictly.
if (!$action || !$experiment_uuid || !in_array($action, ['turn', 'turns', 'reward'])) {
http_response_code(400);
exit();
exit('Invalid request parameters');
}

// Additional validation for experiment_uuid (should be alphanumeric/hash)
if (!preg_match('/^[a-zA-Z0-9]+$/', $experiment_uuid)) {
http_response_code(400);
exit();
exit('Invalid experiment_uuid format');
}

// Catch exceptions when site is not configured or storage fails.
try {
// Assumes module in modules/contrib/rl, so three levels below root.
chdir('../../..');
$levels_up = '../../../';

chdir($levels_up);
$drupal_root = getcwd();
$autoload_path = $drupal_root . '/../vendor/autoload.php';

if (!file_exists($autoload_path)) {
$script_filename = $_SERVER['SCRIPT_FILENAME'] ?? '';
if (!preg_match('/^[a-zA-Z0-9\/_.-]+$/', $script_filename)) {
http_response_code(500);
exit('Invalid script filename');
}

$drupal_root = dirname(dirname(dirname(dirname($script_filename))));
$autoload_path = $drupal_root . '/../vendor/autoload.php';

if (!file_exists($autoload_path)) {
http_response_code(500);
exit('Drupal autoload.php not found');
}
}

$autoloader = require_once $drupal_root . '/autoload.php';
$autoloader = require_once $autoload_path;

$request = Request::createFromGlobals();
$kernel = DrupalKernel::createFromRequest($request, $autoloader, 'prod');
$kernel->boot();
$container = $kernel->getContainer();

// Check if experiment is registered.
$registry = $container->get('rl.experiment_registry');
if (!$registry->isRegistered($experiment_uuid)) {
// Silently ignore unregistered experiments like statistics module.
exit();
}

// Get the experiment data storage service.
$storage = $container->get('rl.experiment_data_storage');

// Handle the different actions.
switch ($action) {
case 'turn':
// Validate arm_id for single turn.
if ($arm_id && preg_match('/^[a-zA-Z0-9_-]+$/', $arm_id)) {
$storage->recordTurn($experiment_uuid, $arm_id);
}
break;

case 'turns':
// Handle multiple turns with better validation.
$arm_ids = filter_input(INPUT_POST, 'arm_ids', FILTER_SANITIZE_FULL_SPECIAL_CHARS);
if ($arm_ids) {
$arm_ids_array = explode(',', $arm_ids);
$arm_ids_array = array_map('trim', $arm_ids_array);

// Validate each arm_id.
$valid_arm_ids = [];
foreach ($arm_ids_array as $aid) {
if (preg_match('/^[a-zA-Z0-9_-]+$/', $aid)) {
Expand All @@ -81,16 +89,12 @@
break;

case 'reward':
// Validate arm_id for reward.
if ($arm_id && preg_match('/^[a-zA-Z0-9_-]+$/', $arm_id)) {
$storage->recordReward($experiment_uuid, $arm_id);
}
break;
}

// Send success response.
http_response_code(200);
}
catch (\Exception $e) {
// Do nothing if there is PDO Exception or other failure.
}
7 changes: 3 additions & 4 deletions src/Controller/ReportsController.php
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public function experimentsOverview() {

// Get all registered experiments with their totals (if any)
$query = $this->database->select('rl_experiment_registry', 'er')
->fields('er', ['uuid', 'module', 'registered_at']);
->fields('er', ['uuid', 'module', 'experiment_name', 'registered_at']);
$query->leftJoin('rl_experiment_totals', 'et', 'er.uuid = et.experiment_uuid');
$query->addField('et', 'total_turns', 'total_turns');
$query->addField('et', 'created', 'totals_created');
Expand Down Expand Up @@ -140,9 +140,8 @@ public function experimentsOverview() {
? $this->dateFormatter->format($last_activity_timestamp, 'short')
: $this->t('Never');

// Get decorated experiment name or fallback to UUID.
$experiment_display = $this->decoratorManager->decorateExperiment($experiment->uuid);
$experiment_name = $experiment_display ? \Drupal::service('renderer')->renderPlain($experiment_display) : $experiment->uuid;
// Use experiment name from registry or fallback to UUID.
$experiment_name = $experiment->experiment_name ?: $experiment->uuid;

$rows[] = [
['data' => ['#markup' => $operations_markup]],
Expand Down
16 changes: 11 additions & 5 deletions src/Registry/ExperimentRegistry.php
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,21 @@ public function __construct(Connection $database) {
/**
* {@inheritdoc}
*/
public function register(string $uuid, string $module): void {
public function register(string $uuid, string $module, ?string $experiment_name = NULL): void {
try {
// Use merge to handle duplicate registrations gracefully.
$fields = [
'module' => $module,
'registered_at' => \Drupal::time()->getRequestTime(),
];

if ($experiment_name !== NULL) {
$fields['experiment_name'] = $experiment_name;
}

$this->database->merge('rl_experiment_registry')
->key(['uuid' => $uuid])
->fields([
'module' => $module,
'registered_at' => \Drupal::time()->getRequestTime(),
])
->fields($fields)
->execute();
}
catch (\Exception $e) {
Expand Down
4 changes: 3 additions & 1 deletion src/Registry/ExperimentRegistryInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ interface ExperimentRegistryInterface {
* The experiment UUID.
* @param string $module
* The module name that owns this experiment.
* @param string $experiment_name
* Optional human-readable experiment name.
*/
public function register(string $uuid, string $module): void;
public function register(string $uuid, string $module, ?string $experiment_name = NULL): void;

/**
* Check if an experiment UUID is registered.
Expand Down
33 changes: 31 additions & 2 deletions src/Service/ExperimentManager.php
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,40 @@ public function getTotalTurns($experiment_uuid) {
/**
* {@inheritdoc}
*/
public function getThompsonScores($experiment_uuid, $time_window_seconds = NULL) {
public function getThompsonScores($experiment_uuid, $time_window_seconds = NULL, array $requested_arms = []) {
$arms_data = $this->storage->getAllArmsData($experiment_uuid, $time_window_seconds);

// If specific arms are requested, ensure they all have scores.
// New arms get initialized with zero stats for maximum exploration.
if (!empty($requested_arms)) {
foreach ($requested_arms as $arm_id) {
if (!isset($arms_data[$arm_id])) {
// New arm: initialize with zero stats (0 turns, 0 rewards).
// Thompson sampling will give these high exploration scores.
$arms_data[$arm_id] = (object) [
'arm_id' => $arm_id,
'turns' => 0,
'rewards' => 0,
];
}
}
}

// Complete cold start: no arms at all.
if (empty($arms_data)) {
return [];
// If no specific arms requested, we can't generate scores.
if (empty($requested_arms)) {
return [];
}

// If arms were requested, initialize them all as new.
foreach ($requested_arms as $arm_id) {
$arms_data[$arm_id] = (object) [
'arm_id' => $arm_id,
'turns' => 0,
'rewards' => 0,
];
}
}

return $this->tsCalculator->calculateThompsonScores($arms_data);
Expand Down
8 changes: 6 additions & 2 deletions src/Service/ExperimentManagerInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,14 @@ public function getTotalTurns($experiment_uuid);
* The experiment UUID.
* @param int|null $time_window_seconds
* Optional time window in seconds. Only considers arms active within this timeframe.
* @param array $requested_arms
* Optional array of arm IDs that need scores. New arms will be initialized
* with zero stats (0 turns, 0 rewards) to ensure maximum exploration.
*
* @return array
* Array of Thompson Sampling scores keyed by arm_id.
* Array of Thompson Sampling scores keyed by arm_id. Returns empty array
* only if no arms exist AND no requested_arms were provided.
*/
public function getThompsonScores($experiment_uuid, $time_window_seconds = NULL);
public function getThompsonScores($experiment_uuid, $time_window_seconds = NULL, array $requested_arms = []);

}
7 changes: 4 additions & 3 deletions src/Service/ThompsonCalculator.php
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ public function calculateThompsonScores(array $arms_data): array {
$scores = [];

foreach ($arms_data as $id => $arm) {
// Good ratings + 1.
$alpha = $arm->rewards + 1;
// Bad ratings + 1.
$beta = ($arm->turns - $arm->rewards) + 1;
$scores[$id] = $this->randBeta($alpha, $beta);
$base_score = $this->randBeta($alpha, $beta);

$tie_breaker = mt_rand(1, 999) / 1000000;
$scores[$id] = $base_score + $tie_breaker;
}
return $scores;
}
Expand Down
27 changes: 26 additions & 1 deletion src/Storage/ExperimentDataStorage.php
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,34 @@ public function recordTurn($experiment_uuid, $arm_id) {
* {@inheritdoc}
*/
public function recordTurns($experiment_uuid, array $arm_ids) {
$timestamp = \Drupal::time()->getRequestTime();
$arm_count = count($arm_ids);

// Record a turn for each arm (each arm gets exposure).
foreach ($arm_ids as $arm_id) {
$this->recordTurn($experiment_uuid, $arm_id);
$this->database->merge('rl_arm_data')
->key(['experiment_uuid' => $experiment_uuid, 'arm_id' => $arm_id])
->fields([
'turns' => 1,
'created' => $timestamp,
'updated' => $timestamp,
])
->expression('turns', 'turns + :inc', [':inc' => 1])
->expression('updated', ':timestamp', [':timestamp' => $timestamp])
->execute();
}

// Record total turns = number of arms shown (sum of individual turns).
$this->database->merge('rl_experiment_totals')
->key(['experiment_uuid' => $experiment_uuid])
->fields([
'total_turns' => $arm_count,
'created' => $timestamp,
'updated' => $timestamp,
])
->expression('total_turns', 'total_turns + :inc', [':inc' => $arm_count])
->expression('updated', ':timestamp', [':timestamp' => $timestamp])
->execute();
}

/**
Expand Down