diff --git a/sailfish/src/runtime/escape/avx2.rs b/sailfish/src/runtime/escape/avx2.rs index 3ccb0df..b7ee822 100644 --- a/sailfish/src/runtime/escape/avx2.rs +++ b/sailfish/src/runtime/escape/avx2.rs @@ -7,7 +7,6 @@ use std::arch::x86_64::*; use std::slice; use super::super::Buffer; -use super::sse2; use super::{ESCAPED, ESCAPED_LEN, ESCAPE_LUT}; const VECTOR_BYTES: usize = std::mem::size_of::<__m256i>(); @@ -15,10 +14,12 @@ const VECTOR_ALIGN: usize = VECTOR_BYTES - 1; #[target_feature(enable = "avx2")] pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { + debug_assert!(feed.len() >= 16); + let len = feed.len(); if len < VECTOR_BYTES { - sse2::escape(feed, buffer); + escape_small(feed, buffer); return; } @@ -30,11 +31,11 @@ pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { let v_key1 = _mm256_set1_epi8(0x27); let v_key2 = _mm256_set1_epi8(0x3e); - let maskgen = |x: __m256i| -> i32 { + let maskgen = |x: __m256i| -> u32 { _mm256_movemask_epi8(_mm256_or_si256( _mm256_cmpeq_epi8(_mm256_or_si256(x, v_independent1), v_key1), _mm256_cmpeq_epi8(_mm256_or_si256(x, v_independent2), v_key2), - )) + )) as u32 }; let mut ptr = start_ptr; @@ -92,5 +93,106 @@ pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { next_ptr = next_ptr.add(VECTOR_BYTES); } - sse2::escape_aligned(buffer, start_ptr, ptr, end_ptr); + debug_assert!(next_ptr > end_ptr); + + if ptr < end_ptr { + debug_assert!((end_ptr as usize - ptr as usize) < VECTOR_BYTES); + let backs = VECTOR_BYTES - (end_ptr as usize - ptr as usize); + + let mut mask = + maskgen(_mm256_loadu_si256(ptr.sub(backs) as *const __m256i)) >> backs; + while mask != 0 { + let trailing_zeros = mask.trailing_zeros() as usize; + let ptr2 = ptr.add(trailing_zeros); + let c = ESCAPE_LUT[*ptr2 as usize] as usize; + if c < ESCAPED_LEN { + if start_ptr < ptr2 { + let slc = slice::from_raw_parts( + start_ptr, + ptr2 as usize - start_ptr as usize, + ); + buffer.push_str(std::str::from_utf8_unchecked(slc)); + } + buffer.push_str(*ESCAPED.get_unchecked(c)); + start_ptr = ptr2.add(1); + } + mask ^= 1 << trailing_zeros; + } + } + + if end_ptr > start_ptr { + let slc = slice::from_raw_parts(start_ptr, end_ptr as usize - start_ptr as usize); + buffer.push_str(std::str::from_utf8_unchecked(slc)); + } +} + +#[inline] +#[target_feature(enable = "avx2")] +unsafe fn escape_small(feed: &str, buffer: &mut Buffer) { + debug_assert!(feed.len() >= 16); + debug_assert!(feed.len() < VECTOR_BYTES); + + let len = feed.len(); + let mut start_ptr = feed.as_ptr(); + let mut ptr = start_ptr; + let end_ptr = start_ptr.add(len); + + let v_independent1 = _mm_set1_epi8(5); + let v_independent2 = _mm_set1_epi8(2); + let v_key1 = _mm_set1_epi8(0x27); + let v_key2 = _mm_set1_epi8(0x3e); + + let maskgen = |x: __m128i| -> u32 { + _mm_movemask_epi8(_mm_or_si128( + _mm_cmpeq_epi8(_mm_or_si128(x, v_independent1), v_key1), + _mm_cmpeq_epi8(_mm_or_si128(x, v_independent2), v_key2), + )) as u32 + }; + + let mut mask = maskgen(_mm_loadu_si128(ptr as *const __m128i)); + while mask != 0 { + let trailing_zeros = mask.trailing_zeros() as usize; + let ptr2 = ptr.add(trailing_zeros); + let c = ESCAPE_LUT[*ptr2 as usize] as usize; + if c < ESCAPED_LEN { + if start_ptr < ptr2 { + let slc = + slice::from_raw_parts(start_ptr, ptr2 as usize - start_ptr as usize); + buffer.push_str(std::str::from_utf8_unchecked(slc)); + } + buffer.push_str(*ESCAPED.get_unchecked(c)); + start_ptr = ptr2.add(1); + } + mask ^= 1 << trailing_zeros; + } + + if len != 16 { + ptr = ptr.add(16); + let read_ptr = end_ptr.sub(16); + let backs = 32 - len; + let mut mask = maskgen(_mm_loadu_si128(read_ptr as *const __m128i)) >> backs; + + while mask != 0 { + let trailing_zeros = mask.trailing_zeros() as usize; + let ptr2 = ptr.add(trailing_zeros); + let c = ESCAPE_LUT[*ptr2 as usize] as usize; + if c < ESCAPED_LEN { + if start_ptr < ptr2 { + let slc = slice::from_raw_parts( + start_ptr, + ptr2 as usize - start_ptr as usize, + ); + buffer.push_str(std::str::from_utf8_unchecked(slc)); + } + buffer.push_str(*ESCAPED.get_unchecked(c)); + start_ptr = ptr2.add(1); + } + mask ^= 1 << trailing_zeros; + } + } + + if end_ptr > start_ptr { + let slc = slice::from_raw_parts(start_ptr, end_ptr as usize - start_ptr as usize); + buffer.push_str(std::str::from_utf8_unchecked(slc)); + } } diff --git a/sailfish/src/runtime/escape/fallback.rs b/sailfish/src/runtime/escape/fallback.rs index 231795b..e6f5346 100644 --- a/sailfish/src/runtime/escape/fallback.rs +++ b/sailfish/src/runtime/escape/fallback.rs @@ -49,7 +49,7 @@ pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { return; } - let ptr = start_ptr; + let mut ptr = start_ptr; let aligned_ptr = ptr.add(USIZE_BYTES - (start_ptr as usize & USIZE_ALIGN)); debug_assert_eq!(aligned_ptr as usize % USIZE_BYTES, 0); debug_assert!(aligned_ptr <= end_ptr); @@ -59,15 +59,8 @@ pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { start_ptr = naive::proceed(buffer, start_ptr, ptr, aligned_ptr); } - escape_aligned(buffer, start_ptr, aligned_ptr, end_ptr); -} + ptr = aligned_ptr; -pub unsafe fn escape_aligned( - buffer: &mut Buffer, - mut start_ptr: *const u8, - mut ptr: *const u8, - end_ptr: *const u8, -) { while ptr.add(USIZE_BYTES) <= end_ptr { debug_assert_eq!((ptr as usize) % USIZE_BYTES, 0); diff --git a/sailfish/src/runtime/escape/mod.rs b/sailfish/src/runtime/escape/mod.rs index 50ef757..1c4662a 100644 --- a/sailfish/src/runtime/escape/mod.rs +++ b/sailfish/src/runtime/escape/mod.rs @@ -2,9 +2,11 @@ //! //! By default sailfish replaces the characters `&"'<>` with the equivalent html. +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] mod avx2; mod fallback; mod naive; +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] mod sse2; use std::mem; @@ -30,16 +32,10 @@ static ESCAPE_LUT: [u8; 256] = [ const ESCAPED: [&str; 5] = [""", "&", "'", "<", ">"]; const ESCAPED_LEN: usize = 5; +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] static FN: AtomicPtr<()> = AtomicPtr::new(escape as FnRaw); -#[cfg(target_feature = "avx2")] -#[inline] -fn escape(feed: &str, buf: &mut Buffer) { - debug_assert!(feed.len() >= 16); - unsafe { avx2::escape(feed, buf) } -} - -#[cfg(not(target_feature = "avx2"))] +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] fn escape(feed: &str, buf: &mut Buffer) { debug_assert!(feed.len() >= 16); let fun = if is_x86_feature_detected!("avx2") { @@ -66,8 +62,24 @@ pub(crate) fn escape_to_buf(feed: &str, buf: &mut Buffer) { let l = naive::escape_small(feed, buf.as_mut_ptr().add(buf.len())); buf.set_len(buf.len() + l); } else { - let fun = FN.load(Ordering::Relaxed); - mem::transmute::(fun)(feed, buf); + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + #[cfg(target_feature = "avx2")] + { + avx2::escape(feed, buf); + } + + #[cfg(not(target_feature = "avx2"))] + { + let fun = FN.load(Ordering::Relaxed); + mem::transmute::(fun)(feed, buf); + } + } + + #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] + { + fallback::escape(feed, buf); + } } } } @@ -151,31 +163,38 @@ mod tests { let mut buf3 = Buffer::new(); for len in 0..100 { - data.clear(); - for _ in 0..len { - // xorshift - state ^= state << 13; - state ^= state >> 7; - state ^= state << 17; + for _ in 0..10 { + data.clear(); + for _ in 0..len { + // xorshift + state ^= state << 13; + state ^= state >> 7; + state ^= state << 17; - let idx = state as usize % ASCII_CHARS.len(); - data.push(ASCII_CHARS[idx]); + let idx = state as usize % ASCII_CHARS.len(); + data.push(ASCII_CHARS[idx]); + } + + let s = unsafe { std::str::from_utf8_unchecked(&*data) }; + + buf1.clear(); + buf2.clear(); + buf3.clear(); + + unsafe { + escape_to_buf(s, &mut buf1); + fallback::escape(s, &mut buf2); + naive::escape( + &mut buf3, + s.as_ptr(), + s.as_ptr(), + s.as_ptr().add(s.len()), + ); + } + + assert_eq!(buf1.as_str(), buf3.as_str()); + assert_eq!(buf2.as_str(), buf3.as_str()); } - - let s = unsafe { std::str::from_utf8_unchecked(&*data) }; - - buf1.clear(); - buf2.clear(); - buf3.clear(); - - unsafe { - escape_to_buf(s, &mut buf1); - fallback::escape(s, &mut buf2); - naive::escape(&mut buf3, s.as_ptr(), s.as_ptr(), s.as_ptr().add(s.len())); - } - - assert_eq!(buf1.as_str(), buf3.as_str()); - assert_eq!(buf2.as_str(), buf3.as_str()); } } } diff --git a/sailfish/src/runtime/escape/sse2.rs b/sailfish/src/runtime/escape/sse2.rs index 927f763..7287692 100644 --- a/sailfish/src/runtime/escape/sse2.rs +++ b/sailfish/src/runtime/escape/sse2.rs @@ -7,14 +7,12 @@ use std::arch::x86_64::*; use std::slice; use super::super::Buffer; -use super::naive; use super::{ESCAPED, ESCAPED_LEN, ESCAPE_LUT}; const VECTOR_BYTES: usize = std::mem::size_of::<__m128i>(); const VECTOR_ALIGN: usize = VECTOR_BYTES - 1; #[target_feature(enable = "sse2")] -#[inline] pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { let len = feed.len(); let mut start_ptr = feed.as_ptr(); @@ -25,11 +23,11 @@ pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { let v_key1 = _mm_set1_epi8(0x27); let v_key2 = _mm_set1_epi8(0x3e); - let maskgen = |x: __m128i| -> i32 { + let maskgen = |x: __m128i| -> u32 { _mm_movemask_epi8(_mm_or_si128( _mm_cmpeq_epi8(_mm_or_si128(x, v_independent1), v_key1), _mm_cmpeq_epi8(_mm_or_si128(x, v_independent2), v_key2), - )) + )) as u32 }; let mut ptr = start_ptr; @@ -61,29 +59,7 @@ pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { } ptr = aligned_ptr; - escape_aligned(buffer, start_ptr, ptr, end_ptr); -} - -#[target_feature(enable = "sse2")] -#[cfg_attr(feature = "perf-inline", inline)] -pub unsafe fn escape_aligned( - buffer: &mut Buffer, - mut start_ptr: *const u8, - mut ptr: *const u8, - end_ptr: *const u8, -) { let mut next_ptr = ptr.add(VECTOR_BYTES); - let v_independent1 = _mm_set1_epi8(5); - let v_independent2 = _mm_set1_epi8(2); - let v_key1 = _mm_set1_epi8(0x27); - let v_key2 = _mm_set1_epi8(0x3e); - - let maskgen = |x: __m128i| -> i32 { - _mm_movemask_epi8(_mm_or_si128( - _mm_cmpeq_epi8(_mm_or_si128(x, v_independent1), v_key1), - _mm_cmpeq_epi8(_mm_or_si128(x, v_independent2), v_key2), - )) - }; while next_ptr <= end_ptr { debug_assert_eq!((ptr as usize) % VECTOR_BYTES, 0); @@ -110,10 +86,14 @@ pub unsafe fn escape_aligned( next_ptr = next_ptr.add(VECTOR_BYTES); } - next_ptr = ptr.add(8); - if next_ptr <= end_ptr { - debug_assert_eq!((ptr as usize) % VECTOR_BYTES, 0); - let mut mask = maskgen(_mm_loadl_epi64(ptr as *const __m128i)); + debug_assert!(next_ptr > end_ptr); + + if ptr < end_ptr { + debug_assert!((end_ptr as usize - ptr as usize) < VECTOR_BYTES); + let backs = VECTOR_BYTES - (end_ptr as usize - ptr as usize); + let read_ptr = ptr.sub(backs); + + let mut mask = maskgen(_mm_loadu_si128(read_ptr as *const __m128i)) >> backs; while mask != 0 { let trailing_zeros = mask.trailing_zeros() as usize; let ptr2 = ptr.add(trailing_zeros); @@ -131,11 +111,10 @@ pub unsafe fn escape_aligned( } mask ^= 1 << trailing_zeros; } - - ptr = next_ptr; } - debug_assert!(ptr <= end_ptr); - debug_assert!(start_ptr <= ptr); - naive::escape(buffer, start_ptr, ptr, end_ptr); + if end_ptr > start_ptr { + let slc = slice::from_raw_parts(start_ptr, end_ptr as usize - start_ptr as usize); + buffer.push_str(std::str::from_utf8_unchecked(slc)); + } }