Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions crates/catalog/rest/src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::any::Any;
use std::collections::HashMap;
use std::future::Future;
use std::str::FromStr;
use std::sync::Arc;

use async_trait::async_trait;
use iceberg::io::{self, FileIO};
Expand All @@ -38,7 +39,8 @@ use tokio::sync::OnceCell;
use typed_builder::TypedBuilder;

use crate::client::{
HttpClient, deserialize_catalog_response, deserialize_unexpected_catalog_error,
CustomAuthenticator, HttpClient, deserialize_catalog_response,
deserialize_unexpected_catalog_error,
};
use crate::types::{
CatalogConfig, CommitTableRequest, CommitTableResponse, CreateTableRequest,
Expand Down Expand Up @@ -67,6 +69,7 @@ impl Default for RestCatalogBuilder {
warehouse: None,
props: HashMap::new(),
client: None,
authenticator: None,
})
}
}
Expand Down Expand Up @@ -124,6 +127,24 @@ impl RestCatalogBuilder {
self.0.client = Some(client);
self
}

/// Installs a custom token authenticator on the catalog being built.
///
/// Once configured, the authenticator is asked for the bearer token used to
/// authenticate requests, taking precedence over a static token or OAuth
/// credentials.
///
/// # Example
/// ```ignore
/// let authenticator = Arc::new(MyAuthenticator::new());
/// let catalog = RestCatalogBuilder::default()
/// .with_token_authenticator(authenticator)
/// .load("rest", config)
/// .await?;
/// ```
pub fn with_token_authenticator(mut self, authenticator: Arc<dyn CustomAuthenticator>) -> Self {
    // The builder is a thin wrapper; the setting lives on the wrapped config.
    let config = &mut self.0;
    config.authenticator = Some(authenticator);
    self
}
}

/// Rest catalog configuration.
Expand All @@ -142,6 +163,9 @@ pub(crate) struct RestCatalogConfig {

#[builder(default)]
client: Option<Client>,

#[builder(default)]
authenticator: Option<Arc<dyn CustomAuthenticator>>,
}

