Let's look at various control strategies for the cart pole system.
from copy import deepcopy
from IPython.display import display, HTML, clear_output
import matplotlib.pyplot as plt, mpld3
import numpy as np
import os
import random
from scipy.signal import lfilter
import torch
from torch import nn
from torch.nn import functional as F
from torch.distributions.normal import Normal
from torch.distributions.independent import Independent
from pydrake.all import (
AddMultibodyPlantSceneGraph,
ConnectPlanarSceneGraphVisualizer,
DiagramBuilder,
FixedOffsetFrame,
InverseDynamicsController,
LeafSystem,
MultibodyPlant,
Parser,
PlanarJoint,
PrismaticJoint,
RevoluteJoint,
RigidTransform,
RotationMatrix,
Simulator,
SpatialInertia,
UnitInertia,
VectorSystem,
ControllabilityMatrix,
Linearize,
LinearQuadraticRegulator,
MeshcatVisualizer,
Saturation,
SceneGraph,
StartMeshcat,
WrapToSystem,
ConstantVectorSource,
TrajectorySource,
Trajectory,
DirectCollocation,
LogVectorOutput,
PiecewisePolynomial,
PlanarSceneGraphVisualizer,
MultibodyPositionToGeometryPose
)
from manipulation.utils import ConfigureParser
from manipulation import running_as_notebook
# Interactive mpld3 rendering of matplotlib figures is only enabled when
# executing inside a notebook.
if running_as_notebook:
    mpld3.enable_notebook()
Run this cell only once; it opens the visualizer in a new tab.
# Launch the MeshCat visualizer. Execute this cell a single time: every
# StartMeshcat() instance occupies its own port.
meshcat = StartMeshcat()
INFO:drake:Meshcat listening for connections at https://1b09b3b4-e541-498e-9217-19067f09621f.deepnoteproject.com/7000/ Installing NginX server for MeshCat on Deepnote...
Slider UI for applying a pushing force to the cart-pole base. To get the pendulum anywhere close to vertical, you will need to swing quickly between the extreme values of u.
class Utraj(VectorSystem):
    """Force source driven by the MeshCat slider named 'u'.

    Input (size 4): the cart-pole state (x, theta, xdot, thetadot). It is
    connected but not used.
    Output (size 1): the pushing force. It is re-read from the slider every
    time the output is evaluated.
    """

    def __init__(self):
        VectorSystem.__init__(self, 4, 1)
        # Slider value captured at construction. The output below always
        # re-reads the slider, so this is only a snapshot. The slider must
        # already exist when the system is constructed.
        self.u = meshcat.GetSliderValue('u')

    def DoCalcVectorOutput(self, context, state_delta, unused, spring_force):
        # Despite their names, `state_delta` is the plant-state input and
        # `spring_force` is the 1-element output (the pushing force).
        spring_force[:] = meshcat.GetSliderValue('u')
def cartpole_sliders():
    """Simulate the cart-pole pushed by a force taken from a MeshCat slider.

    Builds a continuous-time plant from 'cartpole.urdf' and wires the MeshCat
    slider 'u' (via ``Utraj``) into its actuation port. It then simulates from
    rest.

    In a notebook the simulation advances in 1 s real-time chunks until the
    'Stop Simulation' button is clicked. Outside a notebook nobody can click
    that button, so the loop would never end. Instead, a short 0.1 s horizon
    is simulated and the function returns.
    """
    builder = DiagramBuilder()
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.0)
    parser = Parser(plant)
    ConfigureParser(parser)
    parser.AddModelFromFile('cartpole.urdf')
    plant.Finalize()

    # Reset the MeshCat scene and any controls left over from earlier cells.
    meshcat.Delete()
    meshcat.DeleteAddedControls()
    # Configure for 2D.
    meshcat.Set2dRenderMode(xmin=-10.5, xmax=5.5, ymin=-5.5, ymax=5.5)
    MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

    # The slider must exist before Utraj is constructed: Utraj.__init__ reads it.
    meshcat.AddSlider('u', min=-45.0, max=45.0, step=1, value=0.0)
    force_system = builder.AddSystem(Utraj())
    builder.Connect(plant.get_state_output_port(), force_system.get_input_port())
    builder.Connect(force_system.get_output_port(),
                    plant.get_actuation_input_port())
    diagram = builder.Build()

    simulator = Simulator(diagram)
    context = simulator.get_mutable_context()
    # Initial conditions (x, theta, xdot, thetadot): at rest.
    context.SetContinuousState([0, 0, 0, 0])
    simulator.set_target_realtime_rate(1.0 if running_as_notebook else 0.0)

    if not running_as_notebook:
        # Non-interactive run: the button below could never be clicked.
        simulator.AdvanceTo(0.1)
        return

    print('Use the slider in the MeshCat controls apply pushing force u')
    print("Press 'Stop Simulation' in MeshCat to continue.")
    meshcat.AddButton('Stop Simulation')
    while meshcat.GetButtonClicks('Stop Simulation') < 1:
        simulator.AdvanceTo(simulator.get_context().get_time() + 1.0)


