diff --git a/Cargo.toml b/Cargo.toml index dd61bbf..15115be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,6 +72,7 @@ tokio = { version = "1.0", features = ["full"] } futures-timer = "3.0" serde_json = "1.0" maplit = "1.0" +tempfile = "3" [[bench]] name = "benchmark" diff --git a/src/lua.rs b/src/lua.rs index 8cfc967..c0782ef 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -2606,7 +2606,6 @@ impl Lua { use std::ffi::CStr; // Since Luau has some missing standard function, we re-implement them here - // They are: collectgarbage, loadstring, require unsafe extern "C" fn lua_collectgarbage(state: *mut ffi::lua_State) -> c_int { let option = ffi::luaL_optstring(state, 1, cstr!("collect")); @@ -2629,10 +2628,62 @@ impl Lua { } } - self.globals().raw_set( + fn lua_require(lua: &Lua, name: Option) -> Result { + let name = name.ok_or(Error::RuntimeError("name is nil".into()))?; + + // Find module in the cache + let loaded = unsafe { + let _sg = StackGuard::new(lua.state); + check_stack(lua.state, 2)?; + protect_lua!(lua.state, 0, 1, fn(state) { + ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED")); + })?; + Table(lua.pop_ref()) + }; + if let Some(v) = loaded.raw_get(name.clone())? { + return Ok(v); + } + + // Load file from filesystem + let mut search_path = std::env::var("LUAU_PATH").unwrap_or_default(); + if search_path.is_empty() { + search_path = "?.luau;?.lua".into(); + } + + let mut source = None; + for path in search_path.split(";") { + if let Ok(buf) = std::fs::read(path.replacen("?", &name, 1)) { + source = Some(buf); + break; + } + } + let source = + source.ok_or_else(|| Error::RuntimeError(format!("cannot find '{}'", name)))?; + + let value = lua + .load(&source) + .set_name(&format!("={}", name))? + .set_mode(ChunkMode::Text) + .call::<_, Value>(())?; + + // Save in the cache + loaded.raw_set( + name, + match value.clone() { + Value::Nil => Value::Boolean(true), + v => v, + }, + )?; + + Ok(value) + } + + let globals = self.globals(); + globals.raw_set( "collectgarbage", self.create_c_function(lua_collectgarbage)?, )?; + globals.raw_set("require", self.create_function(lua_require)?)?; Ok(()) } diff --git a/tests/luau.rs b/tests/luau.rs new file mode 100644 index 0000000..2524efb --- /dev/null +++ b/tests/luau.rs @@ -0,0 +1,31 @@ +#![cfg(feature = "luau")] + +use std::env; +use std::fs; + +use mlua::{Lua, Result}; + +#[test] +fn test_require() -> Result<()> { + let lua = Lua::new(); + + let temp_dir = tempfile::tempdir().unwrap(); + fs::write( + temp_dir.path().join("module.luau"), + r#" + counter = counter or 0 + return counter + 1 + "#, + )?; + + env::set_var("LUAU_PATH", temp_dir.path().join("?.luau")); + lua.load( + r#" + local module = require("module") + assert(module == 1) + module = require("module") + assert(module == 1) + "#, + ) + .exec() +}