Skip to content
129 changes: 118 additions & 11 deletions nexus/db-queries/src/db/datastore/silo_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use super::DataStore;
use crate::authz;
use crate::context::OpContext;
use crate::db::IncompleteOnConflictExt;
use crate::db::datastore::DbConnection;
use crate::db::datastore::RunnableQueryNoReturn;
use crate::db::model;
use crate::db::model::Silo;
Expand All @@ -32,6 +33,7 @@ use omicron_common::api::external::InternalContext;
use omicron_common::api::external::ListResultVec;
use omicron_common::api::external::LookupResult;
use omicron_common::api::external::UpdateResult;
use omicron_uuid_kinds::GenericUuid;
use omicron_uuid_kinds::SiloGroupUuid;
use omicron_uuid_kinds::SiloUserUuid;
use uuid::Uuid;
Expand Down Expand Up @@ -68,6 +70,15 @@ impl SiloGroup {
SiloGroup::Scim(u) => u.silo_id,
}
}

/// Set the member count for this group
pub fn set_member_count(&mut self, count: i64) {
match self {
SiloGroup::ApiOnly(g) => g.member_count = count,
SiloGroup::Jit(g) => g.member_count = count,
SiloGroup::Scim(g) => g.member_count = count,
}
}
}

impl From<model::SiloGroup> for SiloGroup {
Expand All @@ -85,6 +96,7 @@ impl From<model::SiloGroup> for SiloGroup {
group with provision type 'api_only' from having a \
null external_id",
),
member_count: 0,
})
}

Expand All @@ -99,6 +111,7 @@ impl From<model::SiloGroup> for SiloGroup {
group with provision type 'jit' from having a null \
external_id",
),
member_count: 0,
}),

UserProvisionType::Scim => SiloGroup::Scim(SiloGroupScim {
Expand All @@ -113,6 +126,7 @@ impl From<model::SiloGroup> for SiloGroup {
display_name",
),
external_id: record.external_id,
member_count: 0,
}),
}
}
Expand Down Expand Up @@ -148,6 +162,9 @@ pub struct SiloGroupApiOnly {

/// The identity provider's ID for this group.
pub external_id: String,

/// The number of members in this group
pub member_count: i64,
}

impl SiloGroupApiOnly {
Expand All @@ -159,6 +176,7 @@ impl SiloGroupApiOnly {
time_deleted: None,
silo_id,
external_id,
member_count: 0,
}
}
}
Expand Down Expand Up @@ -193,6 +211,7 @@ impl From<SiloGroupApiOnly> for views::Group {
// TODO the use of external_id as display_name is temporary
display_name: u.external_id,
silo_id: u.silo_id,
member_count: u.member_count,
}
}
}
Expand All @@ -207,6 +226,9 @@ pub struct SiloGroupJit {

/// The identity provider's ID for this user.
pub external_id: String,

/// The number of members in this group
pub member_count: i64,
}

impl SiloGroupJit {
Expand All @@ -218,6 +240,7 @@ impl SiloGroupJit {
time_deleted: None,
silo_id,
external_id,
member_count: 0,
}
}
}
Expand Down Expand Up @@ -252,6 +275,7 @@ impl From<SiloGroupJit> for views::Group {
// TODO the use of external_id as display_name is temporary
display_name: u.external_id,
silo_id: u.silo_id,
member_count: u.member_count,
}
}
}
Expand All @@ -268,6 +292,9 @@ pub struct SiloGroupScim {
pub display_name: String,

pub external_id: Option<String>,

/// The number of members in this group
pub member_count: i64,
}

impl SiloGroupScim {
Expand All @@ -285,6 +312,7 @@ impl SiloGroupScim {
silo_id,
display_name,
external_id,
member_count: 0,
}
}
}
Expand Down Expand Up @@ -319,6 +347,7 @@ impl From<SiloGroupScim> for views::Group {
// TODO the use of display name as display_name is temporary
display_name: u.display_name,
silo_id: u.silo_id,
member_count: u.member_count,
}
}
}
Expand All @@ -345,6 +374,23 @@ impl<'a> SiloGroupLookup<'a> {
}

