#![allow(dead_code)]
extern crate unchecked_index;
extern crate memchr;
use std::cmp;
use std::mem;
use std::iter::Zip;
use self::unchecked_index::get_unchecked;
use TwoWaySearcher;
/// Free-function form of `Iterator::zip`: pairs up the items of `i` and `j`.
///
/// Both arguments are converted with `IntoIterator`, so slices, references to
/// collections and ready-made iterators can be passed directly. Iteration
/// stops as soon as either side is exhausted.
fn zip<I, J>(i: I, j: J) -> Zip<I::IntoIter, J::IntoIter>
    where I: IntoIterator,
          J: IntoIterator
{
    let left = i.into_iter();
    let right = j.into_iter();
    Iterator::zip(left, right)
}
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[target_feature(enable = "sse4.2")]
unsafe fn pcmpestri_16_mask(text: *const u8, offset: usize, text_len: usize,
needle: __m128i, needle_len: usize) -> u32 {
let text = mask_load(text.offset(offset as _) as *const _, text_len);
_mm_cmpestri(needle, needle_len as _, text, text_len as _, _SIDD_CMP_EQUAL_ORDERED) as _
}
/// Runs `pcmpestri` (equal-ordered mode) on the 16-byte window of `text`
/// starting at `offset`, using an unconditional unaligned 16-byte load.
///
/// `text_len` only tells the instruction how many of the loaded bytes are
/// valid; the load itself always touches 16 bytes. The result is the raw
/// index produced by the instruction; callers in this file treat `16` as
/// "no match in this window".
///
/// Both lengths are clamped to 16 before being narrowed to `i32`. The
/// instruction saturates lengths above 16 anyway, but a plain `as i32` cast
/// would first truncate values of 2^32 and above (for example 2^32 would
/// become 0 and make the window look empty).
///
/// # Safety
///
/// The CPU must support SSE4.2, and `text + offset` must be valid for reads
/// of a full 16 bytes regardless of `text_len`.
#[target_feature(enable = "sse4.2")]
unsafe fn pcmpestri_16_nomask(text: *const u8, offset: usize, text_len: usize,
                              needle: __m128i, needle_len: usize) -> u32 {
    let window = _mm_loadu_si128(text.add(offset) as *const __m128i);
    // Clamp before the narrowing cast so huge lengths cannot wrap.
    let window_len = cmp::min(text_len, 16) as i32;
    let needle_len = cmp::min(needle_len, 16) as i32;
    _mm_cmpestri(needle, needle_len, window, window_len, _SIDD_CMP_EQUAL_ORDERED) as u32
}
/// Compares 16 bytes of `text` (at `offset`) with 16 bytes of `needle`
/// (at `noffset`) lane by lane using `pcmpestrm` in equal-each mode and
/// returns the resulting bit mask.
///
/// `shared_prefix_inner` relies on the mask being `0xffff` when every valid
/// lane compares equal, and on the lowest cleared bit marking the first
/// mismatching byte.
///
/// Both lengths are clamped to 16 before being narrowed to `i32`. The
/// instruction saturates lengths above 16 anyway, but a plain `as i32` cast
/// would first truncate values of 2^32 and above (for example 2^32 would
/// become 0, which makes every lane count as equal without comparing it).
///
/// # Safety
///
/// The CPU must support SSE4.2, and both `text + offset` and
/// `needle + noffset` must be valid for reads of a full 16 bytes.
#[target_feature(enable = "sse4.2")]
unsafe fn pcmpestrm_eq_each(text: *const u8, offset: usize, text_len: usize,
                            needle: *const u8, noffset: usize, needle_len: usize) -> u64 {
    let needle = _mm_loadu_si128(needle.add(noffset) as *const __m128i);
    let text = _mm_loadu_si128(text.add(offset) as *const __m128i);
    // Clamp before the narrowing cast so huge lengths cannot wrap.
    let needle_len = cmp::min(needle_len, 16) as i32;
    let text_len = cmp::min(text_len, 16) as i32;
    let mask = _mm_cmpestrm(needle, needle_len, text, text_len, _SIDD_CMP_EQUAL_EACH);
    #[cfg(target_arch = "x86")] {
        // Go through `u32` so a set bit 31 is zero-extended, not sign-extended.
        (_mm_extract_epi32(mask, 0) as u32 as u64)
            | ((_mm_extract_epi32(mask, 1) as u32 as u64) << 32)
    }
    #[cfg(target_arch = "x86_64")] {
        _mm_extract_epi64(mask, 0) as u64
    }
}
/// Test-only convenience wrapper around `first_start_of_match_mask`.
///
/// Packs `pat` (at most 16 bytes, enforced by the assertion) into a vector
/// with `pat128` and runs the masked search over `text`.
#[cfg(test)]
fn first_start_of_match(text: &[u8], pat: &[u8]) -> Option<(usize, usize)> {
    let pat_len = pat.len();
    assert!(pat_len <= 16);
    let needle = pat128(pat);
    unsafe { first_start_of_match_mask(text, pat_len, needle) }
}
/// Scans `text` in 16-byte windows for the first position where the needle
/// `p` (of length `pat_len`) starts to match.
///
/// Returns `Some((position, matched_len))` for the first window in which
/// `pcmpestri_16_mask` reports an index below 16. `matched_len` is
/// `min(pat_len, 16 - index)`, so it can be shorter than `pat_len` when the
/// reported index is close to the end of the window; callers then verify the
/// remaining bytes themselves. Returns `None` when no window reports a hit.
///
/// Each window is loaded through `pcmpestri_16_mask`, which copies only the
/// bytes that remain in `text`.
///
/// # Safety
///
/// The CPU must support SSE4.2.
#[target_feature(enable = "sse4.2")]
unsafe fn first_start_of_match_mask(text: &[u8], pat_len: usize, p: __m128i) -> Option<(usize, usize)> {
    debug_assert!(pat_len <= 16);
    let base = text.as_ptr();
    let mut window = 0;
    while text.len() >= window + pat_len {
        let remaining = text.len() - window;
        let index = pcmpestri_16_mask(base, window, remaining, p, pat_len) as usize;
        if index != 16 {
            let matched = cmp::min(pat_len, 16 - index);
            return Some((window + index, matched));
        }
        // Nothing in this window: move on to the next 16 bytes.
        window += 16;
    }
    None
}
/// Same scan as `first_start_of_match_mask`, but every window is loaded with
/// an unconditional 16-byte read via `pcmpestri_16_nomask`.
///
/// Returns `Some((position, matched_len))` for the first window that reports
/// an index below 16, with `matched_len = min(pat_len, 16 - index)`, or
/// `None` when no window reports a hit.
///
/// # Safety
///
/// The CPU must support SSE4.2. Because each window is read as a full 16
/// bytes, the memory following `text` must be readable for up to 16 bytes
/// past its end; the callers in this file pass a slice that excludes the
/// final 16 bytes of the real haystack. `pat_len` must not exceed
/// `text.len()`, since the loop bound subtracts one from the other.
#[target_feature(enable = "sse4.2")]
unsafe fn first_start_of_match_nomask(text: &[u8], pat_len: usize, p: __m128i) -> Option<(usize, usize)> {
    debug_assert!(pat_len <= 16);
    debug_assert!(pat_len <= text.len());
    let base = text.as_ptr();
    let mut window = 0;
    while text.len() - pat_len >= window {
        let remaining = text.len() - window;
        let index = pcmpestri_16_nomask(base, window, remaining, p, pat_len) as usize;
        if index != 16 {
            let matched = cmp::min(pat_len, 16 - index);
            return Some((window + index, matched));
        }
        // Nothing in this window: move on to the next 16 bytes.
        window += 16;
    }
    None
}
/// Exercises `first_start_of_match` on a tiny haystack and on every window
/// (1 to 16 bytes) of a longer sentence.
#[test]
fn test_first_start_of_match() {
    let text = b"abc";
    // (needle, expected result) pairs for the three-byte haystack.
    let fixed_cases: [(&[u8], Option<(usize, usize)>); 6] = [
        (&b"d"[..], None),
        (&b"c"[..], Some((2, 1))),
        (&b"abc"[..], Some((0, 3))),
        (&b"T"[..], None),
        (&b"\0text"[..], None),
        (&b"\0"[..], None),
    ];
    for &(needle, expected) in fixed_cases.iter() {
        assert_eq!(first_start_of_match(text, needle), expected);
    }

    let longer = "longer text and so on";
    let haystack = longer.as_bytes();
    for wsz in 1..17 {
        for window in haystack.windows(wsz) {
            let needle = ::std::str::from_utf8(window).unwrap();
            let str_find = longer.find(needle);
            assert!(str_find.is_some());

            let first_start = first_start_of_match(haystack, window);
            assert!(first_start.is_some());
            let (pos, len) = first_start.unwrap();
            assert!(len <= wsz);

            // Either a full match at the true position, or a (possibly
            // partial) candidate that does not lie beyond the true position.
            let exact_hit = len == wsz && Some(pos) == str_find;
            assert!(exact_hit || pos <= str_find.unwrap());
        }
    }
}
/// Finds the first occurrence of the two-byte pattern `pat` in `text`.
///
/// Works by locating the second pattern byte with `memchr` (starting at
/// index 1) and then checking whether the byte just before it equals the
/// first pattern byte. On success returns `(start, end)`, where `end` is one
/// past the last matched byte.
///
/// `pat` is expected to be exactly two bytes and `text` at least as long as
/// `pat`; both are checked only by debug assertions.
fn find_2byte_pat(text: &[u8], pat: &[u8]) -> Option<(usize, usize)> {
    debug_assert!(text.len() >= pat.len());
    debug_assert!(pat.len() == 2);
    let second = pat[1];
    let mut off = 1;
    loop {
        let hit = match memchr::memchr(second, &text[off..]) {
            Some(i) => i,
            None => return None,
        };
        // Index of the byte preceding the hit, i.e. the candidate start.
        let start = off + hit - 1;
        match text.get(start) {
            None => return None,
            Some(&c) if c == pat[0] => return Some((start, start + 2)),
            // Wrong predecessor: resume right after the hit.
            _ => off += hit + 1,
        }
    }
}
/// Searches `text` for the first occurrence of a short pattern (at most
/// 8 bytes, checked by a debug assertion) and returns its start index.
///
/// The search runs in two phases:
///
/// 1. While a whole pattern still fits inside `safetext` (`text` without its
///    last 16 bytes), candidates come from `first_start_of_match_nomask`,
///    whose 16-byte loads therefore stay inside `text`.
/// 2. The remainder is handled with `first_start_of_match_mask`, which only
///    copies the bytes that are actually left.
///
/// In both phases a candidate may cover fewer than `pat.len()` bytes; the
/// remaining pattern bytes are then compared one by one, and on a mismatch
/// the search resumes one byte further on.
///
/// # Safety
///
/// The CPU must support SSE4.2. `text` must be at least as long as `pat`,
/// because the bounds checks compute `text.len() - pat.len()`.
#[target_feature(enable = "sse4.2")]
unsafe fn find_short_pat(text: &[u8], pat: &[u8]) -> Option<usize> {
    debug_assert!(pat.len() <= 8);
    let needle = pat128(pat);
    let plen = pat.len();
    // Everything except the final 16 bytes; empty when `text` is that short.
    let safetext = &text[..cmp::max(text.len(), 16) - 16];
    let mut pos = 0;

    // Phase 1: unmasked loads while a full pattern fits into `safetext`.
    'search: while pos + plen <= safetext.len() {
        let (mpos, mlen) = match first_start_of_match_nomask(&safetext[pos..], plen, needle) {
            Some(hit) => hit,
            None => {
                // No candidate in the safe region: let the tail phase take
                // over from the last start position not yet ruled out.
                pos = cmp::max(pos, safetext.len() - plen);
                break;
            }
        };
        pos += mpos;
        if mlen < plen {
            // Only a partial candidate: check the rest of the pattern.
            if pos > text.len() - plen {
                return None;
            }
            let rest_matches = zip(&text[pos + mlen..], &pat[mlen..]).all(|(a, b)| a == b);
            if !rest_matches {
                pos += 1;
                continue 'search;
            }
        }
        return Some(pos);
    }

    // Phase 2: masked loads for whatever is left of `text`.
    loop {
        if pos > text.len() - plen {
            return None;
        }
        let (mpos, mlen) = first_start_of_match_mask(&text[pos..], plen, needle)?;
        pos += mpos;
        if mlen < plen {
            if pos > text.len() - plen {
                return None;
            }
            let rest_matches = zip(&text[pos + mlen..], &pat[mlen..]).all(|(a, b)| a == b);
            if !rest_matches {
                pos += 1;
                continue;
            }
        }
        return Some(pos);
    }
}
/// Reports whether the SSE4.2 based search routines may be used.
///
/// With the `use_std` crate feature the CPU is probed at run time via
/// `is_x86_feature_detected!`; without it the answer is fixed at compile
/// time by whether the `sse4.2` target feature is enabled.
pub fn is_supported() -> bool {
    #[cfg(feature = "use_std")]
    {
        is_x86_feature_detected!("sse4.2")
    }
    #[cfg(not(feature = "use_std"))]
    {
        cfg!(target_feature = "sse4.2")
    }
}
/// Returns the index of the first occurrence of `pattern` in `text`.
///
/// * An empty pattern matches at index 0.
/// * A pattern longer than the text never matches.
/// * A single-byte pattern is delegated to `memchr`.
/// * Everything else goes through the SSE4.2 search in `find_inner`.
///
/// # Panics
///
/// Panics if `is_supported()` returns `false`.
pub fn find(text: &[u8], pattern: &[u8]) -> Option<usize> {
    assert!(is_supported());
    match pattern.len() {
        0 => Some(0),
        plen if text.len() < plen => None,
        1 => memchr::memchr(pattern[0], text),
        // SSE4.2 availability was asserted above, and the pattern is known
        // to be no longer than the text at this point.
        _ => unsafe { find_inner(text, pattern) },
    }
}
/// Finds the first occurrence of `pat` in `text` and returns its start index.
///
/// Patterns of at most 6 bytes are handed to `find_short_pat`. Longer
/// patterns are split at the position returned by
/// `TwoWaySearcher::crit_params` into `left` and `right`; the first (up to)
/// 16 bytes of `right` are located with SSE4.2 string compares, the rest of
/// `right` is checked with `shared_prefix_inner`, and `left` is compared last.
///
/// NOTE(review): the split position / period / memory bookkeeping appears to
/// follow the Two-Way string matching scheme; confirm against
/// `TwoWaySearcher`, which is defined outside this file section.
///
/// # Safety
///
/// The CPU must support SSE4.2. `text` must be at least as long as `pat`
/// (the tail loop computes `text.len() - pat.len()`); `find` checks both
/// conditions before calling this function.
#[target_feature(enable = "sse4.2")]
pub(crate) unsafe fn find_inner(text: &[u8], pat: &[u8]) -> Option<usize> {
// Short patterns have their own, simpler routine.
if pat.len() <= 6 {
return find_short_pat(text, pat);
}
let (crit_pos, mut period) = TwoWaySearcher::crit_params(pat);
// `memory` doubles as a mode flag: `!0` selects the loop without memory,
// any other value selects the loop that remembers already-matched bytes.
let mut memory;
// If the first `crit_pos` bytes recur `period` bytes later, keep `period`
// and start with an empty memory.
if &pat[..crit_pos] == &pat[period.. period + crit_pos] {
memory = 0;
} else {
// Otherwise disable the memory and replace the shift distance.
memory = !0;
period = cmp::max(crit_pos, pat.len() - crit_pos) + 1;
}
let (left, right) = pat.split_at(crit_pos);
// Only the first 16 bytes of `right` fit into one vector; `right16.len()`
// is the needle length passed to the SIMD search.
let (right16, _right17) = right.split_at(cmp::min(16, right.len()));
assert!(right.len() != 0);
// `pat128` copies at most 16 bytes, so `r` holds the start of `right`.
let r = pat128(right);
// `text` without its last 16 bytes (empty if `text` is shorter), so the
// unconditional 16-byte loads of the `nomask` search stay inside `text`.
let safetext = &text[..cmp::max(text.len(), 16) - 16];
let mut pos = 0;
if memory == !0 {
// Main loop, variant without memory.
'search: loop {
// Stop once a whole pattern no longer fits into the safe region.
if pos + pat.len() > safetext.len() {
break;
}
let start = crit_pos;
// Look for the start of `right`, beginning `crit_pos` bytes after `pos`.
match first_start_of_match_nomask(&safetext[pos + start..], right16.len(), r) {
None => {
// Nothing in the safe region: continue in the tail loop from the last
// start position that has not been ruled out.
pos = cmp::max(pos, safetext.len() - pat.len());
break
}
Some((mpos, mlen)) => {
pos += mpos;
// `mlen` bytes of `right` are confirmed; extend the comparison over the
// rest of `right` if the SIMD hit did not cover all of it.
let mut pfxlen = mlen;
if pfxlen < right.len() {
pfxlen += shared_prefix_inner(&text[pos + start + mlen..], &right[mlen..]);
}
if pfxlen != right.len() {
// Mismatch inside `right`: skip past the mismatching byte.
pos += pfxlen + 1;
continue 'search;
} else {
}
}
}
// `right` matched completely; now check `left`.
if left != &text[pos..pos + left.len()] {
pos += period;
continue 'search;
}
return Some(pos);
}
} else {
// Main loop, variant with memory.
'search_memory: loop {
// Stop once a whole pattern no longer fits into the safe region.
if pos + pat.len() > safetext.len() {
break;
}
// Number of leading bytes of `right` known to match at `pos`.
let mut pfxlen = if memory == 0 {
let start = crit_pos;
match first_start_of_match_nomask(&safetext[pos + start..], right16.len(), r) {
None => {
// Nothing in the safe region: continue in the tail loop.
pos = cmp::max(pos, safetext.len() - pat.len());
break
}
Some((mpos, mlen)) => {
pos += mpos;
mlen
}
}
} else {
// Bytes up to `memory` were matched before the last shift, so the SIMD
// search is skipped and the part of `right` below `memory` is taken as
// already matched.
memory - crit_pos
};
if pfxlen < right.len() {
pfxlen += shared_prefix_inner(&text[pos + crit_pos + pfxlen..], &right[pfxlen..]);
}
if pfxlen != right.len() {
// Mismatch inside `right`: skip past it and forget the memory.
pos += pfxlen + 1;
memory = 0;
continue 'search_memory;
} else {
}
// `right` matched; compare the part of `left` not covered by the memory.
if memory <= left.len() && &left[memory..] != &text[pos + memory..pos + left.len()] {
pos += period;
// Remember how much of the pattern is still known to match after the shift.
memory = pat.len() - period;
continue 'search_memory;
}
return Some(pos);
}
}
// Tail loop: handles the end of `text` with the masked search, which only
// copies the bytes that are actually left. It does not use `memory`.
'tail: loop {
if pos > text.len() - pat.len() {
return None;
}
let start = crit_pos;
match first_start_of_match_mask(&text[pos + start..], right16.len(), r) {
None => return None,
Some((mpos, mlen)) => {
pos += mpos;
let mut pfxlen = mlen;
if pfxlen < right.len() {
pfxlen += shared_prefix_inner(&text[pos + start + mlen..], &right[mlen..]);
}
if pfxlen != right.len() {
// Mismatch inside `right`: skip past the mismatching byte.
pos += pfxlen + 1;
continue 'tail;
} else {
}
}
}
// `right` matched completely; now check `left`.
if left != &text[pos..pos + left.len()] {
pos += period;
continue 'tail;
}
return Some(pos);
}
}
/// Checks `find` against `str::find` for every proper window of a sentence,
/// plus a few hand-picked patterns (longer than 16 bytes, and periodic ones).
#[test]
fn test_find() {
    let text = b"abc";
    assert_eq!(find(text, b"d"), None);
    assert_eq!(find(text, b"c"), Some(2));

    let longer = "longer text and so on, a bit more";
    let haystack = longer.as_bytes();
    for wsz in 1..longer.len() {
        for window in haystack.windows(wsz) {
            let needle = ::std::str::from_utf8(window);
            let str_find = longer.find(needle.unwrap());
            assert!(str_find.is_some());
            assert_eq!(find(haystack, window), str_find, "{:?} {:?}",
                       longer, needle);
        }
    }

    // A pattern that does not fit into a single 16-byte vector.
    let pat = b"ger text and so on";
    assert!(pat.len() > 16);
    assert_eq!(Some(3), find(haystack, pat));

    // Periodic pattern in periodic haystacks of two different lengths.
    let n = "abababab";
    for text in ["cbabababcbabababab", "cbababababababababababababababab"].iter() {
        assert_eq!(text.find(n), find(text.as_bytes(), n.as_bytes()));
    }
}
/// Packs the first (up to) 16 bytes of `pat` into a vector register value.
///
/// Shorter patterns are padded with zero bytes; bytes beyond the sixteenth
/// are ignored, because `mask_load` caps the copy at the vector size.
#[inline(always)]
fn pat128(pat: &[u8]) -> __m128i {
    let src = pat.as_ptr();
    let count = pat.len();
    // `src` is valid for `count` bytes because both come from the same slice.
    unsafe { mask_load(src, count) }
}
/// Loads up to 16 bytes from `ptr` into a vector, zero-filling the rest.
///
/// Exactly `min(len, 16)` bytes are read from `ptr`; they end up in the
/// lowest lanes of the result and all remaining lanes are zero.
///
/// # Safety
///
/// `ptr` must be valid for reads of `min(len, 16)` bytes.
#[inline(always)]
unsafe fn mask_load(ptr: *const u8, mut len: usize) -> __m128i {
    let mut bytes = [0u8; 16];
    // Never copy more than one vector's worth of data.
    len = cmp::min(len, mem::size_of::<__m128i>());
    ::std::ptr::copy_nonoverlapping(ptr, bytes.as_mut_ptr(), len);
    _mm_loadu_si128(bytes.as_ptr() as *const __m128i)
}
/// Returns the length of the longest common prefix of `text` and `pat`.
///
/// # Panics
///
/// Panics if `is_supported()` returns `false`.
pub fn shared_prefix(text: &[u8], pat: &[u8]) -> usize {
    assert!(is_supported());
    // The assertion above guarantees the SSE4.2 requirement of the inner
    // routine.
    unsafe {
        shared_prefix_inner(text, pat)
    }
}
/// Returns the length of the longest common prefix of `text` and `pat`.
///
/// All but the last (up to) 16 bytes of the common length are compared 16
/// bytes at a time with `pcmpestrm_eq_each`; the remaining bytes are
/// compared one by one. Every 16-byte load therefore starts at least 16
/// bytes before the end of both slices.
///
/// # Safety
///
/// The CPU must support SSE4.2.
#[target_feature(enable = "sse4.2")]
unsafe fn shared_prefix_inner(text: &[u8], pat: &[u8]) -> usize {
    let common = cmp::min(text.len(), pat.len());
    // End of the region that is handled by vector compares.
    let simd_end = common.saturating_sub(16);
    let text_ptr = text.as_ptr();
    let pat_ptr = pat.as_ptr();

    let mut offset = 0;
    while offset < simd_end {
        let remaining = simd_end - offset;
        let mask = pcmpestrm_eq_each(text_ptr, offset, remaining, pat_ptr, offset, remaining);
        if mask != 0xffff {
            // Every earlier chunk matched in full, so `offset` bytes are
            // already known to be equal; the lowest cleared bit marks the
            // first differing byte of this chunk.
            return offset + (mask ^ 0xffff).trailing_zeros() as usize;
        }
        offset += 16;
    }

    // The whole vector region matched; finish with a scalar comparison.
    let matched = simd_end;
    let text_tail = get_unchecked(text, matched..common);
    let pat_tail = get_unchecked(pat, matched..common);
    matched + zip(text_tail, pat_tail).take_while(|&(a, b)| a == b).count()
}
/// Checks `shared_prefix` on short inputs, on inputs that differ at the
/// first and the seventeenth byte, and on long buffers with a single
/// difference far from the start.
#[test]
fn test_prefixlen() {
    let base = b"0123456789abcdefeffect";
    let first_byte_differs = b"9123456789abcdefeffect";
    let seventeenth_byte_differs = b"0123456789abcdefgffect";

    assert_eq!(shared_prefix(base, base), base.len());
    assert_eq!(shared_prefix(b"abcd", b"abc"), 3);
    assert_eq!(shared_prefix(b"abcd", b"abcf"), 3);
    assert_eq!(0, shared_prefix(base, first_byte_differs));
    assert_eq!(0, shared_prefix(base, &base[1..]));
    assert_eq!(16, shared_prefix(base, seventeenth_byte_differs));

    // Every suffix (including the empty one) fully matches itself.
    for start in 0..base.len() + 1 {
        let suffix = &base[start..];
        assert_eq!(base.len() - start, shared_prefix(suffix, suffix));
    }

    // Two long buffers that differ only at one index.
    let same = [7u8; 1024];
    let mut altered = [7u8; 1024];
    let diff_at = 1000;
    altered[diff_at] = 0;
    for start in 0..diff_at {
        let plen = shared_prefix(&same[start..], &altered[start..]);
        assert_eq!(plen, diff_at - start);
    }
}