Add `require` function to Luau

This commit is contained in:
Alex Orlenko 2022-03-20 20:01:04 +00:00
parent 405cff5d49
commit 4e0ba6559e
No known key found for this signature in database
GPG Key ID: 4C150C250863B96D
3 changed files with 85 additions and 2 deletions

View File

@ -72,6 +72,7 @@ tokio = { version = "1.0", features = ["full"] }
futures-timer = "3.0"
serde_json = "1.0"
maplit = "1.0"
tempfile = "3"
[[bench]]
name = "benchmark"

View File

@ -2606,7 +2606,6 @@ impl Lua {
use std::ffi::CStr;
// Since Luau has some missing standard function, we re-implement them here
// They are: collectgarbage, loadstring, require
unsafe extern "C" fn lua_collectgarbage(state: *mut ffi::lua_State) -> c_int {
let option = ffi::luaL_optstring(state, 1, cstr!("collect"));
@ -2629,10 +2628,62 @@ impl Lua {
}
}
self.globals().raw_set(
fn lua_require(lua: &Lua, name: Option<std::string::String>) -> Result<Value> {
let name = name.ok_or(Error::RuntimeError("name is nil".into()))?;
// Find module in the cache
let loaded = unsafe {
let _sg = StackGuard::new(lua.state);
check_stack(lua.state, 2)?;
protect_lua!(lua.state, 0, 1, fn(state) {
ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED"));
})?;
Table(lua.pop_ref())
};
if let Some(v) = loaded.raw_get(name.clone())? {
return Ok(v);
}
// Load file from filesystem
let mut search_path = std::env::var("LUAU_PATH").unwrap_or_default();
if search_path.is_empty() {
search_path = "?.luau;?.lua".into();
}
let mut source = None;
for path in search_path.split(";") {
if let Ok(buf) = std::fs::read(path.replacen("?", &name, 1)) {
source = Some(buf);
break;
}
}
let source =
source.ok_or_else(|| Error::RuntimeError(format!("cannot find '{}'", name)))?;
let value = lua
.load(&source)
.set_name(&format!("={}", name))?
.set_mode(ChunkMode::Text)
.call::<_, Value>(())?;
// Save in the cache
loaded.raw_set(
name,
match value.clone() {
Value::Nil => Value::Boolean(true),
v => v,
},
)?;
Ok(value)
}
let globals = self.globals();
globals.raw_set(
"collectgarbage",
self.create_c_function(lua_collectgarbage)?,
)?;
globals.raw_set("require", self.create_function(lua_require)?)?;
Ok(())
}

31
tests/luau.rs Normal file
View File

@ -0,0 +1,31 @@
#![cfg(feature = "luau")]
use std::env;
use std::fs;
use mlua::{Lua, Result};
#[test]
fn test_require() -> Result<()> {
let lua = Lua::new();
let temp_dir = tempfile::tempdir().unwrap();
fs::write(
temp_dir.path().join("module.luau"),
r#"
counter = counter or 0
return counter + 1
"#,
)?;
env::set_var("LUAU_PATH", temp_dir.path().join("?.luau"));
lua.load(
r#"
local module = require("module")
assert(module == 1)
module = require("module")
assert(module == 1)
"#,
)
.exec()
}