Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
use Scheb\TwoFactorBundle\Security\Authentication\Token\TwoFactorTokenInterface;
use Scheb\TwoFactorBundle\Security\TwoFactor\AuthenticationContextInterface;
use Scheb\TwoFactorBundle\Security\TwoFactor\Provider\Exception\UnknownTwoFactorProviderException;
use function array_walk;
use function count;

/**
* @final
Expand All @@ -21,12 +23,9 @@ public function __construct(
) {
}

/**
* @return string[]
*/
private function getActiveTwoFactorProviders(AuthenticationContextInterface $context): array
public function beginTwoFactorAuthentication(AuthenticationContextInterface $context): TwoFactorTokenInterface|null
{
$activeTwoFactorProviders = [];
$activeTwoFactorProviders = $statelessProviders = [];

// Iterate over two-factor providers and begin the two-factor authentication process.
foreach ($this->providerRegistry->getAllProviders() as $providerName => $provider) {
Expand All @@ -35,32 +34,32 @@ private function getActiveTwoFactorProviders(AuthenticationContextInterface $con
}

$activeTwoFactorProviders[] = $providerName;
}
if ($provider->needsPreparation()) {
continue;
}

return $activeTwoFactorProviders;
}
$statelessProviders[] = $providerName;
}

public function beginTwoFactorAuthentication(AuthenticationContextInterface $context): TwoFactorTokenInterface|null
{
$activeTwoFactorProviders = $this->getActiveTwoFactorProviders($context);
if (0 === count($activeTwoFactorProviders)) {
return null;
}

$authenticatedToken = $context->getToken();
if ($activeTwoFactorProviders) {
$twoFactorToken = $this->twoFactorTokenFactory->create($authenticatedToken, $context->getFirewallName(), $activeTwoFactorProviders);
$twoFactorToken = $this->twoFactorTokenFactory->create($authenticatedToken, $context->getFirewallName(), $activeTwoFactorProviders);

$preferredProvider = $this->twoFactorProviderDecider->getPreferredTwoFactorProvider($activeTwoFactorProviders, $twoFactorToken, $context);
array_walk($statelessProviders, static fn (string $providerName) => $twoFactorToken->setTwoFactorProviderPrepared($providerName));

if (null !== $preferredProvider) {
try {
$twoFactorToken->preferTwoFactorProvider($preferredProvider);
} catch (UnknownTwoFactorProviderException) {
// Bad user input
}
}
$preferredProvider = $this->twoFactorProviderDecider->getPreferredTwoFactorProvider($activeTwoFactorProviders, $twoFactorToken, $context);

return $twoFactorToken;
if (null !== $preferredProvider) {
try {
$twoFactorToken->preferTwoFactorProvider($preferredProvider);
} catch (UnknownTwoFactorProviderException) {
// Bad user input
}
}

return null;
return $twoFactorToken;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ interface TwoFactorProviderInterface
*/
public function beginAuthentication(AuthenticationContextInterface $context): bool;

/**
* Determine whether this Provider needs to be prepared (if the prepareAuthentication method needs to be called).
*/
public function needsPreparation(): bool;

/**
* Do all steps necessary to prepare authentication, e.g. generate & send a code.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ public function beginAuthentication(AuthenticationContextInterface $context): bo
return $user instanceof TwoFactorInterface && $user->isEmailAuthEnabled();
}

public function needsPreparation(): bool
{
return true;
}

public function prepareAuthentication(object $user): void
{
if (!($user instanceof TwoFactorInterface)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ public function beginAuthentication(AuthenticationContextInterface $context): bo
return true;
}

public function needsPreparation(): bool
{
return false;
}

public function prepareAuthentication(object $user): void
{
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ public function beginAuthentication(AuthenticationContextInterface $context): bo
return true;
}

public function needsPreparation(): bool
{
return false;
}

public function prepareAuthentication(object $user): void
{
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ class TwoFactorProviderInitiatorTest extends AbstractAuthenticationContextTestCa
protected function setUp(): void
{
$this->provider1 = $this->createMock(TwoFactorProviderInterface::class);
$this->provider1->method('needsPreparation')->willReturn(true);
$this->provider2 = $this->createMock(TwoFactorProviderInterface::class);
$this->provider2->method('needsPreparation')->willReturn(false);

$providerRegistry = $this->createMock(TwoFactorProviderRegistry::class);
$providerRegistry
Expand All @@ -36,6 +38,16 @@ protected function setUp(): void
'test1' => $this->provider1,
'test2' => $this->provider2,
]);
$providerRegistry
->expects($this->any())
->method('getProvider')
->willReturnCallback(function (string $name) {
return match ($name) {
'test1' => $this->provider1,
'test2' => $this->provider2,
default => null,
};
});

$this->twoFactorTokenFactory = $this->createMock(TwoFactorTokenFactory::class);

Expand Down Expand Up @@ -155,4 +167,21 @@ public function beginAuthentication_hasPreferredProvider_setThatProviderPreferre

$this->initiator->beginTwoFactorAuthentication($context);
}

#[Test]
public function beginAuthentication_statelessProviderPrepared_setThatProviderIsPrepared(): void
{
$originalToken = $this->createToken();
$context = $this->createAuthenticationContext(null, $originalToken);
$this->stubProvidersReturn(true, true);

$twoFactorToken = $this->createTwoFactorToken();
$this->stubTwoFactorTokenFactoryReturns($twoFactorToken);
$twoFactorToken
->expects($this->once())
->method('setTwoFactorProviderPrepared')
->with('test2');

$this->initiator->beginTwoFactorAuthentication($context);
}
}