Extended Kalman Filter

In [1]:
import numpy as np, matplotlib.pyplot as plt
In [2]:
samples = 1000

x = np.linspace(0, 1000, samples)
y = np.sin(0.1*x + 4*np.sin(0.01*x)) + np.random.normal(scale=0.1, size=samples)
In [3]:
plt.figure()
plt.plot(x, y)
plt.show()
In [6]:
def Kalman(y, sd=1e-1):
    A = np.array([[2., -1.], [1., 0.]])
    N_y = np.array([[1e-1]])
    N_x = np.diag([sd**2,sd**2])

    est_x_t = np.array([[0.],[0.]])
    E = np.identity(2)

    est_x = np.zeros((2, len(y)))
    est_y = np.zeros((len(y)))
    
    for i in range(len(y)):
        B = np.cos(est_x_t).T
        B[0,1] = 0
        est_y_t = np.sin(est_x_t)[0,0]
        y_t = y[i]
        K = E.dot(B.T).dot(np.linalg.pinv(B.dot(E).dot(B.T)  + N_y))
        x_t = est_x_t + K.dot(y_t - est_y_t)
        
        est_x[:,i:i+1] = x_t
        est_y[i] = est_y_t
        
        E = (np.identity(2)-K.dot(B)).dot(E)
        est_x_t  = A.dot(x_t)
        E = A.dot(E).dot(A.T) + N_x
        
    return est_x, est_y
In [7]:
for sd in [1e-1,1e-2,1e-3,1e-4,1e-5]:
    plt.figure()
    plt.title("Kalman Filter tracking for state sd.="+str(sd))
    plt.plot(x,y, label='real_y')
    plt.plot(x, Kalman(y, sd)[1], label='est_y')
    #plt.plot(x, y_2, label='est_y')
    plt.ylim([-2,2])

    #plt.plot(x, y, label='y')
    plt.legend()
    plt.show()

tracking of the measured output breaks down as we assume an ever lower standard deviation in the state

Predicting HMM States

In [78]:
pTrans = [0.5, 0.1]

pHead  = [0.5, 0.4]

def observe(nums, pT=pTrans, pH=pHead):
    state = 0
    states = []
    observations = np.zeros(nums, dtype=int)
    
    for i in range(nums):
        states.append(state)
        if pH[state] > np.random.rand():
            observations[i]=1
        else:
            observations[i]=0
            
        if pT[state] > np.random.rand():
            state = 1-state
    
    return states, observations
In [79]:
s, obs = observe(100)
In [80]:
plt.figure(figsize=(10,7))

plt.plot(np.arange(len(s)), s, label='state [0=fair]')
plt.plot(np.arange(len(s)), obs, label='observation [1=head]')
plt.ylim(-0.2,1.2)
plt.legend()

plt.show()
In [88]:
def viterbi(y, pT=pTrans, pH=pHead):
    v = np.zeros((len(y), 2))
    b = np.zeros((len(y), 2), dtype=int)
    vals = np.array([0,0])
    v[0] = [1,0]
    
    for i in range(1,len(y)):
        for j in range(2):
            p = np.log(pH[j]) if y[i] else np.log(1-pH[j])
            vals = [v[i-1,0], v[i-1][1]]
            vals[1-j] += np.log(pT[1-j])
            b[i,j] = np.argmax(vals)
            v[i,j] = vals[b[i,j]] + p
    
    guess = np.zeros(len(y), dtype=int)
    state = np.argmax(vals)
    
    for i in range(1, len(y)+1):
        guess[-i] = state
        state = b[-i, state]
    
    return guess
    
In [89]:
guess = viterbi(obs)
accuracy = sum(1-np.abs(s-guess))/float(len(s))
In [90]:
print accuracy
0.82
In [ ]: