From d407c3a538ee9a56d52c482c83218d0042ac7fb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Augusto=20C=C3=A9sar?= Date: Mon, 24 Mar 2025 15:51:28 +0100 Subject: [PATCH 1/5] feat!: remove certificate management from worker --- linkup-cli/src/commands/deploy/resources.rs | 8 - worker/src/lib.rs | 4 - worker/src/libdns.rs | 125 ---------- worker/src/routes/certificate_cache.rs | 158 ------------- worker/src/routes/certificate_cache_lock.rs | 198 ---------------- worker/src/routes/certificate_dns.rs | 239 -------------------- worker/src/routes/mod.rs | 4 - worker/wrangler.toml.sample | 7 - 8 files changed, 743 deletions(-) delete mode 100644 worker/src/libdns.rs delete mode 100644 worker/src/routes/certificate_cache.rs delete mode 100644 worker/src/routes/certificate_cache_lock.rs delete mode 100644 worker/src/routes/certificate_dns.rs delete mode 100644 worker/src/routes/mod.rs diff --git a/linkup-cli/src/commands/deploy/resources.rs b/linkup-cli/src/commands/deploy/resources.rs index adca0ff2..88369c63 100644 --- a/linkup-cli/src/commands/deploy/resources.rs +++ b/linkup-cli/src/commands/deploy/resources.rs @@ -1104,10 +1104,6 @@ pub fn cf_resources( name: "CLOUDLFLARE_ALL_ZONE_IDS".to_string(), text: all_zone_ids.join(","), }, - cloudflare::endpoints::workers::WorkersBinding::DurableObjectNamespace { - name: "CERTIFICATE_LOCKS".to_string(), - class_name: "CertificateStoreLock".to_string(), - }, ], worker_script_schedules: vec![cloudflare::endpoints::workers::WorkersSchedule { cron: Some("0 12 * * 2-6".to_string()), @@ -1122,10 +1118,6 @@ pub fn cf_resources( name: format!("linkup-tunnels-kv-{joined_zone_names}"), binding: "LINKUP_TUNNELS".to_string(), }, - KvNamespace { - name: format!("linkup-certificate-cache-kv-{joined_zone_names}"), - binding: "LINKUP_CERTIFICATE_CACHE".to_string(), - }, ], tunnel_zone_cache_rules: TargetCacheRules { name: "default".to_string(), diff --git a/worker/src/lib.rs b/worker/src/lib.rs index b59edfdb..80b5d3ba 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -23,8 +23,6 @@ use ws::handle_ws_resp; mod http_error; mod kv_store; -mod libdns; -mod routes; mod tunnel; mod ws; @@ -99,8 +97,6 @@ pub fn linkup_router(state: LinkupState) -> Router { .route("/linkup/check", get(always_ok)) .route("/linkup/no-tunnel", get(no_tunnel)) .route("/linkup", any(deprecated_linkup_session_handler)) - .merge(routes::certificate_dns::router()) - .merge(routes::certificate_cache::router()) .route_layer(from_fn_with_state(state.clone(), authenticate)) // Fallback for all other requests .fallback(any(linkup_request_handler)) diff --git a/worker/src/libdns.rs b/worker/src/libdns.rs deleted file mode 100644 index 2de36bd9..00000000 --- a/worker/src/libdns.rs +++ /dev/null @@ -1,125 +0,0 @@ -// This module behaves as the glue between [libdns](https://github.com/libdns/libdns) (used by Caddy) and Cloudflare -// TODO: Maybe upstream this as a provider under a repo like https://github.com/lus/libdns-rs - -use cloudflare::endpoints::dns::{ - CreateDnsRecordParams as CfCreateDnsRecordParams, DnsContent as CfDnsContent, - DnsRecord as CfDnsRecord, PatchDnsRecordParams as CfPatchDnsRecordParams, -}; -use serde::{Deserialize, Serialize}; - -/// This represents the record that is used in Caddy for working with libdns. -/// -/// Reference: https://github.com/libdns/libdns/blob/8b75c024f21e77c1ee32273ad24c579d1379b2b0/libdns.go#L114-L127 -#[derive(Debug, Serialize, Deserialize)] -#[cfg_attr(test, derive(Default))] -pub struct LibDnsRecord { - #[serde(rename = "ID")] - pub id: String, - #[serde(rename = "Type")] - pub record_type: String, - #[serde(rename = "Name")] - pub name: String, - #[serde(rename = "Value")] - pub value: String, - #[serde(rename = "TTL")] - pub ttl: u32, - #[serde(rename = "Priority")] - pub priority: u16, - #[serde(rename = "Weight")] - pub weight: u32, -} - -impl From for LibDnsRecord { - fn from(value: CfDnsRecord) -> Self { - let (ty, content_value, priority) = match value.content { - CfDnsContent::A { content } => ("A", content.to_string(), None), - CfDnsContent::AAAA { content } => ("AAAA", content.to_string(), None), - CfDnsContent::CNAME { content } => ("CNAME", content, None), - CfDnsContent::NS { content } => ("NS", content, None), - CfDnsContent::MX { content, priority } => ("MX", content, Some(priority)), - CfDnsContent::TXT { content } => ("TXT", content, None), - CfDnsContent::SRV { content } => ("SRV", content, None), - }; - - Self { - id: value.id, - record_type: ty.to_string(), - name: value.name, - value: content_value, - ttl: value.ttl, - priority: priority.unwrap_or(0), - weight: 0, - } - } -} - -impl<'a> From<&'a LibDnsRecord> for CfCreateDnsRecordParams<'a> { - fn from(val: &'a LibDnsRecord) -> Self { - CfCreateDnsRecordParams { - ttl: Some(val.ttl), - priority: Some(val.priority), - proxied: Some(false), - name: &val.name, - content: match val.record_type.as_str() { - "A" => cloudflare::endpoints::dns::DnsContent::A { - content: val.value.parse().unwrap(), - }, - "AAAA" => cloudflare::endpoints::dns::DnsContent::AAAA { - content: val.value.parse().unwrap(), - }, - "CNAME" => cloudflare::endpoints::dns::DnsContent::CNAME { - content: val.value.clone(), - }, - "NS" => cloudflare::endpoints::dns::DnsContent::NS { - content: val.value.clone(), - }, - "MX" => cloudflare::endpoints::dns::DnsContent::MX { - content: val.value.clone(), - priority: val.priority, - }, - "TXT" => cloudflare::endpoints::dns::DnsContent::TXT { - content: val.value.clone(), - }, - "SRV" => cloudflare::endpoints::dns::DnsContent::SRV { - content: val.value.clone(), - }, - _ => unreachable!(), - }, - } - } -} - -impl<'a> From<&'a LibDnsRecord> for CfPatchDnsRecordParams<'a> { - fn from(val: &'a LibDnsRecord) -> Self { - CfPatchDnsRecordParams { - ttl: Some(val.ttl), - proxied: Some(false), - name: &val.name, - content: match val.record_type.as_str() { - "A" => cloudflare::endpoints::dns::DnsContent::A { - content: val.value.parse().unwrap(), - }, - "AAAA" => cloudflare::endpoints::dns::DnsContent::AAAA { - content: val.value.parse().unwrap(), - }, - "CNAME" => cloudflare::endpoints::dns::DnsContent::CNAME { - content: val.value.clone(), - }, - "NS" => cloudflare::endpoints::dns::DnsContent::NS { - content: val.value.clone(), - }, - "MX" => cloudflare::endpoints::dns::DnsContent::MX { - content: val.value.clone(), - priority: val.priority, - }, - "TXT" => cloudflare::endpoints::dns::DnsContent::TXT { - content: val.value.clone(), - }, - "SRV" => cloudflare::endpoints::dns::DnsContent::SRV { - content: val.value.clone(), - }, - _ => unreachable!(), - }, - } - } -} diff --git a/worker/src/routes/certificate_cache.rs b/worker/src/routes/certificate_cache.rs deleted file mode 100644 index 139d0044..00000000 --- a/worker/src/routes/certificate_cache.rs +++ /dev/null @@ -1,158 +0,0 @@ -// TODO(augustoccesar)[2025-02-14]: Handle errors instead of using .unwrap() - -use axum::{ - extract::{self, Query, State}, - response::IntoResponse, - routing::{get, put}, - Json, Router, -}; -use base64::prelude::*; -use http::StatusCode; -use serde::{Deserialize, Serialize}; -use worker::console_log; - -use crate::LinkupState; - -use super::certificate_cache_lock; - -pub fn router() -> Router { - Router::new() - .merge(certificate_cache_lock::router()) - .route( - "/linkup/certificate-cache/keys", - get(list_certificate_cache_keys_handler), - ) - .route( - "/linkup/certificate-cache/{key}", - put(upsert_certificate_cache_handler) - .get(get_certificate_cache_handler) - .delete(delete_certificate_cache_handler), - ) -} - -#[derive(Debug, Deserialize)] -struct ListCertificateCacheQuery { - path: String, - recursive: bool, -} - -#[worker::send] -async fn list_certificate_cache_keys_handler( - State(state): State, - Query(query): Query, -) -> impl IntoResponse { - // TODO(augustoccesar)[2025-02-17]: Add pagination here. We should be fine with 1000 for now, but might be a problem in the future. - let keys = state - .certs_kv - .list() - .prefix(query.path) - .limit(1000) - .execute() - .await - .unwrap() - .keys - .iter() - .map(|key| key.name.clone()) - .collect::>(); - - if query.recursive { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - "Recursive listing is not supported.", - ) - .into_response(); - } - - Json(keys).into_response() -} - -#[derive(Debug, Serialize)] -struct CertificateCacheResponse { - data_base64: String, - size: usize, - last_modified: u64, -} - -#[derive(Serialize, Deserialize)] -struct CertificateMetadata { - last_modified: u64, -} - -#[worker::send] -async fn get_certificate_cache_handler( - State(state): State, - extract::Path(key): extract::Path, -) -> impl IntoResponse { - let (data, metadata) = state - .certs_kv - .get(&key) - .bytes_with_metadata::() - .await - .unwrap(); - - match data { - Some(data) => { - let data_base64 = BASE64_STANDARD.encode(&data); - let last_modified = metadata.map_or_else( - || worker::Date::now().as_millis(), - |m| { - serde_json::from_str::<'_, CertificateMetadata>(&m) - .unwrap() - .last_modified - }, - ); - - Json(CertificateCacheResponse { - data_base64, - size: data.len(), - last_modified, - }) - .into_response() - } - None => StatusCode::NOT_FOUND.into_response(), - } -} - -#[derive(Debug, Deserialize)] -struct UpsertCertificateCachePayload { - data_base64: String, -} - -#[worker::send] -async fn upsert_certificate_cache_handler( - State(state): State, - extract::Path(key): extract::Path, - Json(payload): Json, -) -> impl IntoResponse { - let data = BASE64_STANDARD.decode(&payload.data_base64).unwrap(); - let metadata = CertificateMetadata { - last_modified: worker::Date::now().as_millis(), - }; - - let req = state - .certs_kv - .put_bytes(&key, &data) - .unwrap() - .metadata(serde_json::to_string(&metadata).unwrap()) - .unwrap(); - - console_log!("Payload: {}", serde_json::to_string(&req).unwrap()); - - req.execute().await.unwrap(); - - Json(CertificateCacheResponse { - data_base64: payload.data_base64, - size: data.len(), - last_modified: metadata.last_modified, - }) -} - -#[worker::send] -async fn delete_certificate_cache_handler( - State(state): State, - extract::Path(key): extract::Path, -) -> impl IntoResponse { - state.certs_kv.delete(&key).await.unwrap(); - - StatusCode::NO_CONTENT.into_response() -} diff --git a/worker/src/routes/certificate_cache_lock.rs b/worker/src/routes/certificate_cache_lock.rs deleted file mode 100644 index 6eab9765..00000000 --- a/worker/src/routes/certificate_cache_lock.rs +++ /dev/null @@ -1,198 +0,0 @@ -use std::time::Duration; - -use axum::{ - extract::{self, State}, - response::IntoResponse, - routing::{get, put}, - Router, -}; -use http::StatusCode; -use worker::{console_error, console_log, durable_object, Env}; - -use crate::LinkupState; - -pub fn router() -> Router { - Router::new() - .route( - "/linkup/certificate-cache/locks/{key}", - get(get_lock_handler).delete(delete_lock_handler), - ) - .route( - "/linkup/certificate-cache/locks/{key}/touch", - put(touch_lock_handler), - ) -} - -#[worker::send] -async fn get_lock_handler( - State(state): State, - extract::Path(key): extract::Path, -) -> impl IntoResponse { - console_log!("Trying to acquire lock on key: {}", key); - - match fetch_lock(&state.env, &key).await { - Ok(res) if res.status_code() == 200 => (StatusCode::OK).into_response(), - Ok(res) if res.status_code() == 423 => (StatusCode::LOCKED).into_response(), - Ok(res) => { - console_error!( - "Durable object responded with unsupported HTTP status while ACQUIRING lock with key '{}': {}", - &key, - res.status_code() - ); - - (StatusCode::INTERNAL_SERVER_ERROR).into_response() - } - Err(error) => { - console_error!("Error ACQUIRING lock with key '{}': {}", &key, error); - - (StatusCode::INTERNAL_SERVER_ERROR).into_response() - } - } -} - -#[worker::send] -async fn touch_lock_handler( - State(state): State, - extract::Path(key): extract::Path, -) -> impl IntoResponse { - console_log!("Touching lock with key: {}", key); - - match fetch_touch(&state.env, &key).await { - Ok(res) if res.status_code() == 200 => (StatusCode::OK).into_response(), - Ok(res) => { - console_error!( - "Durable object responded with unsupported HTTP status while TOUCHING lock with key '{}': {}", - &key, - res.status_code() - ); - - (StatusCode::INTERNAL_SERVER_ERROR).into_response() - } - Err(error) => { - console_error!("Error TOUCHING lock with key '{}': {}", &key, error); - - (StatusCode::INTERNAL_SERVER_ERROR).into_response() - } - } -} - -#[worker::send] -async fn delete_lock_handler( - State(state): State, - extract::Path(key): extract::Path, -) -> impl IntoResponse { - console_log!("Unlocking lock with key: {}", key); - - match fetch_unlock(&state.env, &key).await { - Ok(res) if res.status_code() == 200 => (StatusCode::OK).into_response(), - Ok(res) => { - console_error!( - "Durable object responded with unsupported HTTP status while UNLOCKING lock with key '{}': {}", - &key, - res.status_code() - ); - - (StatusCode::INTERNAL_SERVER_ERROR).into_response() - } - Err(error) => { - console_error!("Error UNLOCKING lock with key '{}': {}", &key, error); - - (StatusCode::INTERNAL_SERVER_ERROR).into_response() - } - } -} - -fn get_stub(env: &Env, key: &str) -> worker::Result { - let namespace = env.durable_object("CERTIFICATE_LOCKS")?; - namespace.id_from_name(key)?.get_stub() -} - -async fn fetch_lock(env: &Env, key: &str) -> worker::Result { - let stub = get_stub(env, key)?; - stub.fetch_with_str("http://fake_url.com/lock").await -} - -async fn fetch_touch(env: &Env, key: &str) -> worker::Result { - let stub = get_stub(env, key)?; - stub.fetch_with_str("http://fake_url.com/touch").await -} - -async fn fetch_unlock(env: &Env, key: &str) -> worker::Result { - let stub = get_stub(env, key)?; - stub.fetch_with_str("http://fake_url.com/unlock").await -} - -#[durable_object] -pub struct CertificateStoreLock { - state: worker::State, - locked: bool, - last_touched: worker::Date, -} - -impl CertificateStoreLock { - pub async fn lock(&mut self) -> worker::Result { - if self.locked { - Ok(worker::Response::builder().with_status(423).empty()) - } else { - self.state - .storage() - .set_alarm(Duration::from_secs(3)) - .await?; - - self.locked = true; - self.last_touched = worker::Date::now(); - - worker::Response::empty() - } - } - - pub async fn touch(&mut self) -> worker::Result { - self.last_touched = worker::Date::now(); - - worker::Response::empty() - } - - pub async fn unlock(&mut self) -> worker::Result { - self.locked = false; - - if let Err(error) = self.state.storage().delete_alarm().await { - console_log!("Error deleting alarm on unlock: {}", error); - } - - worker::Response::empty() - } -} - -#[durable_object] -impl DurableObject for CertificateStoreLock { - fn new(state: worker::State, _env: Env) -> Self { - Self { - state, - locked: false, - last_touched: worker::Date::now(), - } - } - - async fn fetch(&mut self, req: worker::Request) -> worker::Result { - match req.path().as_str() { - "/lock" => self.lock().await, - "/touch" => self.touch().await, - "/unlock" => self.unlock().await, - _ => Ok(worker::Response::builder().with_status(404).empty()), - } - } - - async fn alarm(&mut self) -> worker::Result { - if worker::Date::now().as_millis() - self.last_touched.as_millis() > 5000 { - self.locked = false; - } else if let Err(error) = self.state.storage().set_alarm(Duration::from_secs(3)).await { - console_log!("Error setting alarm: {}", error); - - // NOTE(augustoccesar)[2025-02-25]: If we fail to set the next alarm, instantly unlock the - // lock to avoid ending on a deadlock. - self.locked = false; - } - - worker::Response::empty() - } -} diff --git a/worker/src/routes/certificate_dns.rs b/worker/src/routes/certificate_dns.rs deleted file mode 100644 index 4d5bbacc..00000000 --- a/worker/src/routes/certificate_dns.rs +++ /dev/null @@ -1,239 +0,0 @@ -use axum::{ - extract::{Query, State}, - response::IntoResponse, - routing::get, - Json, Router, -}; -use serde::Deserialize; - -use crate::{cloudflare_client, libdns::LibDnsRecord, LinkupState}; - -pub fn router() -> Router { - Router::new().route( - "/linkup/certificate-dns", - get(get_certificate_dns_handler) - .post(create_certificate_dns_handler) - .put(update_certificate_dns_handler) - .delete(delete_certificate_dns_handler), - ) -} - -#[derive(Deserialize)] -struct GetCertificateDns { - zone: String, -} - -#[worker::send] -async fn get_certificate_dns_handler( - State(state): State, - Query(query): Query, -) -> impl IntoResponse { - let client = cloudflare_client(&state.cloudflare.api_token); - - let zone = get_zone(&client, &query.zone).await; - - let req = cloudflare::endpoints::dns::ListDnsRecords { - zone_identifier: &zone.id, - params: cloudflare::endpoints::dns::ListDnsRecordsParams::default(), - }; - - let records = client.request(&req).await.unwrap().result; - let mut libdns_records: Vec = Vec::with_capacity(records.len()); - for record in records { - libdns_records.push(record.into()); - } - - format_records_names(&mut libdns_records, &zone.name); - - Json(libdns_records) -} - -#[derive(Debug, Deserialize)] -struct CreateDnsRecords { - zone: String, - records: Vec, -} - -#[worker::send] -async fn create_certificate_dns_handler( - State(state): State, - Json(payload): Json, -) -> impl IntoResponse { - let client = cloudflare_client(&state.cloudflare.api_token); - - let zone = get_zone(&client, &payload.zone).await; - - let mut records: Vec = Vec::with_capacity(payload.records.len()); - - for record in payload.records { - let create_record = cloudflare::endpoints::dns::CreateDnsRecord { - zone_identifier: &zone.id, - params: (&record).into(), - }; - - let response = client.request(&create_record).await.unwrap().result; - - records.push(response.into()); - } - - format_records_names(&mut records, &zone.name); - - Json(records) -} - -#[derive(Debug, Deserialize)] -struct UpdateDnsRecords { - zone: String, - records: Vec, -} - -#[worker::send] -async fn update_certificate_dns_handler( - State(state): State, - Json(payload): Json, -) -> impl IntoResponse { - let client = cloudflare_client(&state.cloudflare.api_token); - - let zone = get_zone(&client, &payload.zone).await; - - let mut updated_records: Vec = Vec::with_capacity(payload.records.len()); - for record in payload.records { - if record.id.is_empty() { - // TODO: Check if we need to implement this for our use case. - unimplemented!("Needs to implement lookup DNS by name and type"); - } - - let req = cloudflare::endpoints::dns::PatchDnsRecord { - zone_identifier: &zone.id, - identifier: &record.id, - params: (&record).into(), - }; - - let res = client.request(&req).await.unwrap().result; - updated_records.push(res.into()); - } - - format_records_names(&mut updated_records, &zone.name); - - Json(updated_records) -} - -#[derive(Debug, Deserialize)] -struct DeleteDnsRecords { - zone: String, - records: Vec, -} - -#[worker::send] -async fn delete_certificate_dns_handler( - State(state): State, - Json(payload): Json, -) -> impl IntoResponse { - let client = cloudflare_client(&state.cloudflare.api_token); - - let zone = get_zone(&client, &payload.zone).await; - - let mut deleted_records = Vec::with_capacity(payload.records.len()); - for record in payload.records { - if record.id.is_empty() { - // TODO: Check if we need to implement this for our use case. - unimplemented!("Needs to implement lookup DNS by name and type"); - } - - let req = cloudflare::endpoints::dns::DeleteDnsRecord { - zone_identifier: &zone.id, - identifier: &record.id, - }; - - client.request(&req).await.unwrap(); - - deleted_records.push(record); - } - - format_records_names(&mut deleted_records, &zone.name); - - Json(deleted_records) -} - -async fn get_zone( - client: &cloudflare::framework::async_api::Client, - zone: &str, -) -> cloudflare::endpoints::zone::Zone { - let req = cloudflare::endpoints::zone::ListZones { - params: cloudflare::endpoints::zone::ListZonesParams { - name: Some(zone.to_string()), - ..Default::default() - }, - }; - - let mut res = client.request(&req).await.unwrap().result; - if res.is_empty() { - panic!("Zone not found"); - } - - if res.len() > 1 { - panic!("Found more than one zone for name"); - } - - res.pop().unwrap() -} - -fn format_records_names(records: &mut [LibDnsRecord], zone: &str) { - for record in records.iter_mut() { - record.name = name_relative_to_zone(&record.name, zone); - } -} - -fn name_relative_to_zone(fqdm: &str, zone: &str) -> String { - let trimmed_fqdm = fqdm.trim_end_matches('.'); - let trimmed_zone = zone.trim_end_matches('.'); - - let fqdm_relative_to_zone = trimmed_fqdm.replace(trimmed_zone, ""); - - fqdm_relative_to_zone.trim_end_matches('.').to_string() -} - -#[cfg(test)] -mod test { - use crate::{ - libdns::LibDnsRecord, - routes::certificate_dns::{format_records_names, name_relative_to_zone}, - }; - - #[test] - fn test_name_relative_to_zone() { - let fqdm = "api.mentimeter.com."; - let zone = "mentimeter.com."; - - assert_eq!("api", name_relative_to_zone(fqdm, zone)); - } - - #[test] - fn test_name_relative_to_zone_subdomain() { - let fqdm = "v2.api.mentimeter.com."; - let zone = "mentimeter.com."; - - assert_eq!("v2.api", name_relative_to_zone(fqdm, zone)); - } - - #[test] - fn test_name_relative_to_zone_not_matching_zone() { - let fqdm = "api.mentimeter.com."; - let zone = "menti.meter."; - - assert_eq!("api.mentimeter.com", name_relative_to_zone(fqdm, zone)); - } - - #[test] - fn test_format_records_names() { - let mut records = vec![LibDnsRecord { - name: "api.mentimeter.com".to_string(), - ..Default::default() - }]; - let zone = "mentimeter.com"; - - format_records_names(&mut records, zone); - - assert_eq!(records[0].name, "api") - } -} diff --git a/worker/src/routes/mod.rs b/worker/src/routes/mod.rs deleted file mode 100644 index 6b6755ed..00000000 --- a/worker/src/routes/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod certificate_cache; -pub mod certificate_dns; - -mod certificate_cache_lock; diff --git a/worker/wrangler.toml.sample b/worker/wrangler.toml.sample index d3a08639..9f140db0 100644 --- a/worker/wrangler.toml.sample +++ b/worker/wrangler.toml.sample @@ -18,12 +18,5 @@ WORKER_TOKEN = "token123" [triggers] crons = [ "0 12 * * 2-6" ] -[durable_objects] -bindings = [{ name = "CERTIFICATE_LOCKS", class_name = "CertificateStoreLock" }] - -[[migrations]] -tag = "v1" -new_classes = ["CertificateStoreLock"] - [build] command = "cargo install -q worker-build && worker-build --release" From 384f4aa56c394d029ddf38ffd1e067a4261e71ff Mon Sep 17 00:00:00 2001 From: Oliver Stenbom Date: Tue, 29 Apr 2025 14:12:11 +0200 Subject: [PATCH 2/5] Fix: remove last cert and print worker token --- linkup-cli/src/commands/deploy/resources.rs | 17 +++++++++++++++++ worker/src/lib.rs | 3 --- worker/wrangler.toml.sample | 1 - 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/linkup-cli/src/commands/deploy/resources.rs b/linkup-cli/src/commands/deploy/resources.rs index 88369c63..813be330 100644 --- a/linkup-cli/src/commands/deploy/resources.rs +++ b/linkup-cli/src/commands/deploy/resources.rs @@ -665,6 +665,23 @@ impl TargetCfResources { } } + let worker_token = final_metadata + .bindings + .iter() + .find(|b| matches!(b, cloudflare::endpoints::workers::WorkersBinding::PlainText { name, .. } if name == "WORKER_TOKEN")); + + if let Some(cloudflare::endpoints::workers::WorkersBinding::PlainText { + text, .. + }) = worker_token + { + println!("@@@@@@@"); + println!( + "The worker_token to add to your linkup config is: {:?}", + text + ); + println!("@@@@@@@"); + } + notifier.notify("Uploading worker script..."); api.create_worker_script(script_name.clone(), final_metadata, parts.clone()) .await?; diff --git a/worker/src/lib.rs b/worker/src/lib.rs index 80b5d3ba..f959a3b3 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -44,7 +44,6 @@ pub struct LinkupState { pub min_supported_client_version: Version, pub sessions_kv: KvStore, pub tunnels_kv: KvStore, - pub certs_kv: KvStore, pub cloudflare: CloudflareEnvironemnt, pub env: Env, } @@ -58,7 +57,6 @@ impl TryFrom for LinkupState { let sessions_kv = value.kv("LINKUP_SESSIONS")?; let tunnels_kv = value.kv("LINKUP_TUNNELS")?; - let certs_kv = value.kv("LINKUP_CERTIFICATE_CACHE")?; let cf_account_id = value.var("CLOUDFLARE_ACCOUNT_ID")?; let cf_tunnel_zone_id = value.var("CLOUDFLARE_TUNNEL_ZONE_ID")?; let cf_all_zone_ids: Vec = value @@ -74,7 +72,6 @@ impl TryFrom for LinkupState { min_supported_client_version, sessions_kv, tunnels_kv, - certs_kv, cloudflare: CloudflareEnvironemnt { account_id: cf_account_id.to_string(), tunnel_zone_id: cf_tunnel_zone_id.to_string(), diff --git a/worker/wrangler.toml.sample b/worker/wrangler.toml.sample index 9f140db0..426d63f4 100644 --- a/worker/wrangler.toml.sample +++ b/worker/wrangler.toml.sample @@ -5,7 +5,6 @@ compatibility_date = "2024-05-30" kv_namespaces = [ { binding = "LINKUP_SESSIONS", id = "xxx_sessions" }, { binding = "LINKUP_TUNNELS", id = "xxx_tunnels" }, - { binding = "LINKUP_CERTIFICATE_CACHE", id = "xxx_cache" }, ] [vars] From d471a81b19688507f8c4687168735121e18fe6f3 Mon Sep 17 00:00:00 2001 From: Oliver Stenbom Date: Tue, 29 Apr 2025 14:28:18 +0200 Subject: [PATCH 3/5] Delete one tunnel dns at a time --- linkup-cli/src/commands/deploy/resources.rs | 27 ++++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/linkup-cli/src/commands/deploy/resources.rs b/linkup-cli/src/commands/deploy/resources.rs index 813be330..33bf33a7 100644 --- a/linkup-cli/src/commands/deploy/resources.rs +++ b/linkup-cli/src/commands/deploy/resources.rs @@ -1002,19 +1002,22 @@ impl TargetCfResources { let dns_records_to_delete: Vec = dns_records.iter().map(|record| record.id.clone()).collect(); - let batch_delete_dns_req = cloudflare::endpoints::dns::BatchDnsRecords { - zone_identifier: &self.tunnel_zone_id, - params: cloudflare::endpoints::dns::BatchDnsRecordsParams { - deletes: Some(dns_records_to_delete), - }, - }; + for record in dns_records_to_delete { + let delete_req = cloudflare::endpoints::dns::DeleteDnsRecord { + zone_identifier: &self.tunnel_zone_id, + identifier: &record, + }; - match cloudflare_client.request(&batch_delete_dns_req).await { - Ok(_) => { - notifier.notify("DNS records deleted"); - } - Err(error) => { - notifier.notify(&format!("Failed to delete DNS records: {}", error)); + match cloudflare_client.request(&delete_req).await { + Ok(_) => { + notifier.notify(&format!("DNS record '{}' deleted", record)); + } + Err(error) => { + notifier.notify(&format!( + "Failed to delete DNS record '{}': {}", + record, error + )); + } } } } From e70fdfe737544b7590563402875ab983cc5912e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Augusto=20C=C3=A9sar?= Date: Tue, 13 May 2025 14:21:05 +0200 Subject: [PATCH 4/5] refactor: restructure websocket implementation (#235) Use [tungstenite](https://docs.rs/tungstenite/latest/tungstenite/) for handling websocket on local-server. With this implementation we should have a more stable handling of the events and specially of the closing. --- Cargo.lock | 1 + local-server/Cargo.toml | 3 +- local-server/src/lib.rs | 179 +++++++++++----------------------- local-server/src/ws.rs | 158 ++++++++++++++++++++++++++++++ server-tests/tests/ws_test.rs | 41 ++++++-- worker/src/lib.rs | 18 ++-- worker/src/ws.rs | 76 ++++++++++----- 7 files changed, 311 insertions(+), 165 deletions(-) create mode 100644 local-server/src/ws.rs diff --git a/Cargo.lock b/Cargo.lock index ea3f0137..ee01bb98 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1777,6 +1777,7 @@ dependencies = [ "rustls-pemfile", "thiserror 2.0.11", "tokio", + "tokio-tungstenite", "tower 0.5.2", "tower-http", ] diff --git a/local-server/Cargo.toml b/local-server/Cargo.toml index a2890748..8c57691c 100644 --- a/local-server/Cargo.toml +++ b/local-server/Cargo.toml @@ -8,7 +8,7 @@ name = "linkup_local_server" path = "src/lib.rs" [dependencies] -axum = { version = "0.8.1", features = ["http2", "json"] } +axum = { version = "0.8.1", features = ["http2", "json", "ws"] } axum-server = { version = "0.7", features = ["tls-rustls"] } http = "1.2.0" hickory-server = { version = "0.25.1", features = ["resolver"] } @@ -28,6 +28,7 @@ tokio = { version = "1.43.0", features = [ "signal", "rt-multi-thread", ] } +tokio-tungstenite = "0.26.1" tower-http = { version = "0.6.2", features = ["trace"] } tower = "0.5.2" rcgen = { version = "0.13", features = ["x509-parser"] } diff --git a/local-server/src/lib.rs b/local-server/src/lib.rs index af483cbd..d3dee5aa 100644 --- a/local-server/src/lib.rs +++ b/local-server/src/lib.rs @@ -27,7 +27,7 @@ use http::{header::HeaderMap, Uri}; use hyper_rustls::HttpsConnector; use hyper_util::{ client::legacy::{connect::HttpConnector, Client}, - rt::{TokioExecutor, TokioIo}, + rt::TokioExecutor, }; use linkup::{ allow_all_cors, get_additional_headers, get_target_service, MemoryStringStore, NameKind, @@ -40,10 +40,12 @@ use std::{ }; use std::{path::Path, sync::Arc}; use tokio::{net::UdpSocket, signal}; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; use tower::ServiceBuilder; use tower_http::trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer}; pub mod certificates; +mod ws; type HttpsClient = Client, Body>; @@ -182,6 +184,7 @@ pub async fn start_dns_server(linkup_session_name: String, domains: Vec) async fn linkup_request_handler( Extension(store): Extension, Extension(client): Extension, + ws: ws::ExtractOptionalWebSocketUpgrade, req: Request, ) -> Response { let sessions = SessionAllocator::new(&store); @@ -224,15 +227,58 @@ async fn linkup_request_handler( let extra_headers = get_additional_headers(&url, &headers, &session_name, &target_service); - if req - .headers() - .get("upgrade") - .map(|v| v == "websocket") - .unwrap_or(false) - { - handle_ws_req(req, target_service, extra_headers, client).await - } else { - handle_http_req(req, target_service, extra_headers, client).await + match ws.0 { + Some(downstream_upgrade) => { + let mut url = target_service.url; + if url.starts_with("http://") { + url = url.replace("http://", "ws://"); + } else if url.starts_with("https://") { + url = url.replace("https://", "wss://"); + } + + let uri = url.parse::().unwrap(); + let mut upstream_request = uri.into_client_request().unwrap(); + + let extra_http_headers: HeaderMap = extra_headers.into(); + for (key, value) in extra_http_headers.iter() { + upstream_request.headers_mut().insert(key, value.clone()); + } + + let (upstream_ws_stream, upstream_response) = + match tokio_tungstenite::connect_async(upstream_request).await { + Ok(connection) => connection, + Err(error) => match error { + tokio_tungstenite::tungstenite::Error::Http(response) => { + let (parts, body) = response.into_parts(); + let body = body.unwrap_or_default(); + + return Response::from_parts(parts, Body::from(body)); + } + error => { + return Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body(Body::from(error.to_string())) + .unwrap() + } + }, + }; + + let mut upstream_upgrade_response = + downstream_upgrade.on_upgrade(ws::context_handle_socket(upstream_ws_stream)); + + let websocket_upgrade_response_headers = upstream_upgrade_response.headers_mut(); + for upstream_header in upstream_response.headers() { + if !websocket_upgrade_response_headers.contains_key(upstream_header.0) { + websocket_upgrade_response_headers + .append(upstream_header.0, upstream_header.1.clone()); + } + } + + websocket_upgrade_response_headers.extend(allow_all_cors()); + + upstream_upgrade_response + } + None => handle_http_req(req, target_service, extra_headers, client).await, } } @@ -272,119 +318,6 @@ async fn handle_http_req( resp.into_response() } -async fn handle_ws_req( - req: Request, - target_service: TargetService, - extra_headers: linkup::HeaderMap, - client: HttpsClient, -) -> Response { - let extra_http_headers: HeaderMap = extra_headers.into(); - - let target_ws_req_result = Request::builder() - .uri(target_service.url) - .method(req.method().clone()) - .body(Body::empty()); - - let mut target_ws_req = match target_ws_req_result { - Ok(request) => request, - Err(e) => { - return ApiError::new( - format!("Failed to build request: {}", e), - StatusCode::INTERNAL_SERVER_ERROR, - ) - .into_response(); - } - }; - - target_ws_req.headers_mut().extend(req.headers().clone()); - target_ws_req.headers_mut().extend(extra_http_headers); - target_ws_req.headers_mut().remove(http::header::HOST); - - // Send the modified request to the target service. - let target_ws_resp = match client.request(target_ws_req).await { - Ok(resp) => resp, - Err(e) => { - return ApiError::new( - format!("Failed to proxy request: {}", e), - StatusCode::BAD_GATEWAY, - ) - .into_response() - } - }; - - let status = target_ws_resp.status(); - if status != 101 { - return ApiError::new( - format!( - "Failed to proxy request: expected 101 Switching Protocols, got {}", - status - ), - StatusCode::BAD_GATEWAY, - ) - .into_response(); - } - - let target_ws_resp_headers = target_ws_resp.headers().clone(); - - let upgraded_target = match hyper::upgrade::on(target_ws_resp).await { - Ok(upgraded) => upgraded, - Err(e) => { - return ApiError::new( - format!("Failed to upgrade connection: {}", e), - StatusCode::BAD_GATEWAY, - ) - .into_response() - } - }; - - tokio::spawn(async move { - // We won't get passed this until the 101 response returns to the client - let upgraded_incoming = match hyper::upgrade::on(req).await { - Ok(upgraded) => upgraded, - Err(e) => { - println!("Failed to upgrade incoming connection: {}", e); - return; - } - }; - - let mut incoming_stream = TokioIo::new(upgraded_incoming); - let mut target_stream = TokioIo::new(upgraded_target); - - let res = tokio::io::copy_bidirectional(&mut incoming_stream, &mut target_stream).await; - - match res { - Ok((incoming_to_target, target_to_incoming)) => { - println!( - "Copied {} bytes from incoming to target and {} bytes from target to incoming", - incoming_to_target, target_to_incoming - ); - } - Err(e) => { - eprintln!("Error copying between incoming and target: {}", e); - } - } - }); - - let mut resp_builder = Response::builder().status(101); - let resp_headers_result = resp_builder.headers_mut(); - if let Some(resp_headers) = resp_headers_result { - for (header, value) in target_ws_resp_headers { - if let Some(header_name) = header { - resp_headers.append(header_name, value); - } - } - } - - match resp_builder.body(Body::empty()) { - Ok(response) => response, - Err(e) => ApiError::new( - format!("Failed to build response: {}", e), - StatusCode::INTERNAL_SERVER_ERROR, - ) - .into_response(), - } -} - async fn linkup_config_handler( Extension(store): Extension, Json(update_req): Json, diff --git a/local-server/src/ws.rs b/local-server/src/ws.rs new file mode 100644 index 00000000..fe770ffe --- /dev/null +++ b/local-server/src/ws.rs @@ -0,0 +1,158 @@ +use std::{future::Future, pin::Pin}; + +use axum::extract::{ws::WebSocket, FromRequestParts, WebSocketUpgrade}; +use futures::{SinkExt, StreamExt}; +use http::{request::Parts, StatusCode}; +use tokio::net::TcpStream; +use tokio_tungstenite::{ + tungstenite::{self, Message}, + MaybeTlsStream, WebSocketStream, +}; + +pub struct ExtractOptionalWebSocketUpgrade(pub Option); + +impl FromRequestParts for ExtractOptionalWebSocketUpgrade +where + S: Send + Sync, +{ + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let upgrade = WebSocketUpgrade::from_request_parts(parts, state).await; + + match upgrade { + Ok(upgrade) => Ok(ExtractOptionalWebSocketUpgrade(Some(upgrade))), + Err(_) => { + // TODO: Maybe log? + Ok(ExtractOptionalWebSocketUpgrade(None)) + } + } + } +} + +fn tungstenite_to_axum(message: tungstenite::Message) -> axum::extract::ws::Message { + match message { + Message::Text(utf8_bytes) => axum::extract::ws::Message::Text( + axum::extract::ws::Utf8Bytes::from(utf8_bytes.as_str()), + ), + Message::Binary(bytes) => axum::extract::ws::Message::Binary(bytes), + Message::Ping(bytes) => axum::extract::ws::Message::Ping(bytes), + Message::Pong(bytes) => axum::extract::ws::Message::Pong(bytes), + Message::Close(close_frame) => match close_frame { + Some(frame) => axum::extract::ws::Message::Close(Some(axum::extract::ws::CloseFrame { + code: frame.code.into(), + reason: axum::extract::ws::Utf8Bytes::from(frame.reason.as_str()), + })), + None => axum::extract::ws::Message::Close(None), + }, + Message::Frame(_frame) => unreachable!(), + } +} + +fn axum_to_tungstenite(message: axum::extract::ws::Message) -> tungstenite::Message { + match message { + axum::extract::ws::Message::Text(utf8_bytes) => { + tungstenite::Message::Text(tungstenite::Utf8Bytes::from(utf8_bytes.as_str())) + } + axum::extract::ws::Message::Binary(bytes) => tungstenite::Message::Binary(bytes), + axum::extract::ws::Message::Ping(bytes) => tungstenite::Message::Ping(bytes), + axum::extract::ws::Message::Pong(bytes) => tungstenite::Message::Pong(bytes), + axum::extract::ws::Message::Close(close_frame) => { + tungstenite::Message::Close(close_frame.map(|frame| { + tungstenite::protocol::frame::CloseFrame { + code: frame.code.into(), + reason: tungstenite::Utf8Bytes::from(frame.reason.as_str()), + } + })) + } + } +} + +type WrappedSocketHandler = + Box Pin + Send>> + Send>; + +pub fn context_handle_socket( + upstream_ws: WebSocketStream>, +) -> WrappedSocketHandler { + Box::new(move |downstream: WebSocket| { + Box::pin(async move { + use futures::future::{select, Either}; + + let (mut upstream_write, mut upstream_read) = upstream_ws.split(); + let (mut downstream_write, mut downstream_read) = downstream.split(); + + let mut is_closed = false; + + loop { + match select(downstream_read.next(), upstream_read.next()).await { + Either::Left((Some(downstream_message), _)) => match downstream_message { + Ok(message) => { + let tungstenite_message = axum_to_tungstenite(message); + + match &tungstenite_message { + Message::Close(_) => { + let _ = upstream_write.send(tungstenite_message).await; + + if is_closed { + break; + } else { + is_closed = true; + } + } + _ => { + if let Err(e) = upstream_write.send(tungstenite_message).await { + eprintln!("Error sending message to upstream: {}", e); + break; + } + } + } + } + Err(e) => { + eprint!("Got error on reading message from downstream: {}", e); + break; + } + }, + Either::Right((Some(upstream_message), _)) => match upstream_message { + Ok(message) => { + let axum_message = tungstenite_to_axum(message); + + match &axum_message { + axum::extract::ws::Message::Close(_) => { + let _ = downstream_write.send(axum_message).await; + + if is_closed { + break; + } else { + is_closed = true; + } + } + _ => { + if let Err(e) = downstream_write.send(axum_message).await { + eprintln!("Error sending message to upstream: {}", e); + break; + } + } + } + } + Err(e) => { + eprint!("Got error on reading message from upstream: {}", e); + break; + } + }, + other => { + // TODO: On the select! macro, if nothing is matched, it panics. I guess + // this might be better than panicking? Or do we want to "fail loudly" here? + // + // https://docs.rs/tokio/latest/tokio/macro.select.html#panics + eprint!("Received unexpected message: {other:?}"); + + break; + } + } + } + + let _ = upstream_write.close().await; + let _ = downstream_write.close().await; + }) + }) +} diff --git a/server-tests/tests/ws_test.rs b/server-tests/tests/ws_test.rs index 75589182..51469ed6 100644 --- a/server-tests/tests/ws_test.rs +++ b/server-tests/tests/ws_test.rs @@ -5,7 +5,7 @@ use axum::response::IntoResponse; use axum::Router; use futures::{SinkExt, StreamExt}; use helpers::ServerKind; -use http::Uri; +use http::{HeaderName, HeaderValue}; use rstest::rstest; use tokio::net::TcpListener; @@ -27,7 +27,7 @@ async fn can_request_underlying_websocket_server( assert_eq!(session_resp.text().await.unwrap(), "ws-session"); // Connect to the WebSocket server through the proxy - let uri = Uri::from_str(url.as_str()).unwrap(); + let uri = http::Uri::from_str(url.as_str()).unwrap(); let req = http::Request::builder() .uri(format!("ws://{}/ws", uri.authority().unwrap())) .header("referer", "example.com") @@ -46,6 +46,10 @@ async fn can_request_underlying_websocket_server( .expect("Failed to connect to WebSocket server"); assert_eq!(ws_resp.status(), 101); + assert_eq!( + ws_resp.headers().get("my-special-header"), + Some(&HeaderValue::from_str("special-value").unwrap()) + ); // Send a message let msg = "Hello, WebSocket!"; @@ -58,9 +62,9 @@ async fn can_request_underlying_websocket_server( Some(Ok(tokio_tungstenite::tungstenite::Message::Text(text))) => { assert_eq!(text, msg); } - anythingelse => { - println!("{:?}", anythingelse); - panic!("Failed to receive message") + anything_else => { + println!("{:?}", anything_else); + panic!("Failed to receive echoed message") } } @@ -68,10 +72,31 @@ async fn can_request_underlying_websocket_server( .close(None) .await .expect("Failed to close WebSocket"); + + match ws_stream.next().await { + Some(Ok(tokio_tungstenite::tungstenite::Message::Close(frame))) => { + println!("Received close frame from server: {:?}", frame); + } + None => { + println!("Connection closed without explicit close frame from server"); + } + other => { + panic!( + "Expected a close frame or stream termination, but got: {:?}", + other + ); + } + } } async fn websocket_echo(ws: WebSocketUpgrade) -> impl IntoResponse { - ws.on_upgrade(handle_websocket) + let mut response = ws.on_upgrade(handle_websocket); + response.headers_mut().append( + HeaderName::from_str("my-special-header").unwrap(), + HeaderValue::from_str("special-value").unwrap(), + ); + + response } async fn handle_websocket(mut socket: WebSocket) { @@ -86,6 +111,10 @@ async fn handle_websocket(mut socket: WebSocket) { break; } } else if let Message::Close(_) = msg { + println!("Received close on server, sending close back"); + if let Err(e) = socket.send(Message::Close(None)).await { + println!("Failed to send message: {:?}", e); + } if let Err(e) = socket.close().await { println!("Failed to close: {:?}", e); } diff --git a/worker/src/lib.rs b/worker/src/lib.rs index b59edfdb..3b095db9 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -336,7 +336,7 @@ async fn linkup_request_handler( req.headers_mut().extend(extra_http_headers); req.headers_mut().remove(http::header::HOST); - let worker_req: worker::Request = match req.try_into() { + let upstream_request: worker::Request = match req.try_into() { Ok(req) => req, Err(e) => { return HttpError::new( @@ -347,11 +347,11 @@ async fn linkup_request_handler( } }; - let cacheable_req = is_cacheable_request(&worker_req, &config); - let cache_key = get_cache_key(&worker_req, &session_name).unwrap_or_default(); + let cacheable_req = is_cacheable_request(&upstream_request, &config); + let cache_key = get_cache_key(&upstream_request, &session_name).unwrap_or_default(); if cacheable_req { - if let Some(worker_resp) = get_cached_req(cache_key.clone()).await { - let resp: HttpResponse = match worker_resp.try_into() { + if let Some(upstream_response) = get_cached_req(cache_key.clone()).await { + let resp: HttpResponse = match upstream_response.try_into() { Ok(resp) => resp, Err(e) => { return HttpError::new( @@ -365,7 +365,7 @@ async fn linkup_request_handler( } } - let mut worker_resp = match Fetch::Request(worker_req).send().await { + let mut upstream_response = match Fetch::Request(upstream_request).send().await { Ok(resp) => resp, Err(e) => { return HttpError::new( @@ -377,10 +377,10 @@ async fn linkup_request_handler( }; if is_websocket { - handle_ws_resp(worker_resp).await.into_response() + handle_ws_resp(upstream_response).await.into_response() } else { if cacheable_req { - let cache_clone = match worker_resp.cloned() { + let cache_clone = match upstream_response.cloned() { Ok(resp) => resp, Err(e) => { return HttpError::new( @@ -398,7 +398,7 @@ async fn linkup_request_handler( .into_response(); } } - handle_http_resp(worker_resp).await.into_response() + handle_http_resp(upstream_response).await.into_response() } } diff --git a/worker/src/ws.rs b/worker/src/ws.rs index e145a6aa..cfec1f1a 100644 --- a/worker/src/ws.rs +++ b/worker/src/ws.rs @@ -1,4 +1,7 @@ +use std::str::FromStr; + use axum::{http::StatusCode, response::IntoResponse}; +use http::{HeaderName, HeaderValue}; use linkup::allow_all_cors; use worker::{console_log, Error, HttpResponse, WebSocket, WebSocketPair, WebsocketEvent}; @@ -9,12 +12,13 @@ use futures::{ use crate::http_error::HttpError; -pub async fn handle_ws_resp(worker_resp: worker::Response) -> impl IntoResponse { - let dest_ws_res = match worker_resp.websocket() { +pub async fn handle_ws_resp(upstream_response: worker::Response) -> impl IntoResponse { + let upstream_response_headers = upstream_response.headers().clone(); + let upstream_ws_result = match upstream_response.websocket() { Some(ws) => Ok(ws), None => Err(Error::RustError("server did not accept".into())), }; - let dest_ws = match dest_ws_res { + let upstream_ws = match upstream_ws_result { Ok(ws) => ws, Err(e) => { return HttpError::new( @@ -25,7 +29,7 @@ pub async fn handle_ws_resp(worker_resp: worker::Response) -> impl IntoResponse } }; - let source_ws = match WebSocketPair::new() { + let downstream_ws = match WebSocketPair::new() { Ok(ws) => ws, Err(e) => { return HttpError::new( @@ -35,38 +39,44 @@ pub async fn handle_ws_resp(worker_resp: worker::Response) -> impl IntoResponse .into_response() } }; - let source_ws_server = source_ws.server; + let downstream_ws_server = downstream_ws.server; worker::wasm_bindgen_futures::spawn_local(async move { - let mut dest_events = dest_ws.events().expect("could not open dest event stream"); - let mut source_events = source_ws_server + let mut upstream_events = upstream_ws + .events() + .expect("could not open dest event stream"); + let mut downstream_events = downstream_ws_server .events() .expect("could not open source event stream"); - dest_ws.accept().expect("could not accept dest ws"); - source_ws_server + upstream_ws.accept().expect("could not accept dest ws"); + downstream_ws_server .accept() .expect("could not accept source ws"); + let mut is_closed = false; + loop { - match future::select(source_events.next(), dest_events.next()).await { - Either::Left((Some(source_event), _)) => { + match future::select(downstream_events.next(), upstream_events.next()).await { + Either::Left((Some(downstream_event), _)) => { if let Err(e) = forward_ws_event( - source_event, - &source_ws_server, - &dest_ws, + downstream_event, + &downstream_ws_server, + &upstream_ws, "to destination".into(), + &mut is_closed, ) { console_log!("Error forwarding source event: {:?}", e); break; } } - Either::Right((Some(dest_event), _)) => { + Either::Right((Some(upstream_event), _)) => { if let Err(e) = forward_ws_event( - dest_event, - &dest_ws, - &source_ws_server, + upstream_event, + &upstream_ws, + &downstream_ws_server, "to source".into(), + &mut is_closed, ) { console_log!("Error forwarding dest event: {:?}", e); break; @@ -76,8 +86,8 @@ pub async fn handle_ws_resp(worker_resp: worker::Response) -> impl IntoResponse console_log!("No event received, error"); close_with_internal_error( "Received something other than event from streams".to_string(), - &source_ws_server, - &dest_ws, + &downstream_ws_server, + &upstream_ws, ); break; } @@ -85,7 +95,7 @@ pub async fn handle_ws_resp(worker_resp: worker::Response) -> impl IntoResponse } }); - let worker_resp = match worker::Response::from_websocket(source_ws.client) { + let downstream_resp = match worker::Response::from_websocket(downstream_ws.client) { Ok(res) => res, Err(e) => { return HttpError::new( @@ -95,7 +105,8 @@ pub async fn handle_ws_resp(worker_resp: worker::Response) -> impl IntoResponse .into_response() } }; - let mut resp: HttpResponse = match worker_resp.try_into() { + + let mut resp: HttpResponse = match downstream_resp.try_into() { Ok(resp) => resp, Err(e) => { return HttpError::new( @@ -106,6 +117,15 @@ pub async fn handle_ws_resp(worker_resp: worker::Response) -> impl IntoResponse } }; + for upstream_header in upstream_response_headers.entries() { + if !resp.headers().contains_key(&upstream_header.0) { + resp.headers_mut().append( + HeaderName::from_str(&upstream_header.0).unwrap(), + HeaderValue::from_str(&upstream_header.1).unwrap(), + ); + } + } + resp.headers_mut().extend(allow_all_cors()); resp.into_response() @@ -116,6 +136,7 @@ fn forward_ws_event( from: &WebSocket, to: &WebSocket, description: String, + is_closed: &mut bool, ) -> Result<(), Error> { match event { Ok(WebsocketEvent::Message(msg)) => { @@ -144,11 +165,14 @@ fn forward_ws_event( } } Ok(WebsocketEvent::Close(close)) => { - let close_res = to.close(Some(1000), Some(close.reason())); - if let Err(e) = close_res { - console_log!("Error closing {} with close event: {:?}", description, e); + let _ = to.close(Some(1000), Some(close.reason())); + if *is_closed { + return Err(Error::RustError("Closed!".into())); + } else { + *is_closed = true; } - Err(Error::RustError(format!("Close event: {}", close.reason()))) + + Ok(()) } Err(e) => { let err_msg = format!("Other {} error: {:?}", description, e); From affc9138a001ac33249d1031f4d14e166fa4a0f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Augusto=20C=C3=A9sar?= Date: Tue, 13 May 2025 15:13:35 +0200 Subject: [PATCH 5/5] fix: update documentation link on --help --- linkup-cli/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linkup-cli/src/main.rs b/linkup-cli/src/main.rs index dadedf97..a7df5783 100644 --- a/linkup-cli/src/main.rs +++ b/linkup-cli/src/main.rs @@ -185,7 +185,7 @@ pub enum CheckErr { #[derive(Parser)] #[command( name = "linkup", - about = "Connect remote and local dev/preview environments\n\nIf you need help running linkup, start here:\nhttps://github.com/mentimeter/linkup/blob/main/docs/using-linkup.md", + about = "Connect remote and local dev/preview environments\n\nIf you need help running linkup, start here:\nhttps://mentimeter.github.io/linkup", version = env!("CARGO_PKG_VERSION"), )] struct Cli {