impl RestCatalogConfig {
Expand Down Expand Up @@ -349,7 +373,13 @@ impl RestCatalog {
async fn context(&self) -> Result<&RestContext> {
self.ctx
.get_or_try_init(|| async {
let client = HttpClient::new(&self.user_config)?;
let mut client = HttpClient::new(&self.user_config)?;

// Set authenticator if one was configured
if let Some(authenticator) = &self.user_config.authenticator {
client = client.with_authenticator(authenticator.clone());
}

let catalog_config = RestCatalog::load_config(&client, &self.user_config).await?;
let config = self.user_config.clone().merge_with_config(catalog_config);
let client = client.update_with(&config)?;
Expand Down
72 changes: 52 additions & 20 deletions crates/catalog/rest/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use http::StatusCode;
use iceberg::{Error, ErrorKind, Result};
Expand All @@ -28,6 +29,17 @@ use tokio::sync::Mutex;
use crate::RestCatalogConfig;
use crate::types::{ErrorResponse, TokenResponse};

/// Pluggable source of bearer tokens for the REST catalog client.
///
/// Implement this trait to provide custom token generation/refresh logic
/// instead of relying on a static token or OAuth credentials. When an
/// authenticator is installed on the client it takes precedence over both.
///
/// Implementations must be `Send + Sync + Debug`; the client stores them as
/// `Arc<dyn CustomAuthenticator>`.
#[async_trait::async_trait]
pub trait CustomAuthenticator: Send + Sync + Debug {
/// Returns the token to send in the `Authorization: Bearer <token>` header.
///
/// The client calls this each time it authenticates a request and does not
/// cache the returned value itself, so implementations backed by a slow
/// source should cache and refresh internally.
///
/// # Errors
///
/// An error returned here is propagated by the client's authentication step.
/// A token that cannot be encoded as a header value is rejected by the client
/// with an `ErrorKind::DataInvalid` error.
async fn get_token(&self) -> Result<String>;
}

pub(crate) struct HttpClient {
client: Client,

Expand All @@ -39,6 +51,8 @@ pub(crate) struct HttpClient {
token_endpoint: String,
/// The credential to be used for authentication.
credential: Option<(Option<String>, String)>,
/// Custom token authenticator (takes precedence over credential/token)
authenticator: Option<Arc<dyn CustomAuthenticator>>,
/// Extra headers to be added to each request.
extra_headers: HeaderMap,
/// Extra oauth parameters to be added to each authentication request.
Expand All @@ -63,6 +77,7 @@ impl HttpClient {
token: Mutex::new(cfg.token()),
token_endpoint: cfg.get_token_endpoint(),
credential: cfg.credential(),
authenticator: None,
extra_headers,
extra_oauth_params: cfg.extra_oauth_params(),
})
Expand All @@ -86,6 +101,7 @@ impl HttpClient {
self.token_endpoint
},
credential: cfg.credential().or(self.credential),
authenticator: self.authenticator,
extra_headers,
extra_oauth_params: if !cfg.extra_oauth_params().is_empty() {
cfg.extra_oauth_params()
Expand Down Expand Up @@ -174,6 +190,27 @@ impl HttpClient {
Ok(auth_res.access_token)
}

/// Set a custom token authenticator.
///
/// When set, the authenticator will be called to get tokens instead of using
/// static tokens or OAuth credentials. This allows for custom token management
/// such as reading from files, APIs, or other custom sources.
pub fn with_authenticator(mut self, authenticator: Arc<dyn CustomAuthenticator>) -> Self {
self.authenticator = Some(authenticator);
self
}

/// Add bearer token to request authorization header.
fn set_bearer_token(req: &mut Request, token: &str, error_msg: &str) -> Result<()> {
req.headers_mut().insert(
http::header::AUTHORIZATION,
format!("Bearer {token}")
.parse()
.map_err(|e| Error::new(ErrorKind::DataInvalid, error_msg).with_source(e))?,
);
Ok(())
}

/// Invalidate the current token without generating a new one. On the next request, the client
/// will attempt to generate a new token.
pub(crate) async fn invalidate_token(&self) -> Result<()> {
Expand All @@ -195,18 +232,24 @@ impl HttpClient {

/// Authenticates the request by adding a bearer token to the authorization header.
///
/// This method supports three authentication modes:
/// This method supports four authentication modes (in order of precedence):
///
/// 1. **No authentication** - Skip authentication when both `credential` and `token` are missing.
/// 2. **Token authentication** - Use the provided `token` directly for authentication.
/// 3. **OAuth authentication** - Exchange `credential` for a token, cache it, then use it for authentication.
/// 1. **Custom authenticator** - If set, use the custom CustomAuthenticator to get tokens.
/// 2. **Token authentication** - Use the provided static `token` directly.
/// 3. **OAuth authentication** - Exchange `credential` for a token, cache it, then use it.
/// 4. **No authentication** - Skip authentication when none of the above are available.
///
/// When both `credential` and `token` are present, `token` takes precedence.
///
/// # TODO: Support automatic token refreshing.
/// When an authenticator is provided, it takes precedence over static tokens and credentials.
async fn authenticate(&self, req: &mut Request) -> Result<()> {
// Try authenticator first (highest priority)
if let Some(authenticator) = &self.authenticator {
let token = authenticator.get_token().await?;
Self::set_bearer_token(req, &token, "Invalid custom token")?;
return Ok(());
}

// Clone the token from lock without holding the lock for entire function.
let token = self.token.lock().await.clone();
let token: Option<String> = self.token.lock().await.clone();

if self.credential.is_none() && token.is_none() {
return Ok(());
Expand All @@ -224,18 +267,7 @@ impl HttpClient {
}
};

// Insert token in request.
req.headers_mut().insert(
http::header::AUTHORIZATION,
format!("Bearer {token}").parse().map_err(|e| {
Error::new(
ErrorKind::DataInvalid,
"Invalid token received from catalog server!",
)
.with_source(e)
})?,
);

Self::set_bearer_token(req, &token, "Invalid token received from catalog server!")?;
Ok(())
}

Expand Down
1 change: 1 addition & 0 deletions crates/catalog/rest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,4 @@ mod client;
mod types;

pub use catalog::*;
pub use client::CustomAuthenticator;
146 changes: 143 additions & 3 deletions crates/catalog/rest/tests/rest_catalog_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@

use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::RwLock;
use std::sync::{Arc, Mutex, RwLock};

use async_trait::async_trait;
use ctor::{ctor, dtor};
use iceberg::spec::{FormatVersion, NestedField, PrimitiveType, Schema, Type};
use iceberg::transaction::{ApplyTransactionAction, Transaction};
use iceberg::{Catalog, CatalogBuilder, Namespace, NamespaceIdent, TableCreation, TableIdent};
use iceberg_catalog_rest::{REST_CATALOG_PROP_URI, RestCatalog, RestCatalogBuilder};
use iceberg::{
Catalog, CatalogBuilder, Namespace, NamespaceIdent, Result as IcebergResult, TableCreation,
TableIdent,
};
use iceberg_catalog_rest::{
CustomAuthenticator, REST_CATALOG_PROP_URI, RestCatalog, RestCatalogBuilder,
};
use iceberg_test_utils::docker::DockerCompose;
use iceberg_test_utils::{normalize_test_name, set_up};
use port_scanner::scan_port_addr;
Expand Down Expand Up @@ -449,3 +455,137 @@ async fn test_register_table() {
table_registered.identifier().to_string()
);
}

/// Test authenticator that records how many times a token has been requested.
///
/// The counter sits behind an `Arc` so a test can keep its own handle and read
/// the count after the authenticator has been handed to the catalog.
#[derive(Debug)]
struct CountingAuthenticator {
/// Number of `get_token` calls made so far.
count: Arc<Mutex<usize>>,
}

#[async_trait]
impl CustomAuthenticator for CountingAuthenticator {
    /// Increments the shared call counter and returns `token_<n>`, where `n` is
    /// the counter value after the increment, so each call yields a new token.
    async fn get_token(&self) -> IcebergResult<String> {
        // Scope the guard so the lock is released before the token string is built.
        let call_number = {
            let mut calls = self.count.lock().unwrap();
            *calls += 1;
            *calls
        };
        Ok(format!("token_{call_number}"))
    }
}

async fn get_catalog_with_authenticator(
authenticator: Arc<dyn CustomAuthenticator>,
) -> RestCatalog {
set_up();
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of duplicating this code, let's reuse get_catalog. Maybe create a get_catalog_with_authenticator that wraps get_catalog, or have a private method that both call.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or some other solution, but we should hide all these details like get_catalog does.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a custom get_catalog method

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to convey that it should reuse code with fn get_catalog though, would that be possible?


let rest_catalog_ip = {
let guard = DOCKER_COMPOSE_ENV.read().unwrap();
let docker_compose = guard.as_ref().unwrap();
docker_compose.get_container_ip("rest")
};

let rest_socket_addr = SocketAddr::new(rest_catalog_ip, REST_CATALOG_PORT);
while !scan_port_addr(rest_socket_addr) {
info!("Waiting for 1s rest catalog to ready...");
sleep(std::time::Duration::from_millis(1000)).await;
}

RestCatalogBuilder::default()
.with_token_authenticator(authenticator)
.load(
"rest",
HashMap::from([(
REST_CATALOG_PROP_URI.to_string(),
format!("http://{rest_socket_addr}"),
)]),
)
.await
.unwrap()
}

/// Verifies that the custom authenticator is asked for a token again on
/// successive catalog operations.
#[tokio::test]
async fn test_authenticator_token_refresh() {
    // The test keeps one handle to the counter; the authenticator owns the other.
    let calls = Arc::new(Mutex::new(0));
    let authenticator = Arc::new(CountingAuthenticator {
        count: Arc::clone(&calls),
    });

    let catalog = get_catalog_with_authenticator(authenticator).await;

    // Each namespace creation is a separate operation that should request a token.
    for name in ["test_refresh_1", "test_refresh_2"] {
        let namespace = Namespace::with_properties(
            NamespaceIdent::from_strs([name]).unwrap(),
            HashMap::new(),
        );
        catalog
            .create_namespace(namespace.name(), HashMap::new())
            .await
            .unwrap();
    }

    // The authenticator must have been consulted more than once.
    let observed_calls = *calls.lock().unwrap();
    assert!(
        observed_calls >= 2,
        "Authenticator should have been called at least twice, but was called {} times",
        observed_calls
    );
}

/// Verifies that the same authenticator keeps being used across different
/// kinds of catalog operations (namespace creation followed by listing).
#[tokio::test]
async fn test_authenticator_persists_across_operations() {
    let calls = Arc::new(Mutex::new(0));
    let authenticator = Arc::new(CountingAuthenticator {
        count: Arc::clone(&calls),
    });

    let catalog = get_catalog_with_authenticator(authenticator).await;

    // First operation: create the nested namespace `test_persist.auth`.
    let child = Namespace::with_properties(
        NamespaceIdent::from_strs(["test_persist", "auth"]).unwrap(),
        HashMap::new(),
    );
    catalog
        .create_namespace(child.name(), HashMap::new())
        .await
        .unwrap();

    let calls_after_create = *calls.lock().unwrap();

    // Second operation: list the children of `test_persist`, which must include `auth`.
    let parent = NamespaceIdent::from_strs(["test_persist"]).unwrap();
    let listed = catalog.list_namespaces(Some(&parent)).await.unwrap();
    let found = listed.iter().any(|ident| ident == child.name());
    assert!(
        found,
        "Namespace {:?} not found in list {:?}",
        child.name(),
        listed
    );

    let calls_after_list = *calls.lock().unwrap();

    // The counter must have advanced for both operations.
    assert!(
        calls_after_create > 0,
        "Authenticator should be used for create"
    );
    assert!(
        calls_after_list > calls_after_create,
        "Authenticator should be used for list operation too"
    );
}
Loading