Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions litebox/src/platform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,33 @@ pub trait SystemInfoProvider {
/// execution context and transfer control to the syscall handler.
fn get_syscall_entry_point(&self) -> usize;
}

/// A provider for thread-local storage.
pub trait ThreadLocalStorageProvider {
type ThreadLocalStorage;

/// Set a thread-local storage value for the current thread.
///
/// # Panics
///
/// Panics if TLS is set already.
fn set_thread_local_storage(&self, value: Self::ThreadLocalStorage);

/// Invokes the provided callback function with the thread-local storage value for the current thread.
///
/// # Panics
///
/// Panics if TLS is not set yet.
/// Panics if TLS is borrowed already (e.g., recursive call).
fn with_thread_local_storage_mut<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut Self::ThreadLocalStorage) -> R;

/// Release the thread-local storage value for the current thread
///
/// # Panics
///
/// Panics if TLS is not set yet.
/// Panics if TLS is being used by [`Self::with_thread_local_storage_mut`].
fn release_thread_local_storage(&self) -> Self::ThreadLocalStorage;
}
1 change: 1 addition & 0 deletions litebox_common_linux/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2024"

[dependencies]
bitfield = "0.19.1"
bitflags = "2.9.0"
cfg-if = "1.0.0"
litebox = { path = "../litebox/", version = "0.1.0" }
Expand Down
99 changes: 95 additions & 4 deletions litebox_common_linux/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use syscalls::Sysno;

pub mod errno;

extern crate alloc;

// TODO(jayb): Should errno::Errno be publicly re-exported?

bitflags::bitflags! {
Expand Down Expand Up @@ -681,13 +683,102 @@ pub unsafe fn wrfsbase(fs_base: usize) {
}
}

/// Reads the GS segment base address
///
/// ## Safety
///
/// If `CR4.FSGSBASE` is not set, this instruction will throw an `#UD`.
#[cfg(target_arch = "x86_64")]
pub unsafe fn rdgsbase() -> usize {
let ret: usize;
unsafe {
core::arch::asm!(
"rdgsbase {}",
out(reg) ret,
options(nostack, nomem)
);
}
ret
}

/// Writes the GS segment base address
///
/// ## Safety
///
/// If `CR4.FSGSBASE` is not set, this instruction will throw an `#UD`.
///
/// The caller must ensure that this write operation has no unsafe side
/// effects, as the GS segment base address might be in use.
#[cfg(target_arch = "x86_64")]
pub unsafe fn wrgsbase(gs_base: usize) {
unsafe {
core::arch::asm!(
"wrgsbase {}",
in(reg) gs_base,
options(nostack, nomem)
);
}
}

/// Linux's `user_desc` struct used by the `set_thread_area` syscall.
#[repr(C, packed)]
#[derive(Debug, Clone)]
pub struct UserDesc {
pub entry_number: i32,
pub base_addr: i32,
pub limit: i32,
pub flags: i32,
pub entry_number: u32,
pub base_addr: u32,
pub limit: u32,
pub flags: UserDescFlags,
}

bitfield::bitfield! {
/// Flags for the `user_desc` struct.
#[derive(Clone, Copy)]
pub struct UserDescFlags(u32);
impl Debug;
/// 1 if the segment is 32-bit
pub seg_32bit, set_seg_32bit: 0;
/// Contents of the segment
pub contents, set_contents: 1, 2;
/// Read-exec only
pub read_exec_only, set_read_exec_only: 3;
/// Limit in pages
pub limit_in_pages, set_limit_in_pages: 4;
/// Segment not present
pub seg_not_present, set_seg_not_present: 5;
/// Usable by userland
pub useable, set_useable: 6;
/// 1 if the segment is 64-bit (x86_64 only)
pub lm, set_lm: 7;
}

/// Struct for thread-local storage.
pub struct ThreadLocalStorage<Platform: litebox::platform::RawPointerProvider> {
/// Indicates whether the TLS is being borrowed.
pub borrowed: bool,

#[cfg(target_arch = "x86")]
pub self_ptr: *mut ThreadLocalStorage<Platform>,
pub current_task: alloc::boxed::Box<Task>,

pub __phantom: core::marker::PhantomData<Platform>,
}

impl<Platform: litebox::platform::RawPointerProvider> ThreadLocalStorage<Platform> {
pub const fn new(task: alloc::boxed::Box<Task>) -> Self {
Self {
borrowed: false,
#[cfg(target_arch = "x86")]
self_ptr: core::ptr::null_mut(),
current_task: task,
__phantom: core::marker::PhantomData,
}
}
}

/// A task associated with a thread in LiteBox.
pub struct Task {
/// Thread identifier
pub tid: u32,
}

#[repr(C)]
Expand Down
194 changes: 180 additions & 14 deletions litebox_platform_linux_userland/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,23 @@ fn set_fs_base_arch_prctl(fs_base: usize) -> Result<usize, litebox_common_linux:
})
}

#[cfg(target_arch = "x86")]
fn set_thread_area(
user_desc: litebox::platform::trivial_providers::TransparentMutPtr<
litebox_common_linux::UserDesc,
>,
) -> Result<usize, litebox_common_linux::errno::Errno> {
unsafe { syscalls::syscall1(syscalls::Sysno::set_thread_area, user_desc.as_usize()) }.map_err(
|err| match err {
syscalls::Errno::EFAULT => litebox_common_linux::errno::Errno::EFAULT,
syscalls::Errno::EINVAL => litebox_common_linux::errno::Errno::EINVAL,
syscalls::Errno::ENOSYS => litebox_common_linux::errno::Errno::ENOSYS,
syscalls::Errno::ESRCH => litebox_common_linux::errno::Errno::ESRCH,
_ => panic!("unexpected error {err}"),
},
)
}

pub struct PunchthroughToken {
punchthrough: PunchthroughSyscall<LinuxUserland>,
}
Expand Down Expand Up @@ -674,19 +691,7 @@ impl litebox::platform::PunchthroughToken for PunchthroughToken {
}
#[cfg(target_arch = "x86")]
PunchthroughSyscall::SetThreadArea { user_desc } => {
use litebox::platform::RawConstPointer as _;
unsafe {
syscalls::syscall1(syscalls::Sysno::set_thread_area, user_desc.as_usize())
}
.map_err(|err| {
litebox::platform::PunchthroughError::Failure(match err {
syscalls::Errno::EFAULT => litebox_common_linux::errno::Errno::EFAULT,
syscalls::Errno::EINVAL => litebox_common_linux::errno::Errno::EINVAL,
syscalls::Errno::ENOSYS => litebox_common_linux::errno::Errno::ENOSYS,
syscalls::Errno::ESRCH => litebox_common_linux::errno::Errno::ESRCH,
_ => panic!("unexpected error {err}"),
})
})
set_thread_area(user_desc).map_err(litebox::platform::PunchthroughError::Failure)
}
}
}
Expand Down Expand Up @@ -1207,12 +1212,148 @@ impl litebox::platform::SystemInfoProvider for LinuxUserland {
}
}

impl LinuxUserland {
#[cfg(target_arch = "x86_64")]
fn get_thread_local_storage() -> *mut litebox_common_linux::ThreadLocalStorage<LinuxUserland> {
let tls = unsafe { litebox_common_linux::rdgsbase() };
if tls == 0 {
return core::ptr::null_mut();
}
tls as *mut litebox_common_linux::ThreadLocalStorage<LinuxUserland>
}

#[cfg(target_arch = "x86")]
fn get_thread_local_storage() -> *mut litebox_common_linux::ThreadLocalStorage<LinuxUserland> {
let mut fs_selector: u16;
unsafe {
core::arch::asm!(
"mov {0:x}, fs",
out(reg) fs_selector,
options(nostack, preserves_flags)
);
}
if fs_selector == 0 {
return core::ptr::null_mut();
}

let mut addr: usize;
unsafe {
core::arch::asm!(
"mov {0}, fs:{offset}",
out(reg) addr,
offset = const core::mem::offset_of!(litebox_common_linux::ThreadLocalStorage<LinuxUserland>, self_ptr),
options(nostack, preserves_flags)
);
}
addr as *mut litebox_common_linux::ThreadLocalStorage<LinuxUserland>
}
}

