import matplotlib.pyplot as plt
import numpy as np
# Signal length. Must be even: the transform below maps each pair of samples
# to one smoothing and one detail coefficient.
n = 2**12

# Test signal: a unit-height box on samples 4..29 (inclusive).
x = np.zeros([n])
x[4:30] = 1

# Daubechies D4 scaling-filter coefficients.
c0 = (1 + np.sqrt(3)) / (4 * np.sqrt(2))
c1 = (3 + np.sqrt(3)) / (4 * np.sqrt(2))
c2 = (3 - np.sqrt(3)) / (4 * np.sqrt(2))
c3 = (1 - np.sqrt(3)) / (4 * np.sqrt(2))

# One level of the D4 transform as a dense n x n matrix (n*n float64 values,
# i.e. 128 MiB for n = 2**12). Even row 2*i applies the smoothing filter
# (c0, c1, c2, c3) and odd row 2*i+1 the detail filter (c3, -c2, c1, -c0),
# both starting at column 2*i. Column indices wrap modulo n so the last row
# pair reaches back to columns 0 and 1; without this wraparound the final two
# rows would stay all-zero and the matrix would not be orthogonal.
wav_transform = np.zeros([n, n])
_even_rows = np.arange(0, n, 2)
for _offset, (_smooth, _detail) in enumerate(
    zip((c0, c1, c2, c3), (c3, -c2, c1, -c0))
):
    _cols = (_even_rows + _offset) % n
    wav_transform[_even_rows, _cols] = _smooth
    wav_transform[_even_rows + 1, _cols] = _detail
def wavelet(x_):
    """Apply one level of ``wav_transform`` to *x_* and keep the smoothing part.

    Args:
        x_: 1-D array whose length matches the number of columns of the
            module-level ``wav_transform`` matrix (``np.dot`` raises
            ``ValueError`` otherwise).

    Returns:
        A 1-D array of ``len(x_) // 2`` coefficients: the even-indexed rows of
        ``wav_transform @ x_``. Those rows carry the c0..c3 smoothing filter.
        The odd rows (detail coefficients) are discarded.
    """
    xw = np.dot(wav_transform, x_)
    # Every other entry starting at 0 is a smoothing coefficient; cap the
    # length at len(x_) // 2 as the original element-copy loop did.
    return xw[0::2][: len(x_) // 2]


# The function must be defined before this call; calling it first raises
# NameError.
x = wavelet(x)
plt.plot(x)
plt.show()  # needed for the figure to appear when run as a script