diff --git a/mlua_derive/src/lib.rs b/mlua_derive/src/lib.rs index 76fb415..7931dc0 100644 --- a/mlua_derive/src/lib.rs +++ b/mlua_derive/src/lib.rs @@ -1,7 +1,7 @@ use proc_macro::TokenStream; use proc_macro2::{Ident, Span}; use quote::quote; -use syn::{parse_macro_input, AttributeArgs, ItemFn, Lit, Meta, NestedMeta}; +use syn::{parse_macro_input, AttributeArgs, Error, ItemFn, Lit, Meta, NestedMeta, Result}; #[cfg(feature = "macros")] use { @@ -13,41 +13,49 @@ use { struct ModuleArgs { name: Option, } + impl ModuleArgs { - fn parse(attr: AttributeArgs) -> Self { + fn parse(args: AttributeArgs) -> Result { let mut ret = Self::default(); - for arg in attr { + for arg in args { match arg { NestedMeta::Meta(Meta::NameValue(meta)) => { - if meta.path.segments.last().unwrap().ident == "name" { - if let Lit::Str(val) = meta.lit { - if let Ok(val) = val.parse() { - ret.name = Some(val); + if meta.path.is_ident("name") { + match meta.lit { + Lit::Str(val) => { + ret.name = Some(val.parse()?); + } + _ => { + return Err(Error::new_spanned(meta.lit, "expected string literal")) } } + } else { + return Err(Error::new_spanned(meta.path, "expected `name`")); } } - _ => {} + _ => { + return Err(Error::new_spanned(arg, "invalid argument")); + } } } - ret + + Ok(ret) } } #[proc_macro_attribute] pub fn lua_module(attr: TokenStream, item: TokenStream) -> TokenStream { let args = parse_macro_input!(attr as AttributeArgs); - let args = ModuleArgs::parse(args); + let args = match ModuleArgs::parse(args) { + Ok(args) => args, + Err(err) => return err.to_compile_error().into(), + }; let func = parse_macro_input!(item as ItemFn); let func_name = func.sig.ident.clone(); - let module_name = if let Some(name) = args.name { - name - } else { - func_name.clone() - }; - let ext_entrypoint_name = Ident::new(&format!("luaopen_{}", module_name), Span::call_site()); + let module_name = args.name.unwrap_or_else(|| func_name.clone()); + let ext_entrypoint_name = Ident::new(&format!("luaopen_{module_name}"), Span::call_site()); let wrapped = quote! { ::mlua::require_module_feature!(); diff --git a/tests/module/src/lib.rs b/tests/module/src/lib.rs index c624078..4e5aae5 100644 --- a/tests/module/src/lib.rs +++ b/tests/module/src/lib.rs @@ -12,7 +12,7 @@ fn check_userdata(_: &Lua, ud: MyUserData) -> LuaResult { Ok(ud.0) } -#[mlua::lua_module(name = "rust_module_first")] +#[mlua::lua_module] fn rust_module(lua: &Lua) -> LuaResult { let exports = lua.create_table()?; exports.set("sum", lua.create_function(sum)?)?; @@ -26,8 +26,8 @@ struct MyUserData(i32); impl LuaUserData for MyUserData {} -#[mlua::lua_module] -fn rust_module_second(lua: &Lua) -> LuaResult { +#[mlua::lua_module(name = "rust_module_second")] +fn rust_module2(lua: &Lua) -> LuaResult { let exports = lua.create_table()?; exports.set("userdata", lua.create_userdata(MyUserData(123))?)?; Ok(exports)