Refactor Lua instance structure.

The idea is to keep same Lua instance across all calls and only change context inside callbacks.
This should solve #104.
This commit is contained in:
Alex Orlenko 2022-04-13 13:41:13 +01:00
parent 5cd82d0f6b
commit 0215c31a3a
No known key found for this signature in database
GPG Key ID: 4C150C250863B96D
2 changed files with 130 additions and 138 deletions

View File

@ -5,6 +5,8 @@ use std::collections::HashMap;
use std::ffi::CString;
use std::fmt;
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
use std::os::raw::{c_char, c_int, c_void};
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe, Location};
use std::sync::{Arc, Mutex};
@ -66,19 +68,25 @@ use {
#[cfg(feature = "serialize")]
use serde::Serialize;
/// Top level Lua struct which holds the Lua state itself.
pub struct Lua {
/// Top level Lua struct which represents an instance of Lua VM.
#[repr(transparent)]
pub struct Lua(Arc<UnsafeCell<LuaInner>>);
/// An inner Lua struct which holds a raw Lua state.
pub struct LuaInner {
pub(crate) state: *mut ffi::lua_State,
main_state: *mut ffi::lua_State,
extra: Arc<UnsafeCell<ExtraData>>,
ephemeral: bool,
safe: bool,
// Lua has lots of interior mutability, should not be RefUnwindSafe
_no_ref_unwind_safe: PhantomData<UnsafeCell<()>>,
}
// Data associated with the Lua.
struct ExtraData {
pub(crate) struct ExtraData {
// Same layout as `Lua`
inner: Option<ManuallyDrop<Arc<UnsafeCell<LuaInner>>>>,
registered_userdata: FxHashMap<TypeId, c_int>,
registered_userdata_mt: FxHashMap<*const c_void, Option<TypeId>>,
registry_unref_list: Arc<Mutex<Option<Vec<c_int>>>>,
@ -90,7 +98,6 @@ struct ExtraData {
libs: StdLib,
mem_info: Option<ptr::NonNull<MemoryInfo>>,
safe: bool, // Same as in the Lua struct
ref_thread: *mut ffi::lua_State,
ref_stack_size: c_int,
@ -221,48 +228,52 @@ const MULTIVALUE_CACHE_SIZE: usize = 32;
/// Requires `feature = "send"`
#[cfg(feature = "send")]
#[cfg_attr(docsrs, doc(cfg(feature = "send")))]
unsafe impl Send for Lua {}
unsafe impl Send for LuaInner {}
impl Drop for Lua {
#[cfg(not(feature = "module"))]
impl Drop for LuaInner {
fn drop(&mut self) {
unsafe {
if !self.ephemeral {
let extra = &mut *self.extra.get();
let drain_iter = extra.wrapped_failures_cache.drain(..);
#[cfg(feature = "async")]
let drain_iter = drain_iter.chain(extra.recycled_thread_cache.drain(..));
for index in drain_iter {
ffi::lua_pushnil(extra.ref_thread);
ffi::lua_replace(extra.ref_thread, index);
extra.ref_free.push(index);
}
#[cfg(feature = "async")]
{
// Destroy Waker slot
ffi::lua_pushnil(extra.ref_thread);
ffi::lua_replace(extra.ref_thread, extra.ref_waker_idx);
extra.ref_free.push(extra.ref_waker_idx);
}
#[cfg(feature = "luau")]
{
let callbacks = ffi::lua_callbacks(self.state);
let extra_ptr = (*callbacks).userdata as *mut Arc<UnsafeCell<ExtraData>>;
drop(Box::from_raw(extra_ptr));
(*callbacks).userdata = ptr::null_mut();
}
mlua_debug_assert!(
ffi::lua_gettop(extra.ref_thread) == extra.ref_stack_top
&& extra.ref_stack_top as usize == extra.ref_free.len(),
"reference leak detected"
);
ffi::lua_close(self.main_state);
let extra = &mut *self.extra.get();
let drain_iter = extra.wrapped_failures_cache.drain(..);
#[cfg(feature = "async")]
let drain_iter = drain_iter.chain(extra.recycled_thread_cache.drain(..));
for index in drain_iter {
ffi::lua_pushnil(extra.ref_thread);
ffi::lua_replace(extra.ref_thread, index);
extra.ref_free.push(index);
}
#[cfg(feature = "async")]
{
// Destroy Waker slot
ffi::lua_pushnil(extra.ref_thread);
ffi::lua_replace(extra.ref_thread, extra.ref_waker_idx);
extra.ref_free.push(extra.ref_waker_idx);
}
#[cfg(feature = "luau")]
{
let callbacks = ffi::lua_callbacks(self.state);
let extra_ptr = (*callbacks).userdata as *mut Arc<UnsafeCell<ExtraData>>;
drop(Box::from_raw(extra_ptr));
(*callbacks).userdata = ptr::null_mut();
}
mlua_debug_assert!(
ffi::lua_gettop(extra.ref_thread) == extra.ref_stack_top
&& extra.ref_stack_top as usize == extra.ref_free.len(),
"reference leak detected"
);
ffi::lua_close(self.main_state);
}
}
}
impl Drop for ExtraData {
fn drop(&mut self) {
#[cfg(feature = "module")]
unsafe {
ManuallyDrop::drop(&mut self.inner.take().unwrap())
};
*mlua_expect!(self.registry_unref_list.lock(), "unref list poisoned") = None;
if let Some(mem_info) = self.mem_info {
drop(unsafe { Box::from_raw(mem_info.as_ptr()) });
@ -276,6 +287,20 @@ impl fmt::Debug for Lua {
}
}
impl Deref for Lua {
type Target = LuaInner;
fn deref(&self) -> &Self::Target {
unsafe { &*(*self.0).get() }
}
}
impl DerefMut for Lua {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *(*self.0).get() }
}
}
impl Lua {
/// Creates a new Lua state and loads the **safe** subset of the standard libraries.
///
@ -336,7 +361,6 @@ impl Lua {
mlua_expect!(lua.disable_c_modules(), "Error during disabling C modules");
}
lua.safe = true;
unsafe { (*lua.extra.get()).safe = true };
Ok(lua)
}
@ -430,9 +454,7 @@ impl Lua {
ffi::luaL_requiref(state, cstr!("_G"), ffi::luaopen_base, 1);
ffi::lua_pop(state, 1);
let mut lua = Lua::init_from_ptr(state);
lua.ephemeral = false;
let lua = Lua::init_from_ptr(state);
let extra = &mut *lua.extra.get();
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
@ -544,6 +566,7 @@ impl Lua {
// Create ExtraData
let extra = Arc::new(UnsafeCell::new(ExtraData {
inner: None,
registered_userdata: FxHashMap::default(),
registered_userdata_mt: FxHashMap::default(),
registry_unref_list: Arc::new(Mutex::new(Some(Vec::new()))),
@ -551,7 +574,6 @@ impl Lua {
ref_thread,
libs: StdLib::NONE,
mem_info: None,
safe: false,
// We need 1 extra stack space to move values in and out of the ref stack.
ref_stack_size: ffi::LUA_MINSTACK - 1,
ref_stack_top,
@ -606,14 +628,19 @@ impl Lua {
(*ffi::lua_callbacks(main_state)).userdata = extra_raw as *mut c_void;
}
Lua {
let inner = Arc::new(UnsafeCell::new(LuaInner {
state,
main_state,
extra,
ephemeral: true,
extra: Arc::clone(&extra),
safe: false,
_no_ref_unwind_safe: PhantomData,
}
}));
(*extra.get()).inner = Some(ManuallyDrop::new(Arc::clone(&inner)));
#[cfg(not(feature = "module"))]
Arc::decrement_strong_count(Arc::as_ptr(&inner));
Lua(inner)
}
/// Loads the specified subset of the standard libraries into an existing Lua state.
@ -1476,12 +1503,11 @@ impl Lua {
///
/// [`ToLua`]: crate::ToLua
/// [`ToLuaMulti`]: crate::ToLuaMulti
pub fn create_function<'lua, 'callback, A, R, F>(&'lua self, func: F) -> Result<Function<'lua>>
pub fn create_function<'lua, A, R, F>(&'lua self, func: F) -> Result<Function<'lua>>
where
'lua: 'callback,
A: FromLuaMulti<'callback>,
R: ToLuaMulti<'callback>,
F: 'static + MaybeSend + Fn(&'callback Lua, A) -> Result<R>,
A: FromLuaMulti<'lua>,
R: ToLuaMulti<'lua>,
F: 'static + MaybeSend + Fn(&'lua Lua, A) -> Result<R>,
{
self.create_callback(Box::new(move |lua, args| {
func(lua, A::from_lua_multi(args, lua)?)?.to_lua_multi(lua)
@ -1494,15 +1520,11 @@ impl Lua {
/// [`create_function`] for more information about the implementation.
///
/// [`create_function`]: #method.create_function
pub fn create_function_mut<'lua, 'callback, A, R, F>(
&'lua self,
func: F,
) -> Result<Function<'lua>>
pub fn create_function_mut<'lua, A, R, F>(&'lua self, func: F) -> Result<Function<'lua>>
where
'lua: 'callback,
A: FromLuaMulti<'callback>,
R: ToLuaMulti<'callback>,
F: 'static + MaybeSend + FnMut(&'callback Lua, A) -> Result<R>,
A: FromLuaMulti<'lua>,
R: ToLuaMulti<'lua>,
F: 'static + MaybeSend + FnMut(&'lua Lua, A) -> Result<R>,
{
let func = RefCell::new(func);
self.create_function(move |lua, args| {
@ -1564,15 +1586,11 @@ impl Lua {
/// [`AsyncThread`]: crate::AsyncThread
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub fn create_async_function<'lua, 'callback, A, R, F, FR>(
&'lua self,
func: F,
) -> Result<Function<'lua>>
pub fn create_async_function<'lua, A, R, F, FR>(&'lua self, func: F) -> Result<Function<'lua>>
where
'lua: 'callback,
A: FromLuaMulti<'callback>,
R: ToLuaMulti<'callback>,
F: 'static + MaybeSend + Fn(&'callback Lua, A) -> FR,
A: FromLuaMulti<'lua>,
R: ToLuaMulti<'lua>,
F: 'static + MaybeSend + Fn(&'lua Lua, A) -> FR,
FR: 'lua + Future<Output = Result<R>>,
{
self.create_async_callback(Box::new(move |lua, args| {
@ -2459,25 +2477,22 @@ impl Lua {
}
// Creates a Function out of a Callback containing a 'static Fn. This is safe ONLY because the
// Fn is 'static, otherwise it could capture 'callback arguments improperly. Without ATCs, we
// Fn is 'static, otherwise it could capture 'lua arguments improperly. Without ATCs, we
// cannot easily deal with the "correct" callback type of:
//
// Box<for<'lua> Fn(&'lua Lua, MultiValue<'lua>) -> Result<MultiValue<'lua>>)>
//
// So we instead use a caller provided lifetime, which without the 'static requirement would be
// unsafe.
pub(crate) fn create_callback<'lua, 'callback>(
pub(crate) fn create_callback<'lua>(
&'lua self,
func: Callback<'callback, 'static>,
) -> Result<Function<'lua>>
where
'lua: 'callback,
{
func: Callback<'lua, 'static>,
) -> Result<Function<'lua>> {
unsafe extern "C" fn call_callback(state: *mut ffi::lua_State) -> c_int {
let extra = match ffi::lua_type(state, ffi::lua_upvalueindex(1)) {
ffi::LUA_TUSERDATA => {
let upvalue = get_userdata::<CallbackUpvalue>(state, ffi::lua_upvalueindex(1));
(*upvalue).lua.extra.get()
(*upvalue).extra.get()
}
_ => ptr::null_mut(),
};
@ -2492,10 +2507,10 @@ impl Lua {
check_stack(state, ffi::LUA_MINSTACK - nargs)?;
}
let mut lua = (*upvalue).lua.clone();
lua.state = state;
let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap());
let _guard = StateGuard::new(&mut *lua.0.get(), state);
let mut args = MultiValue::new_or_cached(&lua);
let mut args = MultiValue::new_or_cached(lua);
args.reserve(nargs as usize);
for _ in 0..nargs {
args.push_front(lua.pop_value());
@ -2518,9 +2533,9 @@ impl Lua {
let _sg = StackGuard::new(self.state);
check_stack(self.state, 4)?;
let lua = self.clone();
let func = mem::transmute(func);
push_gc_userdata(self.state, CallbackUpvalue { lua, func })?;
let extra = Arc::clone(&self.extra);
push_gc_userdata(self.state, CallbackUpvalue { extra, func })?;
protect_lua!(self.state, 1, 1, fn(state) {
ffi::lua_pushcclosure(state, call_callback, 1);
})?;
@ -2530,13 +2545,10 @@ impl Lua {
}
#[cfg(feature = "async")]
pub(crate) fn create_async_callback<'lua, 'callback>(
pub(crate) fn create_async_callback<'lua>(
&'lua self,
func: AsyncCallback<'callback, 'static>,
) -> Result<Function<'lua>>
where
'lua: 'callback,
{
func: AsyncCallback<'lua, 'static>,
) -> Result<Function<'lua>> {
#[cfg(any(
feature = "lua54",
feature = "lua53",
@ -2550,28 +2562,12 @@ impl Lua {
}
}
struct StateGuard(*mut Lua, *mut ffi::lua_State);
impl StateGuard {
unsafe fn new(lua: *mut Lua, state: *mut ffi::lua_State) -> Self {
let orig_state = (*lua).state;
(*lua).state = state;
Self(lua, orig_state)
}
}
impl Drop for StateGuard {
fn drop(&mut self) {
unsafe { (*self.0).state = self.1 }
}
}
unsafe extern "C" fn call_callback(state: *mut ffi::lua_State) -> c_int {
let extra = match ffi::lua_type(state, ffi::lua_upvalueindex(1)) {
ffi::LUA_TUSERDATA => {
let upvalue =
get_userdata::<AsyncCallbackUpvalue>(state, ffi::lua_upvalueindex(1));
(*upvalue).lua.extra.get()
(*upvalue).extra.get()
}
_ => ptr::null_mut(),
};
@ -2586,8 +2582,8 @@ impl Lua {
check_stack(state, ffi::LUA_MINSTACK - nargs)?;
}
let lua = &mut (*upvalue).lua;
let _guard = StateGuard::new(lua, state);
let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap());
let _guard = StateGuard::new(&mut *lua.0.get(), state);
let mut args = MultiValue::new_or_cached(lua);
args.reserve(nargs as usize);
@ -2596,8 +2592,8 @@ impl Lua {
}
let fut = ((*upvalue).func)(lua, args);
let lua = lua.clone();
push_gc_userdata(state, AsyncPollUpvalue { lua, fut })?;
let extra = Arc::clone(&(*upvalue).extra);
push_gc_userdata(state, AsyncPollUpvalue { extra, fut })?;
protect_lua!(state, 1, 1, fn(state) {
ffi::lua_pushcclosure(state, poll_future, 1);
})?;
@ -2610,7 +2606,7 @@ impl Lua {
let extra = match ffi::lua_type(state, ffi::lua_upvalueindex(1)) {
ffi::LUA_TUSERDATA => {
let upvalue = get_userdata::<AsyncPollUpvalue>(state, ffi::lua_upvalueindex(1));
(*upvalue).lua.extra.get()
(*upvalue).extra.get()
}
_ => ptr::null_mut(),
};
@ -2625,8 +2621,8 @@ impl Lua {
check_stack(state, ffi::LUA_MINSTACK - nargs)?;
}
let lua = &mut (*upvalue).lua;
lua.state = state;
let lua: &Lua = mem::transmute((*extra).inner.as_ref().unwrap());
let _guard = StateGuard::new(&mut *lua.0.get(), state);
// Try to get an outer poll waker
let waker = lua.waker().unwrap_or_else(noop_waker);
@ -2657,9 +2653,9 @@ impl Lua {
let _sg = StackGuard::new(self.state);
check_stack(self.state, 4)?;
let lua = self.clone();
let func = mem::transmute(func);
push_gc_userdata(self.state, AsyncCallbackUpvalue { lua, func })?;
let extra = Arc::clone(&self.extra);
push_gc_userdata(self.state, AsyncCallbackUpvalue { extra, func })?;
protect_lua!(self.state, 1, 1, fn(state) {
ffi::lua_pushcclosure(state, call_callback, 1);
})?;
@ -2752,18 +2748,6 @@ impl Lua {
Ok(AnyUserData(self.pop_ref()))
}
#[inline]
pub(crate) fn clone(&self) -> Self {
Lua {
state: self.state,
main_state: self.main_state,
extra: Arc::clone(&self.extra),
ephemeral: true,
safe: self.safe,
_no_ref_unwind_safe: PhantomData,
}
}
#[cfg(not(feature = "luau"))]
fn disable_c_modules(&self) -> Result<()> {
let package: Table = self.globals().get("package")?;
@ -2794,17 +2778,9 @@ impl Lua {
pub(crate) unsafe fn make_from_ptr(state: *mut ffi::lua_State) -> Option<Self> {
let _sg = StackGuard::new(state);
assert_stack(state, 1);
let extra = extra_data(state)?;
let safe = (*extra.get()).safe;
Some(Lua {
state,
main_state: get_main_state(state).unwrap_or(state),
extra,
ephemeral: true,
safe,
_no_ref_unwind_safe: PhantomData,
})
let inner = &*(*extra.get()).inner.as_ref().unwrap();
Some(Lua(Arc::clone(inner)))
}
#[inline]
@ -2827,6 +2803,21 @@ impl Lua {
}
}
struct StateGuard<'a>(&'a mut LuaInner, *mut ffi::lua_State);
impl<'a> StateGuard<'a> {
fn new(inner: &'a mut LuaInner, mut state: *mut ffi::lua_State) -> Self {
mem::swap(&mut (*inner).state, &mut state);
Self(inner, state)
}
}
impl<'a> Drop for StateGuard<'a> {
fn drop(&mut self) {
mem::swap(&mut (*self.0).state, &mut self.1);
}
}
#[cfg(feature = "luau")]
unsafe fn extra_data(state: *mut ffi::lua_State) -> Option<Arc<UnsafeCell<ExtraData>>> {
let extra_ptr = (*ffi::lua_callbacks(state)).userdata as *mut Arc<UnsafeCell<ExtraData>>;

View File

@ -1,3 +1,4 @@
use std::cell::UnsafeCell;
use std::hash::{Hash, Hasher};
use std::os::raw::{c_int, c_void};
use std::sync::{Arc, Mutex};
@ -13,7 +14,7 @@ use crate::error::Result;
use crate::ffi;
#[cfg(not(feature = "luau"))]
use crate::hook::Debug;
use crate::lua::Lua;
use crate::lua::{ExtraData, Lua};
use crate::util::{assert_stack, StackGuard};
use crate::value::MultiValue;
@ -30,7 +31,7 @@ pub(crate) type Callback<'lua, 'a> =
Box<dyn Fn(&'lua Lua, MultiValue<'lua>) -> Result<MultiValue<'lua>> + 'a>;
pub(crate) struct CallbackUpvalue<'lua> {
pub(crate) lua: Lua,
pub(crate) extra: Arc<UnsafeCell<ExtraData>>,
pub(crate) func: Callback<'lua, 'static>,
}
@ -40,13 +41,13 @@ pub(crate) type AsyncCallback<'lua, 'a> =
#[cfg(feature = "async")]
pub(crate) struct AsyncCallbackUpvalue<'lua> {
pub(crate) lua: Lua,
pub(crate) extra: Arc<UnsafeCell<ExtraData>>,
pub(crate) func: AsyncCallback<'lua, 'static>,
}
#[cfg(feature = "async")]
pub(crate) struct AsyncPollUpvalue<'lua> {
pub(crate) lua: Lua,
pub(crate) extra: Arc<UnsafeCell<ExtraData>>,
pub(crate) fut: LocalBoxFuture<'lua, Result<MultiValue<'lua>>>,
}