/// Similar to libc, we use fs/gs registers to store thread-local storage (TLS).
/// To avoid conflicts with libc's TLS, we choose to use gs on x86_64 and fs on x86
/// as libc uses fs on x86_64 and gs on x86.
impl litebox::platform::ThreadLocalStorageProvider for LinuxUserland {
type ThreadLocalStorage = litebox_common_linux::ThreadLocalStorage<LinuxUserland>;

#[cfg(target_arch = "x86_64")]
fn set_thread_local_storage(&self, tls: Self::ThreadLocalStorage) {
let old_gs_base = unsafe { litebox_common_linux::rdgsbase() };
assert!(old_gs_base == 0, "TLS already set for this thread");
let tls = Box::new(tls);
unsafe { litebox_common_linux::wrgsbase(Box::into_raw(tls) as usize) };
}

#[cfg(target_arch = "x86")]
fn set_thread_local_storage(&self, tls: Self::ThreadLocalStorage) {
let mut old_fs_selector: u16;
unsafe {
core::arch::asm!(
"mov {0:x}, fs",
out(reg) old_fs_selector,
options(nostack, preserves_flags)
);
}
assert!(old_fs_selector == 0, "TLS already set for this thread");

let mut tls = Box::new(tls);
tls.self_ptr = tls.as_mut();

let mut flags = litebox_common_linux::UserDescFlags(0);
flags.set_seg_32bit(true);
flags.set_useable(true);
let mut user_desc = litebox_common_linux::UserDesc {
entry_number: u32::MAX,
base_addr: Box::into_raw(tls) as u32,
limit: u32::try_from(core::mem::size_of::<Self::ThreadLocalStorage>()).unwrap() - 1,
flags,
};
let user_desc_ptr = litebox::platform::trivial_providers::TransparentMutPtr {
inner: &raw mut user_desc,
};
set_thread_area(user_desc_ptr).expect("Failed to set thread area for TLS");

let new_fs_selector = ((user_desc.entry_number & 0xfff) << 3) | 0x3; // user mode
// set fs selector
unsafe {
core::arch::asm!(
"mov fs, {0:x}",
in(reg) new_fs_selector,
options(nostack, preserves_flags)
);
}
}

#[cfg(target_arch = "x86_64")]
fn release_thread_local_storage(&self) -> Self::ThreadLocalStorage {
let tls = Self::get_thread_local_storage();
assert!(!tls.is_null(), "TLS must be set before releasing it");
unsafe {
litebox_common_linux::wrgsbase(0);
}

let tls = unsafe { Box::from_raw(tls) };
assert!(!tls.borrowed, "TLS must not be borrowed when releasing it");
*tls
}

#[cfg(target_arch = "x86")]
fn release_thread_local_storage(&self) -> Self::ThreadLocalStorage {
let tls = Self::get_thread_local_storage();
assert!(!tls.is_null(), "TLS must be set before releasing it");
unsafe {
core::arch::asm!(
"mov fs, {0}",
in(reg) 0,
options(nostack, preserves_flags)
);
}

let tls = unsafe { Box::from_raw(tls) };
assert!(!tls.borrowed, "TLS must not be borrowed when releasing it");
*tls
}

fn with_thread_local_storage_mut<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut Self::ThreadLocalStorage) -> R,
{
let tls = Self::get_thread_local_storage();
assert!(!tls.is_null(), "TLS must be set before accessing it");
let tls = unsafe { &mut *tls };
assert!(!tls.borrowed, "TLS is already borrowed");
tls.borrowed = true; // mark as borrowed
let ret = f(tls);
tls.borrowed = false; // mark as not borrowed anymore
ret
}
}

#[cfg(test)]
mod tests {
use core::sync::atomic::AtomicU32;
use std::thread::sleep;

use litebox::platform::RawMutex;
use litebox::platform::{RawMutex, ThreadLocalStorageProvider as _};

use crate::LinuxUserland;
use litebox::platform::PageManagementProvider;
Expand Down Expand Up @@ -1249,4 +1390,29 @@ mod tests {
prev = page.end;
}
}

#[test]
fn test_tls() {
let platform = LinuxUserland::new(None);
let tls = LinuxUserland::get_thread_local_storage();
assert!(tls.is_null(), "TLS should be null in the main thread");
platform.set_thread_local_storage(litebox_common_linux::ThreadLocalStorage::new(Box::new(
litebox_common_linux::Task { tid: 0xffff },
)));
platform.with_thread_local_storage_mut(|tls| {
assert_eq!(
tls.current_task.tid, 0xffff,
"TLS should have the correct task ID"
);
tls.current_task.tid = 0x1234; // Change the task ID
});
let tls = platform.release_thread_local_storage();
assert_eq!(
tls.current_task.tid, 0x1234,
"TLS should have the correct task ID"
);

let tls = LinuxUserland::get_thread_local_storage();
assert!(tls.is_null(), "TLS should be null after releasing it");
}
}