Skip to content

Store factor to file #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 90 additions & 5 deletions pydiso/mkl_solver.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ cimport numpy as np
from cython cimport numeric

import warnings
from time import time
import numpy as np
import scipy.sparse as sp
import os
from pathlib import Path

ctypedef long long MKL_INT64
ctypedef unsigned long long MKL_UINT64
Expand All @@ -30,6 +32,7 @@ cdef extern from 'mkl.h':
void mkl_get_version(MKLVersion* pv)

void mkl_set_num_threads(int nth)

int mkl_domain_set_num_threads(int nt, int domain)
int mkl_get_max_threads()
int mkl_domain_get_max_threads(int domain)
Expand All @@ -39,6 +42,10 @@ cdef extern from 'mkl.h':

ctypedef void * _MKL_DSS_HANDLE_t

void pardiso_handle_store(_MKL_DSS_HANDLE_t pt, char *dirname, int *err)

void pardiso_handle_restore(_MKL_DSS_HANDLE_t pt, char *dirname, int *err)

void pardiso(_MKL_DSS_HANDLE_t, const int*, const int*, const int*,
const int *, const int *, const void *, const int *,
const int *, int *, const int *, int *,
Expand Down Expand Up @@ -184,13 +191,16 @@ cdef class MKLPardisoSolver:
cdef int_t _factored
cdef size_t shape[2]
cdef int_t _initialized
cdef char* call_flag_dir
cdef char* _flag_dir
cdef int_t _store

cdef void * a

cdef object _data_type
cdef object _Adata #a reference to make sure the pointer "a" doesn't get destroyed

def __init__(self, A, matrix_type=None, factor=True, verbose=False):
def __init__(self, A, matrix_type=None, factor=True, verbose=False, store_factorization_dir=None):
'''ParidsoSolver(A, matrix_type=None, factor=True, verbose=False)
An interface to the intel MKL pardiso sparse matrix solver.

Expand Down Expand Up @@ -305,6 +315,27 @@ cdef class MKLPardisoSolver:
self._set_A(A.data)
self._analyze()
self._factored = False

# check if we want to store the factorization
if store_factorization_dir is not None:

# check if the flag files exist. If so delete them so factorization file get overwritten
check_file = Path(store_factorization_dir) / 'factorization_done.txt'

if os.path.exists(check_file):

second_file_to_remove = Path(store_factorization_dir) / "flagfile.txt"
os.remove(check_file)
os.remove(second_file_to_remove)

self._store = True
flag_dir_ = bytes(store_factorization_dir, 'utf-8')
self._flag_dir = flag_dir_

else:

self._store = False

if factor:
self._factor()

Expand Down Expand Up @@ -422,6 +453,11 @@ cdef class MKLPardisoSolver:
else:
self._par.iparm[i] = val

def store_factorization(self, directory=b'./'):

self._store = True
self._flag_dir = directory

@property
def nnz(self):
return self.iparm[17]
Expand Down Expand Up @@ -515,11 +551,47 @@ cdef class MKLPardisoSolver:
cdef _factor(self):
#phase = 22
self._factored = False

if self._store:
try:

err = self._run_pardiso(22)
if err!=0:
raise PardisoError("Factor step error, "+_err_messages[err])
self._factored = True
flag_file = self._flag_dir.decode("utf-8") + 'flagfile.txt'

self.call_flag_dir = self._flag_dir

with open(flag_file, 'x') as f:
f.write('inversion in progress')

err = self._run_pardiso(22)

self._pardiso_store(self.call_flag_dir)

done_file = self._flag_dir.decode("utf-8") + 'factorization_done.txt'

with open(done_file, 'w') as f2:
f2.write('done')

self._factored = True
return

except FileExistsError:

# flag file exists, wait for "done" file and read in factorization
done_file = self._flag_dir.decode("utf-8") + 'factorization_done.txt'

while not os.path.isfile(done_file):
time.sleep(1)

# now read in the factorization from the file
self.call_flag_dir = self._flag_dir
self._pardiso_restore(self.call_flag_dir)

else:

err = self._run_pardiso(22)
if err!=0:
raise PardisoError("Factor step error, "+_err_messages[err])
self._factored = True

cdef _solve(self, void* b, void* x, int_t nrhs_in):
#phase = 33
Expand All @@ -544,3 +616,16 @@ cdef class MKLPardisoSolver:
&phase64, &self._par64.n, self.a, &self._par64.ia[0], &self._par64.ja[0],
&self._par64.perm[0], &nrhs64, self._par64.iparm, &self._par64.msglvl, b, x, &error64)
return error64

cdef _pardiso_store(self, char *dir_name):

cdef int_t error=0

pardiso_handle_store(self.handle, dir_name, &error)

cdef _pardiso_restore(self, char *dir_name):

cdef int_t error=0

pardiso_handle_restore(self.handle, dir_name, &error)

38 changes: 38 additions & 0 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,44 @@ def test_multiple_RHS():
assert rel_err < 1E3*eps
return rel_err

def test_multiple_RHS_store_factorization():
A = A_real_dict["real_symmetric_positive_definite"]
x = np.c_[xr, xr]
b = A @ x

solver = Solver(A, "real_symmetric_positive_definite", store_factorization_dir='./')
x2 = solver.solve(b)

eps = np.finfo(np.float64).eps
rel_err = np.linalg.norm(x-x2)/np.linalg.norm(x)
assert rel_err < 1E3*eps
return rel_err

def test_multiple_RHS_store_factorization_clean_flag_files():
A = A_real_dict["real_symmetric_positive_definite"]
x = np.c_[xr, xr]
b = A @ x

solver = Solver(A, "real_symmetric_positive_definite", store_factorization_dir='./')
x2 = solver.solve(b)

eps = np.finfo(np.float64).eps
rel_err = np.linalg.norm(x-x2)/np.linalg.norm(x)

assert rel_err < 1E3*eps

# run again to make sure the created flag files are checked and removed and running again works
x3 = solver.solve(b)

eps3 = np.finfo(np.float64).eps
rel_err3 = np.linalg.norm(x-x2)/np.linalg.norm(x)

assert rel_err3 < 1E3*eps3

assert rel_err == rel_err3

return rel_err


def test_matrix_type_errors():
A = A_real_dict["real_symmetric_positive_definite"]
Expand Down