cartpole_sliders()
Use the slider in the MeshCat controls apply pushing force u Press 'Stop Simulation' in MeshCat to continue.
The code below is adapted almost directly from Underactuated's example. It adds one modification: stepping through increasingly perturbed starting conditions, starting close to the upright state and moving farther away. The cart-pole dynamics are linearized about the unstable upright equilibrium, and LQR is used to stabilize the cart-pole around this point. It therefore won't work well for initial conditions far from upright.
Start the simulation by pressing the 'Start Simulation' button. Each run lasts 5 seconds. When a run finishes, the 'done' slider is set to 1. At that point, press 'Start Simulation' again to advance to the next initial condition.
# This is from the provided LQR cartpole example in: http://underactuated.mit.edu/acrobot.html
def cartpole_LQR():
    """Balance the cart-pole at the upright with LQR, from perturbed starts.

    An LQR controller designed about the upright equilibrium (theta = pi) is
    connected to a discrete-time (time_step=0.1) cart-pole. The simulation is
    then restarted five times from increasingly perturbed initial states.

    In a notebook, each run waits for the next click of the MeshCat
    'Start Simulation' button and runs for 5 s in real time. When it is
    finished, the 'done' slider is set to 1. Outside a notebook nobody can
    click the button, so the runs start immediately and last 0.1 s each.
    """
    import time  # used only to throttle the button-polling loop

    def UprightState():
        # (x, theta, xdot, thetadot) at the upright equilibrium.
        state = (0, np.pi, 0, 0)
        return state

    def Controllability(plant):
        # Diagnostic helper, not called below. It prints the linearization
        # about the upright and the singular values of its controllability
        # matrix.
        context = plant.CreateDefaultContext()
        plant.get_actuation_input_port().FixValue(context, [0])
        plant.SetPositionsAndVelocities(context, UprightState())
        linearized_plant = Linearize(
            plant,
            context,
            input_port_index=plant.get_actuation_input_port().get_index(),
            output_port_index=plant.get_state_output_port().get_index(),
        )
        print(linearized_plant.A())
        print(linearized_plant.B())
        print(
            f"The singular values of the controllability matrix are: {np.linalg.svd(ControllabilityMatrix(linearized_plant), compute_uv=False)}"
        )

    def BalancingLQR(plant):
        # Design an LQR controller for stabilizing the CartPole around the upright.
        # Returns a (static) AffineSystem that implements the controller (in
        # the original CartPole coordinates).
        context = plant.CreateDefaultContext()
        plant.get_actuation_input_port().FixValue(context, [0])
        plant.SetPositionsAndVelocities(context, UprightState())
        Q = np.diag((10.0, 10.0, 1.0, 1.0))
        R = np.array([1])
        # MultibodyPlant has many (optional) input ports, so we must pass the
        # input_port_index to LQR.
        return LinearQuadraticRegulator(
            plant,
            context,
            Q,
            R,
            input_port_index=plant.get_actuation_input_port().get_index(),
        )

    # Reset the MeshCat scene and controls here, inside the function, so the
    # 'done' slider and 'Start Simulation' button below are always added to
    # a clean set of controls, even if this function is called again.
    meshcat.Delete()
    meshcat.DeleteAddedControls()

    builder = DiagramBuilder()
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.1)
    parser = Parser(plant)
    ConfigureParser(parser)
    parser.AddModelFromFile('cartpole.urdf')
    plant.Finalize()

    controller = builder.AddSystem(BalancingLQR(plant))
    builder.Connect(
        plant.get_state_output_port(), controller.get_input_port(0)
    )
    builder.Connect(
        controller.get_output_port(0), plant.get_actuation_input_port()
    )

    # Setup visualization
    meshcat.Set2dRenderMode(xmin=-2.5, xmax=2.5, ymin=-1.0, ymax=2.5)
    MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)
    # Slider used as an indicator: set to 1 when a run has finished.
    meshcat.AddSlider('done', min=0, max=1, step=1, value=0.0)

    diagram = builder.Build()
    # Set up a simulator to run this diagram
    simulator = Simulator(diagram)
    context = simulator.get_mutable_context()
    plant_context = plant.GetMyMutableContextFromRoot(context)

    # Simulate
    simulator.set_target_realtime_rate(1.0 if running_as_notebook else 0.0)
    duration = 5.0 if running_as_notebook else 0.1
    meshcat.AddButton('Start Simulation')

    print("Press 'Start Simulation' in MeshCat to start simulation, and press 'Start Simulation' again to advance through.")
    print(f"simulation runs for {duration} seconds and resets 5 times to increasingly worse initial conditions. Done button turns to 1 when ready to advance.")
    for i in range(5):
        if running_as_notebook:
            # Block until the button has been clicked i + 1 times. Sleep
            # between polls instead of spinning a CPU core at 100%.
            while meshcat.GetButtonClicks('Start Simulation') < i + 1:
                time.sleep(0.1)
        meshcat.SetSliderValue('done', 0)
        context.SetTime(0.0)
        # Step through more perturbed initial states: every component is
        # offset by 0.05*(i+1), and theta gets an extra 0.22*i.
        perturb = 0.05 * (i + 1) * np.ones((4,))
        perturb[1] += 0.22 * i
        plant.SetPositionsAndVelocities(
            plant_context,
            UprightState() + perturb,
        )
        simulator.Initialize()
        simulator.AdvanceTo(duration)
        meshcat.SetSliderValue('done', 1)


