diff --git a/src/lua.rs b/src/lua.rs index 9c2d59a..9946eed 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -61,6 +61,7 @@ struct ExtraData { registered_userdata_mt: HashSet, registry_unref_list: Arc>>>, + libs: StdLib, mem_info: *mut MemoryInfo, ref_thread: *mut ffi::lua_State, @@ -172,7 +173,9 @@ impl Lua { let mut lua = unsafe { Self::unsafe_new_with(libs) }; - mlua_expect!(lua.disable_c_modules(), "Error during disabling C modules"); + if libs.contains(StdLib::PACKAGE) { + mlua_expect!(lua.disable_c_modules(), "Error during disabling C modules"); + } lua.safe = true; Ok(lua) @@ -262,7 +265,7 @@ impl Lua { lua.ephemeral = false; #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] { - lua.extra.lock().unwrap().mem_info = mem_info; + mlua_expect!(lua.extra.lock(), "extra is poisoned").mem_info = mem_info; } mlua_expect!( @@ -271,6 +274,7 @@ impl Lua { }), "Error during loading standard libraries" ); + mlua_expect!(lua.extra.lock(), "extra is poisoned").libs |= libs; lua } @@ -321,6 +325,7 @@ impl Lua { registered_userdata_mt: HashSet::new(), registry_unref_list: Arc::new(Mutex::new(Some(Vec::new()))), ref_thread, + libs: StdLib::NONE, mem_info: ptr::null_mut(), // We need 1 extra stack space to move values in and out of the ref stack. ref_stack_size: ffi::LUA_MINSTACK - 1, @@ -381,11 +386,20 @@ impl Lua { } let state = self.main_state.unwrap_or(self.state); - unsafe { + let res = unsafe { protect_lua_closure(state, 0, 0, |state| { load_from_std_lib(state, libs); }) + }; + + // If `package` library loaded into a safe lua state then disable C modules + let curr_libs = mlua_expect!(self.extra.lock(), "extra is poisoned").libs; + if self.safe && (curr_libs ^ (curr_libs | libs)).contains(StdLib::PACKAGE) { + mlua_expect!(self.disable_c_modules(), "Error during disabling C modules"); } + mlua_expect!(self.extra.lock(), "extra is poisoned").libs |= libs; + + res } /// Consumes and leaks `Lua` object, returning a static reference `&'static Lua`. @@ -1655,7 +1669,12 @@ impl Lua { 'lua: 'callback, { #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - self.load_from_std_lib(StdLib::COROUTINE)?; + { + let libs = mlua_expect!(self.extra.lock(), "extra is poisoned").libs; + if !libs.contains(StdLib::COROUTINE) { + self.load_from_std_lib(StdLib::COROUTINE)?; + } + } unsafe extern "C" fn call_callback(state: *mut ffi::lua_State) -> c_int { callback_error(state, |nargs| { diff --git a/src/stdlib.rs b/src/stdlib.rs index 02f7e3d..9b20f4c 100644 --- a/src/stdlib.rs +++ b/src/stdlib.rs @@ -47,6 +47,8 @@ impl StdLib { /// (unsafe) [`debug`](https://www.lua.org/manual/5.3/manual.html#6.10) library pub const DEBUG: StdLib = StdLib(1 << 31); + /// No libraries + pub const NONE: StdLib = StdLib(0); /// (unsafe) All standard libraries pub const ALL: StdLib = StdLib(u32::MAX); /// The safe subset of the standard libraries diff --git a/tests/tests.rs b/tests/tests.rs index da8b739..813ea1d 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -72,6 +72,20 @@ fn test_safety() -> Result<()> { Err(e) => panic!("expected SafetyError, got {:?}", e), Ok(_) => panic!("expected SafetyError, got no error"), } + drop(lua); + + // Test safety rules after dynamically loading `package` library + let lua = Lua::new_with(StdLib::NONE)?; + assert!(lua.globals().get::<_, Option>("require")?.is_none()); + lua.load_from_std_lib(StdLib::PACKAGE)?; + match lua.load(r#"package.loadlib()"#).exec() { + Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { + Error::SafetyError(_) => {} + e => panic!("expected SafetyError cause, got {:?}", e), + }, + Err(e) => panic!("expected CallbackError, got {:?}", e), + Ok(_) => panic!("expected CallbackError, got no error"), + }; Ok(()) }