impl DataStore {
/// Helper function to fetch member counts for a list of groups
async fn silo_group_member_counts(
conn: &async_bb8_diesel::Connection<DbConnection>,
group_ids: Vec<Uuid>,
) -> Result<std::collections::HashMap<Uuid, i64>, Error> {
use nexus_db_schema::schema::silo_group_membership::dsl;

dsl::silo_group_membership
.filter(dsl::silo_group_id.eq_any(group_ids))
.group_by(dsl::silo_group_id)
.select((dsl::silo_group_id, diesel::dsl::count(dsl::silo_user_id)))
.load_async::<(Uuid, i64)>(conn)
.await
.map_err(|e| public_error_from_diesel(e, ErrorHandler::Server))
.map(|counts| counts.into_iter().collect())
}

pub(super) async fn silo_group_ensure_query(
opctx: &OpContext,
authz_silo: &authz::Silo,
Expand Down Expand Up @@ -591,16 +637,38 @@ impl DataStore {
silo_group as sg, silo_group_membership as sgm,
};

let page = paginated(sg::dsl::silo_group, sg::id, pagparams)
let conn = self.pool_connection_authorized(opctx).await?;

// First get the groups this user belongs to, in the correct paginated order
let groups = paginated(sg::dsl::silo_group, sg::id, pagparams)
.inner_join(sgm::table.on(sgm::silo_group_id.eq(sg::id)))
.filter(sgm::silo_user_id.eq(to_db_typed_uuid(silo_user_id)))
.filter(sg::time_deleted.is_null())
.select(model::SiloGroup::as_returning())
.get_results_async(&*self.pool_connection_authorized(opctx).await?)
.select(model::SiloGroup::as_select())
.load_async::<model::SiloGroup>(&*conn)
.await
.map_err(|e| public_error_from_diesel(e, ErrorHandler::Server))?
.map_err(|e| public_error_from_diesel(e, ErrorHandler::Server))?;

if groups.is_empty() {
return Ok(Vec::new());
}

let group_ids: Vec<Uuid> =
groups.iter().map(|g| *g.id().as_untyped_uuid()).collect();

let member_counts =
DataStore::silo_group_member_counts(&conn, group_ids).await?;

let page = groups
.into_iter()
.map(|group: model::SiloGroup| group.into())
.map(|group| {
let group_id = *group.id().as_untyped_uuid();
let member_count =
member_counts.get(&group_id).copied().unwrap_or(0);
let mut silo_group: SiloGroup = group.into();
silo_group.set_member_count(member_count);
silo_group
})
.collect::<Vec<SiloGroup>>();

Ok(page)
Expand Down Expand Up @@ -742,9 +810,9 @@ impl DataStore {
let conn = self.pool_connection_authorized(opctx).await?;

let silo = {
use nexus_db_schema::schema::silo::dsl;
dsl::silo
.filter(dsl::id.eq(authz_silo.id()))
use nexus_db_schema::schema::silo::dsl as silo_dsl;
silo_dsl::silo
.filter(silo_dsl::id.eq(authz_silo.id()))
.select(model::Silo::as_select())
.get_result_async::<model::Silo>(&*conn)
.await
Expand All @@ -756,18 +824,57 @@ impl DataStore {
})?
};

let page = paginated(dsl::silo_group, dsl::id, pagparams)
let groups = paginated(dsl::silo_group, dsl::id, pagparams)
.filter(dsl::silo_id.eq(authz_silo.id()))
.filter(dsl::time_deleted.is_null())
.filter(dsl::user_provision_type.eq(silo.user_provision_type))
.select(model::SiloGroup::as_select())
.load_async::<model::SiloGroup>(&*conn)
.await
.map_err(|e| public_error_from_diesel(e, ErrorHandler::Server))?
.map_err(|e| public_error_from_diesel(e, ErrorHandler::Server))?;

if groups.is_empty() {
return Ok(Vec::new());
}

let group_ids: Vec<Uuid> =
groups.iter().map(|g| *g.id().as_untyped_uuid()).collect();

let member_counts =
DataStore::silo_group_member_counts(&conn, group_ids).await?;

let page = groups
.into_iter()
.map(|group: model::SiloGroup| group.into())
.map(|group| {
let group_id = *group.id().as_untyped_uuid();
let member_count =
member_counts.get(&group_id).copied().unwrap_or(0);
let mut silo_group: SiloGroup = group.into();
silo_group.set_member_count(member_count);
silo_group
})
.collect::<Vec<SiloGroup>>();

Ok(page)
}

/// Fetch the member count for a single silo group
pub async fn silo_group_member_count(
&self,
opctx: &OpContext,
group_id: SiloGroupUuid,
) -> Result<i64, Error> {
use nexus_db_schema::schema::silo_group_membership::dsl;

let conn = self.pool_connection_authorized(opctx).await?;

let count = dsl::silo_group_membership
.filter(dsl::silo_group_id.eq(to_db_typed_uuid(group_id)))
.count()
.get_result_async::<i64>(&*conn)
.await
.map_err(|e| public_error_from_diesel(e, ErrorHandler::Server))?;

Ok(count)
}
}
Loading
Loading