From 595dc3e95f38ee719010d6bc926239b7e40b7386 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Thu, 31 Mar 2022 19:31:37 +0100 Subject: [PATCH] Move some Luau functionality to a new module Immplement native "vector" function to construct vectors --- src/lib.rs | 2 + src/lua.rs | 87 ----------------------------------------- src/luau.rs | 106 ++++++++++++++++++++++++++++++++++++++++++++++++++ tests/luau.rs | 33 +++++++++++----- 4 files changed, 132 insertions(+), 96 deletions(-) create mode 100644 src/luau.rs diff --git a/src/lib.rs b/src/lib.rs index a591fd6..fdedfe2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -88,6 +88,8 @@ mod ffi; mod function; mod hook; mod lua; +#[cfg(feature = "luau")] +mod luau; mod multi; mod scope; mod stdlib; diff --git a/src/lua.rs b/src/lua.rs index 888dac0..d7bcb8b 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -2781,93 +2781,6 @@ impl Lua { Ok(()) } - #[cfg(feature = "luau")] - unsafe fn prepare_luau_state(&self) -> Result<()> { - use std::ffi::CStr; - - // Since Luau has some missing standard function, we re-implement them here - - unsafe extern "C" fn lua_collectgarbage(state: *mut ffi::lua_State) -> c_int { - let option = ffi::luaL_optstring(state, 1, cstr!("collect")); - let option = CStr::from_ptr(option); - match option.to_str() { - Ok("collect") => { - ffi::lua_gc(state, ffi::LUA_GCCOLLECT, 0); - 0 - } - Ok("count") => { - let n = ffi::lua_gc(state, ffi::LUA_GCCOUNT, 0); - ffi::lua_pushnumber(state, n as ffi::lua_Number); - 1 - } - // TODO: More variants - _ => ffi::luaL_error( - state, - cstr!("collectgarbage must be called with 'count' or 'collect'"), - ), - } - } - - fn lua_require(lua: &Lua, name: Option) -> Result { - let name = name.ok_or_else(|| Error::RuntimeError("name is nil".into()))?; - - // Find module in the cache - let loaded = unsafe { - let _sg = StackGuard::new(lua.state); - check_stack(lua.state, 2)?; - protect_lua!(lua.state, 0, 1, fn(state) { - ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED")); - })?; - Table(lua.pop_ref()) - }; - if let Some(v) = loaded.raw_get(name.clone())? { - return Ok(v); - } - - // Load file from filesystem - let mut search_path = std::env::var("LUAU_PATH").unwrap_or_default(); - if search_path.is_empty() { - search_path = "?.luau;?.lua".into(); - } - - let mut source = None; - for path in search_path.split(';') { - if let Ok(buf) = std::fs::read(path.replacen('?', &name, 1)) { - source = Some(buf); - break; - } - } - let source = - source.ok_or_else(|| Error::RuntimeError(format!("cannot find '{}'", name)))?; - - let value = lua - .load(&source) - .set_name(&format!("={}", name))? - .set_mode(ChunkMode::Text) - .call::<_, Value>(())?; - - // Save in the cache - loaded.raw_set( - name, - match value.clone() { - Value::Nil => Value::Boolean(true), - v => v, - }, - )?; - - Ok(value) - } - - let globals = self.globals(); - globals.raw_set( - "collectgarbage", - self.create_c_function(lua_collectgarbage)?, - )?; - globals.raw_set("require", self.create_function(lua_require)?)?; - - Ok(()) - } - pub(crate) unsafe fn make_from_ptr(state: *mut ffi::lua_State) -> Option { let _sg = StackGuard::new(state); assert_stack(state, 1); diff --git a/src/luau.rs b/src/luau.rs new file mode 100644 index 0000000..f45290c --- /dev/null +++ b/src/luau.rs @@ -0,0 +1,106 @@ +use std::ffi::CStr; +use std::os::raw::{c_float, c_int}; + +use crate::chunk::ChunkMode; +use crate::error::{Error, Result}; +use crate::ffi; +use crate::lua::Lua; +use crate::table::Table; +use crate::util::{check_stack, StackGuard}; +use crate::value::Value; + +// Since Luau has some missing standard function, we re-implement them here + +impl Lua { + pub(crate) unsafe fn prepare_luau_state(&self) -> Result<()> { + let globals = self.globals(); + + globals.raw_set( + "collectgarbage", + self.create_c_function(lua_collectgarbage)?, + )?; + globals.raw_set("require", self.create_function(lua_require)?)?; + globals.raw_set("vector", self.create_c_function(lua_vector)?)?; + + Ok(()) + } +} + +unsafe extern "C" fn lua_collectgarbage(state: *mut ffi::lua_State) -> c_int { + let option = ffi::luaL_optstring(state, 1, cstr!("collect")); + let option = CStr::from_ptr(option); + match option.to_str() { + Ok("collect") => { + ffi::lua_gc(state, ffi::LUA_GCCOLLECT, 0); + 0 + } + Ok("count") => { + let n = ffi::lua_gc(state, ffi::LUA_GCCOUNT, 0); + ffi::lua_pushnumber(state, n as ffi::lua_Number); + 1 + } + // TODO: More variants + _ => ffi::luaL_error( + state, + cstr!("collectgarbage must be called with 'count' or 'collect'"), + ), + } +} + +fn lua_require(lua: &Lua, name: Option) -> Result { + let name = name.ok_or_else(|| Error::RuntimeError("name is nil".into()))?; + + // Find module in the cache + let loaded = unsafe { + let _sg = StackGuard::new(lua.state); + check_stack(lua.state, 2)?; + protect_lua!(lua.state, 0, 1, fn(state) { + ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED")); + })?; + Table(lua.pop_ref()) + }; + if let Some(v) = loaded.raw_get(name.clone())? { + return Ok(v); + } + + // Load file from filesystem + let mut search_path = std::env::var("LUAU_PATH").unwrap_or_default(); + if search_path.is_empty() { + search_path = "?.luau;?.lua".into(); + } + + let mut source = None; + for path in search_path.split(';') { + if let Ok(buf) = std::fs::read(path.replacen('?', &name, 1)) { + source = Some(buf); + break; + } + } + let source = source.ok_or_else(|| Error::RuntimeError(format!("cannot find '{}'", name)))?; + + let value = lua + .load(&source) + .set_name(&format!("={}", name))? + .set_mode(ChunkMode::Text) + .call::<_, Value>(())?; + + // Save in the cache + loaded.raw_set( + name, + match value.clone() { + Value::Nil => Value::Boolean(true), + v => v, + }, + )?; + + Ok(value) +} + +// Luau vector datatype constructor +unsafe extern "C" fn lua_vector(state: *mut ffi::lua_State) -> c_int { + let x = ffi::luaL_checknumber(state, 1) as c_float; + let y = ffi::luaL_checknumber(state, 2) as c_float; + let z = ffi::luaL_checknumber(state, 3) as c_float; + ffi::lua_pushvector(state, x, y, z); + 1 +} diff --git a/tests/luau.rs b/tests/luau.rs index 608ea98..044c883 100644 --- a/tests/luau.rs +++ b/tests/luau.rs @@ -36,17 +36,32 @@ fn test_require() -> Result<()> { fn test_vectors() -> Result<()> { let lua = Lua::new(); - let globals = lua.globals(); - globals.set( - "vector", - lua.create_function(|_, (x, y, z)| Ok(Value::Vector(x, y, z)))?, - )?; - - let v: [f32; 3] = lua - .load("return vector(1, 2, 3) + vector(3, 2, 1)") - .eval()?; + let v: [f32; 3] = lua.load("vector(1, 2, 3) + vector(3, 2, 1)").eval()?; assert_eq!(v, [4.0, 4.0, 4.0]); + // Test vector methods + lua.load( + r#" + local v = vector(1, 2, 3) + assert(v.x == 1) + assert(v.y == 2) + assert(v.z == 3) + "#, + ) + .exec()?; + + // Test vector methods (fastcall) + lua.load( + r#" + local v = vector(1, 2, 3) + assert(v.x == 1) + assert(v.y == 2) + assert(v.z == 3) + "#, + ) + .set_vector_ctor(Some("vector".to_string())) + .exec()?; + Ok(()) }