Refactor Waker handling in async code.
Instead of storing `Option<Waker>` in the Lua registry, store it on the reference thread. It gives approx +10% performance gain when calling async function.
This commit is contained in:
parent
c62b17a5c8
commit
d098c9ccf6
72
src/lua.rs
72
src/lua.rs
|
@ -76,9 +76,13 @@ struct ExtraData {
|
|||
ref_stack_top: c_int,
|
||||
ref_free: Vec<c_int>,
|
||||
|
||||
// Pool of preallocated `WrappedFailure` enums
|
||||
// Pool of preallocated `WrappedFailure` enums on the ref thread
|
||||
wrapped_failures_pool: Vec<c_int>,
|
||||
|
||||
// Index of `Option<Waker>` userdata on the ref thread
|
||||
#[cfg(feature = "async")]
|
||||
ref_waker_idx: c_int,
|
||||
|
||||
hook_callback: Option<HookCallback>,
|
||||
}
|
||||
|
||||
|
@ -148,8 +152,6 @@ impl LuaOptions {
|
|||
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) static ASYNC_POLL_PENDING: u8 = 0;
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) static WAKER_REGISTRY_KEY: u8 = 0;
|
||||
pub(crate) static EXTRA_REGISTRY_KEY: u8 = 0;
|
||||
|
||||
const WRAPPED_FAILURES_POOL_SIZE: usize = 16;
|
||||
|
@ -169,6 +171,13 @@ impl Drop for Lua {
|
|||
ffi::lua_replace(extra.ref_thread, index);
|
||||
extra.ref_free.push(index);
|
||||
}
|
||||
#[cfg(feature = "async")]
|
||||
{
|
||||
// Destroy Waker slot
|
||||
ffi::lua_pushnil(extra.ref_thread);
|
||||
ffi::lua_replace(extra.ref_thread, extra.ref_waker_idx);
|
||||
extra.ref_free.push(extra.ref_waker_idx);
|
||||
}
|
||||
mlua_debug_assert!(
|
||||
ffi::lua_gettop(extra.ref_thread) == extra.ref_stack_top
|
||||
&& extra.ref_stack_top as usize == extra.ref_free.len(),
|
||||
|
@ -411,13 +420,6 @@ impl Lua {
|
|||
init_gc_metatable::<AsyncCallbackUpvalue>(state, None)?;
|
||||
init_gc_metatable::<AsyncPollUpvalue>(state, None)?;
|
||||
init_gc_metatable::<Option<Waker>>(state, None)?;
|
||||
|
||||
// Create empty Waker slot
|
||||
push_gc_userdata::<Option<Waker>>(state, None)?;
|
||||
protect_lua!(state, 1, 0, fn(state) {
|
||||
let waker_key = &WAKER_REGISTRY_KEY as *const u8 as *const c_void;
|
||||
ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, waker_key);
|
||||
})?;
|
||||
}
|
||||
|
||||
// Init serde metatables
|
||||
|
@ -440,6 +442,17 @@ impl Lua {
|
|||
"Error while creating ref thread",
|
||||
);
|
||||
|
||||
// Create empty Waker slot on the ref thread
|
||||
#[cfg(feature = "async")]
|
||||
let ref_waker_idx = {
|
||||
mlua_expect!(
|
||||
push_gc_userdata::<Option<Waker>>(ref_thread, None),
|
||||
"Error while creating Waker slot"
|
||||
);
|
||||
ffi::lua_gettop(ref_thread)
|
||||
};
|
||||
let ref_stack_top = ffi::lua_gettop(ref_thread);
|
||||
|
||||
// Create ExtraData
|
||||
|
||||
let extra = Arc::new(UnsafeCell::new(ExtraData {
|
||||
|
@ -452,9 +465,11 @@ impl Lua {
|
|||
safe: false,
|
||||
// We need 1 extra stack space to move values in and out of the ref stack.
|
||||
ref_stack_size: ffi::LUA_MINSTACK - 1,
|
||||
ref_stack_top: 0,
|
||||
ref_stack_top,
|
||||
ref_free: Vec::new(),
|
||||
wrapped_failures_pool: Vec::new(),
|
||||
#[cfg(feature = "async")]
|
||||
ref_waker_idx,
|
||||
hook_callback: None,
|
||||
}));
|
||||
|
||||
|
@ -1720,10 +1735,14 @@ impl Lua {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serialize")]
|
||||
/// Executes the function provided on the ref thread
|
||||
#[inline]
|
||||
pub(crate) unsafe fn get_ref_ptr(&self, lref: &LuaRef) -> *const c_void {
|
||||
ffi::lua_topointer((*self.extra.get()).ref_thread, lref.index)
|
||||
pub(crate) unsafe fn ref_thread_exec<F, R>(&self, f: F) -> R
|
||||
where
|
||||
F: FnOnce(*mut ffi::lua_State) -> R,
|
||||
{
|
||||
let ref_thread = (*self.extra.get()).ref_thread;
|
||||
f(ref_thread)
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn push_userdata_metatable<T: 'static + UserData>(&self) -> Result<()> {
|
||||
|
@ -2008,14 +2027,7 @@ impl Lua {
|
|||
lua.state = state;
|
||||
|
||||
// Try to get an outer poll waker
|
||||
let waker_key = &WAKER_REGISTRY_KEY as *const u8 as *const c_void;
|
||||
ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, waker_key);
|
||||
let waker = match get_gc_userdata::<Option<Waker>>(state, -1).as_ref() {
|
||||
Some(Some(waker)) => waker.clone(),
|
||||
_ => noop_waker(),
|
||||
};
|
||||
ffi::lua_pop(state, 1);
|
||||
|
||||
let waker = lua.waker().unwrap_or_else(noop_waker);
|
||||
let mut ctx = Context::from_waker(&waker);
|
||||
|
||||
let fut = &mut (*upvalue).fut;
|
||||
|
@ -2090,6 +2102,22 @@ impl Lua {
|
|||
.into_function()
|
||||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) unsafe fn waker(&self) -> Option<Waker> {
|
||||
let extra = &*self.extra.get();
|
||||
(*get_userdata::<Option<Waker>>(extra.ref_thread, extra.ref_waker_idx)).clone()
|
||||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) unsafe fn set_waker(&self, waker: Option<Waker>) -> Option<Waker> {
|
||||
let extra = &*self.extra.get();
|
||||
let waker_slot = &mut *get_userdata::<Option<Waker>>(extra.ref_thread, extra.ref_waker_idx);
|
||||
match waker {
|
||||
Some(waker) => waker_slot.replace(waker),
|
||||
None => waker_slot.take(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn make_userdata<T>(&self, data: UserDataCell<T>) -> Result<AnyUserData>
|
||||
where
|
||||
T: 'static + UserData,
|
||||
|
|
|
@ -7,6 +7,7 @@ use std::string::String as StdString;
|
|||
use serde::de::{self, IntoDeserializer};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::ffi;
|
||||
use crate::table::{Table, TablePairs, TableSequence};
|
||||
use crate::value::Value;
|
||||
|
||||
|
@ -500,7 +501,9 @@ impl RecursionGuard {
|
|||
#[inline]
|
||||
fn new(table: &Table, visited: &Rc<RefCell<HashSet<*const c_void>>>) -> Self {
|
||||
let visited = Rc::clone(visited);
|
||||
let ptr = unsafe { table.0.lua.get_ref_ptr(&table.0) };
|
||||
let lua = table.0.lua;
|
||||
let ptr =
|
||||
unsafe { lua.ref_thread_exec(|refthr| ffi::lua_topointer(refthr, table.0.index)) };
|
||||
visited.borrow_mut().insert(ptr);
|
||||
RecursionGuard { ptr, visited }
|
||||
}
|
||||
|
@ -521,7 +524,8 @@ fn check_value_if_skip(
|
|||
match value {
|
||||
Value::Table(table) => {
|
||||
let lua = table.0.lua;
|
||||
let ptr = unsafe { lua.get_ref_ptr(&table.0) };
|
||||
let ptr =
|
||||
unsafe { lua.ref_thread_exec(|refthr| ffi::lua_topointer(refthr, table.0.index)) };
|
||||
if visited.borrow().contains(&ptr) {
|
||||
if options.deny_recursive_tables {
|
||||
return Err(de::Error::custom("recursive table detected"));
|
||||
|
|
|
@ -4,7 +4,7 @@ use std::os::raw::c_int;
|
|||
use crate::error::{Error, Result};
|
||||
use crate::ffi;
|
||||
use crate::types::LuaRef;
|
||||
use crate::util::{assert_stack, check_stack, error_traceback, pop_error, StackGuard};
|
||||
use crate::util::{check_stack, error_traceback, pop_error, StackGuard};
|
||||
use crate::value::{FromLuaMulti, MultiValue, ToLuaMulti};
|
||||
|
||||
#[cfg(any(feature = "lua54", all(feature = "luajit", feature = "vendored"), doc))]
|
||||
|
@ -13,15 +13,13 @@ use crate::function::Function;
|
|||
#[cfg(feature = "async")]
|
||||
use {
|
||||
crate::{
|
||||
lua::{ASYNC_POLL_PENDING, WAKER_REGISTRY_KEY},
|
||||
util::get_gc_userdata,
|
||||
lua::{Lua, ASYNC_POLL_PENDING},
|
||||
value::Value,
|
||||
},
|
||||
futures_core::{future::Future, stream::Stream},
|
||||
std::{
|
||||
cell::RefCell,
|
||||
marker::PhantomData,
|
||||
mem,
|
||||
os::raw::c_void,
|
||||
pin::Pin,
|
||||
task::{Context, Poll, Waker},
|
||||
|
@ -114,11 +112,10 @@ impl<'lua> Thread<'lua> {
|
|||
let nargs = args.len() as c_int;
|
||||
let results = unsafe {
|
||||
let _sg = StackGuard::new(lua.state);
|
||||
check_stack(lua.state, cmp::min(nargs + 1, 3))?;
|
||||
check_stack(lua.state, cmp::max(nargs + 1, 3))?;
|
||||
|
||||
lua.push_ref(&self.0);
|
||||
let thread_state = ffi::lua_tothread(lua.state, -1);
|
||||
ffi::lua_pop(lua.state, 1);
|
||||
let thread_state =
|
||||
lua.ref_thread_exec(|ref_thread| ffi::lua_tothread(ref_thread, self.0.index));
|
||||
|
||||
let status = ffi::lua_status(thread_state);
|
||||
if status != ffi::LUA_YIELD && ffi::lua_gettop(thread_state) == 0 {
|
||||
|
@ -155,12 +152,8 @@ impl<'lua> Thread<'lua> {
|
|||
pub fn status(&self) -> ThreadStatus {
|
||||
let lua = self.0.lua;
|
||||
unsafe {
|
||||
let _sg = StackGuard::new(lua.state);
|
||||
assert_stack(lua.state, 1);
|
||||
|
||||
lua.push_ref(&self.0);
|
||||
let thread_state = ffi::lua_tothread(lua.state, -1);
|
||||
ffi::lua_pop(lua.state, 1);
|
||||
let thread_state =
|
||||
lua.ref_thread_exec(|ref_thread| ffi::lua_tothread(ref_thread, self.0.index));
|
||||
|
||||
let status = ffi::lua_status(thread_state);
|
||||
if status != ffi::LUA_OK && status != ffi::LUA_YIELD {
|
||||
|
@ -288,7 +281,7 @@ where
|
|||
_ => return Poll::Ready(None),
|
||||
};
|
||||
|
||||
let _wg = WakerGuard::new(lua.state, cx.waker().clone());
|
||||
let _wg = WakerGuard::new(lua, cx.waker().clone());
|
||||
let ret: MultiValue = if let Some(args) = self.args0.borrow_mut().take() {
|
||||
self.thread.resume(args?)?
|
||||
} else {
|
||||
|
@ -319,7 +312,7 @@ where
|
|||
_ => return Poll::Ready(Err(Error::CoroutineInactive)),
|
||||
};
|
||||
|
||||
let _wg = WakerGuard::new(lua.state, cx.waker().clone());
|
||||
let _wg = WakerGuard::new(lua, cx.waker().clone());
|
||||
let ret: MultiValue = if let Some(args) = self.args0.borrow_mut().take() {
|
||||
self.thread.resume(args?)?
|
||||
} else {
|
||||
|
@ -352,37 +345,27 @@ fn is_poll_pending(val: &MultiValue) -> bool {
|
|||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
struct WakerGuard(*mut ffi::lua_State, Option<Waker>);
|
||||
struct WakerGuard<'lua> {
|
||||
lua: &'lua Lua,
|
||||
prev: Option<Waker>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
impl WakerGuard {
|
||||
pub fn new(state: *mut ffi::lua_State, waker: Waker) -> Result<WakerGuard> {
|
||||
impl<'lua> WakerGuard<'lua> {
|
||||
#[inline]
|
||||
pub fn new(lua: &Lua, waker: Waker) -> Result<WakerGuard> {
|
||||
unsafe {
|
||||
let _sg = StackGuard::new(state);
|
||||
check_stack(state, 3)?;
|
||||
|
||||
let waker_key = &WAKER_REGISTRY_KEY as *const u8 as *const c_void;
|
||||
ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, waker_key);
|
||||
let waker_slot = get_gc_userdata::<Option<Waker>>(state, -1).as_mut();
|
||||
let old = mlua_expect!(waker_slot, "Waker is destroyed").replace(waker);
|
||||
|
||||
Ok(WakerGuard(state, old))
|
||||
let prev = lua.set_waker(Some(waker));
|
||||
Ok(WakerGuard { lua, prev })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
impl Drop for WakerGuard {
|
||||
impl<'lua> Drop for WakerGuard<'lua> {
|
||||
fn drop(&mut self) {
|
||||
let state = self.0;
|
||||
unsafe {
|
||||
let _sg = StackGuard::new(state);
|
||||
assert_stack(state, 3);
|
||||
|
||||
let waker_key = &WAKER_REGISTRY_KEY as *const u8 as *const c_void;
|
||||
ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, waker_key);
|
||||
let waker_slot = get_gc_userdata::<Option<Waker>>(state, -1).as_mut();
|
||||
mem::swap(mlua_expect!(waker_slot, "Waker is destroyed"), &mut self.1);
|
||||
self.lua.set_waker(self.prev.take());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue