diff --git a/Cargo.toml b/Cargo.toml index 9da21c57c..54626ea62 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ license = "BSD-2-Clause" half = { version = "2.0", default-features = false, optional = true } libc = "0.2" nalgebra = { version = ">=0.30, <0.34", default-features = false, optional = true } +faer = { version = "0.21.9", optional = true } num-complex = ">= 0.2, < 0.5" num-integer = "0.1" num-traits = "0.2" @@ -25,6 +26,9 @@ ndarray = ">= 0.15, < 0.17" pyo3 = { version = "0.24", default-features = false, features = ["macros"] } rustc-hash = "2.0" +[features] +faer = ["dep:faer"] + [dev-dependencies] pyo3 = { version = "0.24", default-features = false, features = ["auto-initialize"] } nalgebra = { version = ">=0.30, <0.34", default-features = false, features = ["std"] } diff --git a/src/convert.rs b/src/convert.rs index 66c557b1b..16ca0255e 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -2,7 +2,7 @@ use std::{mem, os::raw::c_int, ptr}; -use ndarray::{ArrayBase, Data, Dim, Dimension, IntoDimension, Ix1, OwnedRepr}; +use ndarray::{ArrayBase, Data, Dim, Dimension, IntoDimension, Ix1, Ix2, OwnedRepr}; use pyo3::{Bound, Python}; use crate::array::{PyArray, PyArrayMethods}; @@ -90,6 +90,32 @@ impl IntoPyArray for Vec { } } +#[cfg(feature = "faer")] +impl IntoPyArray for faer::Mat { + type Item = T; + type Dim = Ix2; + + fn into_pyarray<'py>(mut self, py: Python<'py>) -> Bound<'py, PyArray> { + let dims = Dim([self.nrows(), self.ncols()]); + let rstride = self.row_stride(); + let cstride = self.col_stride(); + let strides = [ + rstride * mem::size_of::() as npy_intp, + cstride * mem::size_of::() as npy_intp, + ]; + let data_ptr = self.as_ptr_mut(); + unsafe { + PyArray::from_raw_parts( + py, + dims, + strides.as_ptr(), + data_ptr, + PySliceContainer::from(self), + ) + } + } +} + impl IntoPyArray for ArrayBase, D> where A: Element, diff --git a/src/slice_container.rs b/src/slice_container.rs index 0c29eae61..8df0eb32a 100644 --- a/src/slice_container.rs +++ b/src/slice_container.rs @@ -71,6 +71,35 @@ impl From> for PySliceContainer { } } +#[cfg(feature = "faer")] +impl From> for PySliceContainer { + fn from(data: faer::Mat) -> Self { + unsafe fn drop_faer_mat(ptr: *mut u8, len_nrows: usize, cap_ncols: usize) { + let _ = faer::mat::MatMut::from_raw_parts_mut( + ptr as *mut T, + len_nrows, + cap_ncols, + 1, + cap_ncols as isize, + ); + } + + let mut data = mem::ManuallyDrop::new(data); + + let ptr = data.as_ptr_mut() as *mut u8; + let len = data.nrows(); + let cap = data.ncols(); + let drop = drop_faer_mat::; + + Self { + ptr, + len, + cap, + drop, + } + } +} + impl From, D>> for PySliceContainer where A: Send + Sync, diff --git a/tests/to_py.rs b/tests/to_py.rs index c18d2d6db..aece50b25 100644 --- a/tests/to_py.rs +++ b/tests/to_py.rs @@ -288,6 +288,31 @@ fn slice_container_type_confusion() { }); } +#[cfg(feature = "faer")] +#[test] +fn faer_mat_to_numpy() { + let faer_mat: faer::Mat = faer::Scale(2.0) * faer::mat::Mat::::identity(2, 2); + let faer_mat_wide: faer::Mat = faer::mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + let faer_mat_tall: faer::Mat = faer_mat_wide.transpose().to_owned(); + Python::with_gil(|py| { + let mat_pyarray = faer_mat.into_pyarray(py); + let mat_wide_pyarray = faer_mat_wide.into_pyarray(py); + let mat_tall_pyarray = faer_mat_tall.into_pyarray(py); + assert_eq!( + mat_pyarray.readonly().as_array(), + array![[2.0f64, 0.0f64], [0.0f64, 2.0f64]] + ); + assert_eq!( + mat_wide_pyarray.readonly().as_array(), + array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]] + ); + assert_eq!( + mat_tall_pyarray.readonly().as_array(), + array![[1.0f64, 4.0], [2.0, 5.0], [3.0, 6.0]] + ); + }); +} + #[cfg(feature = "nalgebra")] #[test] fn matrix_to_numpy() {