wk10¶

Let's look at various control strategies for the cart pole system.

See https://deepnote.com/workspace/miana-smith-49fc76d3-915a-4104-b9b0-e7fbf3cac38e/project/nmmwk10-1b09b3b4-e541-498e-9217-19067f09621f/notebook/Notebook%201-e7dc8d2a0e164c74b908fcb9b791602b for hosted notebook.

Imports¶

In [1]:
from copy import deepcopy
from IPython.display import display, HTML, clear_output
import matplotlib.pyplot as plt, mpld3
import numpy as np
import os
import random
from scipy.signal import lfilter

import torch
from torch import nn
from torch.nn import functional as F
from torch.distributions.normal import Normal
from torch.distributions.independent import Independent

from pydrake.all import (
    AddMultibodyPlantSceneGraph,
    ConnectPlanarSceneGraphVisualizer,
    DiagramBuilder,
    FixedOffsetFrame,
    InverseDynamicsController,
    LeafSystem,
    MultibodyPlant,
    Parser,
    PlanarJoint,
    PrismaticJoint,
    RevoluteJoint,
    RigidTransform,
    RotationMatrix,
    Simulator,
    SpatialInertia,
    UnitInertia,
    VectorSystem,
    ControllabilityMatrix,
    Linearize, 
    LinearQuadraticRegulator,
    MeshcatVisualizer,
    Saturation, 
    SceneGraph,
    StartMeshcat, 
    WrapToSystem, 
    ConstantVectorSource, 
    TrajectorySource, 
    Trajectory, 
    DirectCollocation, 
    LogVectorOutput, 
    PiecewisePolynomial, 
    PlanarSceneGraphVisualizer,
    MultibodyPositionToGeometryPose
)
from manipulation.utils import ConfigureParser
from manipulation import running_as_notebook



if running_as_notebook:
    mpld3.enable_notebook()

Start visualizer¶

Run this cell only once; the visualizer opens in a new tab.

In [2]:
# Start the visualizer (run this cell only once, each instance consumes a port)
meshcat = StartMeshcat()
INFO:drake:Meshcat listening for connections at https://1b09b3b4-e541-498e-9217-19067f09621f.deepnoteproject.com/7000/
Installing NginX server for MeshCat on Deepnote...
Meshcat URL: https://1b09b3b4-e541-498e-9217-19067f09621f.deepnoteproject.com/7000/

Interactive cart pole¶

A slider UI applies a pushing force u to the cart pole base. To get the pendulum anywhere near vertical, you need to swing fairly quickly between the extreme values of u.

In [3]:
class Utraj(VectorSystem):

    def __init__(self):
        # 4 inputs: state: x, theta, xdot, thetadot
        # 1 output: pushing force
        VectorSystem.__init__(self, 4, 1)

    # note that this function is called at each time step
    def DoCalcVectorOutput(self, context, state, unused, force):
        # read the slider on every evaluation so the force tracks the UI
        force[:] = meshcat.GetSliderValue('u')
            
def cartpole_sliders():    
    builder = DiagramBuilder()
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.0)
    parser = Parser(plant)
    ConfigureParser(parser)
    file_name = 'cartpole.urdf'
    parser.AddModelFromFile(file_name)
    plant.Finalize()
    # reset meshcat visualization
    meshcat.Delete()
    meshcat.DeleteAddedControls()
    # configure for 2D
    meshcat.Set2dRenderMode(xmin=-10.5, xmax=5.5, ymin=-5.5, ymax=5.5)
    MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

    # Setup slider input
    meshcat.AddSlider('u', min=-45.0, max=45.0, step=1, value=0.0)

    # Add the slider input to the base
    force_system = builder.AddSystem(Utraj())
    builder.Connect(plant.get_state_output_port(), force_system.get_input_port())
    builder.Connect(force_system.get_output_port(),
                    plant.get_actuation_input_port())

    diagram = builder.Build()

    # Set up a simulator to run this diagram
    simulator = Simulator(diagram)
    context = simulator.get_mutable_context()

    # Set the initial conditions (x, theta, xdot, thetadot)
    context.SetContinuousState([0, 0, 0, 0]) # at rest

    simulator.set_target_realtime_rate(1.0)

    print('Use the slider in the MeshCat controls to apply pushing force u')
    print("Press 'Stop Simulation' in MeshCat to continue.")
    meshcat.AddButton('Stop Simulation')
    while meshcat.GetButtonClicks('Stop Simulation') < 1:
        simulator.AdvanceTo(simulator.get_context().get_time() + 1.0)

cartpole_sliders()
Use the slider in the MeshCat controls to apply pushing force u
Press 'Stop Simulation' in MeshCat to continue.

Optimization based trajectories¶

LQR - balancing¶

The code below is taken almost directly from Underactuated's example, modified to step through increasingly perturbed starting conditions (starting close to the upright state and moving farther away). The cart pole dynamics are linearized about the unstable equilibrium point, and LQR stabilizes the cart pole around this point, so it won't work well for initial conditions far from upright.

Start the simulation by pressing the 'Start Simulation' button. Each run lasts 5 seconds. When a run finishes, the 'done' slider is set to 1. At this point, press 'Start Simulation' again to advance to the next initial condition.
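To make the "linearize about the upright, then LQR" step concrete without Drake, here is a minimal NumPy-only sketch. It assumes a point-mass cart pole with theta = 0 hanging down, and the masses, length and gravity are illustrative values, not read from cartpole.urdf. The helper names (`cartpole_dynamics`, `linearize`, `lqr_gain`) are made up for this sketch; in the notebook itself, Drake's `Linearize` and `LinearQuadraticRegulator` do this work.

```python
import numpy as np

# Illustrative parameters (assumed, not read from cartpole.urdf)
M_CART, M_POLE, LENGTH, GRAVITY = 10.0, 1.0, 0.5, 9.81


def cartpole_dynamics(state, force):
    """Point-mass cart pole; state = (x, theta, xdot, thetadot), theta = 0 hangs down."""
    _, theta, xdot, thetadot = state
    s, c = np.sin(theta), np.cos(theta)
    denom = M_CART + M_POLE * s**2
    xddot = (force + M_POLE * s * (LENGTH * thetadot**2 + GRAVITY * c)) / denom
    thetaddot = (-force * c - M_POLE * LENGTH * thetadot**2 * c * s
                 - (M_CART + M_POLE) * GRAVITY * s) / (LENGTH * denom)
    return np.array([xdot, thetadot, xddot, thetaddot])


def linearize(state0, force0, eps=1e-6):
    """Central-difference Jacobians A = df/dstate and B = df/dforce."""
    n = len(state0)
    A = np.zeros((n, n))
    for i in range(n):
        step = np.zeros(n)
        step[i] = eps
        A[:, i] = (cartpole_dynamics(state0 + step, force0)
                   - cartpole_dynamics(state0 - step, force0)) / (2 * eps)
    B = (cartpole_dynamics(state0, force0 + eps)
         - cartpole_dynamics(state0, force0 - eps)) / (2 * eps)
    return A, B.reshape(n, 1)


def lqr_gain(A, B, Q, R):
    """Continuous-time LQR gain from the stable subspace of the Hamiltonian matrix."""
    n = A.shape[0]
    R_inv = np.linalg.inv(R)
    H = np.block([[A, -B @ R_inv @ B.T], [-Q, -A.T]])
    eigvals, eigvecs = np.linalg.eig(H)
    stable = eigvecs[:, eigvals.real < 0]  # n eigenvectors with Re < 0
    S = np.real(stable[n:, :] @ np.linalg.inv(stable[:n, :]))  # Riccati solution
    return R_inv @ B.T @ S


upright = np.array([0.0, np.pi, 0.0, 0.0])
A, B = linearize(upright, 0.0)
K = lqr_gain(A, B, Q=np.diag([10.0, 10.0, 1.0, 1.0]), R=np.array([[1.0]]))

open_loop_poles = np.linalg.eigvals(A)           # one pole in the right half plane
closed_loop_poles = np.linalg.eigvals(A - B @ K)  # all poles in the left half plane
print("open loop:", open_loop_poles)
print("closed loop:", closed_loop_poles)
```

The controller is then u = -K (state - upright). It is only valid near the upright, which is why the later initial conditions in the cell below are expected to fail.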

In [15]:
# This is from the provided LQR cartpole example in: http://underactuated.mit.edu/acrobot.html

# reset meshcat visualization
meshcat.Delete()
meshcat.DeleteAddedControls()

def cartpole_LQR():
    def UprightState():
        state = (0, np.pi, 0, 0)
        return state

    def Controllability(plant):
        context = plant.CreateDefaultContext()
        plant.get_actuation_input_port().FixValue(context, [0])
        plant.SetPositionsAndVelocities(context, UprightState())

        linearized_plant = Linearize(
            plant,
            context,
            input_port_index=plant.get_actuation_input_port().get_index(),
            output_port_index=plant.get_state_output_port().get_index(),
        )
        print(linearized_plant.A())
        print(linearized_plant.B())
        print(
            f"The singular values of the controllability matrix are: {np.linalg.svd(ControllabilityMatrix(linearized_plant), compute_uv=False)}"
        )

    def BalancingLQR(plant):
        # Design an LQR controller for stabilizing the CartPole around the upright.
        # Returns a (static) AffineSystem that implements the controller (in
        # the original CartPole coordinates).

        context = plant.CreateDefaultContext()
        plant.get_actuation_input_port().FixValue(context, [0])

        plant.SetPositionsAndVelocities(context, UprightState())

        Q = np.diag((10.0, 10.0, 1.0, 1.0))
        R = np.array([1])

        # MultibodyPlant has many (optional) input ports, so we must pass the
        # input_port_index to LQR.
        return LinearQuadraticRegulator(
            plant,
            context,
            Q,
            R,
            input_port_index=plant.get_actuation_input_port().get_index(),
        )

    builder = DiagramBuilder()
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.1)
    parser = Parser(plant)
    ConfigureParser(parser)
    file_name = 'cartpole.urdf'
    parser.AddModelFromFile(file_name)
    plant.Finalize()

    controller = builder.AddSystem(BalancingLQR(plant))
    builder.Connect(
        plant.get_state_output_port(), controller.get_input_port(0)
    )
    builder.Connect(
        controller.get_output_port(0), plant.get_actuation_input_port()
    )

    # Setup visualization
    meshcat.Delete()
    meshcat.Set2dRenderMode(xmin=-2.5, xmax=2.5, ymin=-1.0, ymax=2.5)
    MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)
    
    # add a slider to publish values
    meshcat.AddSlider('done', min=0, max=1, step=1, value=0.0)

    diagram = builder.Build()

    # Set up a simulator to run this diagram
    simulator = Simulator(diagram)
    context = simulator.get_mutable_context()
    plant_context = plant.GetMyMutableContextFromRoot(context)

    # Simulate
    simulator.set_target_realtime_rate(1.0 if running_as_notebook else 0.0)
    duration = 5.0 if running_as_notebook else 0.1
    meshcat.AddButton('Start Simulation')
    # while meshcat.GetButtonClicks('Start Simulation') < 1:
    #     #pause until button click
    #     pass
    print("Press 'Start Simulation' in MeshCat to start simulation, and press 'Start Simulation' again to advance through.")
    print(f"simulation runs for {duration} seconds and resets 5 times to increasingly worse initial conditions. Done button turns to 1 when ready to advance.")
    for i in range(5):
        while meshcat.GetButtonClicks('Start Simulation') < (i+1):
            #pause until button click
            pass
        meshcat.SetSliderValue('done', 0)
        context.SetTime(0.0)
        perturb = 0.05*(i+1)*np.ones((4,)) # step through more perturbed initial states
        perturb[1] += 0.22*i # want to mess with initial theta a bit more
        plant.SetPositionsAndVelocities(
            plant_context,
            UprightState() + perturb,
        )
        simulator.Initialize()
        simulator.AdvanceTo(duration)
        meshcat.SetSliderValue('done', 1)
cartpole_LQR()
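Press 'Start Simulation' in MeshCat to start simulation, and press 'Start Simulation' again to advance through.
simulation runs for 5.0 seconds and resets 5 times to increasingly worse initial conditions. Done button turns to 1 when ready to advance.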
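The perturbation schedule in the loop above is easy to lose track of, so here is a small standalone sketch that tabulates the five initial conditions. It mirrors the `perturb` arithmetic from `cartpole_LQR`; the function name `perturbed_initial_state` is just for illustration.

```python
import math

UPRIGHT = (0.0, math.pi, 0.0, 0.0)  # (x, theta, xdot, thetadot)


def perturbed_initial_state(i):
    """Initial state for run i (0-based), mirroring the loop in cartpole_LQR."""
    perturb = [0.05 * (i + 1)] * 4   # same offset on every state
    perturb[1] += 0.22 * i           # extra offset on theta
    return tuple(s + p for s, p in zip(UPRIGHT, perturb))


for i in range(5):
    x, theta, xdot, thetadot = perturbed_initial_state(i)
    print(f"run {i + 1}: theta - pi = {theta - math.pi:.2f} rad, "
          f"x = xdot = thetadot = {x:.2f}")
```

So the pole starts 0.05 rad from upright on the first run and about 1.13 rad (roughly 65 degrees) from upright on the last, well outside where the linearization can be expected to hold.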

Swing up¶

Energy controller: this drives the pendulum up toward the upright (unstable) equilibrium, but it won't stabilize it there. The cell simulates forward for 5 seconds to show this. The performance could likely be "better", however you define that, with some time spent tuning the gains.
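To see the "reaches the right energy but doesn't stay at the top" behavior in isolation, here is a minimal sketch on a plain torque-driven pendulum rather than the cart pole. Everything here is an assumption made for illustration: theta = 0 hangs down, the mass, length and gain are arbitrary, and the torque law u = -k * thetadot * E is the simple-pendulum version of energy shaping, not the cart pole law used in the cell below.

```python
import math

# Illustrative values (assumed)
M, L, G = 1.0, 0.5, 9.81
K_E = 1.0


def energy_error(theta, thetadot):
    """Pendulum energy minus the energy of the upright at rest (theta = pi)."""
    return 0.5 * M * L**2 * thetadot**2 - M * G * L * (1.0 + math.cos(theta))


def deriv(theta, thetadot):
    """Pendulum dynamics under the energy-shaping torque."""
    torque = -K_E * thetadot * energy_error(theta, thetadot)
    thetaddot = (torque - M * G * L * math.sin(theta)) / (M * L**2)
    return thetadot, thetaddot


def simulate(theta, thetadot, duration=10.0, dt=1e-3):
    """Fixed-step RK4 integration."""
    for _ in range(int(round(duration / dt))):
        k1 = deriv(theta, thetadot)
        k2 = deriv(theta + 0.5 * dt * k1[0], thetadot + 0.5 * dt * k1[1])
        k3 = deriv(theta + 0.5 * dt * k2[0], thetadot + 0.5 * dt * k2[1])
        k4 = deriv(theta + dt * k3[0], thetadot + dt * k3[1])
        theta += dt * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0]) / 6
        thetadot += dt * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1]) / 6
    return theta, thetadot


initial_error = energy_error(0.1, 0.0)          # close to -2*M*G*L
final_error = energy_error(*simulate(0.1, 0.0))
print(f"energy error: {initial_error:.3f} -> {final_error:.3f}")
```

With this law the energy error obeys dE/dt = -k * thetadot^2 * E, so it shrinks toward zero, but zero energy error describes a whole orbit through the upright, not just the upright point. Nothing holds the pendulum at the top, which is why a separate balancing controller is needed.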

In [58]:
def get_cartpole_parameters():
    # parse urdf
    plant = MultibodyPlant(time_step=0)
    parser = Parser(plant)
    ConfigureParser(parser)
    file_name = 'cartpole.urdf'
    parser.AddModelFromFile(file_name)
    plant.Finalize()

    # retrieve physical parameters
    m_cart = plant.GetBodyByName("cart").default_mass()
    m_pole = plant.GetBodyByName("pole").default_mass()
    g = -plant.gravity_field().gravity_vector()[-1]
    l = 0.5 # don't know how to get this out the urdf

    return m_cart, m_pole, l, g

class SwingUpController(VectorSystem):

    def __init__(self, K):
        # 4 inputs: state: x, theta, xdot, thetadot
        # 1 output: pushing force
        VectorSystem.__init__(self, 4, 1)
        self.m_cart, self.m_pole, self.l, self.g = get_cartpole_parameters()
        # controller gains
        self.ke = K[0]
        self.kp = K[1]
        self.kd = K[2]

    # note that this function is called at each time step
    def DoCalcVectorOutput(self, context, state, unused, u):
        x = state[0]
        theta = state[1]
        xdot = state[2]
        thetadot = state[3]
        # short local names for the parameters and gains
        m = self.m_pole
        l = self.l
        g = self.g
        ke = self.ke
        kp = self.kp
        kd = self.kd
        # pendulum energy relative to the upright at rest. theta = 0 hangs down
        # and theta = pi is upright (see UprightState), so E = 0 at the upright
        # and E = -2*m*g*l when hanging at rest
        E = 1/2*m*l**2*thetadot**2 - m*g*l*(1 + np.cos(theta))
        # energy shaping for the pendulum + pd control for the base
        u[:] = ke*thetadot*np.cos(theta)*E - kp*x - kd*xdot


def cartpole_swing():
    builder = DiagramBuilder()
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.0)
    parser = Parser(plant)
    ConfigureParser(parser)
    file_name = 'cartpole.urdf'
    parser.AddModelFromFile(file_name)
    plant.Finalize()
    # reset meshcat visualization
    meshcat.Delete()
    meshcat.DeleteAddedControls()
    # configure for 2D
    meshcat.Set2dRenderMode(xmin=-10.5, xmax=5.5, ymin=-5.5, ymax=5.5)
    MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

    # Add controller
    K = [1.0,1.0,1.0]
    force_system = builder.AddSystem(SwingUpController(K))
    builder.Connect(plant.get_state_output_port(), force_system.get_input_port())
    builder.Connect(force_system.get_output_port(),
                    plant.get_actuation_input_port())

    diagram = builder.Build()

    # Set up a simulator to run this diagram
    simulator = Simulator(diagram)
    context = simulator.get_mutable_context()

    # Set the initial conditions (x, theta, xdot, thetadot)
    context.SetContinuousState([0.1, 2.0, 0.1, 0]) #

    simulator.set_target_realtime_rate(1.5)
    meshcat.AddButton('Start Simulation')
    while meshcat.GetButtonClicks('Start Simulation') < 1:
            pass
    simulator.AdvanceTo(5.0)

# reset meshcat visualization
meshcat.Delete()
meshcat.DeleteAddedControls()
# simulate (swing up only)

cartpole_swing()

LQR + energy control¶

We will combine these two control strategies: start in energy control and, once we get into the vicinity of the LQR controller's basin of attraction, switch over to LQR.
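The switching rule itself can be sketched separately from Drake. This is a minimal version that assumes the upright is theta = pi (as in `UprightState`) and uses only the pole angle to decide; the 0.25 rad threshold and the function names are illustrative, and a fuller check would probably also bound the velocities.

```python
import math


def angle_from_upright(theta):
    """Signed angle between the pole and the upright, wrapped to [-pi, pi]."""
    error = theta - math.pi
    return math.atan2(math.sin(error), math.cos(error))


def select_controller(theta, threshold=0.25):
    """Pick 'lqr' near the upright and 'energy' everywhere else."""
    return "lqr" if abs(angle_from_upright(theta)) < threshold else "energy"


for theta in (0.0, 2.0, math.pi - 0.1, -math.pi, 3 * math.pi + 0.1):
    print(f"theta = {theta:6.3f} -> {select_controller(theta)}")
```

Wrapping the angle matters because the pole can reach the upright at theta = pi, -pi, 3*pi and so on, depending on which way and how many times it has swung around.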

In [73]:
def UprightState():
    state = (0, np.pi, 0, 0)
    return state

def BalancingLQR(plant):
    # Design an LQR controller for stabilizing the CartPole around the upright.
    # Returns a (static) AffineSystem that implements the controller (in
    # the original CartPole coordinates).

    context = plant.CreateDefaultContext()
    plant.get_actuation_input_port().FixValue(context, [0])

    plant.SetPositionsAndVelocities(context, UprightState())

    Q = np.diag((10.0, 10.0, 1.0, 1.0))
    R = np.array([1])

    # MultibodyPlant has many (optional) input ports, so we must pass the
    # input_port_index to LQR.
    return LinearQuadraticRegulator(
        plant,
        context,
        Q,
        R,
        input_port_index=plant.get_actuation_input_port().get_index(),
    )



class StableSwingUpController(VectorSystem):
    def __init__(self, K, plant, SwingUpController, BalancingLQR):
        # 4 inputs: state: x, theta, xdot, thetadot
        # 1 output: pushing force
        VectorSystem.__init__(self, 4,  1)
        # energy controller; its DoCalcVectorOutput is called directly below
        self.swing_up = SwingUpController(K)
        # BalancingLQR returns an AffineSystem, which has no DoCalcVectorOutput.
        # It computes u = D*state + y0, so store D and y0 and evaluate by hand
        lqr = BalancingLQR(plant)
        self.D = lqr.D()
        self.y0 = lqr.y0()

    def DoCalcVectorOutput(self, context, state, unused, u):
        # wrap theta to [0, 2*pi) so the upright (theta = pi, see UprightState)
        # has a single representation
        wrapped = np.array(state, dtype=float)
        wrapped[1] = np.mod(wrapped[1], 2*np.pi)

        if abs(wrapped[1] - np.pi) > 0.25:
            # far from upright: do energy control (writes u in place)
            self.swing_up.DoCalcVectorOutput(context, state, unused, u)
        else:
            # near upright: do LQR
            u[:] = self.D @ wrapped + self.y0



def cartpole_swingup():
    builder = DiagramBuilder()
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.0)
    parser = Parser(plant)
    ConfigureParser(parser)
    file_name = 'cartpole.urdf'
    parser.AddModelFromFile(file_name)
    plant.Finalize()
    # reset meshcat visualization
    meshcat.Delete()
    meshcat.DeleteAddedControls()
    # configure for 2D
    meshcat.Set2dRenderMode(xmin=-10.5, xmax=5.5, ymin=-5.5, ymax=5.5)
    MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

    # Add controller
    K = [1.0,1.0,1.0]
    force_system = builder.AddSystem(StableSwingUpController(K,plant,SwingUpController, BalancingLQR))
    builder.Connect(plant.get_state_output_port(), force_system.get_input_port())
    builder.Connect(force_system.get_output_port(),
                    plant.get_actuation_input_port())

    diagram = builder.Build()

    # Set up a simulator to run this diagram
    simulator = Simulator(diagram)
    context = simulator.get_mutable_context()

    # Set the initial conditions (x, theta, xdot, thetadot)
    context.SetContinuousState([0.1, 0.1, 0.1, 0]) # near the hanging-down equilibrium

    simulator.set_target_realtime_rate(1.0)
    meshcat.AddButton('Start Simulation')
    # while meshcat.GetButtonClicks('Start Simulation') < 1:
    #         pass
    simulator.AdvanceTo(2.0)

# reset meshcat visualization
meshcat.Delete()
meshcat.DeleteAddedControls()
# simulate (swing up, then balance)

cartpole_swingup()

In [63]:
print('hi')
hi