From 9fdba541e983ddee34c602863df64fd3fe9f26bf Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Thu, 15 Jun 2023 00:34:41 +0100 Subject: [PATCH] Update `UserDataMethods::add_async_method()` functions to take `&T` as second argument instead of cloning `T`. New functions: `UserDataMethods::add_async_method_mut()`, `UserDataMethods::add_async_meta_method_mut()`. --- examples/async_http_client.rs | 5 +- examples/async_tcp_server.rs | 35 +- src/lua.rs | 25 +- src/scope.rs | 48 +- src/userdata.rs | 65 ++- src/userdata_impl.rs | 500 +++++++++++------- tests/async.rs | 27 +- tests/compile.rs | 6 +- tests/compile/async_any_userdata_method.rs | 14 + .../compile/async_any_userdata_method.stderr | 81 +++ tests/compile/async_nonstatic_userdata.stderr | 5 +- tests/compile/async_userdata_method.rs | 14 + tests/compile/async_userdata_method.stderr | 17 + 13 files changed, 572 insertions(+), 270 deletions(-) create mode 100644 tests/compile/async_any_userdata_method.rs create mode 100644 tests/compile/async_any_userdata_method.stderr create mode 100644 tests/compile/async_userdata_method.rs create mode 100644 tests/compile/async_userdata_method.stderr diff --git a/examples/async_http_client.rs b/examples/async_http_client.rs index dde997e..bdd9dbc 100644 --- a/examples/async_http_client.rs +++ b/examples/async_http_client.rs @@ -3,14 +3,13 @@ use std::collections::HashMap; use hyper::body::{Body as HyperBody, HttpBody as _}; use hyper::Client as HyperClient; -use mlua::{chunk, AnyUserData, ExternalResult, Lua, Result, UserData, UserDataMethods}; +use mlua::{chunk, ExternalResult, Lua, Result, UserData, UserDataMethods}; struct BodyReader(HyperBody); impl UserData for BodyReader { fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_async_function("read", |lua, reader: AnyUserData| async move { - let mut reader = reader.borrow_mut::()?; + methods.add_async_method_mut("read", |lua, reader, ()| async move { if let Some(bytes) = reader.0.data().await { let bytes = bytes.into_lua_err()?; return Some(lua.create_string(&bytes)).transpose(); diff --git a/examples/async_tcp_server.rs b/examples/async_tcp_server.rs index edfc114..edc4149 100644 --- a/examples/async_tcp_server.rs +++ b/examples/async_tcp_server.rs @@ -6,9 +6,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::task; -use mlua::{ - chunk, AnyUserData, Function, Lua, RegistryKey, String as LuaString, UserData, UserDataMethods, -}; +use mlua::{chunk, Function, Lua, RegistryKey, String as LuaString, UserData, UserDataMethods}; struct LuaTcpStream(TcpStream); @@ -18,28 +16,19 @@ impl UserData for LuaTcpStream { Ok(this.0.peer_addr()?.to_string()) }); - methods.add_async_function( - "read", - |lua, (this, size): (AnyUserData, usize)| async move { - let mut this = this.borrow_mut::()?; - let mut buf = vec![0; size]; - let n = this.0.read(&mut buf).await?; - buf.truncate(n); - lua.create_string(&buf) - }, - ); + methods.add_async_method_mut("read", |lua, this, size| async move { + let mut buf = vec![0; size]; + let n = this.0.read(&mut buf).await?; + buf.truncate(n); + lua.create_string(&buf) + }); - methods.add_async_function( - "write", - |_, (this, data): (AnyUserData, LuaString)| async move { - let mut this = this.borrow_mut::()?; - let n = this.0.write(&data.as_bytes()).await?; - Ok(n) - }, - ); + methods.add_async_method_mut("write", |_, this, data: LuaString| async move { + let n = this.0.write(&data.as_bytes()).await?; + Ok(n) + }); - methods.add_async_function("close", |_, this: AnyUserData| async move { - let mut this = this.borrow_mut::()?; + methods.add_async_method_mut("close", |_, this, ()| async move { this.0.shutdown().await?; Ok(()) }); diff --git a/src/lua.rs b/src/lua.rs index ebd1260..17d041c 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -2658,18 +2658,17 @@ impl Lua { } } - // Pushes a LuaRef value onto the stack, checking that it's a registered + // Returns `TypeId` for the LuaRef, checking that it's a registered // and not destructed UserData. - // Uses 2 stack spaces, does not call checkstack. - pub(crate) unsafe fn push_userdata_ref(&self, lref: &LuaRef) -> Result> { - let state = self.state(); - self.push_ref(lref); - if ffi::lua_getmetatable(state, -1) == 0 { - ffi::lua_pop(state, 1); + // + // Returns `None` if the userdata is registered but non-static. + pub(crate) unsafe fn get_userdata_type_id(&self, lref: &LuaRef) -> Result> { + let ref_thread = self.ref_thread(); + if ffi::lua_getmetatable(ref_thread, lref.index) == 0 { return Err(Error::UserDataTypeMismatch); } - let mt_ptr = ffi::lua_topointer(state, -1); - ffi::lua_pop(state, 1); + let mt_ptr = ffi::lua_topointer(ref_thread, -1); + ffi::lua_pop(ref_thread, 1); // Fast path to skip looking up the metatable in the map let (last_mt, last_type_id) = (*self.extra.get()).last_checked_userdata_mt; @@ -2689,6 +2688,14 @@ impl Lua { } } + // Pushes a LuaRef (userdata) value onto the stack, returning their `TypeId`. + // Uses 1 stack space, does not call checkstack. + pub(crate) unsafe fn push_userdata_ref(&self, lref: &LuaRef) -> Result> { + let type_id = self.get_userdata_type_id(lref)?; + self.push_ref(lref); + Ok(type_id) + } + // Creates a Function out of a Callback containing a 'static Fn. This is safe ONLY because the // Fn is 'static, otherwise it could capture 'lua arguments improperly. Without ATCs, we // cannot easily deal with the "correct" callback type of: diff --git a/src/scope.rs b/src/scope.rs index 3d4cb94..a6f8b48 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -611,12 +611,28 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> for NonStaticUserDataMethods<'l } #[cfg(feature = "async")] - fn add_async_method(&mut self, _name: impl AsRef, _method: M) + fn add_async_method<'s, M, A, MR, R>(&mut self, _name: impl AsRef, _method: M) where - T: Clone, - M: Fn(&'lua Lua, T, A) -> MR + MaybeSend + 'static, + 'lua: 's, + T: 'static, + M: Fn(&'lua Lua, &'s T, A) -> MR + MaybeSend + 'static, A: FromLuaMulti<'lua>, - MR: Future> + 'lua, + MR: Future> + 's, + R: IntoLuaMulti<'lua>, + { + // The panic should never happen as async non-static code wouldn't compile + // Non-static lifetime must be bounded to 'lua lifetime + panic!("asynchronous methods are not supported for non-static userdata") + } + + #[cfg(feature = "async")] + fn add_async_method_mut<'s, M, A, MR, R>(&mut self, _name: impl AsRef, _method: M) + where + 'lua: 's, + T: 'static, + M: Fn(&'lua Lua, &'s mut T, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti<'lua>, + MR: Future> + 's, R: IntoLuaMulti<'lua>, { // The panic should never happen as async non-static code wouldn't compile @@ -686,12 +702,28 @@ impl<'lua, T: UserData> UserDataMethods<'lua, T> for NonStaticUserDataMethods<'l } #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] - fn add_async_meta_method(&mut self, _name: impl AsRef, _method: M) + fn add_async_meta_method<'s, M, A, MR, R>(&mut self, _name: impl AsRef, _method: M) where - T: Clone, - M: Fn(&'lua Lua, T, A) -> MR + MaybeSend + 'static, + 'lua: 's, + T: 'static, + M: Fn(&'lua Lua, &'s T, A) -> MR + MaybeSend + 'static, A: FromLuaMulti<'lua>, - MR: Future> + 'lua, + MR: Future> + 's, + R: IntoLuaMulti<'lua>, + { + // The panic should never happen as async non-static code wouldn't compile + // Non-static lifetime must be bounded to 'lua lifetime + panic!("asynchronous meta methods are not supported for non-static userdata") + } + + #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] + fn add_async_meta_method_mut<'s, M, A, MR, R>(&mut self, _name: impl AsRef, _method: M) + where + 'lua: 's, + T: 'static, + M: Fn(&'lua Lua, &'s mut T, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti<'lua>, + MR: Future> + 's, R: IntoLuaMulti<'lua>, { // The panic should never happen as async non-static code wouldn't compile diff --git a/src/userdata.rs b/src/userdata.rs index 8215083..d2209ef 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -259,8 +259,7 @@ pub trait UserDataMethods<'lua, T> { A: FromLuaMulti<'lua>, R: IntoLuaMulti<'lua>; - /// Add an async method which accepts a `T` as the first parameter and returns Future. - /// The passed `T` is cloned from the original value. + /// Add an async method which accepts a `&T` as the first parameter and returns Future. /// /// Refer to [`add_method`] for more information about the implementation. /// @@ -269,12 +268,31 @@ pub trait UserDataMethods<'lua, T> { /// [`add_method`]: #method.add_method #[cfg(feature = "async")] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - fn add_async_method(&mut self, name: impl AsRef, method: M) + fn add_async_method<'s, M, A, MR, R>(&mut self, name: impl AsRef, method: M) where - T: Clone, - M: Fn(&'lua Lua, T, A) -> MR + MaybeSend + 'static, + 'lua: 's, + T: 'static, + M: Fn(&'lua Lua, &'s T, A) -> MR + MaybeSend + 'static, A: FromLuaMulti<'lua>, - MR: Future> + 'lua, + MR: Future> + 's, + R: IntoLuaMulti<'lua>; + + /// Add an async method which accepts a `&mut T` as the first parameter and returns Future. + /// + /// Refer to [`add_method`] for more information about the implementation. + /// + /// Requires `feature = "async"` + /// + /// [`add_method`]: #method.add_method + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + fn add_async_method_mut<'s, M, A, MR, R>(&mut self, name: impl AsRef, method: M) + where + 'lua: 's, + T: 'static, + M: Fn(&'lua Lua, &'s mut T, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti<'lua>, + MR: Future> + 's, R: IntoLuaMulti<'lua>; /// Add a regular method as a function which accepts generic arguments, the first argument will @@ -349,8 +367,7 @@ pub trait UserDataMethods<'lua, T> { A: FromLuaMulti<'lua>, R: IntoLuaMulti<'lua>; - /// Add an async metamethod which accepts a `T` as the first parameter and returns Future. - /// The passed `T` is cloned from the original value. + /// Add an async metamethod which accepts a `&T` as the first parameter and returns Future. /// /// This is an async version of [`add_meta_method`]. /// @@ -359,12 +376,31 @@ pub trait UserDataMethods<'lua, T> { /// [`add_meta_method`]: #method.add_meta_method #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - fn add_async_meta_method(&mut self, name: impl AsRef, method: M) + fn add_async_meta_method<'s, M, A, MR, R>(&mut self, name: impl AsRef, method: M) where - T: Clone, - M: Fn(&'lua Lua, T, A) -> MR + MaybeSend + 'static, + 'lua: 's, + T: 'static, + M: Fn(&'lua Lua, &'s T, A) -> MR + MaybeSend + 'static, A: FromLuaMulti<'lua>, - MR: Future> + 'lua, + MR: Future> + 's, + R: IntoLuaMulti<'lua>; + + /// Add an async metamethod which accepts a `&mut T` as the first parameter and returns Future. + /// + /// This is an async version of [`add_meta_method_mut`]. + /// + /// Requires `feature = "async"` + /// + /// [`add_meta_method_mut`]: #method.add_meta_method_mut + #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + fn add_async_meta_method_mut<'s, M, A, MR, R>(&mut self, name: impl AsRef, method: M) + where + 'lua: 's, + T: 'static, + M: Fn(&'lua Lua, &'s mut T, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti<'lua>, + MR: Future> + 's, R: IntoLuaMulti<'lua>; /// Add a metamethod which accepts generic arguments. @@ -1055,6 +1091,11 @@ impl<'lua> AnyUserData<'lua> { OwnedAnyUserData(self.0.into_owned()) } + #[inline(always)] + pub(crate) fn type_id(&self) -> Result> { + unsafe { self.0.lua.get_userdata_type_id(&self.0) } + } + /// Returns a type name of this `UserData` (from `__name` metatable field). pub(crate) fn type_name(&self) -> Result> { let lua = self.0.lua; diff --git a/src/userdata_impl.rs b/src/userdata_impl.rs index 363ef63..0d63fd7 100644 --- a/src/userdata_impl.rs +++ b/src/userdata_impl.rs @@ -1,6 +1,9 @@ +#![allow(clippy::await_holding_refcell_ref, clippy::await_holding_lock)] + use std::any::TypeId; use std::cell::{Ref, RefCell, RefMut}; use std::marker::PhantomData; +use std::os::raw::c_int; use std::string::String as StdString; use std::sync::{Arc, Mutex, RwLock}; @@ -10,7 +13,7 @@ use crate::types::{Callback, MaybeSend}; use crate::userdata::{ AnyUserData, MetaMethod, UserData, UserDataCell, UserDataFields, UserDataMethods, }; -use crate::util::{check_stack, get_userdata, short_type_name, StackGuard}; +use crate::util::{get_userdata, short_type_name}; use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, MultiValue, Value}; #[cfg(not(feature = "send"))] @@ -75,62 +78,54 @@ impl<'lua, T: 'static> UserDataRegistrar<'lua, T> { } Box::new(move |lua, mut args| { - let front = args.pop_front(); + let front = args + .pop_front() + .ok_or_else(|| Error::from_lua_conversion("missing argument", "userdata", None)); + let front = try_self_arg!(front); let call = |ud| { // Self was at index 1, so we pass 2 here let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; method(lua, ud, args)?.into_lua_multi(lua) }; - if let Some(front) = front { - let state = lua.state(); - let userdata = try_self_arg!(AnyUserData::from_lua(front, lua)); - unsafe { - let _sg = StackGuard::new(state); - check_stack(state, 2)?; - - let type_id = try_self_arg!(lua.push_userdata_ref(&userdata.0)); - match type_id { - Some(id) if id == TypeId::of::() => { - let ud = try_self_arg!(get_userdata_ref::(state)); - call(&ud) - } - #[cfg(not(feature = "send"))] - Some(id) if id == TypeId::of::>>() => { - let ud = try_self_arg!(get_userdata_ref::>>(state)); - let ud = try_self_arg!(ud.try_borrow(), Error::UserDataBorrowError); - call(&ud) - } - Some(id) if id == TypeId::of::>>() => { - let ud = try_self_arg!(get_userdata_ref::>>(state)); - let ud = try_self_arg!(ud.try_lock(), Error::UserDataBorrowError); - call(&ud) - } - #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_ref::>>(state); - let ud = try_self_arg!(ud); - let ud = try_self_arg!(ud.try_lock().ok_or(Error::UserDataBorrowError)); - call(&ud) - } - Some(id) if id == TypeId::of::>>() => { - let ud = try_self_arg!(get_userdata_ref::>>(state)); - let ud = try_self_arg!(ud.try_read(), Error::UserDataBorrowError); - call(&ud) - } - #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_ref::>>(state); - let ud = try_self_arg!(ud); - let ud = try_self_arg!(ud.try_read().ok_or(Error::UserDataBorrowError)); - call(&ud) - } - _ => Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)), - } - } - } else { - let err = Error::from_lua_conversion("missing argument", "userdata", None); - Err(Error::bad_self_argument(&name, err)) + let userdata = try_self_arg!(AnyUserData::from_lua(front, lua)); + let (ref_thread, index) = (lua.ref_thread(), userdata.0.index); + match try_self_arg!(userdata.type_id()) { + Some(id) if id == TypeId::of::() => unsafe { + let ud = try_self_arg!(get_userdata_ref::(ref_thread, index)); + call(&ud) + }, + #[cfg(not(feature = "send"))] + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = try_self_arg!(get_userdata_ref::>>(ref_thread, index)); + let ud = try_self_arg!(ud.try_borrow(), Error::UserDataBorrowError); + call(&ud) + }, + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = try_self_arg!(get_userdata_ref::>>(ref_thread, index)); + let ud = try_self_arg!(ud.try_lock(), Error::UserDataBorrowError); + call(&ud) + }, + #[cfg(feature = "parking_lot")] + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = get_userdata_ref::>>(ref_thread, index); + let ud = try_self_arg!(ud); + let ud = try_self_arg!(ud.try_lock().ok_or(Error::UserDataBorrowError)); + call(&ud) + }, + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = try_self_arg!(get_userdata_ref::>>(ref_thread, index)); + let ud = try_self_arg!(ud.try_read(), Error::UserDataBorrowError); + call(&ud) + }, + #[cfg(feature = "parking_lot")] + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = get_userdata_ref::>>(ref_thread, index); + let ud = try_self_arg!(ud); + let ud = try_self_arg!(ud.try_read().ok_or(Error::UserDataBorrowError)); + call(&ud) + }, + _ => Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)), } }) } @@ -156,158 +151,236 @@ impl<'lua, T: 'static> UserDataRegistrar<'lua, T> { let mut method = method .try_borrow_mut() .map_err(|_| Error::RecursiveMutCallback)?; - let front = args.pop_front(); + let front = args + .pop_front() + .ok_or_else(|| Error::from_lua_conversion("missing argument", "userdata", None)); + let front = try_self_arg!(front); let call = |ud| { // Self was at index 1, so we pass 2 here let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; method(lua, ud, args)?.into_lua_multi(lua) }; - if let Some(front) = front { - let state = lua.state(); - let userdata = try_self_arg!(AnyUserData::from_lua(front, lua)); - unsafe { - let _sg = StackGuard::new(state); - check_stack(state, 2)?; - - let type_id = try_self_arg!(lua.push_userdata_ref(&userdata.0)); - match type_id { - Some(id) if id == TypeId::of::() => { - let mut ud = try_self_arg!(get_userdata_mut::(state)); - call(&mut ud) - } - #[cfg(not(feature = "send"))] - Some(id) if id == TypeId::of::>>() => { - let ud = try_self_arg!(get_userdata_mut::>>(state)); - let mut ud = - try_self_arg!(ud.try_borrow_mut(), Error::UserDataBorrowMutError); - call(&mut ud) - } - Some(id) if id == TypeId::of::>>() => { - let ud = try_self_arg!(get_userdata_mut::>>(state)); - let mut ud = - try_self_arg!(ud.try_lock(), Error::UserDataBorrowMutError); - call(&mut ud) - } - #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_mut::>>(state); - let ud = try_self_arg!(ud); - let mut ud = - try_self_arg!(ud.try_lock().ok_or(Error::UserDataBorrowMutError)); - call(&mut ud) - } - Some(id) if id == TypeId::of::>>() => { - let ud = try_self_arg!(get_userdata_mut::>>(state)); - let mut ud = - try_self_arg!(ud.try_write(), Error::UserDataBorrowMutError); - call(&mut ud) - } - #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_mut::>>(state); - let ud = try_self_arg!(ud); - let mut ud = - try_self_arg!(ud.try_write().ok_or(Error::UserDataBorrowMutError)); - call(&mut ud) - } - _ => Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)), - } - } - } else { - let err = Error::from_lua_conversion("missing argument", "userdata", None); - Err(Error::bad_self_argument(&name, err)) + let userdata = try_self_arg!(AnyUserData::from_lua(front, lua)); + let (ref_thread, index) = (lua.ref_thread(), userdata.0.index); + match try_self_arg!(userdata.type_id()) { + Some(id) if id == TypeId::of::() => unsafe { + let mut ud = try_self_arg!(get_userdata_mut::(ref_thread, index)); + call(&mut ud) + }, + #[cfg(not(feature = "send"))] + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = try_self_arg!(get_userdata_mut::>>(ref_thread, index)); + let mut ud = try_self_arg!(ud.try_borrow_mut(), Error::UserDataBorrowMutError); + call(&mut ud) + }, + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = try_self_arg!(get_userdata_mut::>>(ref_thread, index)); + let mut ud = try_self_arg!(ud.try_lock(), Error::UserDataBorrowMutError); + call(&mut ud) + }, + #[cfg(feature = "parking_lot")] + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = get_userdata_mut::>>(ref_thread, index); + let ud = try_self_arg!(ud); + let mut ud = try_self_arg!(ud.try_lock().ok_or(Error::UserDataBorrowMutError)); + call(&mut ud) + }, + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = try_self_arg!(get_userdata_mut::>>(ref_thread, index)); + let mut ud = try_self_arg!(ud.try_write(), Error::UserDataBorrowMutError); + call(&mut ud) + }, + #[cfg(feature = "parking_lot")] + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = get_userdata_mut::>>(ref_thread, index); + let ud = try_self_arg!(ud); + let mut ud = try_self_arg!(ud.try_write().ok_or(Error::UserDataBorrowMutError)); + call(&mut ud) + }, + _ => Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)), } }) } #[cfg(feature = "async")] - fn box_async_method(name: &str, method: M) -> AsyncCallback<'lua, 'static> + fn box_async_method<'s, M, A, MR, R>(name: &str, method: M) -> AsyncCallback<'lua, 'static> where - T: Clone, - M: Fn(&'lua Lua, T, A) -> MR + MaybeSend + 'static, + 'lua: 's, + T: 'static, + M: Fn(&'lua Lua, &'s T, A) -> MR + MaybeSend + 'static, A: FromLuaMulti<'lua>, - MR: Future> + 'lua, + MR: Future> + 's, R: IntoLuaMulti<'lua>, { let name = get_function_name::(name); - macro_rules! try_self_arg { - ($res:expr) => { - $res.map_err(|err| Error::bad_self_argument(&name, err))? - }; - ($res:expr, $err:expr) => { - $res.map_err(|_| Error::bad_self_argument(&name, $err))? - }; - } + let method = Arc::new(method); Box::new(move |lua, mut args| { - let front = args.pop_front(); - let call = |ud| { - // Self was at index 1, so we pass 2 here - let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; - Ok(method(lua, ud, args)) - }; - - let fut_res = || { - if let Some(front) = front { - let state = lua.state(); - let userdata = AnyUserData::from_lua(front, lua)?; - unsafe { - let _sg = StackGuard::new(state); - check_stack(state, 2)?; - - let type_id = try_self_arg!(lua.push_userdata_ref(&userdata.0)); - match type_id { - Some(id) if id == TypeId::of::() => { - let ud = get_userdata_ref::(state)?; - call(ud.clone()) - } - #[cfg(not(feature = "send"))] - Some(id) if id == TypeId::of::>>() => { - let ud = try_self_arg!(get_userdata_ref::>>(state)); - let ud = try_self_arg!(ud.try_borrow(), Error::UserDataBorrowError); - call(ud.clone()) - } - Some(id) if id == TypeId::of::>>() => { - let ud = try_self_arg!(get_userdata_ref::>>(state)); - let ud = try_self_arg!(ud.try_lock(), Error::UserDataBorrowError); - call(ud.clone()) - } - #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_ref::>>(state); - let ud = try_self_arg!(ud); - let ud = - try_self_arg!(ud.try_lock().ok_or(Error::UserDataBorrowError)); - call(ud.clone()) - } - Some(id) if id == TypeId::of::>>() => { - let ud = try_self_arg!(get_userdata_ref::>>(state)); - let ud = try_self_arg!(ud.try_read(), Error::UserDataBorrowError); - call(ud.clone()) - } - #[cfg(feature = "parking_lot")] - Some(id) if id == TypeId::of::>>() => { - let ud = get_userdata_ref::>>(state); - let ud = try_self_arg!(ud); - let ud = - try_self_arg!(ud.try_read().ok_or(Error::UserDataBorrowError)); - call(ud.clone()) - } - _ => Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)), - } - } - } else { - let err = Error::from_lua_conversion("missing argument", "userdata", None); - Err(Error::bad_self_argument(&name, err)) - } - }; - match fut_res() { - Ok(fut) => { - Box::pin(fut.and_then(move |ret| future::ready(ret.into_lua_multi(lua)))) - } - Err(e) => Box::pin(future::err(e)), + let name = name.clone(); + let method = method.clone(); + macro_rules! try_self_arg { + ($res:expr) => { + $res.map_err(|err| Error::bad_self_argument(&name, err))? + }; + ($res:expr, $err:expr) => { + $res.map_err(|_| Error::bad_self_argument(&name, $err))? + }; } + + Box::pin(async move { + let front = args.pop_front().ok_or_else(|| { + Error::from_lua_conversion("missing argument", "userdata", None) + }); + let front = try_self_arg!(front); + let userdata: AnyUserData = try_self_arg!(AnyUserData::from_lua(front, lua)); + let (ref_thread, index) = (lua.ref_thread(), userdata.0.index); + match try_self_arg!(userdata.type_id()) { + Some(id) if id == TypeId::of::() => unsafe { + let ud = try_self_arg!(get_userdata_ref::(ref_thread, index)); + let ud = std::mem::transmute::<&T, &T>(&ud); + // Self was at index 1, so we pass 2 here + let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; + method(lua, ud, args).await?.into_lua_multi(lua) + }, + #[cfg(not(feature = "send"))] + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = + try_self_arg!(get_userdata_ref::>>(ref_thread, index)); + let ud = try_self_arg!(ud.try_borrow(), Error::UserDataBorrowError); + let ud = std::mem::transmute::<&T, &T>(&ud); + let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; + method(lua, ud, args).await?.into_lua_multi(lua) + }, + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = + try_self_arg!(get_userdata_ref::>>(ref_thread, index)); + let ud = try_self_arg!(ud.try_lock(), Error::UserDataBorrowError); + let ud = std::mem::transmute::<&T, &T>(&ud); + let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; + method(lua, ud, args).await?.into_lua_multi(lua) + }, + #[cfg(feature = "parking_lot")] + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = get_userdata_ref::>>(ref_thread, index); + let ud = try_self_arg!(ud); + let ud = try_self_arg!(ud.try_lock().ok_or(Error::UserDataBorrowError)); + let ud = std::mem::transmute::<&T, &T>(&ud); + let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; + method(lua, ud, args).await?.into_lua_multi(lua) + }, + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = + try_self_arg!(get_userdata_ref::>>(ref_thread, index)); + let ud = try_self_arg!(ud.try_read(), Error::UserDataBorrowError); + let ud = std::mem::transmute::<&T, &T>(&ud); + let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; + method(lua, ud, args).await?.into_lua_multi(lua) + }, + #[cfg(feature = "parking_lot")] + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = get_userdata_ref::>>(ref_thread, index); + let ud = try_self_arg!(ud); + let ud = try_self_arg!(ud.try_read().ok_or(Error::UserDataBorrowError)); + let ud = std::mem::transmute::<&T, &T>(&ud); + let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; + method(lua, ud, args).await?.into_lua_multi(lua) + }, + _ => Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)), + } + }) + }) + } + + #[cfg(feature = "async")] + fn box_async_method_mut<'s, M, A, MR, R>(name: &str, method: M) -> AsyncCallback<'lua, 'static> + where + 'lua: 's, + T: 'static, + M: Fn(&'lua Lua, &'s mut T, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti<'lua>, + MR: Future> + 's, + R: IntoLuaMulti<'lua>, + { + let name = get_function_name::(name); + let method = Arc::new(method); + + Box::new(move |lua, mut args| { + let name = name.clone(); + let method = method.clone(); + macro_rules! try_self_arg { + ($res:expr) => { + $res.map_err(|err| Error::bad_self_argument(&name, err))? + }; + ($res:expr, $err:expr) => { + $res.map_err(|_| Error::bad_self_argument(&name, $err))? + }; + } + + Box::pin(async move { + let front = args.pop_front().ok_or_else(|| { + Error::from_lua_conversion("missing argument", "userdata", None) + }); + let front = try_self_arg!(front); + let userdata: AnyUserData = try_self_arg!(AnyUserData::from_lua(front, lua)); + let (ref_thread, index) = (lua.ref_thread(), userdata.0.index); + match try_self_arg!(userdata.type_id()) { + Some(id) if id == TypeId::of::() => unsafe { + let mut ud = try_self_arg!(get_userdata_mut::(ref_thread, index)); + let ud = std::mem::transmute::<&mut T, &mut T>(&mut ud); + // Self was at index 1, so we pass 2 here + let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; + method(lua, ud, args).await?.into_lua_multi(lua) + }, + #[cfg(not(feature = "send"))] + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = + try_self_arg!(get_userdata_mut::>>(ref_thread, index)); + let mut ud = + try_self_arg!(ud.try_borrow_mut(), Error::UserDataBorrowMutError); + let ud = std::mem::transmute::<&mut T, &mut T>(&mut ud); + let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; + method(lua, ud, args).await?.into_lua_multi(lua) + }, + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = + try_self_arg!(get_userdata_mut::>>(ref_thread, index)); + let mut ud = try_self_arg!(ud.try_lock(), Error::UserDataBorrowMutError); + let ud = std::mem::transmute::<&mut T, &mut T>(&mut ud); + let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; + method(lua, ud, args).await?.into_lua_multi(lua) + }, + #[cfg(feature = "parking_lot")] + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = get_userdata_mut::>>(ref_thread, index); + let ud = try_self_arg!(ud); + let mut ud = + try_self_arg!(ud.try_lock().ok_or(Error::UserDataBorrowMutError)); + let ud = std::mem::transmute::<&mut T, &mut T>(&mut ud); + let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; + method(lua, ud, args).await?.into_lua_multi(lua) + }, + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = + try_self_arg!(get_userdata_mut::>>(ref_thread, index)); + let mut ud = try_self_arg!(ud.try_write(), Error::UserDataBorrowMutError); + let ud = std::mem::transmute::<&mut T, &mut T>(&mut ud); + let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; + method(lua, ud, args).await?.into_lua_multi(lua) + }, + #[cfg(feature = "parking_lot")] + Some(id) if id == TypeId::of::>>() => unsafe { + let ud = get_userdata_mut::>>(ref_thread, index); + let ud = try_self_arg!(ud); + let mut ud = + try_self_arg!(ud.try_write().ok_or(Error::UserDataBorrowMutError)); + let ud = std::mem::transmute::<&mut T, &mut T>(&mut ud); + let args = A::from_lua_multi_args(args, 2, Some(&name), lua)?; + method(lua, ud, args).await?.into_lua_multi(lua) + }, + _ => Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)), + } + }) }) } @@ -500,12 +573,13 @@ impl<'lua, T: 'static> UserDataMethods<'lua, T> for UserDataRegistrar<'lua, T> { } #[cfg(feature = "async")] - fn add_async_method(&mut self, name: impl AsRef, method: M) + fn add_async_method<'s, M, A, MR, R>(&mut self, name: impl AsRef, method: M) where - T: Clone, - M: Fn(&'lua Lua, T, A) -> MR + MaybeSend + 'static, + 'lua: 's, + T: 'static, + M: Fn(&'lua Lua, &'s T, A) -> MR + MaybeSend + 'static, A: FromLuaMulti<'lua>, - MR: Future> + 'lua, + MR: Future> + 's, R: IntoLuaMulti<'lua>, { let name = name.as_ref(); @@ -513,6 +587,21 @@ impl<'lua, T: 'static> UserDataMethods<'lua, T> for UserDataRegistrar<'lua, T> { .push((name.into(), Self::box_async_method(name, method))); } + #[cfg(feature = "async")] + fn add_async_method_mut<'s, M, A, MR, R>(&mut self, name: impl AsRef, method: M) + where + 'lua: 's, + T: 'static, + M: Fn(&'lua Lua, &'s mut T, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti<'lua>, + MR: Future> + 's, + R: IntoLuaMulti<'lua>, + { + let name = name.as_ref(); + self.async_methods + .push((name.into(), Self::box_async_method_mut(name, method))); + } + fn add_function(&mut self, name: impl AsRef, function: F) where F: Fn(&'lua Lua, A) -> Result + MaybeSend + 'static, @@ -571,12 +660,13 @@ impl<'lua, T: 'static> UserDataMethods<'lua, T> for UserDataRegistrar<'lua, T> { } #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] - fn add_async_meta_method(&mut self, name: impl AsRef, method: M) + fn add_async_meta_method<'s, M, A, MR, R>(&mut self, name: impl AsRef, method: M) where - T: Clone, - M: Fn(&'lua Lua, T, A) -> MR + MaybeSend + 'static, + 'lua: 's, + T: 'static, + M: Fn(&'lua Lua, &'s T, A) -> MR + MaybeSend + 'static, A: FromLuaMulti<'lua>, - MR: Future> + 'lua, + MR: Future> + 's, R: IntoLuaMulti<'lua>, { let name = name.as_ref(); @@ -584,6 +674,21 @@ impl<'lua, T: 'static> UserDataMethods<'lua, T> for UserDataRegistrar<'lua, T> { .push((name.into(), Self::box_async_method(name, method))); } + #[cfg(all(feature = "async", not(any(feature = "lua51", feature = "luau"))))] + fn add_async_meta_method_mut<'s, M, A, MR, R>(&mut self, name: impl AsRef, method: M) + where + 'lua: 's, + T: 'static, + M: Fn(&'lua Lua, &'s mut T, A) -> MR + MaybeSend + 'static, + A: FromLuaMulti<'lua>, + MR: Future> + 's, + R: IntoLuaMulti<'lua>, + { + let name = name.as_ref(); + self.async_meta_methods + .push((name.into(), Self::box_async_method_mut(name, method))); + } + fn add_meta_function(&mut self, name: impl AsRef, function: F) where F: Fn(&'lua Lua, A) -> Result + MaybeSend + 'static, @@ -632,13 +737,16 @@ impl<'lua, T: 'static> UserDataMethods<'lua, T> for UserDataRegistrar<'lua, T> { } #[inline] -unsafe fn get_userdata_ref<'a, T>(state: *mut ffi::lua_State) -> Result> { - (*get_userdata::>(state, -1)).try_borrow() +unsafe fn get_userdata_ref<'a, T>(state: *mut ffi::lua_State, index: c_int) -> Result> { + (*get_userdata::>(state, index)).try_borrow() } #[inline] -unsafe fn get_userdata_mut<'a, T>(state: *mut ffi::lua_State) -> Result> { - (*get_userdata::>(state, -1)).try_borrow_mut() +unsafe fn get_userdata_mut<'a, T>( + state: *mut ffi::lua_State, + index: c_int, +) -> Result> { + (*get_userdata::>(state, index)).try_borrow_mut() } macro_rules! lua_userdata_impl { diff --git a/tests/async.rs b/tests/async.rs index e49c8c2..cb91e07 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -1,6 +1,5 @@ #![cfg(feature = "async")] -use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; use std::time::Duration; @@ -361,19 +360,18 @@ async fn test_async_thread_pool() -> Result<()> { #[tokio::test] async fn test_async_userdata() -> Result<()> { - #[derive(Clone)] - struct MyUserData(Arc); + struct MyUserData(u64); impl UserData for MyUserData { fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { methods.add_async_method("get_value", |_, data, ()| async move { Delay::new(Duration::from_millis(10)).await; - Ok(data.0.load(Ordering::Relaxed)) + Ok(data.0) }); - methods.add_async_method("set_value", |_, data, n| async move { + methods.add_async_method_mut("set_value", |_, data, n| async move { Delay::new(Duration::from_millis(10)).await; - data.0.store(n, Ordering::Relaxed); + data.0 = n; Ok(()) }); @@ -384,7 +382,7 @@ async fn test_async_userdata() -> Result<()> { #[cfg(not(any(feature = "lua51", feature = "luau")))] methods.add_async_meta_method(mlua::MetaMethod::Call, |_, data, ()| async move { - let n = data.0.load(Ordering::Relaxed); + let n = data.0; Delay::new(Duration::from_millis(n)).await; Ok(format!("elapsed:{}ms", n)) }); @@ -395,23 +393,24 @@ async fn test_async_userdata() -> Result<()> { |_, data, key: String| async move { Delay::new(Duration::from_millis(10)).await; match key.as_str() { - "ms" => Ok(Some(data.0.load(Ordering::Relaxed) as f64)), - "s" => Ok(Some((data.0.load(Ordering::Relaxed) as f64) / 1000.0)), + "ms" => Ok(Some(data.0 as f64)), + "s" => Ok(Some((data.0 as f64) / 1000.0)), _ => Ok(None), } }, ); #[cfg(not(any(feature = "lua51", feature = "luau")))] - methods.add_async_meta_method( + methods.add_async_meta_method_mut( mlua::MetaMethod::NewIndex, |_, data, (key, value): (String, f64)| async move { Delay::new(Duration::from_millis(10)).await; match key.as_str() { - "ms" => Ok(data.0.store(value as u64, Ordering::Relaxed)), - "s" => Ok(data.0.store((value * 1000.0) as u64, Ordering::Relaxed)), - _ => Err(Error::external(format!("key '{}' not found", key))), + "ms" => data.0 = value as u64, + "s" => data.0 = (value * 1000.0) as u64, + _ => return Err(Error::external(format!("key '{}' not found", key))), } + Ok(()) }, ); } @@ -420,7 +419,7 @@ async fn test_async_userdata() -> Result<()> { let lua = Lua::new(); let globals = lua.globals(); - let userdata = lua.create_userdata(MyUserData(Arc::new(AtomicU64::new(11))))?; + let userdata = lua.create_userdata(MyUserData(11))?; globals.set("userdata", userdata.clone())?; lua.load( diff --git a/tests/compile.rs b/tests/compile.rs index e4d822a..5a1a626 100644 --- a/tests/compile.rs +++ b/tests/compile.rs @@ -15,7 +15,11 @@ fn test_compilation() { t.compile_fail("tests/compile/static_callback_args.rs"); #[cfg(feature = "async")] - t.compile_fail("tests/compile/async_nonstatic_userdata.rs"); + { + t.compile_fail("tests/compile/async_any_userdata_method.rs"); + t.compile_fail("tests/compile/async_nonstatic_userdata.rs"); + t.compile_fail("tests/compile/async_userdata_method.rs"); + } #[cfg(feature = "send")] t.compile_fail("tests/compile/non_send.rs"); diff --git a/tests/compile/async_any_userdata_method.rs b/tests/compile/async_any_userdata_method.rs new file mode 100644 index 0000000..a9f2f8d --- /dev/null +++ b/tests/compile/async_any_userdata_method.rs @@ -0,0 +1,14 @@ +use mlua::{UserDataMethods, Lua}; + +fn main() { + let lua = Lua::new(); + + lua.register_userdata_type::(|reg| { + let s = String::new(); + let mut s = &s; + reg.add_async_method("t", |_, this: &String, ()| async { + s = this; + Ok(()) + }); + }).unwrap(); +} diff --git a/tests/compile/async_any_userdata_method.stderr b/tests/compile/async_any_userdata_method.stderr new file mode 100644 index 0000000..6790fbc --- /dev/null +++ b/tests/compile/async_any_userdata_method.stderr @@ -0,0 +1,81 @@ +error: lifetime may not live long enough + --> tests/compile/async_any_userdata_method.rs:9:58 + | +9 | reg.add_async_method("t", |_, this: &String, ()| async { + | ___________________________________----------------------_^ + | | | | + | | | return type of closure `[async block@$DIR/tests/compile/async_any_userdata_method.rs:9:58: 12:10]` contains a lifetime `'2` + | | lifetime `'1` represents this closure's body +10 | | s = this; +11 | | Ok(()) +12 | | }); + | |_________^ returning this value requires that `'1` must outlive `'2` + | + = note: closure implements `Fn`, so references to captured variables can't escape the closure + +error[E0596]: cannot borrow `s` as mutable, as it is a captured variable in a `Fn` closure + --> tests/compile/async_any_userdata_method.rs:9:58 + | +9 | reg.add_async_method("t", |_, this: &String, ()| async { + | __________________________________________________________^ +10 | | s = this; + | | - mutable borrow occurs due to use of `s` in closure +11 | | Ok(()) +12 | | }); + | |_________^ cannot borrow as mutable + +error[E0597]: `s` does not live long enough + --> tests/compile/async_any_userdata_method.rs:8:21 + | +8 | let mut s = &s; + | ^^ borrowed value does not live long enough +9 | / reg.add_async_method("t", |_, this: &String, ()| async { +10 | | s = this; +11 | | Ok(()) +12 | | }); + | |__________- argument requires that `s` is borrowed for `'static` +13 | }).unwrap(); + | - `s` dropped here while still borrowed + +error[E0521]: borrowed data escapes outside of closure + --> tests/compile/async_any_userdata_method.rs:9:9 + | +6 | lua.register_userdata_type::(|reg| { + | --- + | | + | `reg` is a reference that is only valid in the closure body + | has type `&mut LuaUserDataRegistrar<'1, std::string::String>` +... +9 | / reg.add_async_method("t", |_, this: &String, ()| async { +10 | | s = this; +11 | | Ok(()) +12 | | }); + | | ^ + | | | + | |__________`reg` escapes the closure body here + | argument requires that `'1` must outlive `'static` + | + = note: requirement occurs because of a mutable reference to `LuaUserDataRegistrar<'_, std::string::String>` + = note: mutable references are invariant over their type parameter + = help: see for more information about variance + +error[E0373]: closure may outlive the current function, but it borrows `s`, which is owned by the current function + --> tests/compile/async_any_userdata_method.rs:9:35 + | +9 | reg.add_async_method("t", |_, this: &String, ()| async { + | ^^^^^^^^^^^^^^^^^^^^^^ may outlive borrowed value `s` +10 | s = this; + | - `s` is borrowed here + | +note: function requires argument type to outlive `'static` + --> tests/compile/async_any_userdata_method.rs:9:9 + | +9 | / reg.add_async_method("t", |_, this: &String, ()| async { +10 | | s = this; +11 | | Ok(()) +12 | | }); + | |__________^ +help: to force the closure to take ownership of `s` (and any other referenced variables), use the `move` keyword + | +9 | reg.add_async_method("t", move |_, this: &String, ()| async { + | ++++ diff --git a/tests/compile/async_nonstatic_userdata.stderr b/tests/compile/async_nonstatic_userdata.stderr index ac21155..316feb1 100644 --- a/tests/compile/async_nonstatic_userdata.stderr +++ b/tests/compile/async_nonstatic_userdata.stderr @@ -4,11 +4,8 @@ error: lifetime may not live long enough 7 | impl<'a> UserData for MyUserData<'a> { | -- lifetime `'a` defined here 8 | fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - | ---- lifetime `'lua` defined here 9 | / methods.add_async_method("print", |_, data, ()| async move { 10 | | println!("{}", data.0); 11 | | Ok(()) 12 | | }); - | |______________^ argument requires that `'a` must outlive `'lua` - | - = help: consider adding the following bound: `'a: 'lua` + | |______________^ requires that `'a` must outlive `'static` diff --git a/tests/compile/async_userdata_method.rs b/tests/compile/async_userdata_method.rs new file mode 100644 index 0000000..7b07dc1 --- /dev/null +++ b/tests/compile/async_userdata_method.rs @@ -0,0 +1,14 @@ +use mlua::{UserData, UserDataMethods}; + +struct MyUserData; + +impl UserData for MyUserData { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_async_method("method", |_, this: &'static Self, ()| async { + Ok(()) + }); + // ^ lifetime may not live long enough + } +} + +fn main() {} diff --git a/tests/compile/async_userdata_method.stderr b/tests/compile/async_userdata_method.stderr new file mode 100644 index 0000000..a7f25a2 --- /dev/null +++ b/tests/compile/async_userdata_method.stderr @@ -0,0 +1,17 @@ +warning: unused variable: `this` + --> tests/compile/async_userdata_method.rs:7:48 + | +7 | methods.add_async_method("method", |_, this: &'static Self, ()| async { + | ^^^^ help: if this is intentional, prefix it with an underscore: `_this` + | + = note: `#[warn(unused_variables)]` on by default + +error: lifetime may not live long enough + --> tests/compile/async_userdata_method.rs:7:9 + | +6 | fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + | ---- lifetime `'lua` defined here +7 | / methods.add_async_method("method", |_, this: &'static Self, ()| async { +8 | | Ok(()) +9 | | }); + | |__________^ argument requires that `'lua` must outlive `'static`