diff --git a/sailfish/src/runtime/escape/avx2.rs b/sailfish/src/runtime/escape/avx2.rs index b75043b..2639f42 100644 --- a/sailfish/src/runtime/escape/avx2.rs +++ b/sailfish/src/runtime/escape/avx2.rs @@ -14,20 +14,20 @@ const VECTOR_BYTES: usize = std::mem::size_of::<__m256i>(); const VECTOR_ALIGN: usize = VECTOR_BYTES - 1; #[target_feature(enable = "avx2")] -pub unsafe fn escape(buffer: &mut Buffer, bytes: &[u8]) { - let len = bytes.len(); +pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { + let len = feed.len(); if len < 8 { - let start_ptr = bytes.as_ptr(); + let start_ptr = feed.as_ptr(); let end_ptr = start_ptr.add(len); naive::escape(buffer, start_ptr, start_ptr, end_ptr); return; } else if len < VECTOR_BYTES { - sse2::escape(buffer, bytes); + sse2::escape(feed, buffer); return; } - let mut start_ptr = bytes.as_ptr(); + let mut start_ptr = feed.as_ptr(); let end_ptr = start_ptr.add(len); let v_independent1 = _mm256_set1_epi8(5); diff --git a/sailfish/src/runtime/escape/fallback.rs b/sailfish/src/runtime/escape/fallback.rs index 540b934..231795b 100644 --- a/sailfish/src/runtime/escape/fallback.rs +++ b/sailfish/src/runtime/escape/fallback.rs @@ -39,12 +39,12 @@ fn contains_key(x: usize) -> bool { } #[inline] -pub unsafe fn escape(buffer: &mut Buffer, bytes: &[u8]) { - let len = bytes.len(); - let mut start_ptr = bytes.as_ptr(); +pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { + let len = feed.len(); + let mut start_ptr = feed.as_ptr(); let end_ptr = start_ptr.add(len); - if bytes.len() < USIZE_BYTES { + if feed.len() < USIZE_BYTES { naive::escape(buffer, start_ptr, start_ptr, end_ptr); return; } diff --git a/sailfish/src/runtime/escape/mod.rs b/sailfish/src/runtime/escape/mod.rs index 606c42e..eccae83 100644 --- a/sailfish/src/runtime/escape/mod.rs +++ b/sailfish/src/runtime/escape/mod.rs @@ -7,8 +7,13 @@ mod fallback; mod naive; mod sse2; +use std::mem; +use std::sync::atomic::{AtomicPtr, Ordering}; + use super::buffer::Buffer; +type FnRaw = *mut (); + static ESCAPE_LUT: [u8; 256] = [ 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 9, 9, 9, 1, 2, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, @@ -25,38 +30,37 @@ static ESCAPE_LUT: [u8; 256] = [ const ESCAPED: [&str; 5] = [""", "&", "'", "<", ">"]; const ESCAPED_LEN: usize = 5; +static FN: AtomicPtr<()> = AtomicPtr::new(escape as FnRaw); + #[cfg(target_feature = "avx2")] -pub(crate) fn escape_to_buf(feed: &str, buf: &mut Buffer) { +pub fn escape(feed: &str, buf: &mut Buffer) { unsafe { avx2::escape(buf, feed.as_bytes()) } } +/// default escape function #[cfg(not(target_feature = "avx2"))] +pub fn escape(feed: &str, buf: &mut Buffer) { + let fun = if is_x86_feature_detected!("avx2") { + avx2::escape + } else if is_x86_feature_detected!("sse2") { + sse2::escape + } else { + fallback::escape + }; + + FN.store(fun as FnRaw, Ordering::Relaxed); + unsafe { fun(feed, buf) }; +} + +pub fn register_escape_fn(fun: fn(&str, &mut Buffer)) { + FN.store(fun as FnRaw, Ordering::Relaxed); +} + +#[inline] pub(crate) fn escape_to_buf(feed: &str, buf: &mut Buffer) { - use std::mem; - use std::sync::atomic::{AtomicPtr, Ordering}; - - type FnRaw = *mut (); - - static FN: AtomicPtr<()> = AtomicPtr::new(detect as FnRaw); - - fn detect(buffer: &mut Buffer, bytes: &[u8]) { - let fun = if is_x86_feature_detected!("avx2") { - avx2::escape as FnRaw - } else if is_x86_feature_detected!("sse2") { - sse2::escape as FnRaw - } else { - fallback::escape as FnRaw - }; - - FN.store(fun as FnRaw, Ordering::Relaxed); - unsafe { - mem::transmute::(fun)(buffer, bytes); - } - } - unsafe { let fun = FN.load(Ordering::Relaxed); - mem::transmute::(fun)(buf, feed.as_bytes()); + mem::transmute::(fun)(feed, buf); } } @@ -157,8 +161,8 @@ mod tests { buf3.clear(); unsafe { - escape_to_buf(&*s, &mut buf1); - fallback::escape(&mut buf2, s.as_bytes()); + escape_to_buf(s, &mut buf1); + fallback::escape(s, &mut buf2); naive::escape(&mut buf3, s.as_ptr(), s.as_ptr(), s.as_ptr().add(s.len())); } diff --git a/sailfish/src/runtime/escape/sse2.rs b/sailfish/src/runtime/escape/sse2.rs index 1b84ef7..3a650af 100644 --- a/sailfish/src/runtime/escape/sse2.rs +++ b/sailfish/src/runtime/escape/sse2.rs @@ -15,16 +15,16 @@ const VECTOR_ALIGN: usize = VECTOR_BYTES - 1; #[target_feature(enable = "sse2")] #[inline] -pub unsafe fn escape(buffer: &mut Buffer, bytes: &[u8]) { - let len = bytes.len(); +pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { + let len = feed.len(); if len < VECTOR_BYTES { - let start_ptr = bytes.as_ptr(); + let start_ptr = feed.as_ptr(); let end_ptr = start_ptr.add(len); naive::escape(buffer, start_ptr, start_ptr, end_ptr); return; } - let mut start_ptr = bytes.as_ptr(); + let mut start_ptr = feed.as_ptr(); let end_ptr = start_ptr.add(len); let v_independent1 = _mm_set1_epi8(5);