Refactor Waker handling in async code.

Instead of storing `Option<Waker>` in the Lua registry, store it on the reference thread.
It gives approx +10% performance gain when calling async function.
This commit is contained in:
Alex Orlenko 2021-10-03 21:21:53 +01:00
parent c62b17a5c8
commit d098c9ccf6
No known key found for this signature in database
GPG Key ID: 4C150C250863B96D
3 changed files with 76 additions and 61 deletions

View File

@ -76,9 +76,13 @@ struct ExtraData {
ref_stack_top: c_int,
ref_free: Vec<c_int>,
// Pool of preallocated `WrappedFailure` enums
// Pool of preallocated `WrappedFailure` enums on the ref thread
wrapped_failures_pool: Vec<c_int>,
// Index of `Option<Waker>` userdata on the ref thread
#[cfg(feature = "async")]
ref_waker_idx: c_int,
hook_callback: Option<HookCallback>,
}
@ -148,8 +152,6 @@ impl LuaOptions {
#[cfg(feature = "async")]
pub(crate) static ASYNC_POLL_PENDING: u8 = 0;
#[cfg(feature = "async")]
pub(crate) static WAKER_REGISTRY_KEY: u8 = 0;
pub(crate) static EXTRA_REGISTRY_KEY: u8 = 0;
const WRAPPED_FAILURES_POOL_SIZE: usize = 16;
@ -169,6 +171,13 @@ impl Drop for Lua {
ffi::lua_replace(extra.ref_thread, index);
extra.ref_free.push(index);
}
#[cfg(feature = "async")]
{
// Destroy Waker slot
ffi::lua_pushnil(extra.ref_thread);
ffi::lua_replace(extra.ref_thread, extra.ref_waker_idx);
extra.ref_free.push(extra.ref_waker_idx);
}
mlua_debug_assert!(
ffi::lua_gettop(extra.ref_thread) == extra.ref_stack_top
&& extra.ref_stack_top as usize == extra.ref_free.len(),
@ -411,13 +420,6 @@ impl Lua {
init_gc_metatable::<AsyncCallbackUpvalue>(state, None)?;
init_gc_metatable::<AsyncPollUpvalue>(state, None)?;
init_gc_metatable::<Option<Waker>>(state, None)?;
// Create empty Waker slot
push_gc_userdata::<Option<Waker>>(state, None)?;
protect_lua!(state, 1, 0, fn(state) {
let waker_key = &WAKER_REGISTRY_KEY as *const u8 as *const c_void;
ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, waker_key);
})?;
}
// Init serde metatables
@ -440,6 +442,17 @@ impl Lua {
"Error while creating ref thread",
);
// Create empty Waker slot on the ref thread
#[cfg(feature = "async")]
let ref_waker_idx = {
mlua_expect!(
push_gc_userdata::<Option<Waker>>(ref_thread, None),
"Error while creating Waker slot"
);
ffi::lua_gettop(ref_thread)
};
let ref_stack_top = ffi::lua_gettop(ref_thread);
// Create ExtraData
let extra = Arc::new(UnsafeCell::new(ExtraData {
@ -452,9 +465,11 @@ impl Lua {
safe: false,
// We need 1 extra stack space to move values in and out of the ref stack.
ref_stack_size: ffi::LUA_MINSTACK - 1,
ref_stack_top: 0,
ref_stack_top,
ref_free: Vec::new(),
wrapped_failures_pool: Vec::new(),
#[cfg(feature = "async")]
ref_waker_idx,
hook_callback: None,
}));
@ -1720,10 +1735,14 @@ impl Lua {
}
}
#[cfg(feature = "serialize")]
/// Executes the function provided on the ref thread
#[inline]
pub(crate) unsafe fn get_ref_ptr(&self, lref: &LuaRef) -> *const c_void {
ffi::lua_topointer((*self.extra.get()).ref_thread, lref.index)
pub(crate) unsafe fn ref_thread_exec<F, R>(&self, f: F) -> R
where
F: FnOnce(*mut ffi::lua_State) -> R,
{
let ref_thread = (*self.extra.get()).ref_thread;
f(ref_thread)
}
pub(crate) unsafe fn push_userdata_metatable<T: 'static + UserData>(&self) -> Result<()> {
@ -2008,14 +2027,7 @@ impl Lua {
lua.state = state;
// Try to get an outer poll waker
let waker_key = &WAKER_REGISTRY_KEY as *const u8 as *const c_void;
ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, waker_key);
let waker = match get_gc_userdata::<Option<Waker>>(state, -1).as_ref() {
Some(Some(waker)) => waker.clone(),
_ => noop_waker(),
};
ffi::lua_pop(state, 1);
let waker = lua.waker().unwrap_or_else(noop_waker);
let mut ctx = Context::from_waker(&waker);
let fut = &mut (*upvalue).fut;
@ -2090,6 +2102,22 @@ impl Lua {
.into_function()
}
#[cfg(feature = "async")]
pub(crate) unsafe fn waker(&self) -> Option<Waker> {
let extra = &*self.extra.get();
(*get_userdata::<Option<Waker>>(extra.ref_thread, extra.ref_waker_idx)).clone()
}
#[cfg(feature = "async")]
pub(crate) unsafe fn set_waker(&self, waker: Option<Waker>) -> Option<Waker> {
let extra = &*self.extra.get();
let waker_slot = &mut *get_userdata::<Option<Waker>>(extra.ref_thread, extra.ref_waker_idx);
match waker {
Some(waker) => waker_slot.replace(waker),
None => waker_slot.take(),
}
}
pub(crate) unsafe fn make_userdata<T>(&self, data: UserDataCell<T>) -> Result<AnyUserData>
where
T: 'static + UserData,

View File

@ -7,6 +7,7 @@ use std::string::String as StdString;
use serde::de::{self, IntoDeserializer};
use crate::error::{Error, Result};
use crate::ffi;
use crate::table::{Table, TablePairs, TableSequence};
use crate::value::Value;
@ -500,7 +501,9 @@ impl RecursionGuard {
#[inline]
fn new(table: &Table, visited: &Rc<RefCell<HashSet<*const c_void>>>) -> Self {
let visited = Rc::clone(visited);
let ptr = unsafe { table.0.lua.get_ref_ptr(&table.0) };
let lua = table.0.lua;
let ptr =
unsafe { lua.ref_thread_exec(|refthr| ffi::lua_topointer(refthr, table.0.index)) };
visited.borrow_mut().insert(ptr);
RecursionGuard { ptr, visited }
}
@ -521,7 +524,8 @@ fn check_value_if_skip(
match value {
Value::Table(table) => {
let lua = table.0.lua;
let ptr = unsafe { lua.get_ref_ptr(&table.0) };
let ptr =
unsafe { lua.ref_thread_exec(|refthr| ffi::lua_topointer(refthr, table.0.index)) };
if visited.borrow().contains(&ptr) {
if options.deny_recursive_tables {
return Err(de::Error::custom("recursive table detected"));

View File

@ -4,7 +4,7 @@ use std::os::raw::c_int;
use crate::error::{Error, Result};
use crate::ffi;
use crate::types::LuaRef;
use crate::util::{assert_stack, check_stack, error_traceback, pop_error, StackGuard};
use crate::util::{check_stack, error_traceback, pop_error, StackGuard};
use crate::value::{FromLuaMulti, MultiValue, ToLuaMulti};
#[cfg(any(feature = "lua54", all(feature = "luajit", feature = "vendored"), doc))]
@ -13,15 +13,13 @@ use crate::function::Function;
#[cfg(feature = "async")]
use {
crate::{
lua::{ASYNC_POLL_PENDING, WAKER_REGISTRY_KEY},
util::get_gc_userdata,
lua::{Lua, ASYNC_POLL_PENDING},
value::Value,
},
futures_core::{future::Future, stream::Stream},
std::{
cell::RefCell,
marker::PhantomData,
mem,
os::raw::c_void,
pin::Pin,
task::{Context, Poll, Waker},
@ -114,11 +112,10 @@ impl<'lua> Thread<'lua> {
let nargs = args.len() as c_int;
let results = unsafe {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, cmp::min(nargs + 1, 3))?;
check_stack(lua.state, cmp::max(nargs + 1, 3))?;
lua.push_ref(&self.0);
let thread_state = ffi::lua_tothread(lua.state, -1);
ffi::lua_pop(lua.state, 1);
let thread_state =
lua.ref_thread_exec(|ref_thread| ffi::lua_tothread(ref_thread, self.0.index));
let status = ffi::lua_status(thread_state);
if status != ffi::LUA_YIELD && ffi::lua_gettop(thread_state) == 0 {
@ -155,12 +152,8 @@ impl<'lua> Thread<'lua> {
pub fn status(&self) -> ThreadStatus {
let lua = self.0.lua;
unsafe {
let _sg = StackGuard::new(lua.state);
assert_stack(lua.state, 1);
lua.push_ref(&self.0);
let thread_state = ffi::lua_tothread(lua.state, -1);
ffi::lua_pop(lua.state, 1);
let thread_state =
lua.ref_thread_exec(|ref_thread| ffi::lua_tothread(ref_thread, self.0.index));
let status = ffi::lua_status(thread_state);
if status != ffi::LUA_OK && status != ffi::LUA_YIELD {
@ -288,7 +281,7 @@ where
_ => return Poll::Ready(None),
};
let _wg = WakerGuard::new(lua.state, cx.waker().clone());
let _wg = WakerGuard::new(lua, cx.waker().clone());
let ret: MultiValue = if let Some(args) = self.args0.borrow_mut().take() {
self.thread.resume(args?)?
} else {
@ -319,7 +312,7 @@ where
_ => return Poll::Ready(Err(Error::CoroutineInactive)),
};
let _wg = WakerGuard::new(lua.state, cx.waker().clone());
let _wg = WakerGuard::new(lua, cx.waker().clone());
let ret: MultiValue = if let Some(args) = self.args0.borrow_mut().take() {
self.thread.resume(args?)?
} else {
@ -352,37 +345,27 @@ fn is_poll_pending(val: &MultiValue) -> bool {
}
#[cfg(feature = "async")]
struct WakerGuard(*mut ffi::lua_State, Option<Waker>);
struct WakerGuard<'lua> {
lua: &'lua Lua,
prev: Option<Waker>,
}
#[cfg(feature = "async")]
impl WakerGuard {
pub fn new(state: *mut ffi::lua_State, waker: Waker) -> Result<WakerGuard> {
impl<'lua> WakerGuard<'lua> {
#[inline]
pub fn new(lua: &Lua, waker: Waker) -> Result<WakerGuard> {
unsafe {
let _sg = StackGuard::new(state);
check_stack(state, 3)?;
let waker_key = &WAKER_REGISTRY_KEY as *const u8 as *const c_void;
ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, waker_key);
let waker_slot = get_gc_userdata::<Option<Waker>>(state, -1).as_mut();
let old = mlua_expect!(waker_slot, "Waker is destroyed").replace(waker);
Ok(WakerGuard(state, old))
let prev = lua.set_waker(Some(waker));
Ok(WakerGuard { lua, prev })
}
}
}
#[cfg(feature = "async")]
impl Drop for WakerGuard {
impl<'lua> Drop for WakerGuard<'lua> {
fn drop(&mut self) {
let state = self.0;
unsafe {
let _sg = StackGuard::new(state);
assert_stack(state, 3);
let waker_key = &WAKER_REGISTRY_KEY as *const u8 as *const c_void;
ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, waker_key);
let waker_slot = get_gc_userdata::<Option<Waker>>(state, -1).as_mut();
mem::swap(mlua_expect!(waker_slot, "Waker is destroyed"), &mut self.1);
self.lua.set_waker(self.prev.take());
}
}
}