diff --git a/src/lua.rs b/src/lua.rs index 81ac37d..15af0a7 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -5,6 +5,8 @@ use std::collections::HashMap; use std::ffi::CString; use std::fmt; use std::marker::PhantomData; +use std::mem::ManuallyDrop; +use std::ops::{Deref, DerefMut}; use std::os::raw::{c_char, c_int, c_void}; use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe, Location}; use std::sync::{Arc, Mutex}; @@ -66,19 +68,25 @@ use { #[cfg(feature = "serialize")] use serde::Serialize; -/// Top level Lua struct which holds the Lua state itself. -pub struct Lua { +/// Top level Lua struct which represents an instance of Lua VM. +#[repr(transparent)] +pub struct Lua(Arc>); + +/// An inner Lua struct which holds a raw Lua state. +pub struct LuaInner { pub(crate) state: *mut ffi::lua_State, main_state: *mut ffi::lua_State, extra: Arc>, - ephemeral: bool, safe: bool, // Lua has lots of interior mutability, should not be RefUnwindSafe _no_ref_unwind_safe: PhantomData>, } // Data associated with the Lua. -struct ExtraData { +pub(crate) struct ExtraData { + // Same layout as `Lua` + inner: Option>>>, + registered_userdata: FxHashMap, registered_userdata_mt: FxHashMap<*const c_void, Option>, registry_unref_list: Arc>>>, @@ -90,7 +98,6 @@ struct ExtraData { libs: StdLib, mem_info: Option>, - safe: bool, // Same as in the Lua struct ref_thread: *mut ffi::lua_State, ref_stack_size: c_int, @@ -221,48 +228,52 @@ const MULTIVALUE_CACHE_SIZE: usize = 32; /// Requires `feature = "send"` #[cfg(feature = "send")] #[cfg_attr(docsrs, doc(cfg(feature = "send")))] -unsafe impl Send for Lua {} +unsafe impl Send for LuaInner {} -impl Drop for Lua { +#[cfg(not(feature = "module"))] +impl Drop for LuaInner { fn drop(&mut self) { unsafe { - if !self.ephemeral { - let extra = &mut *self.extra.get(); - let drain_iter = extra.wrapped_failures_cache.drain(..); - #[cfg(feature = "async")] - let drain_iter = drain_iter.chain(extra.recycled_thread_cache.drain(..)); - for index in drain_iter { - ffi::lua_pushnil(extra.ref_thread); - ffi::lua_replace(extra.ref_thread, index); - extra.ref_free.push(index); - } - #[cfg(feature = "async")] - { - // Destroy Waker slot - ffi::lua_pushnil(extra.ref_thread); - ffi::lua_replace(extra.ref_thread, extra.ref_waker_idx); - extra.ref_free.push(extra.ref_waker_idx); - } - #[cfg(feature = "luau")] - { - let callbacks = ffi::lua_callbacks(self.state); - let extra_ptr = (*callbacks).userdata as *mut Arc>; - drop(Box::from_raw(extra_ptr)); - (*callbacks).userdata = ptr::null_mut(); - } - mlua_debug_assert!( - ffi::lua_gettop(extra.ref_thread) == extra.ref_stack_top - && extra.ref_stack_top as usize == extra.ref_free.len(), - "reference leak detected" - ); - ffi::lua_close(self.main_state); + let extra = &mut *self.extra.get(); + let drain_iter = extra.wrapped_failures_cache.drain(..); + #[cfg(feature = "async")] + let drain_iter = drain_iter.chain(extra.recycled_thread_cache.drain(..)); + for index in drain_iter { + ffi::lua_pushnil(extra.ref_thread); + ffi::lua_replace(extra.ref_thread, index); + extra.ref_free.push(index); } + #[cfg(feature = "async")] + { + // Destroy Waker slot + ffi::lua_pushnil(extra.ref_thread); + ffi::lua_replace(extra.ref_thread, extra.ref_waker_idx); + extra.ref_free.push(extra.ref_waker_idx); + } + #[cfg(feature = "luau")] + { + let callbacks = ffi::lua_callbacks(self.state); + let extra_ptr = (*callbacks).userdata as *mut Arc>; + drop(Box::from_raw(extra_ptr)); + (*callbacks).userdata = ptr::null_mut(); + } + mlua_debug_assert!( + ffi::lua_gettop(extra.ref_thread) == extra.ref_stack_top + && extra.ref_stack_top as usize == extra.ref_free.len(), + "reference leak detected" + ); + ffi::lua_close(self.main_state); } } } impl Drop for ExtraData { fn drop(&mut self) { + #[cfg(feature = "module")] + unsafe { + ManuallyDrop::drop(&mut self.inner.take().unwrap()) + }; + *mlua_expect!(self.registry_unref_list.lock(), "unref list poisoned") = None; if let Some(mem_info) = self.mem_info { drop(unsafe { Box::from_raw(mem_info.as_ptr()) }); @@ -276,6 +287,20 @@ impl fmt::Debug for Lua { } } +impl Deref for Lua { + type Target = LuaInner; + + fn deref(&self) -> &Self::Target { + unsafe { &*(*self.0).get() } + } +} + +impl DerefMut for Lua { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *(*self.0).get() } + } +} + impl Lua { /// Creates a new Lua state and loads the **safe** subset of the standard libraries. /// @@ -336,7 +361,6 @@ impl Lua { mlua_expect!(lua.disable_c_modules(), "Error during disabling C modules"); } lua.safe = true; - unsafe { (*lua.extra.get()).safe = true }; Ok(lua) } @@ -430,9 +454,7 @@ impl Lua { ffi::luaL_requiref(state, cstr!("_G"), ffi::luaopen_base, 1); ffi::lua_pop(state, 1); - let mut lua = Lua::init_from_ptr(state); - lua.ephemeral = false; - + let lua = Lua::init_from_ptr(state); let extra = &mut *lua.extra.get(); #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] @@ -544,6 +566,7 @@ impl Lua { // Create ExtraData let extra = Arc::new(UnsafeCell::new(ExtraData { + inner: None, registered_userdata: FxHashMap::default(), registered_userdata_mt: FxHashMap::default(), registry_unref_list: Arc::new(Mutex::new(Some(Vec::new()))), @@ -551,7 +574,6 @@ impl Lua { ref_thread, libs: StdLib::NONE, mem_info: None, - safe: false, // We need 1 extra stack space to move values in and out of the ref stack. ref_stack_size: ffi::LUA_MINSTACK - 1, ref_stack_top, @@ -606,14 +628,19 @@ impl Lua { (*ffi::lua_callbacks(main_state)).userdata = extra_raw as *mut c_void; } - Lua { + let inner = Arc::new(UnsafeCell::new(LuaInner { state, main_state, - extra, - ephemeral: true, + extra: Arc::clone(&extra), safe: false, _no_ref_unwind_safe: PhantomData, - } + })); + + (*extra.get()).inner = Some(ManuallyDrop::new(Arc::clone(&inner))); + #[cfg(not(feature = "module"))] + Arc::decrement_strong_count(Arc::as_ptr(&inner)); + + Lua(inner) } /// Loads the specified subset of the standard libraries into an existing Lua state. @@ -1476,12 +1503,11 @@ impl Lua { /// /// [`ToLua`]: crate::ToLua /// [`ToLuaMulti`]: crate::ToLuaMulti - pub fn create_function<'lua, 'callback, A, R, F>(&'lua self, func: F) -> Result> + pub fn create_function<'lua, A, R, F>(&'lua self, func: F) -> Result> where - 'lua: 'callback, - A: FromLuaMulti<'callback>, - R: ToLuaMulti<'callback>, - F: 'static + MaybeSend + Fn(&'callback Lua, A) -> Result, + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result, { self.create_callback(Box::new(move |lua, args| { func(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua) @@ -1494,15 +1520,11 @@ impl Lua { /// [`create_function`] for more information about the implementation. /// /// [`create_function`]: #method.create_function - pub fn create_function_mut<'lua, 'callback, A, R, F>( - &'lua self, - func: F, - ) -> Result> + pub fn create_function_mut<'lua, A, R, F>(&'lua self, func: F) -> Result> where - 'lua: 'callback, - A: FromLuaMulti<'callback>, - R: ToLuaMulti<'callback>, - F: 'static + MaybeSend + FnMut(&'callback Lua, A) -> Result, + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result, { let func = RefCell::new(func); self.create_function(move |lua, args| { @@ -1564,15 +1586,11 @@ impl Lua { /// [`AsyncThread`]: crate::AsyncThread #[cfg(feature = "async")] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - pub fn create_async_function<'lua, 'callback, A, R, F, FR>( - &'lua self, - func: F, - ) -> Result> + pub fn create_async_function<'lua, A, R, F, FR>(&'lua self, func: F) -> Result> where - 'lua: 'callback, - A: FromLuaMulti<'callback>, - R: ToLuaMulti<'callback>, - F: 'static + MaybeSend + Fn(&'callback Lua, A) -> FR, + A: FromLuaMulti<'lua>, + R: ToLuaMulti<'lua>, + F: 'static + MaybeSend + Fn(&'lua Lua, A) -> FR, FR: 'lua + Future>, { self.create_async_callback(Box::new(move |lua, args| { @@ -2459,25 +2477,22 @@ impl Lua { } // Creates a Function out of a Callback containing a 'static Fn. This is safe ONLY because the - // Fn is 'static, otherwise it could capture 'callback arguments improperly. Without ATCs, we + // Fn is 'static, otherwise it could capture 'lua arguments improperly. Without ATCs, we // cannot easily deal with the "correct" callback type of: // // Box Fn(&'lua Lua, MultiValue<'lua>) -> Result>)> // // So we instead use a caller provided lifetime, which without the 'static requirement would be // unsafe. - pub(crate) fn create_callback<'lua, 'callback>( + pub(crate) fn create_callback<'lua>( &'lua self, - func: Callback<'callback, 'static>, - ) -> Result> - where - 'lua: 'callback, - { + func: Callback<'lua, 'static>, + ) -> Result> { unsafe extern "C" fn call_callback(state: *mut ffi::lua_State) -> c_int { let extra = match ffi::lua_type(state, ffi::lua_upvalueindex(1)) { ffi::LUA_TUSERDATA => { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - (*upvalue).lua.extra.get() + (*upvalue).extra.get() } _ => ptr::null_mut(), }; @@ -2492,10 +2507,10 @@ impl Lua { check_stack(state, ffi::LUA_MINSTACK - nargs)?; } - let mut lua = (*upvalue).lua.clone(); - lua.state = state; + let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap()); + let _guard = StateGuard::new(&mut *lua.0.get(), state); - let mut args = MultiValue::new_or_cached(&lua); + let mut args = MultiValue::new_or_cached(lua); args.reserve(nargs as usize); for _ in 0..nargs { args.push_front(lua.pop_value()); @@ -2518,9 +2533,9 @@ impl Lua { let _sg = StackGuard::new(self.state); check_stack(self.state, 4)?; - let lua = self.clone(); let func = mem::transmute(func); - push_gc_userdata(self.state, CallbackUpvalue { lua, func })?; + let extra = Arc::clone(&self.extra); + push_gc_userdata(self.state, CallbackUpvalue { extra, func })?; protect_lua!(self.state, 1, 1, fn(state) { ffi::lua_pushcclosure(state, call_callback, 1); })?; @@ -2530,13 +2545,10 @@ impl Lua { } #[cfg(feature = "async")] - pub(crate) fn create_async_callback<'lua, 'callback>( + pub(crate) fn create_async_callback<'lua>( &'lua self, - func: AsyncCallback<'callback, 'static>, - ) -> Result> - where - 'lua: 'callback, - { + func: AsyncCallback<'lua, 'static>, + ) -> Result> { #[cfg(any( feature = "lua54", feature = "lua53", @@ -2550,28 +2562,12 @@ impl Lua { } } - struct StateGuard(*mut Lua, *mut ffi::lua_State); - - impl StateGuard { - unsafe fn new(lua: *mut Lua, state: *mut ffi::lua_State) -> Self { - let orig_state = (*lua).state; - (*lua).state = state; - Self(lua, orig_state) - } - } - - impl Drop for StateGuard { - fn drop(&mut self) { - unsafe { (*self.0).state = self.1 } - } - } - unsafe extern "C" fn call_callback(state: *mut ffi::lua_State) -> c_int { let extra = match ffi::lua_type(state, ffi::lua_upvalueindex(1)) { ffi::LUA_TUSERDATA => { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - (*upvalue).lua.extra.get() + (*upvalue).extra.get() } _ => ptr::null_mut(), }; @@ -2586,8 +2582,8 @@ impl Lua { check_stack(state, ffi::LUA_MINSTACK - nargs)?; } - let lua = &mut (*upvalue).lua; - let _guard = StateGuard::new(lua, state); + let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap()); + let _guard = StateGuard::new(&mut *lua.0.get(), state); let mut args = MultiValue::new_or_cached(lua); args.reserve(nargs as usize); @@ -2596,8 +2592,8 @@ impl Lua { } let fut = ((*upvalue).func)(lua, args); - let lua = lua.clone(); - push_gc_userdata(state, AsyncPollUpvalue { lua, fut })?; + let extra = Arc::clone(&(*upvalue).extra); + push_gc_userdata(state, AsyncPollUpvalue { extra, fut })?; protect_lua!(state, 1, 1, fn(state) { ffi::lua_pushcclosure(state, poll_future, 1); })?; @@ -2610,7 +2606,7 @@ impl Lua { let extra = match ffi::lua_type(state, ffi::lua_upvalueindex(1)) { ffi::LUA_TUSERDATA => { let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - (*upvalue).lua.extra.get() + (*upvalue).extra.get() } _ => ptr::null_mut(), }; @@ -2625,8 +2621,8 @@ impl Lua { check_stack(state, ffi::LUA_MINSTACK - nargs)?; } - let lua = &mut (*upvalue).lua; - lua.state = state; + let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap()); + let _guard = StateGuard::new(&mut *lua.0.get(), state); // Try to get an outer poll waker let waker = lua.waker().unwrap_or_else(noop_waker); @@ -2657,9 +2653,9 @@ impl Lua { let _sg = StackGuard::new(self.state); check_stack(self.state, 4)?; - let lua = self.clone(); let func = mem::transmute(func); - push_gc_userdata(self.state, AsyncCallbackUpvalue { lua, func })?; + let extra = Arc::clone(&self.extra); + push_gc_userdata(self.state, AsyncCallbackUpvalue { extra, func })?; protect_lua!(self.state, 1, 1, fn(state) { ffi::lua_pushcclosure(state, call_callback, 1); })?; @@ -2752,18 +2748,6 @@ impl Lua { Ok(AnyUserData(self.pop_ref())) } - #[inline] - pub(crate) fn clone(&self) -> Self { - Lua { - state: self.state, - main_state: self.main_state, - extra: Arc::clone(&self.extra), - ephemeral: true, - safe: self.safe, - _no_ref_unwind_safe: PhantomData, - } - } - #[cfg(not(feature = "luau"))] fn disable_c_modules(&self) -> Result<()> { let package: Table = self.globals().get("package")?; @@ -2794,17 +2778,9 @@ impl Lua { pub(crate) unsafe fn make_from_ptr(state: *mut ffi::lua_State) -> Option { let _sg = StackGuard::new(state); assert_stack(state, 1); - let extra = extra_data(state)?; - let safe = (*extra.get()).safe; - Some(Lua { - state, - main_state: get_main_state(state).unwrap_or(state), - extra, - ephemeral: true, - safe, - _no_ref_unwind_safe: PhantomData, - }) + let inner = &*(*extra.get()).inner.as_ref().unwrap(); + Some(Lua(Arc::clone(inner))) } #[inline] @@ -2827,6 +2803,21 @@ impl Lua { } } +struct StateGuard<'a>(&'a mut LuaInner, *mut ffi::lua_State); + +impl<'a> StateGuard<'a> { + fn new(inner: &'a mut LuaInner, mut state: *mut ffi::lua_State) -> Self { + mem::swap(&mut (*inner).state, &mut state); + Self(inner, state) + } +} + +impl<'a> Drop for StateGuard<'a> { + fn drop(&mut self) { + mem::swap(&mut (*self.0).state, &mut self.1); + } +} + #[cfg(feature = "luau")] unsafe fn extra_data(state: *mut ffi::lua_State) -> Option>> { let extra_ptr = (*ffi::lua_callbacks(state)).userdata as *mut Arc>; diff --git a/src/types.rs b/src/types.rs index a14ff3f..2d88f72 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,3 +1,4 @@ +use std::cell::UnsafeCell; use std::hash::{Hash, Hasher}; use std::os::raw::{c_int, c_void}; use std::sync::{Arc, Mutex}; @@ -13,7 +14,7 @@ use crate::error::Result; use crate::ffi; #[cfg(not(feature = "luau"))] use crate::hook::Debug; -use crate::lua::Lua; +use crate::lua::{ExtraData, Lua}; use crate::util::{assert_stack, StackGuard}; use crate::value::MultiValue; @@ -30,7 +31,7 @@ pub(crate) type Callback<'lua, 'a> = Box) -> Result> + 'a>; pub(crate) struct CallbackUpvalue<'lua> { - pub(crate) lua: Lua, + pub(crate) extra: Arc>, pub(crate) func: Callback<'lua, 'static>, } @@ -40,13 +41,13 @@ pub(crate) type AsyncCallback<'lua, 'a> = #[cfg(feature = "async")] pub(crate) struct AsyncCallbackUpvalue<'lua> { - pub(crate) lua: Lua, + pub(crate) extra: Arc>, pub(crate) func: AsyncCallback<'lua, 'static>, } #[cfg(feature = "async")] pub(crate) struct AsyncPollUpvalue<'lua> { - pub(crate) lua: Lua, + pub(crate) extra: Arc>, pub(crate) fut: LocalBoxFuture<'lua, Result>>, }