cartpole_LQR()
Press 'Start Simulation' in MeshCat to start simulation. simulation runs for 5.0 seconds and resets 5 times to a new slightly perturbed position
Energy controller: this will get us to the stable point, but it won't stabilize us near there! This cell simulates forward for 5 seconds to show this. The performance could likely be made "better" (however you define that) by spending some time tuning the gains.
def get_cartpole_parameters():
    """Return ``(m_cart, m_pole, l, g)`` for the cart-pole in 'cartpole.urdf'.

    The two masses are the default masses of the parsed "cart" and "pole"
    bodies. ``g`` is the negated z-component of the plant's gravity vector.
    ``l`` is a hard-coded 0.5, because it is not extracted from the URDF here.
    """
    model = MultibodyPlant(time_step=0)
    urdf_parser = Parser(model)
    ConfigureParser(urdf_parser)
    urdf_parser.AddModelFromFile('cartpole.urdf')
    model.Finalize()

    cart_body = model.GetBodyByName("cart")
    pole_body = model.GetBodyByName("pole")
    gravity_z = model.gravity_field().gravity_vector()[-1]
    pole_length = 0.5  # NOTE: hard-coded; not read from the URDF
    return (cart_body.default_mass(), pole_body.default_mass(),
            pole_length, -gravity_z)
class SwingUpController(VectorSystem):
    """Energy-shaping swing-up controller with a PD term on the cart position.

    Input (size 4): the cart-pole state (x, theta, xdot, thetadot). It uses
    the same angle convention as ``UprightState`` in this notebook:
    theta = pi is upright and theta = 0 is hanging.
    Output (size 1): the pushing force on the cart,

        u = ke * thetadot * cos(theta) * E_err - kp * x - kd * xdot,

    where ``E_err`` is the pole's energy minus its energy at rest upright.

    NOTE(review): this law (Underactuated, cart-pole energy shaping) is
    derived for a commanded cart *acceleration*. Here it is applied directly
    as a force, so the gains may need tuning.

    Args:
        K: gains ``(ke, kp, kd)``.
    """

    def __init__(self, K):
        # 4 inputs: state: x, theta, xdot, thetadot
        # 1 output: pushing force
        VectorSystem.__init__(self, 4, 1)
        self.m_cart, self.m_pole, self.l, self.g = get_cartpole_parameters()
        # controller gains
        self.ke = K[0]
        self.kp = K[1]
        self.kd = K[2]

    def DoCalcVectorOutput(self, context, state, unused, u):
        x = state[0]
        theta = state[1]
        xdot = state[2]
        thetadot = state[3]

        m = self.m_pole
        l = self.l
        g = self.g

        # Pole energy E = 1/2 m l^2 thetadot^2 - m g l cos(theta), with zero
        # potential at the pivot height. At rest upright (theta = pi),
        # E = m g l. The previous formula, -m g l (1 - cos(theta)), measured
        # the error relative to theta = 0, which is the hanging state in this
        # notebook's convention.
        energy_error = 0.5 * m * l**2 * thetadot**2 - m * g * l * (1 + np.cos(theta))

        # Energy pumping plus PD regulation of the cart back to x = 0.
        u[:] = (self.ke * thetadot * np.cos(theta) * energy_error
                - self.kp * x - self.kd * xdot)
