导入必备包
import numpy as np
import pandas as pd
import time
from numpy.testing import *
from scipy.linalg import *
import gc
from matplotlib import pyplot as plt
%matplotlib inline
from numba import jit, void, double
def det_by_lu(y, x):
    """Doolittle LU factorization of the square matrix ``x``, in place, without pivoting.

    On return the strict lower triangle of ``x`` holds the multipliers of the
    unit lower-triangular factor L (its unit diagonal is implicit) and the
    upper triangle holds U, so that L @ U reproduces the input matrix.
    ``y[0]`` is overwritten with the product of U's diagonal, which equals the
    determinant of the input because det(L) == 1. Zero pivots are not handled.
    """
    y[0] = 1.
    size = x.shape[0]
    for col in range(size):
        # Row `col` received all its updates in earlier iterations, so this is
        # already U's final diagonal entry; it is not modified below.
        pivot = x[col, col]
        y[0] *= pivot
        for row in range(col + 1, size):
            x[row, col] /= pivot
            multiplier = x[row, col]
            for j in range(col + 1, size):
                x[row, j] -= multiplier * x[col, j]
def run_python(A, B, y, N):
    """Check, then time, the pure-Python ``det_by_lu`` on the N x N matrix ``A``.

    Parameters
    ----------
    A : N x N ndarray; only read.
    B : ndarray with A's shape used as scratch space. It is refilled from
        ``A`` before every call because the factorization overwrites it.
    y : 1-element ndarray that receives the determinant.
    N : matrix size.

    Returns
    -------
    float
        Mean seconds per call, measured with ``time.perf_counter``.

    Raises
    ------
    AssertionError
        If the packed factors do not satisfy L @ U == A.
    """
    # Verify once, outside the timed region, that the factorization is right.
    np.copyto(B, A)
    det_by_lu(y, B)
    L = np.tril(B, -1) + np.eye(N)
    U = np.triu(B)
    assert_almost_equal(L.dot(U), A)

    # Small matrices are repeated more often so their timing is measurable.
    loops = 1 + (100000 // (N * N))

    # Disable GC while timing (as timeit does), but always restore the
    # caller's GC state, even if the loop raises or is interrupted.
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        # perf_counter has far finer resolution than time.time().
        start = time.perf_counter()
        for _ in range(loops):
            np.copyto(B, A)
            det_by_lu(y, B)
        elapsed = time.perf_counter() - start
    finally:
        if gc_was_enabled:
            gc.enable()
    return elapsed / loops
def numpy_det_by_lu(y, x):
    """In-place Doolittle LU of the square matrix ``x`` (no pivoting), NumPy-sliced.

    After the call the strict lower triangle of ``x`` holds the multipliers of
    the unit lower-triangular L and the upper triangle holds U. ``y[0]``
    receives the product of U's diagonal, i.e. the determinant of the input.
    Each row update is a single slice operation rather than a column loop.
    """
    size = x.shape[0]
    y[0] = 1.
    # NOTE(review): presumably silences 0/0 "invalid" warnings from zero
    # pivots — confirm intent.
    with np.errstate(invalid='ignore'):
        for k in range(size):
            pivot_row = x[k]      # a view: row k of x
            y[0] *= pivot_row[k]
            for i in range(k + 1, size):
                row = x[i]        # a view, so writes land in x
                row[k] /= pivot_row[k]
                row[k + 1:] -= row[k] * pivot_row[k + 1:]
def run_numpy(A, B, y, N):
    """Check, then time, ``numpy_det_by_lu`` on the N x N matrix ``A``.

    Parameters
    ----------
    A : N x N ndarray; only read.
    B : ndarray with A's shape used as scratch space. It is refilled from
        ``A`` before every call because the factorization overwrites it.
    y : 1-element ndarray that receives the determinant.
    N : matrix size.

    Returns
    -------
    float
        Mean seconds per call, measured with ``time.perf_counter``.

    Raises
    ------
    AssertionError
        If the packed factors do not satisfy L @ U == A.
    """
    # Verify once, outside the timed region, that the factorization is right.
    np.copyto(B, A)
    numpy_det_by_lu(y, B)
    L = np.tril(B, -1) + np.eye(N)
    U = np.triu(B)
    assert_almost_equal(L.dot(U), A)

    # Small matrices are repeated more often so their timing is measurable.
    loops = 1 + (100000 // (N * N))

    # Disable GC while timing, but always restore the caller's GC state.
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        start = time.perf_counter()
        for _ in range(loops):
            np.copyto(B, A)
            numpy_det_by_lu(y, B)
        elapsed = time.perf_counter() - start
    finally:
        if gc_was_enabled:
            gc.enable()
    return elapsed / loops
def numba_det_by_lu(y, x):
    """In-place Doolittle LU of the square matrix ``x`` (no pivoting).

    Uses only scalar loops and indexing. Afterwards the strict lower triangle
    of ``x`` holds L's multipliers (unit diagonal implied) and the upper
    triangle holds U. ``y[0]`` receives the product of U's diagonal, the
    determinant of the input.
    """
    n = x.shape[0]
    y[0] = 1.
    for k in range(n):
        diag = x[k, k]          # final U[k, k]; untouched by the loops below
        y[0] *= diag
        for i in range(k + 1, n):
            x[i, k] /= diag
            l_ik = x[i, k]
            for j in range(k + 1, n):
                x[i, j] -= l_ik * x[k, j]
# Compile numba_det_by_lu eagerly (an explicit signature is given) for a
# float64 1-D y and float64 2-D x. nopython=True makes numba raise instead of
# silently falling back to slow object mode, which older numba versions did by
# default and which would distort this benchmark.
fastdet_by_lu = jit(void(double[:], double[:, :]), nopython=True)(numba_det_by_lu)
def run_numba(A, B, y, N):
    """Check, then time, the numba-compiled ``fastdet_by_lu`` on ``A``.

    Parameters
    ----------
    A : N x N ndarray; only read.
    B : ndarray with A's shape used as scratch space. It is refilled from
        ``A`` before every call because the factorization overwrites it.
    y : 1-element ndarray that receives the determinant.
    N : matrix size.

    Returns
    -------
    float
        Mean seconds per call, measured with ``time.perf_counter``.

    Raises
    ------
    AssertionError
        If the packed factors do not satisfy L @ U == A.
    """
    # Verify once, outside the timed region, that the factorization is right.
    np.copyto(B, A)
    fastdet_by_lu(y, B)
    L = np.tril(B, -1) + np.eye(N)
    U = np.triu(B)
    assert_almost_equal(L.dot(U), A)

    # More repetitions for small N, capped so tiny matrices don't run forever.
    loops = 1 + min(1000000 // (N * N), 20000)

    # Disable GC while timing, but always restore the caller's GC state.
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        start = time.perf_counter()
        for _ in range(loops):
            np.copyto(B, A)
            fastdet_by_lu(y, B)
        elapsed = time.perf_counter() - start
    finally:
        if gc_was_enabled:
            gc.enable()
    return elapsed / loops
%load_ext cython
%%cython
import cython

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef cython_det_by_lu(double[:] y, double[:, :] x):
    """In-place Doolittle LU of the square memoryview x (no pivoting).

    Afterwards the strict lower triangle of x holds L's multipliers (unit
    diagonal implied) and the upper triangle holds U; y[0] receives the
    product of U's diagonal (the determinant of the input).
    """
    # Py_ssize_t is the type of x.shape[0] and Cython's idiomatic index type,
    # avoiding narrowing to a C int.
    cdef Py_ssize_t N = x.shape[0]
    cdef Py_ssize_t i, j, k
    y[0] = 1.
    for k in range(N):
        y[0] *= x[k, k]
        for i in range(k + 1, N):
            x[i, k] /= x[k, k]
            for j in range(k + 1, N):
                x[i, j] -= x[i, k] * x[k, j]
def run_cython(A, B, y, N):
    """Check, then time, the Cython ``cython_det_by_lu`` on ``A``.

    Parameters
    ----------
    A : N x N ndarray; only read.
    B : ndarray with A's shape used as scratch space. It is refilled from
        ``A`` before every call because the factorization overwrites it.
    y : 1-element ndarray that receives the determinant.
    N : matrix size.

    Returns
    -------
    float
        Mean seconds per call, measured with ``time.perf_counter``.

    Raises
    ------
    AssertionError
        If the packed factors do not satisfy L @ U == A.
    """
    # Verify once, outside the timed region, that the factorization is right.
    np.copyto(B, A)
    cython_det_by_lu(y, B)
    L = np.tril(B, -1) + np.eye(N)
    U = np.triu(B)
    assert_almost_equal(L.dot(U), A)

    # More repetitions for small N, capped so tiny matrices don't run forever.
    loops = 1 + min(1000000 // (N * N), 20000)

    # Disable GC while timing, but always restore the caller's GC state.
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        start = time.perf_counter()
        for _ in range(loops):
            np.copyto(B, A)
            cython_det_by_lu(y, B)
        elapsed = time.perf_counter() - start
    finally:
        if gc_was_enabled:
            gc.enable()
    return elapsed / loops
# Load the C implementation of det_by_lu through cffi (ABI mode via dlopen).
from cffi import FFI

# Location of the compiled C library; edit for your machine. The raw string
# keeps the Windows backslash separators literally. Without them the path
# would be the drive-relative "C:UsersIBM_ADMINlu.dll".
LU_DLL_PATH = r"C:\Users\IBM_ADMIN\lu.dll"

ffi = FFI()
# In ABI mode this declaration is not checked against the real C header, so it
# must match the library's prototype exactly.
ffi.cdef('void det_by_lu(double *y, double *B, int N);')
C = ffi.dlopen(LU_DLL_PATH)
c_det_by_lu = C.det_by_lu
def run_c(A, B, y, N):
    """Check, then time, the C ``det_by_lu`` (called through cffi) on ``A``.

    Parameters
    ----------
    A : N x N ndarray; only read.
    B : C-contiguous float64 ndarray with A's shape used as scratch space.
        It is refilled from ``A`` before every call because the factorization
        overwrites it.
    y : C-contiguous float64 1-element ndarray that receives the determinant.
    N : matrix size.

    Returns
    -------
    float
        Mean seconds per call, measured with ``time.perf_counter``.

    Raises
    ------
    ValueError
        If ``y`` or ``B`` is not a C-contiguous float64 array.
    AssertionError
        If the packed factors do not satisfy L @ U == A.
    """
    # The C routine reads and writes through raw pointers, so anything other
    # than a contiguous float64 buffer would make it touch the wrong memory.
    for buf in (y, B):
        if buf.dtype != np.float64 or not buf.flags['C_CONTIGUOUS']:
            raise ValueError("run_c needs C-contiguous float64 arrays for y and B")

    # np.copyto refills B in place, so these pointers stay valid for every
    # call. Casting them once keeps the cast cost out of the timed loop.
    y_ptr = ffi.cast("double *", y.ctypes.data)
    b_ptr = ffi.cast("double *", B.ctypes.data)
    n_arg = ffi.cast("int", N)

    # Verify once, outside the timed region, that the factorization is right.
    np.copyto(B, A)
    c_det_by_lu(y_ptr, b_ptr, n_arg)
    L = np.tril(B, -1) + np.eye(N)
    U = np.triu(B)
    assert_almost_equal(L.dot(U), A)

    # More repetitions for small N, capped so tiny matrices don't run forever.
    loops = 1 + min(1000000 // (N * N), 20000)

    # Disable GC while timing, but always restore the caller's GC state.
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        start = time.perf_counter()
        for _ in range(loops):
            np.copyto(B, A)
            c_det_by_lu(y_ptr, b_ptr, n_arg)
        elapsed = time.perf_counter() - start
    finally:
        if gc_was_enabled:
            gc.enable()
    return elapsed / loops
def run_scipy(A, B, y, N):
    """Check, then time, ``scipy.linalg.lu`` (with row pivoting) on ``A``.

    Parameters
    ----------
    A : N x N ndarray; only read.
    B : ndarray with A's shape used as scratch space, refilled from ``A``
        before every call since ``overwrite_a=True`` lets SciPy clobber it.
    y : unused; accepted so every runner shares the same signature.
    N : matrix size.

    Returns
    -------
    float
        Mean seconds per call, measured with ``time.perf_counter``.

    Raises
    ------
    AssertionError
        If P @ L @ U does not reproduce ``A``.
    """
    # Verify once, outside the timed region, that the factorization is right.
    np.copyto(B, A)
    P, L, U = lu(B, overwrite_a=True)
    assert_almost_equal(P.dot(L.dot(U)), A)

    # More repetitions for small N, capped so tiny matrices don't run forever.
    loops = 1 + min(1000000 // (N * N), 20000)

    # Disable GC while timing, but always restore the caller's GC state.
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        start = time.perf_counter()
        for _ in range(loops):
            np.copyto(B, A)
            lu(B, overwrite_a=True)
        elapsed = time.perf_counter() - start
    finally:
        if gc_was_enabled:
            gc.enable()
    return elapsed / loops
def run_lapack(A, B, y, N):
    """Check, then time, ``scipy.linalg.lu_factor`` (LAPACK-backed, pivoted) on ``A``.

    Parameters
    ----------
    A : N x N ndarray; only read.
    B : ndarray with A's shape used as scratch space, refilled from ``A``
        before every call since ``overwrite_a=True`` lets SciPy clobber it.
    y : unused; accepted so every runner shares the same signature.
    N : matrix size.

    Returns
    -------
    float
        Mean seconds per call, measured with ``time.perf_counter``.

    Raises
    ------
    AssertionError
        If the packed factors do not reproduce the row-permuted ``A``.
    """
    # Verify once, outside the timed region, that the factorization is right.
    # lu_factor packs L (unit diagonal implied) and U into one matrix and
    # reports pivoting as "row i was interchanged with row piv[i]".
    np.copyto(B, A)
    lu_mat, piv = lu_factor(B, overwrite_a=True)
    L = np.tril(lu_mat, -1) + np.eye(N)
    U = np.triu(lu_mat)
    # Replay the interchanges in order to get the permutation with
    # A[perm] == L @ U.
    perm = np.arange(N)
    for i, p in enumerate(piv):
        perm[i], perm[p] = perm[p], perm[i]
    assert_almost_equal(L.dot(U), A[perm])

    # More repetitions for small N, capped so tiny matrices don't run forever.
    loops = 1 + min(1000000 // (N * N), 20000)

    # Disable GC while timing, but always restore the caller's GC state.
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        start = time.perf_counter()
        for _ in range(loops):
            np.copyto(B, A)
            lu_factor(B, overwrite_a=True)
        elapsed = time.perf_counter() - start
    finally:
        if gc_was_enabled:
            gc.enable()
    return elapsed / loops
def timings(n=7,
            series=('pure python', 'c', 'numba', 'numpy',
                    'cython', 'scipy', 'lapack', 'julia')):
    """Benchmark every LU runner on the first ``n`` matrix sizes.

    Parameters
    ----------
    n : how many sizes to take from the front of the size list below.
    series : column labels. All entries except the last are matched
        positionally to the runner functions; an entry equal to '' skips that
        runner. The last entry ('julia' by default) has no runner here and is
        left as NaN for the caller to fill in.

    Returns
    -------
    pandas.DataFrame
        Indexed by matrix size N, one column per label, holding the mean
        seconds per call. Cells that were not measured are NaN.

    Prints the current N and the index of each runner as it goes.
    """
    Ns = np.array([5, 10, 30, 100, 200, 300, 400, 600, 1000, 2000, 4000, 8000])
    Fs = [run_python, run_c, run_numba, run_numpy,
          run_cython, run_scipy, run_lapack]
    # Slice once so the row count always matches the index, even if n > len(Ns).
    sizes = Ns[:n]
    # NaN (not 0.0) marks never-measured cells, so they can't pass for a
    # zero runtime.
    times = pd.DataFrame(np.full((len(sizes), len(series)), np.nan),
                         index=sizes, columns=list(series))
    for N in sizes:
        print('N =', N, end=" ")
        A = np.random.random((N, N))
        B = np.empty(A.shape)
        y = np.zeros(1)
        for j, label in enumerate(series[:-1]):
            if label != '':
                print(j, end=" ")
                times.loc[N, label] = Fs[j](A, B, y, N)
        print('')
    return times
# Run the benchmark on the first 9 sizes (N = 5 ... 1000). Its progress
# printout is reproduced below.
times = timings(9)
N = 5 0 1 2 3 4 5 6
N = 10 0 1 2 3 4 5 6
N = 30 0 1 2 3 4 5 6
N = 100 0 1 2 3 4 5 6
N = 200 0 1 2 3 4 5 6
N = 300 0 1 2 3 4 5 6
N = 400 0 1 2 3 4 5 6
N = 600 0 1 2 3 4 5 6
N = 1000 0 1 2 3 4 5 6
# Last expression of a notebook cell: displays the timing DataFrame.
times
| N | pure python | c | numba | numpy | cython | scipy | lapack | julia |
|---|---|---|---|---|---|---|---|---|
| 5 | 0.000051 | 0.000016 | 0.000002 | 0.000074 | 0.000006 | 0.000029 | 0.000031 | 6.091400e-07 |
| 10 | 0.000312 | 0.000016 | 0.000003 | 0.000234 | 0.000006 | 0.000030 | 0.000031 | 1.060710e-06 |
| 30 | 0.007800 | 0.000028 | 0.000014 | 0.001950 | 0.000014 | 0.000070 | 0.000056 | 9.082080e-06 |
| 100 | 0.289310 | 0.000154 | 0.000463 | 0.029782 | 0.000309 | 0.000309 | 0.000309 | 2.265530e-04 |
| 200 | 2.277604 | 0.001800 | 0.007200 | 0.119600 | 0.003600 | 0.001200 | 0.001200 | 1.740604e-03 |
| 300 | 7.636214 | 0.007800 | 0.019500 | 0.226200 | 0.007800 | 0.003900 | 0.001300 | 5.823171e-03 |
| 400 | 18.267632 | 0.017829 | 0.051257 | 0.514801 | 0.020057 | 0.008914 | 0.002229 | 1.372135e-02 |
| 600 | 62.197309 | 0.062400 | 0.124800 | 0.982802 | 0.088400 | 0.036400 | 0.010400 | 4.543215e-02 |
| 1000 | 290.472510 | 0.257401 | 0.569401 | 3.042005 | 0.288600 | 0.070200 | 0.039000 | 2.642414e-01 |
def plot_times(times,
               cols=None,
               name="runtimes.png"):
    """Plot runtime against matrix size on log-log axes, save it, and show it.

    Parameters
    ----------
    times : DataFrame indexed by matrix size N, one column per implementation.
    cols : column labels to plot. None or an empty sequence (the old default
        ``[]``) plots every column. Labels equal to '' are skipped. Any
        sequence works, including a pandas Index or a NumPy array.
    name : file name handed to ``plt.savefig``.
    """
    # `cols == []` is unreliable for array-likes (elementwise comparison),
    # so test for None / emptiness explicitly.
    if cols is None or len(cols) == 0:
        cols = times.columns
    plt.figure(figsize=(7, 5))
    for label in cols:
        if label != '':
            plt.loglog(times.index, times[label], label=label)
    plt.xlabel("N (matrix size)")
    plt.ylabel("runtime [sec]")
    plt.grid()
    plt.legend(loc=2)  # 2 == 'upper left'
    plt.savefig(name)
    plt.show()
# One figure per comparison: pure Python and C are always shown as the
# baseline, plus one or two other implementations. Figures are saved under
# the names runtimes_1 ... runtimes_5.
for plot_cols, plot_name in [
    (['pure python', 'c', 'numba'], 'runtimes_1'),
    (['pure python', 'c', 'cython'], 'runtimes_2'),
    (['pure python', 'c', 'numpy'], 'runtimes_3'),
    (['pure python', 'c', 'scipy', 'lapack'], 'runtimes_4'),
    (['pure python', 'c', 'julia'], 'runtimes_5'),
]:
    plot_times(times, cols=plot_cols, name=plot_name)
页面更新:2024-05-12
本站资料均由网友自行发布提供,仅用于学习交流。如有版权问题,请与我联系,QQ:4156828
© CopyRight 2020-2024 All Rights Reserved. Powered By 71396.com 闽ICP备11008920号-4
闽公网安备35020302034903号