diff --git a/Cargo.lock b/Cargo.lock index 1ca62b090..d8312dca8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4433,6 +4433,7 @@ name = "key-server" version = "0.5.11" dependencies = [ "anyhow", + "async-trait", "axum 0.8.4", "bcs", "chrono", diff --git a/crates/key-server/Cargo.toml b/crates/key-server/Cargo.toml index 6b201c78e..f5734875a 100644 --- a/crates/key-server/Cargo.toml +++ b/crates/key-server/Cargo.toml @@ -24,6 +24,7 @@ shared_crypto.workspace = true move-core-types.workspace = true mvr_types = { git = "https://github.com/MystenLabs/mvr", rev = "1993d7188f62564b05f0ccab46bbfb24b0eea326", package = "mvr-types" } +async-trait = "0.1.89" tokio = { version = "1.46.1", features = ["full"] } axum = { version = "0.8", features = ["macros"] } tower-http = { version = "0.6.6", features = ["cors", "limit"] } diff --git a/crates/key-server/src/externals.rs b/crates/key-server/src/externals.rs index 05bd792bc..905cdb6cf 100644 --- a/crates/key-server/src/externals.rs +++ b/crates/key-server/src/externals.rs @@ -4,7 +4,7 @@ use crate::cache::default_lru_cache; use crate::errors::InternalError; use crate::key_server_options::KeyServerOptions; -use crate::sui_rpc_client::SuiRpcClient; +use crate::sui_rpc_client::{RpcClient, SuiRpcClient}; use crate::{mvr_forward_resolution, Timestamp}; use moka::sync::Cache; use once_cell::sync::Lazy; @@ -27,9 +27,9 @@ pub(crate) fn add_upgraded_package(pkg_id: ObjectID, new_pkg_id: ObjectID) { CACHE.insert(new_pkg_id, pkg_id); } -pub(crate) async fn check_mvr_package_id( +pub(crate) async fn check_mvr_package_id( mvr_name: &Option, - sui_rpc_client: &SuiRpcClient, + sui_rpc_client: &SuiRpcClient, key_server_options: &KeyServerOptions, first_pkg_id: ObjectID, req_id: Option<&str>, @@ -63,9 +63,9 @@ pub(crate) async fn check_mvr_package_id( Ok(()) } -pub(crate) async fn fetch_first_pkg_id( +pub(crate) async fn fetch_first_pkg_id( pkg_id: &ObjectID, - sui_rpc_client: &SuiRpcClient, + sui_rpc_client: &SuiRpcClient, ) -> Result { match CACHE.get(pkg_id) { Some(first) => Ok(first), @@ -103,8 +103,8 @@ pub(crate) fn get_mvr_cache(mvr_name: &str) -> Option { } /// Returns the timestamp for the latest checkpoint. -pub(crate) async fn get_latest_checkpoint_timestamp( - sui_rpc_client: SuiRpcClient, +pub(crate) async fn get_latest_checkpoint_timestamp( + sui_rpc_client: SuiRpcClient, ) -> SuiRpcResult { let latest_checkpoint_sequence_number = sui_rpc_client .get_latest_checkpoint_sequence_number() @@ -117,7 +117,9 @@ pub(crate) async fn get_latest_checkpoint_timestamp( Ok(checkpoint.timestamp_ms) } -pub(crate) async fn get_reference_gas_price(sui_rpc_client: SuiRpcClient) -> SuiRpcResult { +pub(crate) async fn get_reference_gas_price( + sui_rpc_client: SuiRpcClient, +) -> SuiRpcResult { let rgp = sui_rpc_client .get_reference_gas_price() .await diff --git a/crates/key-server/src/lib.rs b/crates/key-server/src/lib.rs new file mode 100644 index 000000000..cc9a2914a --- /dev/null +++ b/crates/key-server/src/lib.rs @@ -0,0 +1,926 @@ +// Copyright (c), Mysten Labs, Inc. +// SPDX-License-Identifier: Apache-2.0 +use crate::errors::InternalError::{ + DeprecatedSDKVersion, InvalidSDKVersion, MissingRequiredHeader, +}; +use crate::externals::get_reference_gas_price; +use crate::key_server_options::ServerMode; +use crate::metrics::{call_with_duration, observation_callback, status_callback, Metrics}; +use crate::metrics_push::create_push_client; +use crate::mvr::mvr_forward_resolution; +use crate::periodic_updater::spawn_periodic_updater; +use crate::signed_message::signed_request; +use crate::sui_rpc_client::{verify_personal_message_signature, RpcClient}; +use crate::time::checked_duration_since; +use crate::time::from_mins; +use crate::time::{duration_since_as_f64, saturating_duration_since}; +use crate::types::{MasterKeyPOP, Network}; +use anyhow::Context; +use axum::extract::{Query, Request}; +use axum::http::{HeaderMap, HeaderValue}; +use axum::middleware::{from_fn_with_state, map_response, Next}; +use axum::response::Response; +use axum::routing::{get, post}; +use axum::{extract::State, Json, Router}; +use core::time::Duration; +use crypto::elgamal::encrypt; +use crypto::ibe; +use crypto::ibe::create_proof_of_possession; +use crypto::prefixed_hex::PrefixedHex; +use errors::InternalError; +use externals::get_latest_checkpoint_timestamp; +use fastcrypto::ed25519::{Ed25519PublicKey, Ed25519Signature}; +use fastcrypto::encoding::{Encoding, Hex}; +use fastcrypto::traits::VerifyingKey; +use futures::future::pending; +use jsonrpsee::core::ClientError; +use jsonrpsee::types::error::{INVALID_PARAMS_CODE, METHOD_NOT_FOUND_CODE}; +use key_server_options::KeyServerOptions; +use master_keys::MasterKeys; +use metrics::metrics_middleware; +use mysten_service::get_mysten_service; +use mysten_service::metrics::start_prometheus_server; +use mysten_service::package_name; +use mysten_service::package_version; +use prometheus::Registry; +use rand::thread_rng; +use seal_sdk::types::{DecryptionKey, ElGamalPublicKey, ElgamalVerificationKey, KeyId}; +use seal_sdk::{signed_message, FetchKeyResponse}; +use semver::Version; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::collections::HashMap; +use std::env; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use sui_rpc_client::SuiRpcClient; +use sui_sdk::error::Error; +use sui_sdk::rpc_types::{SuiExecutionStatus, SuiTransactionBlockEffectsAPI}; +use sui_sdk::types::base_types::{ObjectID, SuiAddress}; +use sui_sdk::types::signature::GenericSignature; +use sui_sdk::types::transaction::{ProgrammableTransaction, TransactionData, TransactionKind}; +use sui_sdk::SuiClientBuilder; +use tap::tap::TapFallible; +use tokio::sync::watch::Receiver; +use tokio::task::JoinHandle; +use tower_http::cors::{Any, CorsLayer}; +use tower_http::limit::RequestBodyLimitLayer; +use tracing::{debug, error, info, warn}; +use valid_ptb::ValidPtb; + +pub mod cache; +pub mod errors; +pub mod externals; +pub mod signed_message; +pub mod sui_rpc_client; +pub mod types; +pub mod utils; +pub mod valid_ptb; + +pub mod key_server_options; +pub mod master_keys; +pub mod metrics; +pub mod metrics_push; +pub mod mvr; +pub mod periodic_updater; +#[cfg(test)] +pub mod tests; +pub mod time; + +const GAS_BUDGET: u64 = 500_000_000; +const GIT_VERSION: &str = utils::git_version!(); + +// Transaction size limit: 128KB + 33% for base64 + some extra room for other parameters +const MAX_REQUEST_SIZE: usize = 180 * 1024; + +/// Default encoding used for master and public keys for the key server. +type DefaultEncoding = PrefixedHex; + +// TODO: Remove legacy once key-server crate uses sui-sdk-types. +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Certificate { + pub user: SuiAddress, + pub session_vk: Ed25519PublicKey, + pub creation_time: u64, + pub ttl_min: u16, + pub signature: GenericSignature, + pub mvr_name: Option, +} + +// TODO: Remove legacy once key-server crate uses sui-sdk-types. +#[derive(Serialize, Deserialize)] +pub struct FetchKeyRequest { + // Next fields must be signed to prevent others from sending requests on behalf of the user and + // being able to fetch the key + pub ptb: String, // must adhere specific structure, see ValidPtb + // We don't want to rely on https only for restricting the response to this user, since in the + // case of multiple services, one service can do a replay attack to get the key from other + // services. + pub enc_key: ElGamalPublicKey, + pub enc_verification_key: ElgamalVerificationKey, + pub request_signature: Ed25519Signature, + + pub certificate: Certificate, +} + +/// UNIX timestamp in milliseconds. +type Timestamp = u64; + +#[derive(Clone)] +pub struct Server { + sui_rpc_client: SuiRpcClient, + master_keys: MasterKeys, + key_server_oid_to_pop: HashMap, + options: KeyServerOptions, +} + +impl Server { + async fn new( + sui_rpc_client: SuiRpcClient, + options: KeyServerOptions, + master_keys: MasterKeys, + ) -> Self { + info!("Server started with network: {:?}", options.network); + + let key_server_oid_to_pop = options + .get_supported_key_server_object_ids() + .into_iter() + .map(|ks_oid| { + let key = master_keys + .get_key_for_key_server(&ks_oid) + .expect("checked already"); + let pop = create_proof_of_possession(key, &ks_oid.into_bytes()); + (ks_oid, pop) + }) + .collect(); + + Server { + sui_rpc_client, + master_keys, + key_server_oid_to_pop, + options, + } + } + + pub fn get_pop(&self, service_id: ObjectID) -> Result { + self.key_server_oid_to_pop + .get(&service_id) + .copied() + .ok_or(InternalError::InvalidServiceId) + } + + #[allow(clippy::too_many_arguments)] + async fn check_signature( + &self, + ptb: &ProgrammableTransaction, + enc_key: &ElGamalPublicKey, + enc_verification_key: &ElgamalVerificationKey, + session_sig: &Ed25519Signature, + cert: &Certificate, + package_name: String, + req_id: Option<&str>, + ) -> Result<(), InternalError> { + // Check certificate + + // TTL of the session key must be smaller than the allowed max + let ttl = from_mins(cert.ttl_min); + if ttl > self.options.session_key_ttl_max { + debug!( + "Certificate has invalid time-to-live (req_id: {:?})", + req_id + ); + return Err(InternalError::InvalidCertificate); + } + + // Check that the creation time is not in the future and that the certificate has not expired + match checked_duration_since(cert.creation_time) { + None => { + debug!( + "Certificate has invalid creation time (req_id: {:?})", + req_id + ); + return Err(InternalError::InvalidCertificate); + } + Some(duration) => { + if duration > ttl { + debug!("Certificate has expired (req_id: {:?})", req_id); + return Err(InternalError::InvalidCertificate); + } + } + } + + let msg = signed_message( + package_name, + &cert.session_vk, + cert.creation_time, + cert.ttl_min, + ); + debug!( + "Checking signature on message: {:?} (req_id: {:?})", + msg, req_id + ); + verify_personal_message_signature( + cert.signature.clone(), + msg.as_bytes(), + cert.user, + Some(self.sui_rpc_client.sui_client().clone()), + ) + .await + .tap_err(|e| { + debug!( + "Signature verification failed: {:?} (req_id: {:?})", + e, req_id + ); + }) + .map_err(|_| InternalError::InvalidSignature)?; + + // Check session signature + let signed_msg = signed_request(ptb, enc_key, enc_verification_key); + cert.session_vk + .verify(&signed_msg, session_sig) + .map_err(|_| { + debug!( + "Session signature verification failed (req_id: {:?})", + req_id + ); + InternalError::InvalidSessionSignature + }) + } + + async fn check_policy( + &self, + sender: SuiAddress, + vptb: &ValidPtb, + gas_price: u64, + req_id: Option<&str>, + metrics: Option<&Metrics>, + ) -> Result<(), InternalError> { + debug!( + "Checking policy for ptb: {:?} (req_id: {:?})", + vptb.ptb(), + req_id + ); + // Evaluate the `seal_approve*` function + let tx_data = TransactionData::new_with_gas_coins( + TransactionKind::ProgrammableTransaction(vptb.ptb().clone()), + sender, + vec![], // Empty gas payment for dry run + GAS_BUDGET, + gas_price, + ); + let dry_run_res = self + .sui_rpc_client + .dry_run_transaction_block(tx_data.clone()) + .await + .map_err(|e| { + if let Error::RpcError(ClientError::Call(ref e)) = e { + match e.code() { + INVALID_PARAMS_CODE => { + // This error is generic and happens when one of the parameters of the Move call in the PTB is invalid. + // One reason is that one of the parameters does not exist, in which case it could be a newly created object that the FN has not yet seen. + // There are other possible reasons, so we return the entire message to the user to allow debugging. + // Note that the message is a message from the JSON RPC API, so it is already formatted and does not contain any sensitive information. + debug!("Invalid parameter: {}", e.message()); + return InternalError::InvalidParameter(e.message().to_string()); + } + METHOD_NOT_FOUND_CODE => { + // This means that the seal_approve function is not found on the given module. + debug!("Function not found: {:?}", e); + return InternalError::InvalidPTB( + "The seal_approve function was not found on the module".to_string(), + ); + } + _ => {} + } + } + InternalError::Failure(format!( + "Dry run execution failed ({:?}) (req_id: {:?})", + e, req_id + )) + })?; + + // Record the gas cost. Only do this in permissioned mode to avoid high cardinality metrics in public mode. + if let Some(m) = metrics { + if matches!( + self.options.server_mode, + ServerMode::Permissioned { client_configs: _ } + ) { + let package = vptb.pkg_id().to_hex_uncompressed(); + m.dry_run_gas_cost_per_package + .with_label_values(&[&package]) + .observe(dry_run_res.effects.gas_cost_summary().computation_cost as f64); + } + } + + debug!("Dry run response: {:?} (req_id: {:?})", dry_run_res, req_id); + if let SuiExecutionStatus::Failure { error } = dry_run_res.effects.status() { + debug!( + "Dry run execution asserted (req_id: {:?}) {:?}", + req_id, error + ); + return Err(InternalError::NoAccess(error.clone())); + } + + // all good! + Ok(()) + } + + #[allow(clippy::too_many_arguments)] + async fn check_request( + &self, + valid_ptb: &ValidPtb, + enc_key: &ElGamalPublicKey, + enc_verification_key: &ElgamalVerificationKey, + request_signature: &Ed25519Signature, + certificate: &Certificate, + gas_price: u64, + metrics: Option<&Metrics>, + req_id: Option<&str>, + mvr_name: Option, + ) -> Result<(ObjectID, Vec), InternalError> { + // Handle package upgrades: Use the first as the namespace + let first_pkg_id = + call_with_duration(metrics.map(|m| &m.fetch_pkg_ids_duration), || async { + externals::fetch_first_pkg_id(&valid_ptb.pkg_id(), &self.sui_rpc_client).await + }) + .await?; + + // Make sure that the package is supported. + self.master_keys.has_key_for_package(&first_pkg_id)?; + + // Check if the package id that MVR name points matches the first package ID, if provided. + externals::check_mvr_package_id( + &mvr_name, + &self.sui_rpc_client, + &self.options, + first_pkg_id, + req_id, + ) + .await?; + + // Check all conditions + self.check_signature( + valid_ptb.ptb(), + enc_key, + enc_verification_key, + request_signature, + certificate, + mvr_name.unwrap_or(first_pkg_id.to_hex_uncompressed()), + req_id, + ) + .await?; + + call_with_duration(metrics.map(|m| &m.check_policy_duration), || async { + self.check_policy(certificate.user, valid_ptb, gas_price, req_id, metrics) + .await + }) + .await?; + + // return the full id with the first package id as prefix + Ok((first_pkg_id, valid_ptb.full_ids(&first_pkg_id))) + } + + pub(crate) fn create_response( + &self, + first_pkg_id: ObjectID, + ids: &[KeyId], + enc_key: &ElGamalPublicKey, + ) -> FetchKeyResponse { + debug!( + "Creating response for ids: {:?}", + ids.iter().map(Hex::encode).collect::>() + ); + let master_key = self + .master_keys + .get_key_for_package(&first_pkg_id) + .expect("checked already"); + let decryption_keys = ids + .iter() + .map(|id| { + // Requested key + let key = ibe::extract(master_key, id); + // ElGamal encryption of key under the user's public key + let encrypted_key = encrypt(&mut thread_rng(), &key, enc_key); + DecryptionKey { + id: id.to_owned(), + encrypted_key, + } + }) + .collect(); + FetchKeyResponse { decryption_keys } + } + + /// Spawns a thread that fetches the latest checkpoint timestamp and sends it to a [Receiver] once per `update_interval`. + /// Returns the [Receiver]. + async fn spawn_latest_checkpoint_timestamp_updater( + &self, + metrics: Option<&Metrics>, + ) -> (Receiver, JoinHandle<()>) { + spawn_periodic_updater( + &self.sui_rpc_client, + self.options.checkpoint_update_interval, + get_latest_checkpoint_timestamp, + "latest checkpoint timestamp", + metrics.map(|m| { + observation_callback(&m.checkpoint_timestamp_delay, |ts| { + let duration = duration_since_as_f64(ts); + debug!("Latest checkpoint timestamp delay is {duration} ms"); + duration + }) + }), + metrics.map(|m| { + observation_callback(&m.get_checkpoint_timestamp_duration, |d: Duration| { + d.as_millis() as f64 + }) + }), + metrics.map(|m| status_callback(&m.get_checkpoint_timestamp_status)), + ) + .await + } + + /// Spawns a thread that fetches RGP and sends it to a [Receiver] once per `update_interval`. + /// Returns the [Receiver]. + async fn spawn_reference_gas_price_updater( + &self, + metrics: Option<&Metrics>, + ) -> (Receiver, JoinHandle<()>) { + spawn_periodic_updater( + &self.sui_rpc_client, + self.options.rgp_update_interval, + get_reference_gas_price, + "RGP", + None::, + None::, + metrics.map(|m| status_callback(&m.get_reference_gas_price_status)), + ) + .await + } + + /// Spawn a metrics push background jobs that push metrics to seal-proxy + fn spawn_metrics_push_job(&self, registry: prometheus::Registry) -> JoinHandle<()> { + let push_config = self.options.metrics_push_config.clone(); + if let Some(push_config) = push_config { + tokio::spawn(async move { + let mut interval = tokio::time::interval(push_config.push_interval); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + let mut client = create_push_client(); + tracing::info!("starting metrics push to '{}'", &push_config.push_url); + loop { + tokio::select! { + _ = interval.tick() => { + if let Err(error) = metrics_push::push_metrics( + push_config.clone(), + &client, + ®istry, + ).await { + tracing::warn!(?error, "unable to push metrics"); + client = create_push_client(); + } + } + } + } + }) + } else { + tokio::spawn(async move { + warn!("No metrics push config is found"); + pending().await + }) + } + } +} + +#[allow(clippy::single_match)] +pub async fn fetch_key( + server: Arc>, + payload: &FetchKeyRequest, + valid_ptb: ValidPtb, + req_id: Option<&str>, + sdk_version: &str, + gas_price: u64, + metrics: Option<&Metrics>, +) -> Result { + // Report the number of id's in the request to the metrics. + let check_request_result = server + .check_request( + &valid_ptb, + &payload.enc_key, + &payload.enc_verification_key, + &payload.request_signature, + &payload.certificate, + gas_price, + metrics, + req_id, + payload.certificate.mvr_name.clone(), + ) + .await; + + let request_info = json!({ "user": payload.certificate.user, "package_id": valid_ptb.pkg_id(), "req_id": req_id, "sdk_version": sdk_version }); + match check_request_result { + Ok((first_pkg_id, full_ids)) => { + info!("Valid request: {request_info}"); + + let response = server.create_response(first_pkg_id, &full_ids, &payload.enc_key); + + Ok(response) + } + Err(InternalError::Failure(s)) => { + warn!("Check request failed with debug message '{s}': {request_info}"); + + Err(InternalError::Failure(s)) + } + Err(error) => Err(error), + } +} + +#[allow(clippy::single_match)] +async fn handle_fetch_key_internal( + app_state: &MyState, + payload: &FetchKeyRequest, + req_id: Option<&str>, + sdk_version: &str, +) -> Result { + app_state.check_full_node_is_fresh()?; + + let valid_ptb = ValidPtb::try_from_base64(&payload.ptb)?; + + // Report the number of id's in the request to the metrics. + app_state + .metrics + .requests_per_number_of_ids + .observe(valid_ptb.inner_ids().len() as f64); + + fetch_key( + app_state.server.clone(), + payload, + valid_ptb, + req_id, + sdk_version, + app_state.reference_gas_price(), + Some(&app_state.metrics), + ) + .await +} + +async fn handle_fetch_key( + State(app_state): State>, + headers: HeaderMap, + Json(payload): Json, +) -> Result, InternalError> { + let req_id = headers + .get("Request-Id") + .map(|v| v.to_str().unwrap_or_default()); + let sdk_version = headers + .get("Client-Sdk-Version") + .and_then(|v| v.to_str().ok()) + .unwrap_or_default(); + + app_state.metrics.requests.inc(); + + debug!( + "Checking request for ptb: {:?}, cert {:?} (req_id: {:?})", + payload.ptb, payload.certificate, req_id + ); + + handle_fetch_key_internal(&app_state, &payload, req_id, sdk_version) + .await + .tap_err(|e| app_state.metrics.observe_error(e.as_str())) + .map(Json) +} + +#[derive(Serialize, Deserialize)] +struct GetServiceResponse { + service_id: ObjectID, + pop: MasterKeyPOP, +} + +async fn handle_get_service( + State(app_state): State>, + Query(params): Query>, +) -> Result, InternalError> { + app_state.metrics.service_requests.inc(); + + let service_id = params + .get("service_id") + .ok_or(InternalError::InvalidServiceId) + .and_then(|id| { + ObjectID::from_hex_literal(id).map_err(|_| InternalError::InvalidServiceId) + })?; + + let pop = app_state.server.get_pop(service_id)?; + + Ok(Json(GetServiceResponse { service_id, pop })) +} + +#[derive(Clone)] +struct MyState { + metrics: Arc, + server: Arc>, + latest_checkpoint_timestamp_receiver: Receiver, + reference_gas_price_receiver: Receiver, +} + +impl MyState { + fn check_full_node_is_fresh(&self) -> Result<(), InternalError> { + // Compute the staleness of the latest checkpoint timestamp. + let staleness = + saturating_duration_since(*self.latest_checkpoint_timestamp_receiver.borrow()); + if staleness > self.server.options.allowed_staleness { + return Err(InternalError::Failure(format!( + "Full node is stale. Latest checkpoint is {} ms old.", + staleness.as_millis() + ))); + } + Ok(()) + } + + fn reference_gas_price(&self) -> u64 { + *self.reference_gas_price_receiver.borrow() + } + + fn validate_sdk_version(&self, version_string: &str) -> Result<(), InternalError> { + let version = Version::parse(version_string).map_err(|_| InvalidSDKVersion)?; + if !self + .server + .options + .sdk_version_requirement + .matches(&version) + { + return Err(DeprecatedSDKVersion); + } + Ok(()) + } +} + +/// Middleware to validate the SDK version. +async fn handle_request_headers( + state: State>, + request: Request, + next: Next, +) -> Result { + // Log the request id and SDK version + let version = request.headers().get("Client-Sdk-Version"); + + info!( + "Request id: {:?}, SDK version: {:?}, SDK type: {:?}, Target API version: {:?}", + request + .headers() + .get("Request-Id") + .map(|v| v.to_str().unwrap_or_default()), + version, + request.headers().get("Client-Sdk-Type"), + request.headers().get("Client-Target-Api-Version") + ); + + version + .ok_or(MissingRequiredHeader("Client-Sdk-Version".to_string())) + .and_then(|v| v.to_str().map_err(|_| InvalidSDKVersion)) + .and_then(|v| state.validate_sdk_version(v)) + .tap_err(|e| { + debug!("Invalid SDK version: {:?}", e); + state.metrics.observe_error(e.as_str()); + })?; + Ok(next.run(request).await) +} + +/// Middleware to add headers to all responses. +async fn add_response_headers(mut response: Response) -> Response { + let headers = response.headers_mut(); + headers.insert( + "X-KeyServer-Version", + HeaderValue::from_static(package_version!()), + ); + headers.insert( + "X-KeyServer-GitVersion", + HeaderValue::from_static(GIT_VERSION), + ); + response +} + +/// Creates a [prometheus::core::Collector] that tracks the uptime of the server. +fn uptime_metric(version: &str) -> Box { + let opts = prometheus::opts!("uptime", "uptime of the key server in seconds") + .variable_label("version"); + + let start_time = std::time::Instant::now(); + let uptime = move || start_time.elapsed().as_secs(); + let metric = prometheus_closure_metric::ClosureMetric::new( + opts, + prometheus_closure_metric::ValueType::Counter, + uptime, + &[version], + ) + .unwrap(); + + Box::new(metric) +} + +/// Spawn server's background tasks: +/// - background checkpoint downloader +/// - reference gas price updater. +/// - optional metrics pusher (if configured). +/// +/// The returned JoinHandle can be used to catch any tasks error or panic. +async fn start_server_background_tasks( + server: Arc>, + metrics: Arc, + registry: prometheus::Registry, +) -> ( + Receiver, + Receiver, + JoinHandle>, +) { + // Spawn background checkpoint timestamp updater. + let (latest_checkpoint_timestamp_receiver, latest_checkpoint_timestamp_handle) = server + .spawn_latest_checkpoint_timestamp_updater(Some(&metrics)) + .await; + + // Spawn background reference gas price updater. + let (reference_gas_price_receiver, reference_gas_price_handle) = server + .spawn_reference_gas_price_updater(Some(&metrics)) + .await; + + // Spawn metrics push task + let metrics_push_handle = server.spawn_metrics_push_job(registry); + + // Spawn a monitor task that will exit the program if any updater task panics + let handle: JoinHandle> = tokio::spawn(async move { + tokio::select! { + result = latest_checkpoint_timestamp_handle => { + if let Err(e) = result { + error!("Latest checkpoint timestamp updater panicked: {:?}", e); + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } + return Err(e.into()); + } + } + result = reference_gas_price_handle => { + if let Err(e) = result { + error!("Reference gas price updater panicked: {:?}", e); + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } + return Err(e.into()); + } + } + result = metrics_push_handle => { + if let Err(e) = result { + error!("Metrics push task panicked: {:?}", e); + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } + return Err(e.into()); + } + } + } + + unreachable!("One of the background tasks should have returned an error"); + }); + + ( + latest_checkpoint_timestamp_receiver, + reference_gas_price_receiver, + handle, + ) +} + +pub async fn app( + options: KeyServerOptions, + master_keys: MasterKeys, +) -> anyhow::Result<(JoinHandle>, Router)> { + let (server, metrics, registry) = get_server(options, master_keys).await?; + + let (latest_checkpoint_timestamp_receiver, reference_gas_price_receiver, monitor_handle) = + start_server_background_tasks(server.clone(), metrics.clone(), registry.clone()).await; + + let state = MyState { + metrics, + server, + latest_checkpoint_timestamp_receiver, + reference_gas_price_receiver, + }; + + let cors = CorsLayer::new() + .allow_methods(Any) + .allow_origin(Any) + .allow_headers(Any) + .expose_headers(Any); + + let app = get_mysten_service::>(package_name!(), package_version!()) + .merge( + axum::Router::new() + .route("/v1/fetch_key", post(handle_fetch_key)) + .route("/v1/service", get(handle_get_service)) + .layer(from_fn_with_state(state.clone(), handle_request_headers)) + .layer(map_response(add_response_headers)) + // Outside most middlewares that tracks metrics for HTTP requests and response + // status. + .layer(from_fn_with_state( + state.metrics.clone(), + metrics_middleware, + )), + ) + .with_state(state) + // Global body size limit + .layer(RequestBodyLimitLayer::new(MAX_REQUEST_SIZE)) + .layer(cors); + Ok((monitor_handle, app)) +} + +pub async fn get_server( + options: KeyServerOptions, + master_keys: MasterKeys, +) -> anyhow::Result<(Arc>, Arc, Registry)> { + info!("Setting up metrics"); + let registry = start_prometheus_server(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), + options.metrics_host_port, + )) + .default_registry(); + + // Tracks the uptime of the server. + let registry_clone = registry.clone(); + tokio::task::spawn(async move { + registry_clone + .register(uptime_metric( + format!("{}-{}", package_version!(), GIT_VERSION).as_str(), + )) + .expect("metrics defined at compile time must be valid"); + }); + + // hook up custom application metrics + let metrics = Arc::new(Metrics::new(®istry)); + + let sui_rpc_client = SuiRpcClient::new( + Client::new_from_builder( + SuiClientBuilder::default() + .request_timeout(options.rpc_config.timeout) + .build(&options.network.node_url()), + ) + .await + .expect("SuiClientBuilder should not failed unless provided with invalid network url"), + options.rpc_config.retry_config.clone(), + Some(metrics.clone()), + ); + + info!( + "Starting server, version {}", + format!("{}-{}", package_version!(), GIT_VERSION).as_str() + ); + options.validate()?; + + let server = Arc::new(Server::new(sui_rpc_client, options, master_keys).await); + + Ok((server, metrics, registry)) +} + +pub fn get_server_options_from_env() -> anyhow::Result { + match env::var("CONFIG_PATH") { + Ok(config_path) => { + info!("Loading config file: {}", config_path); + let mut opts: KeyServerOptions = serde_yaml::from_reader( + std::fs::File::open(&config_path) + .context(format!("Cannot open configuration file {config_path}"))?, + ) + .expect("Failed to parse configuration file"); + + // Handle Custom network NODE_URL configuration + if let Network::Custom { + ref mut node_url, .. + } = opts.network + { + let env_node_url = env::var("NODE_URL").ok(); + + match (node_url.as_ref(), env_node_url.as_ref()) { + (Some(_), Some(_)) => { + panic!("NODE_URL cannot be provided in both config file and environment variable. Please use only one source."); + } + (None, Some(url)) => { + info!("Using NODE_URL from environment variable: {}", url); + *node_url = Some(url.clone()); + } + (Some(url), None) => { + info!("Using NODE_URL from config file: {}", url); + } + (None, None) => { + panic!("Custom network requires NODE_URL to be set either in config file or as environment variable"); + } + } + } + + Ok(opts) + } + Err(_) => { + info!("Using local environment variables for configuration, should only be used for testing"); + let network = env::var("NETWORK") + .map(|n| Network::from_str_unchecked(&n)) + .unwrap_or(Network::Testnet); + let options = KeyServerOptions::new_open_server_with_default_values( + network, + utils::decode_object_id("KEY_SERVER_OBJECT_ID")?, + ); + + Ok(options) + } + } +} diff --git a/crates/key-server/src/master_keys.rs b/crates/key-server/src/master_keys.rs index 72628d55e..105f562da 100644 --- a/crates/key-server/src/master_keys.rs +++ b/crates/key-server/src/master_keys.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 use crate::errors::InternalError; -use crate::key_server_options::{ClientConfig, ClientKeyType, KeyServerOptions, ServerMode}; +use crate::key_server_options::{ClientConfig, ClientKeyType, ServerMode}; use crate::types::IbeMasterKey; use crate::utils::{decode_byte_array, decode_master_key}; use crate::DefaultEncoding; @@ -12,6 +12,7 @@ use crypto::ibe::SEED_LENGTH; use fastcrypto::encoding::{Base64, Encoding}; use fastcrypto::serde_helpers::ToFromByteArray; use std::collections::HashMap; +use std::env; use sui_types::base_types::ObjectID; use tracing::info; @@ -30,29 +31,40 @@ pub enum MasterKeys { } impl MasterKeys { - pub(crate) fn load(options: &KeyServerOptions) -> anyhow::Result { + pub fn load_from_env(server_mode: &ServerMode) -> anyhow::Result { + let master_keys_hex_string = env::var(MASTER_KEY_ENV_VAR) + .map_err(|_| anyhow!("Environment variable {} must be set", MASTER_KEY_ENV_VAR))?; + + Self::load(server_mode, &master_keys_hex_string) + } + pub fn load(server_mode: &ServerMode, master_key_hex_string: &str) -> anyhow::Result { info!("Loading keys from env variables"); - match &options.server_mode { + match &server_mode { ServerMode::Open { .. } => { - let master_key = match decode_master_key::(MASTER_KEY_ENV_VAR) { + let master_key = match decode_master_key::(master_key_hex_string) { Ok(master_key) => master_key, // TODO: Fallback to Base64 encoding for backward compatibility. - Err(_) => crate::utils::decode_master_key::(MASTER_KEY_ENV_VAR)?, + Err(_) => crate::utils::decode_master_key::(master_key_hex_string)?, }; Ok(MasterKeys::Open { master_key }) } ServerMode::Permissioned { client_configs } => { let mut pkg_id_to_key = HashMap::new(); let mut key_server_oid_to_key = HashMap::new(); - let seed = decode_byte_array::(MASTER_KEY_ENV_VAR)?; + let seed = + decode_byte_array::(master_key_hex_string)?; for config in client_configs { let master_key = match &config.client_master_key { ClientKeyType::Derived { derivation_index } => { ibe::derive_master_key(&seed, *derivation_index) } ClientKeyType::Imported { env_var } => { - decode_master_key::(env_var)? + let env = env::var(env_var).map_err(|_| { + anyhow!("Environment variable {} must be set", env_var) + })?; + + decode_master_key::(&env)? } ClientKeyType::Exported { .. } => continue, }; @@ -161,13 +173,13 @@ fn test_master_keys_open_mode() { ); with_vars([("MASTER_KEY", None::<&str>)], || { - assert!(MasterKeys::load(&options).is_err()); + assert!(MasterKeys::load_from_env(&options.server_mode).is_err()); }); let sk = IbeMasterKey::generator(); let sk_as_bytes = DefaultEncoding::encode(bcs::to_bytes(&sk).unwrap()); with_vars([("MASTER_KEY", Some(sk_as_bytes))], || { - let mk = MasterKeys::load(&options); + let mk = MasterKeys::load_from_env(&options.server_mode); assert_eq!( mk.unwrap() .get_key_for_package(&ObjectID::from_hex_literal("0x1").unwrap()) @@ -179,7 +191,7 @@ fn test_master_keys_open_mode() { #[test] fn test_master_keys_permissioned_mode() { - use crate::key_server_options::ClientConfig; + use crate::key_server_options::{ClientConfig, KeyServerOptions}; use crate::types::Network; use fastcrypto::encoding::Encoding; use fastcrypto::groups::GroupElement; @@ -226,7 +238,7 @@ fn test_master_keys_permissioned_mode() { ("ALICE_KEY", Some(DefaultEncoding::encode(seed))), ], || { - let mk = MasterKeys::load(&options).unwrap(); + let mk = MasterKeys::load_from_env(&options.server_mode).unwrap(); let k1 = mk.get_key_for_key_server(&ObjectID::from_hex_literal("0x4").unwrap()); let k2 = mk.get_key_for_key_server(&ObjectID::from_hex_literal("0x6").unwrap()); assert!(k1.is_ok()); @@ -239,7 +251,7 @@ fn test_master_keys_permissioned_mode() { ("ALICE_KEY", Some(&DefaultEncoding::encode(seed))), ], || { - assert!(MasterKeys::load(&options).is_err()); + assert!(MasterKeys::load_from_env(&options.server_mode).is_err()); }, ); with_vars( @@ -248,7 +260,7 @@ fn test_master_keys_permissioned_mode() { ("ALICE_KEY", None::<&String>), ], || { - assert!(MasterKeys::load(&options).is_err()); + assert!(MasterKeys::load_from_env(&options.server_mode).is_err()); }, ); } diff --git a/crates/key-server/src/metrics.rs b/crates/key-server/src/metrics.rs index acc076069..722ba304f 100644 --- a/crates/key-server/src/metrics.rs +++ b/crates/key-server/src/metrics.rs @@ -12,7 +12,7 @@ use std::sync::Arc; use std::time::Instant; #[derive(Debug)] -pub(crate) struct Metrics { +pub struct Metrics { /// Total number of requests received pub requests: IntCounter, diff --git a/crates/key-server/src/mvr.rs b/crates/key-server/src/mvr.rs index 1c1615451..44838683e 100644 --- a/crates/key-server/src/mvr.rs +++ b/crates/key-server/src/mvr.rs @@ -16,7 +16,7 @@ use crate::errors::InternalError; use crate::errors::InternalError::{Failure, InvalidMVRName, InvalidPackage}; use crate::key_server_options::KeyServerOptions; -use crate::sui_rpc_client::SuiRpcClient; +use crate::sui_rpc_client::{RpcClient, SuiRpcClient}; use crate::types::Network; use move_core_types::account_address::AccountAddress; use move_core_types::identifier::Identifier; @@ -91,8 +91,8 @@ impl From> for HashMap { } /// Given an MVR name, look up the package it points to. -pub(crate) async fn mvr_forward_resolution( - sui_rpc_client: &SuiRpcClient, +pub(crate) async fn mvr_forward_resolution( + sui_rpc_client: &SuiRpcClient, mvr_name: &str, key_server_options: &KeyServerOptions, ) -> Result { @@ -171,9 +171,9 @@ pub(crate) fn resolve_network(network: &Network) -> Result( mvr_name: &str, - mainnet_sui_rpc_client: &SuiRpcClient, + mainnet_sui_rpc_client: &SuiRpcClient, ) -> Result, InternalError> { let dynamic_field_name = dynamic_field_name(mvr_name)?; let record_id = mainnet_sui_rpc_client @@ -213,10 +213,14 @@ fn dynamic_field_name(mvr_name: &str) -> Result }) } -async fn get_object Deserialize<'a>>( +async fn get_object( object_id: ObjectID, - sui_rpc_client: &SuiRpcClient, -) -> Result { + sui_rpc_client: &SuiRpcClient, +) -> Result +where + T: for<'a> Deserialize<'a>, + Client: RpcClient, +{ bcs::from_bytes( sui_rpc_client .get_object_with_options(object_id, SuiObjectDataOptions::new().with_bcs()) diff --git a/crates/key-server/src/periodic_updater.rs b/crates/key-server/src/periodic_updater.rs index 26f42302b..c0fe297b5 100644 --- a/crates/key-server/src/periodic_updater.rs +++ b/crates/key-server/src/periodic_updater.rs @@ -1,7 +1,7 @@ // Copyright (c), Mysten Labs, Inc. // SPDX-License-Identifier: Apache-2.0 -use crate::sui_rpc_client::SuiRpcClient; +use crate::sui_rpc_client::{RpcClient, SuiRpcClient}; use std::time::{Duration, Instant}; use sui_sdk::error::SuiRpcResult; use tokio::sync::watch::{channel, Receiver}; @@ -12,8 +12,8 @@ use tracing::debug; /// If a subscriber is provided, it will be called when the value is updated. /// If a duration_callback is provided, it will be called with the duration of each fetch operation. /// Returns the [Receiver]. -pub async fn spawn_periodic_updater( - client: &SuiRpcClient, +pub async fn spawn_periodic_updater( + client: &SuiRpcClient, update_interval: Duration, fetch_fn: F, value_name: &'static str, @@ -22,7 +22,8 @@ pub async fn spawn_periodic_updater( success_callback: Option, ) -> (Receiver, JoinHandle<()>) where - F: Fn(SuiRpcClient) -> Fut + Send + 'static, + Client: RpcClient, + F: Fn(SuiRpcClient) -> Fut + Send + 'static, Fut: Future> + Send, G: Fn(u64) + Send + 'static, H: Fn(Duration) + Send + 'static, diff --git a/crates/key-server/src/server.rs b/crates/key-server/src/server.rs index 7d9a15c72..8604177c6 100644 --- a/crates/key-server/src/server.rs +++ b/crates/key-server/src/server.rs @@ -1,776 +1,21 @@ // Copyright (c), Mysten Labs, Inc. // SPDX-License-Identifier: Apache-2.0 -use crate::errors::InternalError::{ - DeprecatedSDKVersion, InvalidSDKVersion, MissingRequiredHeader, -}; -use crate::externals::get_reference_gas_price; -use crate::key_server_options::ServerMode; -use crate::metrics::{call_with_duration, observation_callback, status_callback, Metrics}; -use crate::metrics_push::create_push_client; -use crate::mvr::mvr_forward_resolution; -use crate::periodic_updater::spawn_periodic_updater; -use crate::signed_message::signed_request; -use crate::time::checked_duration_since; -use crate::time::from_mins; -use crate::time::{duration_since_as_f64, saturating_duration_since}; -use crate::types::{MasterKeyPOP, Network}; -use anyhow::{Context, Result}; -use axum::extract::{Query, Request}; -use axum::http::{HeaderMap, HeaderValue}; -use axum::middleware::{from_fn_with_state, map_response, Next}; -use axum::response::Response; -use axum::routing::{get, post}; -use axum::{extract::State, Json, Router}; -use core::time::Duration; -use crypto::elgamal::encrypt; -use crypto::ibe; -use crypto::ibe::create_proof_of_possession; -use crypto::prefixed_hex::PrefixedHex; -use errors::InternalError; -use externals::get_latest_checkpoint_timestamp; -use fastcrypto::ed25519::{Ed25519PublicKey, Ed25519Signature}; -use fastcrypto::encoding::{Encoding, Hex}; -use fastcrypto::traits::VerifyingKey; -use futures::future::pending; -use jsonrpsee::core::ClientError; -use jsonrpsee::types::error::{INVALID_PARAMS_CODE, METHOD_NOT_FOUND_CODE}; -use key_server_options::KeyServerOptions; -use master_keys::MasterKeys; -use metrics::metrics_middleware; -use mysten_service::get_mysten_service; -use mysten_service::metrics::start_prometheus_server; -use mysten_service::package_name; -use mysten_service::package_version; -use mysten_service::serve; -use rand::thread_rng; -use seal_sdk::types::{DecryptionKey, ElGamalPublicKey, ElgamalVerificationKey, KeyId}; -use seal_sdk::{signed_message, FetchKeyResponse}; -use semver::Version; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use std::collections::HashMap; -use std::env; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::sync::Arc; -use sui_rpc_client::SuiRpcClient; -use sui_sdk::error::Error; -use sui_sdk::rpc_types::{SuiExecutionStatus, SuiTransactionBlockEffectsAPI}; -use sui_sdk::types::base_types::{ObjectID, SuiAddress}; -use sui_sdk::types::signature::GenericSignature; -use sui_sdk::types::transaction::{ProgrammableTransaction, TransactionData, TransactionKind}; -use sui_sdk::verify_personal_message_signature::verify_personal_message_signature; -use sui_sdk::SuiClientBuilder; -use tap::tap::TapFallible; -use tap::Tap; -use tokio::sync::watch::Receiver; -use tokio::task::JoinHandle; -use tower_http::cors::{Any, CorsLayer}; -use tower_http::limit::RequestBodyLimitLayer; -use tracing::{debug, error, info, warn}; -use valid_ptb::ValidPtb; - -mod cache; -mod errors; -mod externals; -mod signed_message; -mod sui_rpc_client; -mod types; -mod utils; -mod valid_ptb; - -mod key_server_options; -mod master_keys; -mod metrics; -mod metrics_push; -mod mvr; -mod periodic_updater; -#[cfg(test)] -pub mod tests; -mod time; - -const GAS_BUDGET: u64 = 500_000_000; -const GIT_VERSION: &str = utils::git_version!(); - -// Transaction size limit: 128KB + 33% for base64 + some extra room for other parameters -const MAX_REQUEST_SIZE: usize = 180 * 1024; - -/// Default encoding used for master and public keys for the key server. -type DefaultEncoding = PrefixedHex; - -// TODO: Remove legacy once key-server crate uses sui-sdk-types. -#[derive(Clone, Serialize, Deserialize, Debug)] -struct Certificate { - pub user: SuiAddress, - pub session_vk: Ed25519PublicKey, - pub creation_time: u64, - pub ttl_min: u16, - pub signature: GenericSignature, - pub mvr_name: Option, -} - -// TODO: Remove legacy once key-server crate uses sui-sdk-types. -#[derive(Serialize, Deserialize)] -struct FetchKeyRequest { - // Next fields must be signed to prevent others from sending requests on behalf of the user and - // being able to fetch the key - ptb: String, // must adhere specific structure, see ValidPtb - // We don't want to rely on https only for restricting the response to this user, since in the - // case of multiple services, one service can do a replay attack to get the key from other - // services. - enc_key: ElGamalPublicKey, - enc_verification_key: ElgamalVerificationKey, - request_signature: Ed25519Signature, - - certificate: Certificate, -} - -/// UNIX timestamp in milliseconds. -type Timestamp = u64; - -#[derive(Clone)] -struct Server { - sui_rpc_client: SuiRpcClient, - master_keys: MasterKeys, - key_server_oid_to_pop: HashMap, - options: KeyServerOptions, -} - -impl Server { - async fn new(options: KeyServerOptions, metrics: Option>) -> Self { - let sui_rpc_client = SuiRpcClient::new( - SuiClientBuilder::default() - .request_timeout(options.rpc_config.timeout) - .build(&options.network.node_url()) - .await - .expect( - "SuiClientBuilder should not failed unless provided with invalid network url", - ), - options.rpc_config.retry_config.clone(), - metrics, - ); - info!("Server started with network: {:?}", options.network); - let master_keys = MasterKeys::load(&options).unwrap_or_else(|e| { - panic!("Failed to load master keys: {}", e); - }); - - let key_server_oid_to_pop = options - .get_supported_key_server_object_ids() - .into_iter() - .map(|ks_oid| { - let key = master_keys - .get_key_for_key_server(&ks_oid) - .expect("checked already"); - let pop = create_proof_of_possession(key, &ks_oid.into_bytes()); - (ks_oid, pop) - }) - .collect(); - - Server { - sui_rpc_client, - master_keys, - key_server_oid_to_pop, - options, - } - } - - #[allow(clippy::too_many_arguments)] - async fn check_signature( - &self, - ptb: &ProgrammableTransaction, - enc_key: &ElGamalPublicKey, - enc_verification_key: &ElgamalVerificationKey, - session_sig: &Ed25519Signature, - cert: &Certificate, - package_name: String, - req_id: Option<&str>, - ) -> Result<(), InternalError> { - // Check certificate - - // TTL of the session key must be smaller than the allowed max - let ttl = from_mins(cert.ttl_min); - if ttl > self.options.session_key_ttl_max { - debug!( - "Certificate has invalid time-to-live (req_id: {:?})", - req_id - ); - return Err(InternalError::InvalidCertificate); - } - - // Check that the creation time is not in the future and that the certificate has not expired - match checked_duration_since(cert.creation_time) { - None => { - debug!( - "Certificate has invalid creation time (req_id: {:?})", - req_id - ); - return Err(InternalError::InvalidCertificate); - } - Some(duration) => { - if duration > ttl { - debug!("Certificate has expired (req_id: {:?})", req_id); - return Err(InternalError::InvalidCertificate); - } - } - } - - let msg = signed_message( - package_name, - &cert.session_vk, - cert.creation_time, - cert.ttl_min, - ); - debug!( - "Checking signature on message: {:?} (req_id: {:?})", - msg, req_id - ); - verify_personal_message_signature( - cert.signature.clone(), - msg.as_bytes(), - cert.user, - Some(self.sui_rpc_client.sui_client().clone()), - ) - .await - .tap_err(|e| { - debug!( - "Signature verification failed: {:?} (req_id: {:?})", - e, req_id - ); - }) - .map_err(|_| InternalError::InvalidSignature)?; - - // Check session signature - let signed_msg = signed_request(ptb, enc_key, enc_verification_key); - cert.session_vk - .verify(&signed_msg, session_sig) - .map_err(|_| { - debug!( - "Session signature verification failed (req_id: {:?})", - req_id - ); - InternalError::InvalidSessionSignature - }) - } - - async fn check_policy( - &self, - sender: SuiAddress, - vptb: &ValidPtb, - gas_price: u64, - req_id: Option<&str>, - metrics: Option<&Metrics>, - ) -> Result<(), InternalError> { - debug!( - "Checking policy for ptb: {:?} (req_id: {:?})", - vptb.ptb(), - req_id - ); - // Evaluate the `seal_approve*` function - let tx_data = TransactionData::new_with_gas_coins( - TransactionKind::ProgrammableTransaction(vptb.ptb().clone()), - sender, - vec![], // Empty gas payment for dry run - GAS_BUDGET, - gas_price, - ); - let dry_run_res = self - .sui_rpc_client - .dry_run_transaction_block(tx_data.clone()) - .await - .map_err(|e| { - if let Error::RpcError(ClientError::Call(ref e)) = e { - match e.code() { - INVALID_PARAMS_CODE => { - // This error is generic and happens when one of the parameters of the Move call in the PTB is invalid. - // One reason is that one of the parameters does not exist, in which case it could be a newly created object that the FN has not yet seen. - // There are other possible reasons, so we return the entire message to the user to allow debugging. - // Note that the message is a message from the JSON RPC API, so it is already formatted and does not contain any sensitive information. - debug!("Invalid parameter: {}", e.message()); - return InternalError::InvalidParameter(e.message().to_string()); - } - METHOD_NOT_FOUND_CODE => { - // This means that the seal_approve function is not found on the given module. - debug!("Function not found: {:?}", e); - return InternalError::InvalidPTB( - "The seal_approve function was not found on the module".to_string(), - ); - } - _ => {} - } - } - InternalError::Failure(format!( - "Dry run execution failed ({:?}) (req_id: {:?})", - e, req_id - )) - })?; - - // Record the gas cost. Only do this in permissioned mode to avoid high cardinality metrics in public mode. - if let Some(m) = metrics { - if matches!( - self.options.server_mode, - ServerMode::Permissioned { client_configs: _ } - ) { - let package = vptb.pkg_id().to_hex_uncompressed(); - m.dry_run_gas_cost_per_package - .with_label_values(&[&package]) - .observe(dry_run_res.effects.gas_cost_summary().computation_cost as f64); - } - } - - debug!("Dry run response: {:?} (req_id: {:?})", dry_run_res, req_id); - if let SuiExecutionStatus::Failure { error } = dry_run_res.effects.status() { - debug!( - "Dry run execution asserted (req_id: {:?}) {:?}", - req_id, error - ); - return Err(InternalError::NoAccess(error.clone())); - } - - // all good! - Ok(()) - } - - #[allow(clippy::too_many_arguments)] - async fn check_request( - &self, - valid_ptb: &ValidPtb, - enc_key: &ElGamalPublicKey, - enc_verification_key: &ElgamalVerificationKey, - request_signature: &Ed25519Signature, - certificate: &Certificate, - gas_price: u64, - metrics: Option<&Metrics>, - req_id: Option<&str>, - mvr_name: Option, - ) -> Result<(ObjectID, Vec), InternalError> { - // Handle package upgrades: Use the first as the namespace - let first_pkg_id = - call_with_duration(metrics.map(|m| &m.fetch_pkg_ids_duration), || async { - externals::fetch_first_pkg_id(&valid_ptb.pkg_id(), &self.sui_rpc_client).await - }) - .await?; - - // Make sure that the package is supported. - self.master_keys.has_key_for_package(&first_pkg_id)?; - - // Check if the package id that MVR name points matches the first package ID, if provided. - externals::check_mvr_package_id( - &mvr_name, - &self.sui_rpc_client, - &self.options, - first_pkg_id, - req_id, - ) - .await?; - - // Check all conditions - self.check_signature( - valid_ptb.ptb(), - enc_key, - enc_verification_key, - request_signature, - certificate, - mvr_name.unwrap_or(first_pkg_id.to_hex_uncompressed()), - req_id, - ) - .await?; - - call_with_duration(metrics.map(|m| &m.check_policy_duration), || async { - self.check_policy(certificate.user, valid_ptb, gas_price, req_id, metrics) - .await - }) - .await?; - - // return the full id with the first package id as prefix - Ok((first_pkg_id, valid_ptb.full_ids(&first_pkg_id))) - } - - fn create_response( - &self, - first_pkg_id: ObjectID, - ids: &[KeyId], - enc_key: &ElGamalPublicKey, - ) -> FetchKeyResponse { - debug!( - "Creating response for ids: {:?}", - ids.iter().map(Hex::encode).collect::>() - ); - let master_key = self - .master_keys - .get_key_for_package(&first_pkg_id) - .expect("checked already"); - let decryption_keys = ids - .iter() - .map(|id| { - // Requested key - let key = ibe::extract(master_key, id); - // ElGamal encryption of key under the user's public key - let encrypted_key = encrypt(&mut thread_rng(), &key, enc_key); - DecryptionKey { - id: id.to_owned(), - encrypted_key, - } - }) - .collect(); - FetchKeyResponse { decryption_keys } - } - - /// Spawns a thread that fetches the latest checkpoint timestamp and sends it to a [Receiver] once per `update_interval`. - /// Returns the [Receiver]. - async fn spawn_latest_checkpoint_timestamp_updater( - &self, - metrics: Option<&Metrics>, - ) -> (Receiver, JoinHandle<()>) { - spawn_periodic_updater( - &self.sui_rpc_client, - self.options.checkpoint_update_interval, - get_latest_checkpoint_timestamp, - "latest checkpoint timestamp", - metrics.map(|m| { - observation_callback(&m.checkpoint_timestamp_delay, |ts| { - let duration = duration_since_as_f64(ts); - debug!("Latest checkpoint timestamp delay is {duration} ms"); - duration - }) - }), - metrics.map(|m| { - observation_callback(&m.get_checkpoint_timestamp_duration, |d: Duration| { - d.as_millis() as f64 - }) - }), - metrics.map(|m| status_callback(&m.get_checkpoint_timestamp_status)), - ) - .await - } - - /// Spawns a thread that fetches RGP and sends it to a [Receiver] once per `update_interval`. - /// Returns the [Receiver]. - async fn spawn_reference_gas_price_updater( - &self, - metrics: Option<&Metrics>, - ) -> (Receiver, JoinHandle<()>) { - spawn_periodic_updater( - &self.sui_rpc_client, - self.options.rgp_update_interval, - get_reference_gas_price, - "RGP", - None::, - None::, - metrics.map(|m| status_callback(&m.get_reference_gas_price_status)), - ) - .await - } - - /// Spawn a metrics push background jobs that push metrics to seal-proxy - fn spawn_metrics_push_job(&self, registry: prometheus::Registry) -> JoinHandle<()> { - let push_config = self.options.metrics_push_config.clone(); - if let Some(push_config) = push_config { - tokio::spawn(async move { - let mut interval = tokio::time::interval(push_config.push_interval); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - let mut client = create_push_client(); - tracing::info!("starting metrics push to '{}'", &push_config.push_url); - loop { - tokio::select! { - _ = interval.tick() => { - if let Err(error) = metrics_push::push_metrics( - push_config.clone(), - &client, - ®istry, - ).await { - tracing::warn!(?error, "unable to push metrics"); - client = create_push_client(); - } - } - } - } - }) - } else { - tokio::spawn(async move { - warn!("No metrics push config is found"); - pending().await - }) - } - } -} - -#[allow(clippy::single_match)] -async fn handle_fetch_key_internal( - app_state: &MyState, - payload: &FetchKeyRequest, - req_id: Option<&str>, - sdk_version: &str, -) -> Result<(ObjectID, Vec), InternalError> { - app_state.check_full_node_is_fresh()?; - - let valid_ptb = ValidPtb::try_from_base64(&payload.ptb)?; - - // Report the number of id's in the request to the metrics. - app_state - .metrics - .requests_per_number_of_ids - .observe(valid_ptb.inner_ids().len() as f64); - - app_state - .server - .check_request( - &valid_ptb, - &payload.enc_key, - &payload.enc_verification_key, - &payload.request_signature, - &payload.certificate, - app_state.reference_gas_price(), - Some(&app_state.metrics), - req_id, - payload.certificate.mvr_name.clone(), - ) - .await - .tap(|r| { - let request_info = json!({ "user": payload.certificate.user, "package_id": valid_ptb.pkg_id(), "req_id": req_id, "sdk_version": sdk_version }); - match r { - Ok(_) => info!("Valid request: {request_info}"), - Err(InternalError::Failure(s)) => warn!("Check request failed with debug message '{s}': {request_info}"), - _ => {}, - } - }) -} - -async fn handle_fetch_key( - State(app_state): State, - headers: HeaderMap, - Json(payload): Json, -) -> Result, InternalError> { - let req_id = headers - .get("Request-Id") - .map(|v| v.to_str().unwrap_or_default()); - let sdk_version = headers - .get("Client-Sdk-Version") - .and_then(|v| v.to_str().ok()) - .unwrap_or_default(); - app_state.metrics.requests.inc(); - - debug!( - "Checking request for ptb: {:?}, cert {:?} (req_id: {:?})", - payload.ptb, payload.certificate, req_id - ); - - handle_fetch_key_internal(&app_state, &payload, req_id, sdk_version) - .await - .tap_err(|e| app_state.metrics.observe_error(e.as_str())) - .map(|(first_pkg_id, full_ids)| { - Json( - app_state - .server - .create_response(first_pkg_id, &full_ids, &payload.enc_key), - ) - }) -} - -#[derive(Serialize, Deserialize)] -struct GetServiceResponse { - service_id: ObjectID, - pop: MasterKeyPOP, -} - -async fn handle_get_service( - State(app_state): State, - Query(params): Query>, -) -> Result, InternalError> { - app_state.metrics.service_requests.inc(); - - let service_id = params - .get("service_id") - .ok_or(InternalError::InvalidServiceId) - .and_then(|id| { - ObjectID::from_hex_literal(id).map_err(|_| InternalError::InvalidServiceId) - })?; - - let pop = *app_state - .server - .key_server_oid_to_pop - .get(&service_id) - .ok_or(InternalError::InvalidServiceId)?; - - Ok(Json(GetServiceResponse { service_id, pop })) -} - -#[derive(Clone)] -struct MyState { - metrics: Arc, - server: Arc, - latest_checkpoint_timestamp_receiver: Receiver, - reference_gas_price_receiver: Receiver, -} - -impl MyState { - fn check_full_node_is_fresh(&self) -> Result<(), InternalError> { - // Compute the staleness of the latest checkpoint timestamp. - let staleness = - saturating_duration_since(*self.latest_checkpoint_timestamp_receiver.borrow()); - if staleness > self.server.options.allowed_staleness { - return Err(InternalError::Failure(format!( - "Full node is stale. Latest checkpoint is {} ms old.", - staleness.as_millis() - ))); - } - Ok(()) - } - - fn reference_gas_price(&self) -> u64 { - *self.reference_gas_price_receiver.borrow() - } - - fn validate_sdk_version(&self, version_string: &str) -> Result<(), InternalError> { - let version = Version::parse(version_string).map_err(|_| InvalidSDKVersion)?; - if !self - .server - .options - .sdk_version_requirement - .matches(&version) - { - return Err(DeprecatedSDKVersion); - } - Ok(()) - } -} - -/// Middleware to validate the SDK version. -async fn handle_request_headers( - state: State, - request: Request, - next: Next, -) -> Result { - // Log the request id and SDK version - let version = request.headers().get("Client-Sdk-Version"); - - info!( - "Request id: {:?}, SDK version: {:?}, SDK type: {:?}, Target API version: {:?}", - request - .headers() - .get("Request-Id") - .map(|v| v.to_str().unwrap_or_default()), - version, - request.headers().get("Client-Sdk-Type"), - request.headers().get("Client-Target-Api-Version") - ); - - version - .ok_or(MissingRequiredHeader("Client-Sdk-Version".to_string())) - .and_then(|v| v.to_str().map_err(|_| InvalidSDKVersion)) - .and_then(|v| state.validate_sdk_version(v)) - .tap_err(|e| { - debug!("Invalid SDK version: {:?}", e); - state.metrics.observe_error(e.as_str()); - })?; - Ok(next.run(request).await) -} - -/// Middleware to add headers to all responses. -async fn add_response_headers(mut response: Response) -> Response { - let headers = response.headers_mut(); - headers.insert( - "X-KeyServer-Version", - HeaderValue::from_static(package_version!()), - ); - headers.insert( - "X-KeyServer-GitVersion", - HeaderValue::from_static(GIT_VERSION), - ); - response -} - -/// Creates a [prometheus::core::Collector] that tracks the uptime of the server. -fn uptime_metric(version: &str) -> Box { - let opts = prometheus::opts!("uptime", "uptime of the key server in seconds") - .variable_label("version"); - - let start_time = std::time::Instant::now(); - let uptime = move || start_time.elapsed().as_secs(); - let metric = prometheus_closure_metric::ClosureMetric::new( - opts, - prometheus_closure_metric::ValueType::Counter, - uptime, - &[version], - ) - .unwrap(); - - Box::new(metric) -} - -/// Spawn server's background tasks: -/// - background checkpoint downloader -/// - reference gas price updater. -/// - optional metrics pusher (if configured). -/// -/// The returned JoinHandle can be used to catch any tasks error or panic. -async fn start_server_background_tasks( - server: Arc, - metrics: Arc, - registry: prometheus::Registry, -) -> ( - Receiver, - Receiver, - JoinHandle>, -) { - // Spawn background checkpoint timestamp updater. - let (latest_checkpoint_timestamp_receiver, latest_checkpoint_timestamp_handle) = server - .spawn_latest_checkpoint_timestamp_updater(Some(&metrics)) - .await; - - // Spawn background reference gas price updater. - let (reference_gas_price_receiver, reference_gas_price_handle) = server - .spawn_reference_gas_price_updater(Some(&metrics)) - .await; - - // Spawn metrics push task - let metrics_push_handle = server.spawn_metrics_push_job(registry); - - // Spawn a monitor task that will exit the program if any updater task panics - let handle: JoinHandle> = tokio::spawn(async move { - tokio::select! { - result = latest_checkpoint_timestamp_handle => { - if let Err(e) = result { - error!("Latest checkpoint timestamp updater panicked: {:?}", e); - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } - return Err(e.into()); - } - } - result = reference_gas_price_handle => { - if let Err(e) = result { - error!("Reference gas price updater panicked: {:?}", e); - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } - return Err(e.into()); - } - } - result = metrics_push_handle => { - if let Err(e) = result { - error!("Metrics push task panicked: {:?}", e); - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } - return Err(e.into()); - } - } - } - - unreachable!("One of the background tasks should have returned an error"); - }); - - ( - latest_checkpoint_timestamp_receiver, - reference_gas_price_receiver, - handle, - ) -} +use anyhow::Result; +use key_server::master_keys::MasterKeys; +use key_server::{app, get_server_options_from_env}; +use mysten_service::serve; +use sui_sdk::SuiClient; +use tracing::error; #[tokio::main] async fn main() -> Result<()> { let _guard = mysten_service::logging::init(); - let (monitor_handle, app) = app().await?; + + let options = get_server_options_from_env()?; + let master_keys = MasterKeys::load_from_env(&options.server_mode)?; + + let (monitor_handle, app) = app::(options, master_keys).await?; tokio::select! { server_result = serve(app) => { @@ -783,117 +28,3 @@ async fn main() -> Result<()> { } } } - -pub(crate) async fn app() -> Result<(JoinHandle>, Router)> { - // If CONFIG_PATH is set, read the configuration from the file. - // Otherwise, use the local environment variables. - let options = match env::var("CONFIG_PATH") { - Ok(config_path) => { - info!("Loading config file: {}", config_path); - let mut opts: KeyServerOptions = serde_yaml::from_reader( - std::fs::File::open(&config_path) - .context(format!("Cannot open configuration file {config_path}"))?, - ) - .expect("Failed to parse configuration file"); - - // Handle Custom network NODE_URL configuration - if let Network::Custom { - ref mut node_url, .. - } = opts.network - { - let env_node_url = env::var("NODE_URL").ok(); - - match (node_url.as_ref(), env_node_url.as_ref()) { - (Some(_), Some(_)) => { - panic!("NODE_URL cannot be provided in both config file and environment variable. Please use only one source."); - } - (None, Some(url)) => { - info!("Using NODE_URL from environment variable: {}", url); - *node_url = Some(url.clone()); - } - (Some(url), None) => { - info!("Using NODE_URL from config file: {}", url); - } - (None, None) => { - panic!("Custom network requires NODE_URL to be set either in config file or as environment variable"); - } - } - } - - opts - } - Err(_) => { - info!("Using local environment variables for configuration, should only be used for testing"); - let network = env::var("NETWORK") - .map(|n| Network::from_str(&n)) - .unwrap_or(Network::Testnet); - KeyServerOptions::new_open_server_with_default_values( - network, - utils::decode_object_id("KEY_SERVER_OBJECT_ID")?, - ) - } - }; - - info!("Setting up metrics"); - let registry = start_prometheus_server(SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), - options.metrics_host_port, - )) - .default_registry(); - - // Tracks the uptime of the server. - let registry_clone = registry.clone(); - tokio::task::spawn(async move { - registry_clone - .register(uptime_metric( - format!("{}-{}", package_version!(), GIT_VERSION).as_str(), - )) - .expect("metrics defined at compile time must be valid"); - }); - - // hook up custom application metrics - let metrics = Arc::new(Metrics::new(®istry)); - - info!( - "Starting server, version {}", - format!("{}-{}", package_version!(), GIT_VERSION).as_str() - ); - options.validate()?; - let server = Arc::new(Server::new(options, Some(metrics.clone())).await); - - let (latest_checkpoint_timestamp_receiver, reference_gas_price_receiver, monitor_handle) = - start_server_background_tasks(server.clone(), metrics.clone(), registry.clone()).await; - - let state = MyState { - metrics, - server, - latest_checkpoint_timestamp_receiver, - reference_gas_price_receiver, - }; - - let cors = CorsLayer::new() - .allow_methods(Any) - .allow_origin(Any) - .allow_headers(Any) - .expose_headers(Any); - - let app = get_mysten_service::(package_name!(), package_version!()) - .merge( - axum::Router::new() - .route("/v1/fetch_key", post(handle_fetch_key)) - .route("/v1/service", get(handle_get_service)) - .layer(from_fn_with_state(state.clone(), handle_request_headers)) - .layer(map_response(add_response_headers)) - // Outside most middlewares that tracks metrics for HTTP requests and response - // status. - .layer(from_fn_with_state( - state.metrics.clone(), - metrics_middleware, - )), - ) - .with_state(state) - // Global body size limit - .layer(RequestBodyLimitLayer::new(MAX_REQUEST_SIZE)) - .layer(cors); - Ok((monitor_handle, app)) -} diff --git a/crates/key-server/src/sui_rpc_client.rs b/crates/key-server/src/sui_rpc_client.rs index 48e17caec..2119c8b26 100644 --- a/crates/key-server/src/sui_rpc_client.rs +++ b/crates/key-server/src/sui_rpc_client.rs @@ -1,8 +1,14 @@ // Copyright (c), Mysten Labs, Inc. // SPDX-License-Identifier: Apache-2.0 +use crate::{key_server_options::RetryConfig, metrics::Metrics}; +use async_trait::async_trait; +use fastcrypto::encoding::{Base64, Encoding}; +use fastcrypto::traits::ToFromBytes; +use shared_crypto::intent::{Intent, IntentMessage, PersonalMessage}; use std::sync::Arc; - +use sui_sdk::error::Error; +use sui_sdk::rpc_types::{ZkLoginIntentScope, ZkLoginVerifyResult}; use sui_sdk::{ error::SuiRpcResult, rpc_types::{ @@ -11,11 +17,13 @@ use sui_sdk::{ }, SuiClient, }; -use sui_types::base_types::ObjectID; +use sui_types::base_types::{ObjectID, SuiAddress}; +use sui_types::messages_checkpoint::CheckpointSequenceNumber; +use sui_types::signature::AuthenticatorTrait; +use sui_types::signature::{GenericSignature, VerifyParams}; +use sui_types::signature_verification::VerifiedDigestCache; use sui_types::{dynamic_field::DynamicFieldName, transaction::TransactionData}; -use crate::{key_server_options::RetryConfig, metrics::Metrics}; - /// Trait for determining if an error is retriable pub trait RetriableError { /// Returns true if the error is transient and the operation should be retried @@ -115,17 +123,121 @@ where } } +#[async_trait] +pub trait RpcClient: Clone + Send + Sync + 'static { + async fn new_from_builder(build: Fut) -> SuiRpcResult + where + Fut: Future> + Send; + + async fn dry_run_transaction_block( + &self, + tx: TransactionData, + ) -> SuiRpcResult; + + async fn get_object_with_options( + &self, + object_id: ObjectID, + options: SuiObjectDataOptions, + ) -> SuiRpcResult; + + async fn get_latest_checkpoint_sequence_number(&self) + -> SuiRpcResult; + + async fn get_checkpoint(&self, id: CheckpointId) -> SuiRpcResult; + + async fn get_dynamic_field_object( + &self, + parent_object_id: ObjectID, + name: DynamicFieldName, + ) -> SuiRpcResult; + + async fn get_reference_gas_price(&self) -> SuiRpcResult; + + async fn verify_zklogin_signature( + &self, + bytes: String, + signature: String, + intent_scope: ZkLoginIntentScope, + address: SuiAddress, + ) -> SuiRpcResult; +} + +#[async_trait] +impl RpcClient for SuiClient { + async fn new_from_builder(build: Fut) -> SuiRpcResult + where + Fut: Future> + Send, + { + build.await + } + + async fn dry_run_transaction_block( + &self, + tx: TransactionData, + ) -> SuiRpcResult { + self.read_api().dry_run_transaction_block(tx).await + } + + async fn get_object_with_options( + &self, + object_id: ObjectID, + options: SuiObjectDataOptions, + ) -> SuiRpcResult { + self.read_api() + .get_object_with_options(object_id, options) + .await + } + + async fn get_latest_checkpoint_sequence_number( + &self, + ) -> SuiRpcResult { + self.read_api() + .get_latest_checkpoint_sequence_number() + .await + } + + async fn get_checkpoint(&self, id: CheckpointId) -> SuiRpcResult { + self.read_api().get_checkpoint(id).await + } + + async fn get_dynamic_field_object( + &self, + parent_object_id: ObjectID, + name: DynamicFieldName, + ) -> SuiRpcResult { + self.read_api() + .get_dynamic_field_object(parent_object_id, name) + .await + } + + async fn get_reference_gas_price(&self) -> SuiRpcResult { + self.read_api().get_reference_gas_price().await + } + + async fn verify_zklogin_signature( + &self, + bytes: String, + signature: String, + intent_scope: ZkLoginIntentScope, + address: SuiAddress, + ) -> SuiRpcResult { + self.read_api() + .verify_zklogin_signature(bytes, signature, intent_scope, address) + .await + } +} + /// Client for interacting with the Sui RPC API. #[derive(Clone)] -pub struct SuiRpcClient { - sui_client: SuiClient, +pub struct SuiRpcClient { + sui_client: T, rpc_retry_config: RetryConfig, metrics: Option>, } -impl SuiRpcClient { +impl SuiRpcClient { pub fn new( - sui_client: SuiClient, + sui_client: T, rpc_retry_config: RetryConfig, metrics: Option>, ) -> Self { @@ -137,7 +249,7 @@ impl SuiRpcClient { } /// Returns a reference to the underlying SuiClient. - pub fn sui_client(&self) -> &SuiClient { + pub fn sui_client(&self) -> &T { &self.sui_client } @@ -157,7 +269,6 @@ impl SuiRpcClient { self.metrics.clone(), || async { self.sui_client - .read_api() .dry_run_transaction_block(tx_data.clone()) .await }, @@ -177,7 +288,6 @@ impl SuiRpcClient { self.metrics.clone(), || async { self.sui_client - .read_api() .get_object_with_options(object_id, options.clone()) .await }, @@ -193,7 +303,6 @@ impl SuiRpcClient { self.metrics.clone(), || async { self.sui_client - .read_api() .get_latest_checkpoint_sequence_number() .await }, @@ -207,12 +316,7 @@ impl SuiRpcClient { &self.rpc_retry_config, "get_checkpoint", self.metrics.clone(), - || async { - self.sui_client - .read_api() - .get_checkpoint(checkpoint_id) - .await - }, + || async { self.sui_client.get_checkpoint(checkpoint_id).await }, ) .await } @@ -223,7 +327,7 @@ impl SuiRpcClient { &self.rpc_retry_config, "get_reference_gas_price", self.metrics.clone(), - || async { self.sui_client.read_api().get_reference_gas_price().await }, + || async { self.sui_client.get_reference_gas_price().await }, ) .await } @@ -240,7 +344,6 @@ impl SuiRpcClient { self.metrics.clone(), || async { self.sui_client - .read_api() .get_dynamic_field_object(object_id, dynamic_field_name.clone()) .await }, @@ -249,6 +352,51 @@ impl SuiRpcClient { } } +pub(crate) async fn verify_personal_message_signature( + signature: GenericSignature, + message: &[u8], + address: SuiAddress, + client: Option, +) -> Result<(), Error> { + let intent_msg = IntentMessage::new( + Intent::personal_message(), + PersonalMessage { + message: message.to_vec(), + }, + ); + match signature { + GenericSignature::ZkLoginAuthenticator(ref _sig) => { + if let Some(client) = client { + let bytes = Base64::encode(message); + let sig_string = Base64::encode(signature.as_bytes()); + let res = client + .verify_zklogin_signature( + bytes, + sig_string, + ZkLoginIntentScope::PersonalMessage, + address, + ) + .await?; + if res.success { + Ok(()) + } else { + Err(Error::InvalidSignature) + } + } else { + Err(Error::InvalidSignature) + } + } + _ => signature + .verify_claims::( + &intent_msg, + address, + &VerifyParams::default(), + Arc::new(VerifiedDigestCache::new_empty()), + ) + .map_err(|_| Error::InvalidSignature), + } +} + #[cfg(test)] mod tests { use crate::key_server_options::RetryConfig; diff --git a/crates/key-server/src/tests/e2e.rs b/crates/key-server/src/tests/e2e.rs index dc673a565..573a20b12 100644 --- a/crates/key-server/src/tests/e2e.rs +++ b/crates/key-server/src/tests/e2e.rs @@ -510,7 +510,7 @@ async fn create_server( sui_client: SuiClient, client_configs: Vec, vars: impl AsRef<[(&str, &[u8])]>, -) -> Server { +) -> Server { let options = KeyServerOptions { network: Network::TestCluster, server_mode: ServerMode::Permissioned { client_configs }, @@ -532,7 +532,8 @@ async fn create_server( Server { sui_rpc_client: SuiRpcClient::new(sui_client, RetryConfig::default(), None), - master_keys: temp_env::with_vars(vars, || MasterKeys::load(&options)).unwrap(), + master_keys: temp_env::with_vars(vars, || MasterKeys::load_from_env(&options.server_mode)) + .unwrap(), key_server_oid_to_pop: HashMap::new(), options, } diff --git a/crates/key-server/src/tests/externals.rs b/crates/key-server/src/tests/externals.rs index a5d4e09db..9f96c9c19 100644 --- a/crates/key-server/src/tests/externals.rs +++ b/crates/key-server/src/tests/externals.rs @@ -14,6 +14,7 @@ use rand::thread_rng; use seal_sdk::signed_message; use seal_sdk::types::{ElGamalPublicKey, ElgamalVerificationKey}; use shared_crypto::intent::{Intent, IntentMessage, PersonalMessage}; +use sui_sdk::SuiClient; use sui_types::{ base_types::ObjectID, crypto::Signature, signature::GenericSignature, transaction::ProgrammableTransaction, @@ -57,7 +58,7 @@ pub(super) fn sign( } pub(crate) async fn get_key( - server: &Server, + server: &Server, pkg_id: &ObjectID, ptb: ProgrammableTransaction, kp: &Ed25519KeyPair, diff --git a/crates/key-server/src/tests/mod.rs b/crates/key-server/src/tests/mod.rs index eaee8b4c4..fea2a438e 100644 --- a/crates/key-server/src/tests/mod.rs +++ b/crates/key-server/src/tests/mod.rs @@ -25,6 +25,7 @@ use std::time::Duration; use sui_move_build::BuildConfig; use sui_sdk::json::SuiJsonValue; use sui_sdk::rpc_types::{ObjectChange, SuiData, SuiObjectDataOptions}; +use sui_sdk::SuiClient; use sui_types::base_types::{ObjectID, SuiAddress}; use sui_types::crypto::get_key_pair_from_rng; use sui_types::move_package::UpgradePolicy; @@ -43,7 +44,7 @@ pub(crate) struct SealTestCluster { cluster: TestCluster, #[allow(dead_code)] pub(crate) registry: (ObjectID, ObjectID), - pub(crate) servers: Vec<(ObjectID, Server)>, + pub(crate) servers: Vec<(ObjectID, Server)>, pub(crate) users: Vec, } @@ -142,7 +143,7 @@ impl SealTestCluster { }; } - pub fn server(&self) -> &Server { + pub fn server(&self) -> &Server { &self.servers[0].1 } diff --git a/crates/key-server/src/tests/server.rs b/crates/key-server/src/tests/server.rs index c4cfb81d0..209bc09e5 100644 --- a/crates/key-server/src/tests/server.rs +++ b/crates/key-server/src/tests/server.rs @@ -9,10 +9,11 @@ use tracing_test::traced_test; use crate::externals::get_latest_checkpoint_timestamp; use crate::key_server_options::RetryConfig; use crate::metrics::Metrics; -use crate::start_server_background_tasks; use crate::sui_rpc_client::SuiRpcClient; use crate::tests::SealTestCluster; +use crate::{get_server_options_from_env, start_server_background_tasks}; +use crate::master_keys::MasterKeys; use crate::signed_message::signed_request; use crate::{app, time, Certificate, DefaultEncoding, FetchKeyRequest}; use axum::body::Body; @@ -38,6 +39,7 @@ use serde_json::Value; use shared_crypto::intent::Intent; use shared_crypto::intent::IntentMessage; use std::str::FromStr; +use sui_sdk::SuiClient; use sui_types::base_types::{ObjectID, SuiAddress}; use sui_types::crypto::Signature; use sui_types::signature::GenericSignature; @@ -156,7 +158,9 @@ async fn test_service() { ), ]; temp_env::async_with_vars(vars, async { - let (_, app) = app().await.unwrap(); + let options = get_server_options_from_env().unwrap(); + let master_keys = MasterKeys::load_from_env(&options.server_mode).unwrap(); + let (_, app) = app::(options, master_keys).await.unwrap(); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); @@ -318,7 +322,9 @@ async fn test_fetch_key() { // Run test temp_env::async_with_vars(vars, async { - let (_, app) = app().await.unwrap(); + let options = get_server_options_from_env().unwrap(); + let master_keys = MasterKeys::load_from_env(&options.server_mode).unwrap(); + let (_, app) = app::(options, master_keys).await.unwrap(); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); diff --git a/crates/key-server/src/types.rs b/crates/key-server/src/types.rs index 9a8132c62..3ecc62c51 100644 --- a/crates/key-server/src/types.rs +++ b/crates/key-server/src/types.rs @@ -38,7 +38,7 @@ impl Network { } } - pub fn from_str(str: &str) -> Self { + pub fn from_str_unchecked(str: &str) -> Self { match str.to_ascii_lowercase().as_str() { "devnet" => Network::Devnet, "testnet" => Network::Testnet, diff --git a/crates/key-server/src/utils.rs b/crates/key-server/src/utils.rs index 7eb92d6ed..861251f52 100644 --- a/crates/key-server/src/utils.rs +++ b/crates/key-server/src/utils.rs @@ -36,23 +36,17 @@ use std::env; use sui_types::base_types::ObjectID; /// Read a byte array from an environment variable and decode it using the specified encoding. -pub fn decode_byte_array(env_name: &str) -> anyhow::Result<[u8; N]> { - let hex_string = - env::var(env_name).map_err(|_| anyhow!("Environment variable {} must be set", env_name))?; - let bytes = E::decode(&hex_string) - .map_err(|_| anyhow!("Environment variable {} should be hex encoded", env_name))?; - bytes.try_into().map_err(|_| { - anyhow!( - "Invalid byte array length for environment variable {env_name}. Must be {N} bytes long" - ) - }) +pub fn decode_byte_array(hex_string: &str) -> anyhow::Result<[u8; N]> { + let bytes = E::decode(hex_string).map_err(|_| anyhow!("Variable should be hex encoded"))?; // We don't print the master in the error message for security reasons + bytes + .try_into() + .map_err(|_| anyhow!("Invalid byte array length for variable. Must be {N} bytes long")) } /// Read a master key from an environment variable. -pub fn decode_master_key(env_name: &str) -> anyhow::Result { - let bytes = decode_byte_array::(env_name)?; - IbeMasterKey::from_byte_array(&bytes) - .map_err(|_| anyhow!("Invalid master key for environment variable {env_name}")) +pub fn decode_master_key(hex_string: &str) -> anyhow::Result { + let bytes = decode_byte_array::(hex_string)?; + IbeMasterKey::from_byte_array(&bytes).map_err(|_| anyhow!("Invalid master key for variable.")) } /// Read an ObjectID from an environment variable. diff --git a/crates/key-server/tests/server.rs b/crates/key-server/tests/server.rs new file mode 100644 index 000000000..18600b77b --- /dev/null +++ b/crates/key-server/tests/server.rs @@ -0,0 +1,664 @@ +use anyhow::bail; +use async_trait::async_trait; +use crypto::{elgamal, seal_encrypt, EncryptionInput, IBEPublicKeys}; +use fastcrypto::ed25519::{Ed25519KeyPair, Ed25519PrivateKey}; +use fastcrypto::encoding::{Base64, Encoding}; +use fastcrypto::groups::bls12381::G2Element; +use fastcrypto::serde_helpers::ToFromByteArray; +use fastcrypto::traits::{KeyPair, Signer, ToFromBytes}; +use key_server::errors::InternalError; +use key_server::key_server_options::KeyServerOptions; +use key_server::master_keys::MasterKeys; +use key_server::signed_message::signed_request; +use key_server::sui_rpc_client::RpcClient; +use key_server::types::Network; +use key_server::valid_ptb::ValidPtb; +use key_server::{fetch_key, get_server, Certificate, FetchKeyRequest}; +use rand::prelude::StdRng; +use rand::{thread_rng, SeedableRng}; +use seal_sdk::{seal_decrypt_all_objects, signed_message}; +use shared_crypto::intent::{Intent, IntentMessage}; +use std::str::FromStr; +use sui_sdk::error::SuiRpcResult; +use sui_sdk::rpc_types::{ + Checkpoint, CheckpointId, DryRunTransactionBlockResponse, OwnedObjectRef, SuiExecutionStatus, + SuiGasData, SuiObjectData, SuiObjectDataOptions, SuiObjectRef, SuiObjectResponse, + SuiProgrammableTransactionBlock, SuiRawData, SuiRawMovePackage, SuiTransactionBlockData, + SuiTransactionBlockDataV1, SuiTransactionBlockEffects, SuiTransactionBlockEffectsV1, + SuiTransactionBlockKind, ZkLoginIntentScope, ZkLoginVerifyResult, +}; +use sui_sdk::SuiClient; +use sui_types::base_types::{ObjectID, ObjectType, SuiAddress}; +use sui_types::crypto::Signature; +use sui_types::digests::ObjectDigest; +use sui_types::dynamic_field::DynamicFieldName; +use sui_types::messages_checkpoint::CheckpointSequenceNumber; +use sui_types::object::{Owner, OBJECT_START_VERSION}; +use sui_types::programmable_transaction_builder::ProgrammableTransactionBuilder; +use sui_types::signature::GenericSignature; +use sui_types::transaction::TransactionData; +use sui_types::SUI_CLOCK_OBJECT_ID; + +const KEY_SERVER_OBJECT_ID: &str = "0x1"; +const MASTER_KEY: &str = "0x403a839967eb6b81beac300dc7feab8eab18c4cfcd5f68126d4954c9370855b2"; +const PUBLIC_KEY: &str = "0x8557fc1c2507a1b3898ab1f65654b7b79990bdcfa8caa6ef787418a1ac7657741b36a1aa7830364cd5af4856b0eb45a5118986a08046263048d6b4b4420af54700309c884d01b9d01f41779f9e0f5507f4cca1763a0765d136876e23940d1ec5"; + +const FIRST_PACKAGE_ID: &str = "0x123456"; +const SECOND_PACKAGE_ID: &str = "0xabcdef"; + +#[derive(Clone)] +struct MockSuiClient; + +#[async_trait] +impl RpcClient for MockSuiClient { + async fn new_from_builder(_build: Fut) -> SuiRpcResult + where + Fut: Future> + Send, + { + SuiRpcResult::Ok(MockSuiClient) + } + + async fn dry_run_transaction_block( + &self, + _tx: TransactionData, + ) -> SuiRpcResult { + SuiRpcResult::Ok(DryRunTransactionBlockResponse { + effects: SuiTransactionBlockEffects::V1(SuiTransactionBlockEffectsV1 { + status: SuiExecutionStatus::Success, + executed_epoch: 0, + gas_used: Default::default(), + modified_at_versions: vec![], + shared_objects: vec![], + transaction_digest: Default::default(), + created: vec![], + mutated: vec![], + unwrapped: vec![], + deleted: vec![], + unwrapped_then_deleted: vec![], + wrapped: vec![], + accumulator_events: vec![], + gas_object: OwnedObjectRef { + owner: Owner::Immutable, + reference: SuiObjectRef { + object_id: SUI_CLOCK_OBJECT_ID, + version: Default::default(), + digest: ObjectDigest::new(Default::default()), + }, + }, + events_digest: None, + dependencies: vec![], + abort_error: None, + }), + events: Default::default(), + object_changes: vec![], + balance_changes: vec![], + input: SuiTransactionBlockData::V1(SuiTransactionBlockDataV1 { + transaction: SuiTransactionBlockKind::ProgrammableTransaction( + SuiProgrammableTransactionBlock { + inputs: vec![], + commands: vec![], + }, + ), + sender: Default::default(), + gas_data: SuiGasData { + payment: vec![], + owner: Default::default(), + price: 0, + budget: 0, + }, + }), + execution_error_source: None, + suggested_gas_price: None, + }) + } + + async fn get_object_with_options( + &self, + object_id: ObjectID, + _options: SuiObjectDataOptions, + ) -> SuiRpcResult { + let response = if object_id == ObjectID::from_hex_literal(FIRST_PACKAGE_ID).unwrap() + || object_id == ObjectID::from_hex_literal(SECOND_PACKAGE_ID).unwrap() + { + SuiObjectResponse::new( + Some(SuiObjectData { + object_id, + version: OBJECT_START_VERSION, + digest: ObjectDigest::new(Default::default()), + type_: Some(ObjectType::Package), + owner: None, + previous_transaction: None, + storage_rebate: None, + display: None, + content: None, + bcs: Some(SuiRawData::Package(SuiRawMovePackage { + id: object_id, + version: OBJECT_START_VERSION, + module_map: Default::default(), + type_origin_table: vec![], + linkage_table: Default::default(), + })), + }), + None, + ) + } else { + todo!() + }; + + SuiRpcResult::Ok(response) + } + + async fn get_latest_checkpoint_sequence_number( + &self, + ) -> SuiRpcResult { + todo!() + } + + async fn get_checkpoint(&self, _id: CheckpointId) -> SuiRpcResult { + todo!() + } + + async fn get_dynamic_field_object( + &self, + _parent_object_id: ObjectID, + _name: DynamicFieldName, + ) -> SuiRpcResult { + todo!() + } + + async fn get_reference_gas_price(&self) -> SuiRpcResult { + todo!() + } + + async fn verify_zklogin_signature( + &self, + _bytes: String, + _signature: String, + _intent_scope: ZkLoginIntentScope, + _address: SuiAddress, + ) -> SuiRpcResult { + todo!() + } +} + +#[tokio::test] +async fn encrypt_and_decrypt_with_mock_server() -> Result<(), anyhow::Error> { + let key_server_object_id = ObjectID::from_str(KEY_SERVER_OBJECT_ID).unwrap(); + + let options = KeyServerOptions::new_open_server_with_default_values( + Network::Devnet, + key_server_object_id, + ); + let master_keys = MasterKeys::load(&options.server_mode, MASTER_KEY)?; + + let (server, _, _) = get_server::(options, master_keys) + .await + .unwrap(); + + let server_public_key_bytes = hex::decode(PUBLIC_KEY.strip_prefix("0x").unwrap()).unwrap(); + let server_public_key_g2 = + G2Element::from_byte_array(&server_public_key_bytes.try_into().unwrap()).unwrap(); + + let server_public_keys = IBEPublicKeys::BonehFranklinBLS12381(vec![server_public_key_g2]); + let package_id = ObjectID::from_hex_literal(FIRST_PACKAGE_ID).unwrap(); + let id = vec![1, 2, 3]; + let data: Vec = vec![0, 0, 0, 1]; // 1u16 + + let (encrypted_object, _) = seal_encrypt( + sui_sdk_types::ObjectId::from(package_id.into_bytes()), + id.clone(), + vec![sui_sdk_types::ObjectId::from( + key_server_object_id.into_bytes(), + )], + &server_public_keys, + 1, + EncryptionInput::Aes256Gcm { + data: data.clone(), + aad: None, + }, + ) + .unwrap(); + + let user_secret_key = Ed25519PrivateKey::from_bytes(&[ + 16, 38, 58, 130, 194, 133, 180, 117, 252, 32, 106, 49, 97, 22, 170, 130, 33, 59, 81, 63, + 132, 11, 246, 227, 58, 130, 18, 208, 130, 124, 49, 12, + ]) + .unwrap(); + let keypair = Ed25519KeyPair::from(user_secret_key); + let user = + SuiAddress::from_str("0xb743cafeb5da4914cef0cf0a32400c9adfedc5cdb64209f9e740e56d23065100") + .unwrap(); + + // Generate session key and encryption key + let (enc_secret, enc_key, enc_verification_key) = elgamal::genkey(&mut thread_rng()); + let session = Ed25519KeyPair::generate(&mut StdRng::from_seed([1; 32])); + + // Create certificate + let creation_time = chrono::Utc::now().timestamp_millis() as u64; + let ttl_min = 10; + let message = signed_message( + package_id.to_hex_uncompressed(), + session.public(), + creation_time, + ttl_min, + ); + let msg_with_intent = IntentMessage::new(Intent::personal_message(), message.clone()); + let signature = Signature::new_secure(&msg_with_intent, &keypair); + + let certificate = Certificate { + user, + session_vk: session.public().clone(), + creation_time, + ttl_min, + signature: GenericSignature::Signature(signature), + mvr_name: None, + }; + + let mut ptb_builder = ProgrammableTransactionBuilder::new(); + let id_arg = ptb_builder.pure(id).unwrap(); + + ptb_builder.programmable_move_call( + package_id, + "my_module".parse().unwrap(), + "seal_approve".parse().unwrap(), + vec![], + vec![id_arg], + ); + + let ptb = ptb_builder.finish(); + + let request_message = signed_request(&ptb, &enc_key, &enc_verification_key); + let request_signature = session.sign(&request_message); + + // Create the FetchKeyRequest + let request = FetchKeyRequest { + ptb: Base64::encode(bcs::to_bytes(&ptb).unwrap()), + enc_key, + enc_verification_key, + request_signature, + certificate, + }; + + let fetch_keys_response = fetch_key( + server, + &request, + ValidPtb::try_from(ptb).unwrap(), + None, + "1", + 1, + None, + ) + .await + .unwrap(); + + let sui_types_sdk_key_server_object_id = + sui_sdk_types::ObjectId::from(key_server_object_id.into_bytes()); + + let decrypted = seal_decrypt_all_objects( + &enc_secret, + &[(sui_types_sdk_key_server_object_id, fetch_keys_response)], + &[encrypted_object], + &[(sui_types_sdk_key_server_object_id, server_public_key_g2)] + .into_iter() + .collect(), + ) + .unwrap() + .into_iter() + .next() + .unwrap(); + + assert_eq!(decrypted, data); + + Ok(()) +} + +#[tokio::test] +async fn encrypt_and_decrypt_wrong_id_with_mock_server() -> Result<(), anyhow::Error> { + let key_server_object_id = ObjectID::from_str(KEY_SERVER_OBJECT_ID).unwrap(); + + let options = KeyServerOptions::new_open_server_with_default_values( + Network::Devnet, + key_server_object_id, + ); + let master_keys = MasterKeys::load(&options.server_mode, MASTER_KEY)?; + + let (server, _, _) = get_server::(options, master_keys) + .await + .unwrap(); + + let server_public_key_bytes = hex::decode(PUBLIC_KEY.strip_prefix("0x").unwrap()).unwrap(); + let server_public_key_g2 = + G2Element::from_byte_array(&server_public_key_bytes.try_into().unwrap()).unwrap(); + + let server_public_keys = IBEPublicKeys::BonehFranklinBLS12381(vec![server_public_key_g2]); + let package_id = ObjectID::from_hex_literal(FIRST_PACKAGE_ID).unwrap(); + let id = vec![1, 2, 3]; + let data: Vec = vec![0, 0, 0, 1]; // 1u16 + + let (encrypted_object, _) = seal_encrypt( + sui_sdk_types::ObjectId::from(package_id.into_bytes()), + id.clone(), + vec![sui_sdk_types::ObjectId::from( + key_server_object_id.into_bytes(), + )], + &server_public_keys, + 1, + EncryptionInput::Aes256Gcm { + data: data.clone(), + aad: None, + }, + ) + .unwrap(); + + let user_secret_key = Ed25519PrivateKey::from_bytes(&[ + 16, 38, 58, 130, 194, 133, 180, 117, 252, 32, 106, 49, 97, 22, 170, 130, 33, 59, 81, 63, + 132, 11, 246, 227, 58, 130, 18, 208, 130, 124, 49, 12, + ]) + .unwrap(); + let keypair = Ed25519KeyPair::from(user_secret_key); + let user = + SuiAddress::from_str("0xb743cafeb5da4914cef0cf0a32400c9adfedc5cdb64209f9e740e56d23065100") + .unwrap(); + + // Generate session key and encryption key + let (enc_secret, enc_key, enc_verification_key) = elgamal::genkey(&mut thread_rng()); + let session = Ed25519KeyPair::generate(&mut StdRng::from_seed([1; 32])); + + // Create certificate + let creation_time = chrono::Utc::now().timestamp_millis() as u64; + let ttl_min = 10; + let message = signed_message( + package_id.to_hex_uncompressed(), + session.public(), + creation_time, + ttl_min, + ); + let msg_with_intent = IntentMessage::new(Intent::personal_message(), message.clone()); + let signature = Signature::new_secure(&msg_with_intent, &keypair); + + let certificate = Certificate { + user, + session_vk: session.public().clone(), + creation_time, + ttl_min, + signature: GenericSignature::Signature(signature), + mvr_name: None, + }; + + let mut ptb_builder = ProgrammableTransactionBuilder::new(); + let id_arg = ptb_builder.pure(vec![0u8]).unwrap(); // This should make the later decrypt process to fail + + ptb_builder.programmable_move_call( + package_id, + "my_module".parse().unwrap(), + "seal_approve".parse().unwrap(), + vec![], + vec![id_arg], + ); + + let ptb = ptb_builder.finish(); + + let request_message = signed_request(&ptb, &enc_key, &enc_verification_key); + let request_signature = session.sign(&request_message); + + // Create the FetchKeyRequest + let request = FetchKeyRequest { + ptb: Base64::encode(bcs::to_bytes(&ptb).unwrap()), + enc_key, + enc_verification_key, + request_signature, + certificate, + }; + + let fetch_keys_response = fetch_key( + server, + &request, + ValidPtb::try_from(ptb).unwrap(), + None, + "1", + 1, + None, + ) + .await + .unwrap(); + + let sui_types_sdk_key_server_object_id = + sui_sdk_types::ObjectId::from(key_server_object_id.into_bytes()); + + let decrypted = seal_decrypt_all_objects( + &enc_secret, + &[(sui_types_sdk_key_server_object_id, fetch_keys_response)], + &[encrypted_object], + &[(sui_types_sdk_key_server_object_id, server_public_key_g2)] + .into_iter() + .collect(), + ); + + if decrypted.is_ok() { + bail!("Should not succeed") + } + + Ok(()) +} + +#[tokio::test] +async fn encrypt_and_decrypt_wrong_package_id_with_mock_server() -> Result<(), anyhow::Error> { + let key_server_object_id = ObjectID::from_str(KEY_SERVER_OBJECT_ID).unwrap(); + + let options = KeyServerOptions::new_open_server_with_default_values( + Network::Devnet, + key_server_object_id, + ); + let master_keys = MasterKeys::load(&options.server_mode, MASTER_KEY)?; + + let (server, _, _) = get_server::(options, master_keys) + .await + .unwrap(); + + let server_public_key_bytes = hex::decode(PUBLIC_KEY.strip_prefix("0x").unwrap()).unwrap(); + let server_public_key_g2 = + G2Element::from_byte_array(&server_public_key_bytes.try_into().unwrap()).unwrap(); + + let server_public_keys = IBEPublicKeys::BonehFranklinBLS12381(vec![server_public_key_g2]); + let package_id = ObjectID::from_hex_literal(FIRST_PACKAGE_ID).unwrap(); + let id = vec![1, 2, 3]; + let data: Vec = vec![0, 0, 0, 1]; // 1u16 + + let (encrypted_object, _) = seal_encrypt( + sui_sdk_types::ObjectId::from(package_id.into_bytes()), + id.clone(), + vec![sui_sdk_types::ObjectId::from( + key_server_object_id.into_bytes(), + )], + &server_public_keys, + 1, + EncryptionInput::Aes256Gcm { + data: data.clone(), + aad: None, + }, + ) + .unwrap(); + + let user_secret_key = Ed25519PrivateKey::from_bytes(&[ + 16, 38, 58, 130, 194, 133, 180, 117, 252, 32, 106, 49, 97, 22, 170, 130, 33, 59, 81, 63, + 132, 11, 246, 227, 58, 130, 18, 208, 130, 124, 49, 12, + ]) + .unwrap(); + let keypair = Ed25519KeyPair::from(user_secret_key); + let user = + SuiAddress::from_str("0xb743cafeb5da4914cef0cf0a32400c9adfedc5cdb64209f9e740e56d23065100") + .unwrap(); + + // Generate session key and encryption key + let (enc_secret, enc_key, enc_verification_key) = elgamal::genkey(&mut thread_rng()); + let session = Ed25519KeyPair::generate(&mut StdRng::from_seed([1; 32])); + + // Create certificate + let creation_time = chrono::Utc::now().timestamp_millis() as u64; + let ttl_min = 10; + + let wrong_package_id = ObjectID::from_hex_literal(SECOND_PACKAGE_ID).unwrap(); + + let message = signed_message( + wrong_package_id.to_hex_uncompressed(), + session.public(), + creation_time, + ttl_min, + ); + let msg_with_intent = IntentMessage::new(Intent::personal_message(), message.clone()); + let signature = Signature::new_secure(&msg_with_intent, &keypair); + + let certificate = Certificate { + user, + session_vk: session.public().clone(), + creation_time, + ttl_min, + signature: GenericSignature::Signature(signature), + mvr_name: None, + }; + + let mut ptb_builder = ProgrammableTransactionBuilder::new(); + let id_arg = ptb_builder.pure(id).unwrap(); + + ptb_builder.programmable_move_call( + wrong_package_id, // This should make the later decrypt process to fail + "my_module".parse().unwrap(), + "seal_approve".parse().unwrap(), + vec![], + vec![id_arg], + ); + + let ptb = ptb_builder.finish(); + + let request_message = signed_request(&ptb, &enc_key, &enc_verification_key); + let request_signature = session.sign(&request_message); + + // Create the FetchKeyRequest + let request = FetchKeyRequest { + ptb: Base64::encode(bcs::to_bytes(&ptb).unwrap()), + enc_key, + enc_verification_key, + request_signature, + certificate, + }; + + let fetch_keys_response = fetch_key( + server, + &request, + ValidPtb::try_from(ptb).unwrap(), + None, + "1", + 1, + None, + ) + .await + .unwrap(); + + let sui_types_sdk_key_server_object_id = + sui_sdk_types::ObjectId::from(key_server_object_id.into_bytes()); + + let decrypted = seal_decrypt_all_objects( + &enc_secret, + &[(sui_types_sdk_key_server_object_id, fetch_keys_response)], + &[encrypted_object], + &[(sui_types_sdk_key_server_object_id, server_public_key_g2)] + .into_iter() + .collect(), + ); + + if decrypted.is_ok() { + bail!("Should not succeed") + } + + Ok(()) +} + +#[tokio::test] +async fn encrypt_and_decrypt_invalid_signature_with_mock_server() -> Result<(), anyhow::Error> { + let key_server_object_id = ObjectID::from_str(KEY_SERVER_OBJECT_ID).unwrap(); + + let options = KeyServerOptions::new_open_server_with_default_values( + Network::Devnet, + key_server_object_id, + ); + let master_keys = MasterKeys::load(&options.server_mode, MASTER_KEY)?; + + let (server, _, _) = get_server::(options, master_keys) + .await + .unwrap(); + + let package_id = ObjectID::from_hex_literal(FIRST_PACKAGE_ID).unwrap(); + let id: Vec = vec![1, 2, 3]; + + let user_secret_key = Ed25519PrivateKey::from_bytes(&[ + 16, 38, 58, 130, 194, 133, 180, 117, 252, 32, 106, 49, 97, 22, 170, 130, 33, 59, 81, 63, + 132, 11, 246, 227, 58, 130, 18, 208, 130, 124, 49, 12, + ]) + .unwrap(); + let keypair = Ed25519KeyPair::from(user_secret_key); + let user = + SuiAddress::from_str("0xb743cafeb5da4914cef0cf0a32400c9adfedc5cdb64209f9e740e56d23065100") + .unwrap(); + + let (_, enc_key, enc_verification_key) = elgamal::genkey(&mut thread_rng()); + let session = Ed25519KeyPair::generate(&mut StdRng::from_seed([1; 32])); + + // Create certificate + let creation_time = chrono::Utc::now().timestamp_millis() as u64; + let ttl_min = 10; + let message = "This is an invalid message, causing an invalid signature!"; + let msg_with_intent = IntentMessage::new(Intent::personal_message(), message); + let signature = Signature::new_secure(&msg_with_intent, &keypair); + + let certificate = Certificate { + user, + session_vk: session.public().clone(), + creation_time, + ttl_min, + signature: GenericSignature::Signature(signature), + mvr_name: None, + }; + + let mut ptb_builder = ProgrammableTransactionBuilder::new(); + let id_arg = ptb_builder.pure(id).unwrap(); + + ptb_builder.programmable_move_call( + package_id, + "my_module".parse().unwrap(), + "seal_approve".parse().unwrap(), + vec![], + vec![id_arg], + ); + + let ptb = ptb_builder.finish(); + + let request_message = signed_request(&ptb, &enc_key, &enc_verification_key); + let request_signature = session.sign(&request_message); + + let request = FetchKeyRequest { + ptb: Base64::encode(bcs::to_bytes(&ptb).unwrap()), + enc_key, + enc_verification_key, + request_signature, + certificate, + }; + + let fetch_keys_response_result = fetch_key( + server, + &request, + ValidPtb::try_from(ptb).unwrap(), + None, + "1", + 1, + None, + ) + .await; + + match fetch_keys_response_result { + Ok(_) => bail!("Should not succeed"), + Err(InternalError::InvalidSignature) => {} + Err(error) => bail!("Invalid error: {:?}", error), + } + + Ok(()) +}