def cartpole_swing():
    """Simulate the cart-pole under ``SwingUpController`` alone (no LQR).

    Uses a continuous-time plant from 'cartpole.urdf' with gains
    K = [1, 1, 1], starting at (x, theta, xdot, thetadot) = (0.1, 2.0, 0.1, 0).

    In a notebook, the run waits for the MeshCat 'Start Simulation' button
    and then simulates 5 s at 1.5x real time. Outside a notebook nobody can
    click the button, so the run starts immediately and lasts 0.1 s.
    """
    import time  # used only to throttle the button-polling loop

    builder = DiagramBuilder()
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.0)
    parser = Parser(plant)
    ConfigureParser(parser)
    parser.AddModelFromFile('cartpole.urdf')
    plant.Finalize()

    # reset meshcat visualization
    meshcat.Delete()
    meshcat.DeleteAddedControls()
    # configure for 2D
    meshcat.Set2dRenderMode(xmin=-10.5, xmax=5.5, ymin=-5.5, ymax=5.5)
    MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

    # Add controller
    K = [1.0, 1.0, 1.0]
    force_system = builder.AddSystem(SwingUpController(K))
    builder.Connect(plant.get_state_output_port(), force_system.get_input_port())
    builder.Connect(force_system.get_output_port(),
                    plant.get_actuation_input_port())
    diagram = builder.Build()

    # Set up a simulator to run this diagram
    simulator = Simulator(diagram)
    context = simulator.get_mutable_context()
    # Set the initial conditions (x, theta, xdot, thetadot)
    context.SetContinuousState([0.1, 2.0, 0.1, 0])
    simulator.set_target_realtime_rate(1.5 if running_as_notebook else 0.0)
    duration = 5.0 if running_as_notebook else 0.1

    meshcat.AddButton('Start Simulation')
    if running_as_notebook:
        # Sleep between polls instead of busy-spinning a CPU core.
        while meshcat.GetButtonClicks('Start Simulation') < 1:
            time.sleep(0.1)
    simulator.AdvanceTo(duration)


# simulate (swing up only); cartpole_swing() resets the MeshCat scene itself.
cartpole_swing()
We will combine these two control strategies. We start with energy control, and once we get into the vicinity of the LQR controller's basin of attraction, we switch over.
def UprightState():
    """Return the upright equilibrium as an ``(x, theta, xdot, thetadot)`` tuple.

    The pole angle is pi; the cart position and both velocities are zero.
    """
    return (0, np.pi, 0, 0)
def BalancingLQR(plant):
    """Design an LQR controller that balances the CartPole about the upright.

    The operating point is ``UprightState()`` with zero actuation. The result
    is the (static) AffineSystem returned by ``LinearQuadraticRegulator``,
    which implements the controller in the original CartPole coordinates.

    Args:
        plant: a finalized cart-pole MultibodyPlant.
    """
    actuation_port = plant.get_actuation_input_port()
    operating_point = plant.CreateDefaultContext()
    actuation_port.FixValue(operating_point, [0])
    plant.SetPositionsAndVelocities(operating_point, UprightState())

    state_weights = np.diag((10.0, 10.0, 1.0, 1.0))
    input_weights = np.array([1])

    # MultibodyPlant exposes many (optional) input ports, so the actuation
    # port index has to be handed to LQR explicitly.
    return LinearQuadraticRegulator(
        plant,
        operating_point,
        state_weights,
        input_weights,
        input_port_index=actuation_port.get_index(),
    )
