diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index 39553f7554..ce5656cfca 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -21,6 +21,7 @@ use std::any::Any; use std::collections::HashMap; use std::future::Future; use std::str::FromStr; +use std::sync::Arc; use async_trait::async_trait; use iceberg::io::{self, FileIO}; @@ -38,7 +39,8 @@ use tokio::sync::OnceCell; use typed_builder::TypedBuilder; use crate::client::{ - HttpClient, deserialize_catalog_response, deserialize_unexpected_catalog_error, + CustomAuthenticator, HttpClient, deserialize_catalog_response, + deserialize_unexpected_catalog_error, }; use crate::types::{ CatalogConfig, CommitTableRequest, CommitTableResponse, CreateTableRequest, @@ -67,6 +69,7 @@ impl Default for RestCatalogBuilder { warehouse: None, props: HashMap::new(), client: None, + authenticator: None, }) } } @@ -124,6 +127,24 @@ impl RestCatalogBuilder { self.0.client = Some(client); self } + + /// Set a custom token authenticator. + /// + /// The authenticator will be used to obtain tokens instead of using static tokens + /// or OAuth credentials. + /// + /// # Example + /// ```ignore + /// let authenticator = Arc::new(MyAuthenticator::new()); + /// let catalog = RestCatalogBuilder::default() + /// .with_token_authenticator(authenticator) + /// .load("rest", config) + /// .await?; + /// ``` + pub fn with_token_authenticator(mut self, authenticator: Arc) -> Self { + self.0.authenticator = Some(authenticator); + self + } } /// Rest catalog configuration. @@ -142,6 +163,9 @@ pub(crate) struct RestCatalogConfig { #[builder(default)] client: Option, + + #[builder(default)] + authenticator: Option>, } impl RestCatalogConfig { @@ -349,7 +373,13 @@ impl RestCatalog { async fn context(&self) -> Result<&RestContext> { self.ctx .get_or_try_init(|| async { - let client = HttpClient::new(&self.user_config)?; + let mut client = HttpClient::new(&self.user_config)?; + + // Set authenticator if one was configured + if let Some(authenticator) = &self.user_config.authenticator { + client = client.with_authenticator(authenticator.clone()); + } + let catalog_config = RestCatalog::load_config(&client, &self.user_config).await?; let config = self.user_config.clone().merge_with_config(catalog_config); let client = client.update_with(&config)?; diff --git a/crates/catalog/rest/src/client.rs b/crates/catalog/rest/src/client.rs index 361c036bb6..e8c3307e98 100644 --- a/crates/catalog/rest/src/client.rs +++ b/crates/catalog/rest/src/client.rs @@ -17,6 +17,7 @@ use std::collections::HashMap; use std::fmt::{Debug, Formatter}; +use std::sync::Arc; use http::StatusCode; use iceberg::{Error, ErrorKind, Result}; @@ -28,6 +29,17 @@ use tokio::sync::Mutex; use crate::RestCatalogConfig; use crate::types::{ErrorResponse, TokenResponse}; +/// Trait for custom token authentication. +/// +/// Implement this trait to provide custom token generation/refresh logic +/// instead of using OAuth credentials. +#[async_trait::async_trait] +pub trait CustomAuthenticator: Send + Sync + Debug { + /// Get or refresh the authentication token. + /// Called when the client needs a token for authentication. + async fn get_token(&self) -> Result; +} + pub(crate) struct HttpClient { client: Client, @@ -39,6 +51,8 @@ pub(crate) struct HttpClient { token_endpoint: String, /// The credential to be used for authentication. credential: Option<(Option, String)>, + /// Custom token authenticator (takes precedence over credential/token) + authenticator: Option>, /// Extra headers to be added to each request. extra_headers: HeaderMap, /// Extra oauth parameters to be added to each authentication request. @@ -63,6 +77,7 @@ impl HttpClient { token: Mutex::new(cfg.token()), token_endpoint: cfg.get_token_endpoint(), credential: cfg.credential(), + authenticator: None, extra_headers, extra_oauth_params: cfg.extra_oauth_params(), }) @@ -86,6 +101,7 @@ impl HttpClient { self.token_endpoint }, credential: cfg.credential().or(self.credential), + authenticator: self.authenticator, extra_headers, extra_oauth_params: if !cfg.extra_oauth_params().is_empty() { cfg.extra_oauth_params() @@ -174,6 +190,27 @@ impl HttpClient { Ok(auth_res.access_token) } + /// Set a custom token authenticator. + /// + /// When set, the authenticator will be called to get tokens instead of using + /// static tokens or OAuth credentials. This allows for custom token management + /// such as reading from files, APIs, or other custom sources. + pub fn with_authenticator(mut self, authenticator: Arc) -> Self { + self.authenticator = Some(authenticator); + self + } + + /// Add bearer token to request authorization header. + fn set_bearer_token(req: &mut Request, token: &str, error_msg: &str) -> Result<()> { + req.headers_mut().insert( + http::header::AUTHORIZATION, + format!("Bearer {token}") + .parse() + .map_err(|e| Error::new(ErrorKind::DataInvalid, error_msg).with_source(e))?, + ); + Ok(()) + } + /// Invalidate the current token without generating a new one. On the next request, the client /// will attempt to generate a new token. pub(crate) async fn invalidate_token(&self) -> Result<()> { @@ -195,18 +232,24 @@ impl HttpClient { /// Authenticates the request by adding a bearer token to the authorization header. /// - /// This method supports three authentication modes: + /// This method supports four authentication modes (in order of precedence): /// - /// 1. **No authentication** - Skip authentication when both `credential` and `token` are missing. - /// 2. **Token authentication** - Use the provided `token` directly for authentication. - /// 3. **OAuth authentication** - Exchange `credential` for a token, cache it, then use it for authentication. + /// 1. **Custom authenticator** - If set, use the custom CustomAuthenticator to get tokens. + /// 2. **Token authentication** - Use the provided static `token` directly. + /// 3. **OAuth authentication** - Exchange `credential` for a token, cache it, then use it. + /// 4. **No authentication** - Skip authentication when none of the above are available. /// - /// When both `credential` and `token` are present, `token` takes precedence. - /// - /// # TODO: Support automatic token refreshing. + /// When an authenticator is provided, it takes precedence over static tokens and credentials. async fn authenticate(&self, req: &mut Request) -> Result<()> { + // Try authenticator first (highest priority) + if let Some(authenticator) = &self.authenticator { + let token = authenticator.get_token().await?; + Self::set_bearer_token(req, &token, "Invalid custom token")?; + return Ok(()); + } + // Clone the token from lock without holding the lock for entire function. - let token = self.token.lock().await.clone(); + let token: Option = self.token.lock().await.clone(); if self.credential.is_none() && token.is_none() { return Ok(()); @@ -224,18 +267,7 @@ impl HttpClient { } }; - // Insert token in request. - req.headers_mut().insert( - http::header::AUTHORIZATION, - format!("Bearer {token}").parse().map_err(|e| { - Error::new( - ErrorKind::DataInvalid, - "Invalid token received from catalog server!", - ) - .with_source(e) - })?, - ); - + Self::set_bearer_token(req, &token, "Invalid token received from catalog server!")?; Ok(()) } diff --git a/crates/catalog/rest/src/lib.rs b/crates/catalog/rest/src/lib.rs index 70cdeaabd0..c8e1b98877 100644 --- a/crates/catalog/rest/src/lib.rs +++ b/crates/catalog/rest/src/lib.rs @@ -56,3 +56,4 @@ mod client; mod types; pub use catalog::*; +pub use client::CustomAuthenticator; diff --git a/crates/catalog/rest/tests/rest_catalog_test.rs b/crates/catalog/rest/tests/rest_catalog_test.rs index 59fea0b51f..30be3f2f57 100644 --- a/crates/catalog/rest/tests/rest_catalog_test.rs +++ b/crates/catalog/rest/tests/rest_catalog_test.rs @@ -19,13 +19,19 @@ use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::RwLock; +use std::sync::{Arc, Mutex, RwLock}; +use async_trait::async_trait; use ctor::{ctor, dtor}; use iceberg::spec::{FormatVersion, NestedField, PrimitiveType, Schema, Type}; use iceberg::transaction::{ApplyTransactionAction, Transaction}; -use iceberg::{Catalog, CatalogBuilder, Namespace, NamespaceIdent, TableCreation, TableIdent}; -use iceberg_catalog_rest::{REST_CATALOG_PROP_URI, RestCatalog, RestCatalogBuilder}; +use iceberg::{ + Catalog, CatalogBuilder, Namespace, NamespaceIdent, Result as IcebergResult, TableCreation, + TableIdent, +}; +use iceberg_catalog_rest::{ + CustomAuthenticator, REST_CATALOG_PROP_URI, RestCatalog, RestCatalogBuilder, +}; use iceberg_test_utils::docker::DockerCompose; use iceberg_test_utils::{normalize_test_name, set_up}; use port_scanner::scan_port_addr; @@ -449,3 +455,137 @@ async fn test_register_table() { table_registered.identifier().to_string() ); } + +#[derive(Debug)] +struct CountingAuthenticator { + count: Arc>, +} + +#[async_trait] +impl CustomAuthenticator for CountingAuthenticator { + async fn get_token(&self) -> IcebergResult { + let mut c = self.count.lock().unwrap(); + *c += 1; + // Return a unique token each time to ensure dynamic generation + Ok(format!("token_{}", *c)) + } +} + +async fn get_catalog_with_authenticator( + authenticator: Arc, +) -> RestCatalog { + set_up(); + + let rest_catalog_ip = { + let guard = DOCKER_COMPOSE_ENV.read().unwrap(); + let docker_compose = guard.as_ref().unwrap(); + docker_compose.get_container_ip("rest") + }; + + let rest_socket_addr = SocketAddr::new(rest_catalog_ip, REST_CATALOG_PORT); + while !scan_port_addr(rest_socket_addr) { + info!("Waiting for 1s rest catalog to ready..."); + sleep(std::time::Duration::from_millis(1000)).await; + } + + RestCatalogBuilder::default() + .with_token_authenticator(authenticator) + .load( + "rest", + HashMap::from([( + REST_CATALOG_PROP_URI.to_string(), + format!("http://{rest_socket_addr}"), + )]), + ) + .await + .unwrap() +} + +#[tokio::test] +async fn test_authenticator_token_refresh() { + // Track how many times tokens were requested + let token_request_count = Arc::new(Mutex::new(0)); + let token_request_count_clone = token_request_count.clone(); + + let authenticator = Arc::new(CountingAuthenticator { + count: token_request_count_clone, + }); + + let catalog_with_auth = get_catalog_with_authenticator(authenticator).await; + + // Perform multiple operations that should trigger token requests + let ns1 = Namespace::with_properties( + NamespaceIdent::from_strs(["test_refresh_1"]).unwrap(), + HashMap::new(), + ); + catalog_with_auth + .create_namespace(ns1.name(), HashMap::new()) + .await + .unwrap(); + + let ns2 = Namespace::with_properties( + NamespaceIdent::from_strs(["test_refresh_2"]).unwrap(), + HashMap::new(), + ); + catalog_with_auth + .create_namespace(ns2.name(), HashMap::new()) + .await + .unwrap(); + + // Verify authenticator was called multiple times + let count = *token_request_count.lock().unwrap(); + assert!( + count >= 2, + "Authenticator should have been called at least twice, but was called {} times", + count + ); +} + +#[tokio::test] +async fn test_authenticator_persists_across_operations() { + let operation_count = Arc::new(Mutex::new(0)); + let operation_count_clone = operation_count.clone(); + + let authenticator = Arc::new(CountingAuthenticator { + count: operation_count_clone, + }); + + let catalog_with_auth = get_catalog_with_authenticator(authenticator).await; + + // Create a namespace + let ns = Namespace::with_properties( + NamespaceIdent::from_strs(["test_persist", "auth"]).unwrap(), + HashMap::new(), + ); + catalog_with_auth + .create_namespace(ns.name(), HashMap::new()) + .await + .unwrap(); + + let count_after_create = *operation_count.lock().unwrap(); + + // List the namespace children (should use the same authenticator) + // We need to list children of "test_persist" to find "auth" + let list_result = catalog_with_auth + .list_namespaces(Some(&NamespaceIdent::from_strs(["test_persist"]).unwrap())) + .await + .unwrap(); + assert!( + list_result.contains(&NamespaceIdent::from_strs(["test_persist", "auth"]).unwrap()), + "Namespace {:?} not found in list {:?}", + ns.name(), + list_result + ); + + let count_after_list = *operation_count.lock().unwrap(); + + // Verify authenticator was used for both operations + assert!( + count_after_create > 0, + "Authenticator should be used for create" + ); + assert!( + count_after_list > count_after_create, + "Authenticator should be used for list operation too" + ); +}