class StableSwingUpController(VectorSystem):
    """Swing up with energy shaping, then balance with LQR near the upright.

    Input (size 4): the cart-pole state (x, theta, xdot, thetadot), with
    theta = pi upright, as in ``UprightState``.
    Output (size 1): the pushing force on the cart.

    When the pole angle is more than ``balance_threshold`` radians from
    upright (measured modulo 2*pi), the swing-up controller's output is used.
    Inside that band, the LQR policy is used instead.

    Args:
        K: swing-up gains ``(ke, kp, kd)``.
        plant: finalized cart-pole MultibodyPlant used to design the LQR.
        SwingUpController: swing-up controller class, instantiated once as
            ``SwingUpController(K)``.
        BalancingLQR: callable ``BalancingLQR(plant)`` returning the LQR as
            a stateless AffineSystem.
        balance_threshold: half-width (radians) of the band around upright
            in which LQR is used. Defaults to the original 0.25.
    """

    def __init__(self, K, plant, SwingUpController, BalancingLQR,
                 balance_threshold=0.25):
        VectorSystem.__init__(self, 4, 1)
        self.m_cart, self.m_pole, self.l, self.g = get_cartpole_parameters()
        # controller gains
        self.ke = K[0]
        self.kp = K[1]
        self.kd = K[2]
        self.SwingUpController = SwingUpController
        self.plant = plant
        self.balance_threshold = balance_threshold

        # Build both sub-controllers once here, never inside the output
        # callback (redesigning the LQR at every evaluation is very costly).
        self._swing_up = SwingUpController(K)
        lqr = BalancingLQR(plant)
        # An AffineSystem has no DoCalcVectorOutput. The LQR system has no
        # state, so its output is simply D @ x + y0 for plant state x.
        self._lqr_D = np.asarray(lqr.D())
        self._lqr_y0 = np.asarray(lqr.y0())

    def DoCalcVectorOutput(self, context, state, unused, u):
        theta = state[1]
        # Angle error from upright (pi), wrapped into [-pi, pi).
        theta_error = np.mod(theta, 2 * np.pi) - np.pi

        if abs(theta_error) > self.balance_threshold:
            # Far from upright: energy control. The swing-up controller
            # writes into `u` in place; its return value is None.
            self._swing_up.DoCalcVectorOutput(context, state, unused, u)
        else:
            # Near upright: LQR. Evaluate it with theta mapped next to pi, so
            # full extra turns of the pole do not look like huge errors to a
            # controller linearized at theta = pi.
            lqr_state = np.array(state, dtype=float)
            lqr_state[1] = np.pi + theta_error
            u[:] = self._lqr_D @ lqr_state + self._lqr_y0
def cartpole_swingup():
    """Simulate the cart-pole under ``StableSwingUpController`` for 2 s.

    Uses a continuous-time plant from 'cartpole.urdf' with gains
    K = [1, 1, 1], starting at (x, theta, xdot, thetadot) = (0.1, 0.1, 0.1, 0),
    in real time.
    """
    diagram_builder = DiagramBuilder()
    cartpole, scene_graph = AddMultibodyPlantSceneGraph(diagram_builder, time_step=0.0)
    urdf_parser = Parser(cartpole)
    ConfigureParser(urdf_parser)
    urdf_parser.AddModelFromFile('cartpole.urdf')
    cartpole.Finalize()

    # Start from an empty MeshCat scene with no leftover controls, drawn in 2D.
    meshcat.Delete()
    meshcat.DeleteAddedControls()
    meshcat.Set2dRenderMode(xmin=-10.5, xmax=5.5, ymin=-5.5, ymax=5.5)
    MeshcatVisualizer.AddToBuilder(diagram_builder, scene_graph, meshcat)

    gains = [1.0, 1.0, 1.0]
    controller = diagram_builder.AddSystem(
        StableSwingUpController(gains, cartpole, SwingUpController, BalancingLQR))
    diagram_builder.Connect(cartpole.get_state_output_port(),
                            controller.get_input_port())
    diagram_builder.Connect(controller.get_output_port(),
                            cartpole.get_actuation_input_port())

    simulator = Simulator(diagram_builder.Build())
    # Initial (x, theta, xdot, thetadot).
    simulator.get_mutable_context().SetContinuousState([0.1, 0.1, 0.1, 0])
    simulator.set_target_realtime_rate(1.0)
    # NOTE(review): this button is added but never read, so the simulation
    # starts immediately.
    meshcat.AddButton('Start Simulation')
    simulator.AdvanceTo(2.0)


# Clear the MeshCat scene and controls, then run the combined controller.
meshcat.Delete()
meshcat.DeleteAddedControls()
cartpole_swingup()
--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) Cell In [73], line 99 96 meshcat.DeleteAddedControls() 97 # simulate (swing up only) ---> 99 cartpole_swingup() Cell In [73], line 92, in cartpole_swingup() 89 meshcat.AddButton('Start Simulation') 90 # while meshcat.GetButtonClicks('Start Simulation') < 1: 91 # pass ---> 92 simulator.AdvanceTo(2.0) Cell In [73], line 53, in StableSwingUpController.DoCalcVectorOutput(self, context, state, unused, u) 50 print(u) 51 else: 52 # do LQR ---> 53 u[:] = BalancingLQR(self.plant).DoCalcVectorOutput() AttributeError: 'pydrake.systems.primitives.AffineSystem' object has no attribute 'DoCalcVectorOutput'
# NOTE(review): looks like a leftover debug print; consider removing.
print("hi")
hi