Skip to content

Commit cc8b7a9

Browse files
committed
[wpimath] Add Sleipnir Java bindings
1 parent 7a2a982 commit cc8b7a9

File tree

116 files changed

+13070
-1153
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

116 files changed

+13070
-1153
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Copyright (c) FIRST and other WPILib contributors.
2+
// Open Source Software; you can modify and/or share it under the terms of
3+
// the WPILib BSD license file in the root directory of this project.
4+
5+
package frc.robot;
6+
7+
import static edu.wpi.first.math.autodiff.NumericalIntegration.rk4;
8+
import static edu.wpi.first.math.autodiff.Variable.cos;
9+
import static edu.wpi.first.math.autodiff.Variable.sin;
10+
import static edu.wpi.first.math.autodiff.VariableMatrix.solve;
11+
import static edu.wpi.first.math.optimization.Constraints.eq;
12+
import static edu.wpi.first.math.optimization.Constraints.ge;
13+
import static edu.wpi.first.math.optimization.Constraints.le;
14+
15+
import edu.wpi.first.math.MathUtil;
16+
import edu.wpi.first.math.autodiff.Variable;
17+
import edu.wpi.first.math.autodiff.VariableMatrix;
18+
import edu.wpi.first.math.optimization.Problem;
19+
import edu.wpi.first.math.optimization.solver.Options;
20+
import org.ejml.simple.SimpleMatrix;
21+
22+
public final class CartPoleBenchmark {
23+
private CartPoleBenchmark() {
24+
// Utility class.
25+
}
26+
27+
@SuppressWarnings("LocalVariableName")
28+
private static VariableMatrix cartPoleDynamics(VariableMatrix x, VariableMatrix u) {
29+
final double m_c = 5.0; // Cart mass (kg)
30+
final double m_p = 0.5; // Pole mass (kg)
31+
final double l = 0.5; // Pole length (m)
32+
final double g = 9.806; // Acceleration due to gravity (m/s²)
33+
34+
var q = x.segment(0, 2);
35+
var qdot = x.segment(2, 2);
36+
var theta = q.get(1);
37+
var thetadot = qdot.get(1);
38+
39+
// [ m_c + m_p m_p l cosθ]
40+
// M(q) = [m_p l cosθ m_p l² ]
41+
var M =
42+
new VariableMatrix(
43+
new Variable[][] {
44+
{new Variable(m_c + m_p), cos(theta).times(m_p * l)},
45+
{cos(theta).times(m_p * l), new Variable(m_p * Math.pow(l, 2))}
46+
});
47+
48+
// [0 −m_p lθ̇ sinθ]
49+
// C(q, q̇) = [0 0 ]
50+
var C =
51+
new VariableMatrix(
52+
new Variable[][] {
53+
{new Variable(0), thetadot.times(-m_p * l).times(sin(theta))},
54+
{new Variable(0), new Variable(0)}
55+
});
56+
57+
// [ 0 ]
58+
// τ_g(q) = [-m_p gl sinθ]
59+
var tau_g =
60+
new VariableMatrix(new Variable[][] {{new Variable(0)}, {sin(theta).times(-m_p * g * l)}});
61+
62+
// [1]
63+
// B = [0]
64+
var B = new VariableMatrix(new double[][] {{1}, {0}});
65+
66+
// q̈ = M⁻¹(q)(τ_g(q) − C(q, q̇)q̇ + Bu)
67+
var qddot = new VariableMatrix(4);
68+
qddot.segment(0, 2).set(qdot);
69+
qddot.segment(2, 2).set(solve(M, tau_g.minus(C.times(qdot)).plus(B.times(u))));
70+
return qddot;
71+
}
72+
73+
/** Cart-pole benchmark. */
74+
public static void cartPole() {
75+
final double T = 5.0; // s
76+
final double dt = 0.05; // s
77+
final int N = (int) (T / dt);
78+
79+
final double u_max = 20.0; // N
80+
final double d_max = 2.0; // m
81+
82+
final var x_initial = new SimpleMatrix(new double[][] {{0.0}, {0.0}, {0.0}, {0.0}});
83+
final var x_final = new SimpleMatrix(new double[][] {{1.0}, {Math.PI}, {0.0}, {0.0}});
84+
85+
var problem = new Problem();
86+
87+
// x = [q, q̇]ᵀ = [x, θ, ẋ, θ̇]ᵀ
88+
var X = problem.decisionVariable(4, N + 1);
89+
90+
// Initial guess
91+
for (int k = 0; k < N + 1; ++k) {
92+
X.get(0, k).setValue(MathUtil.lerp(x_initial.get(0), x_final.get(0), (double) k / N));
93+
X.get(1, k).setValue(MathUtil.lerp(x_initial.get(1), x_final.get(1), (double) k / N));
94+
}
95+
96+
// u = f_x
97+
var U = problem.decisionVariable(1, N);
98+
99+
// Initial conditions
100+
problem.subjectTo(eq(X.col(0), x_initial));
101+
102+
// Final conditions
103+
problem.subjectTo(eq(X.col(N), x_final));
104+
105+
// Cart position constraints
106+
problem.subjectTo(ge(X.row(0), 0.0));
107+
problem.subjectTo(le(X.row(0), d_max));
108+
109+
// Input constraints
110+
problem.subjectTo(ge(U, -u_max));
111+
problem.subjectTo(le(U, u_max));
112+
113+
// Dynamics constraints - RK4 integration
114+
for (int k = 0; k < N; ++k) {
115+
problem.subjectTo(
116+
eq(X.col(k + 1), rk4(CartPoleBenchmark::cartPoleDynamics, X.col(k), U.col(k), dt)));
117+
}
118+
119+
// Minimize sum squared inputs
120+
var J = new Variable(0.0);
121+
for (int k = 0; k < N; ++k) {
122+
J = J.plus(U.col(k).T().times(U.col(k)).get(0));
123+
}
124+
problem.minimize(J);
125+
126+
problem.solve(new Options().withDiagnostics(true));
127+
}
128+
}

benchmark/src/main/java/frc/robot/Main.java

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
package frc.robot;
66

77
import edu.wpi.first.math.geometry.Pose2d;
8-
import edu.wpi.first.math.geometry.Rotation2d;
9-
import edu.wpi.first.math.path.TravelingSalesman;
108
import java.util.concurrent.TimeUnit;
119
import org.openjdk.jmh.annotations.Benchmark;
1210
import org.openjdk.jmh.annotations.BenchmarkMode;
@@ -15,41 +13,17 @@
1513
import org.openjdk.jmh.profile.GCProfiler;
1614
import org.openjdk.jmh.runner.Runner;
1715
import org.openjdk.jmh.runner.RunnerException;
18-
import org.openjdk.jmh.runner.options.Options;
1916
import org.openjdk.jmh.runner.options.OptionsBuilder;
2017
import org.openjdk.jmh.runner.options.TimeValue;
2118

2219
public class Main {
23-
private static final Pose2d[] poses = {
24-
new Pose2d(-1, 1, Rotation2d.kCW_90deg),
25-
new Pose2d(-1, 2, Rotation2d.kCCW_90deg),
26-
new Pose2d(0, 0, Rotation2d.kZero),
27-
new Pose2d(0, 3, Rotation2d.kCW_90deg),
28-
new Pose2d(1, 1, Rotation2d.kCCW_90deg),
29-
new Pose2d(1, 2, Rotation2d.kCCW_90deg),
30-
};
31-
private static final int iterations = 100;
32-
33-
private static final TravelingSalesman transformTraveler =
34-
new TravelingSalesman(
35-
(pose1, pose2) -> {
36-
var transform = pose2.minus(pose1);
37-
return Math.hypot(transform.getX(), transform.getY());
38-
});
39-
private static final TravelingSalesman twistTraveler =
40-
new TravelingSalesman(
41-
(pose1, pose2) -> {
42-
var twist = pose2.minus(pose1).log();
43-
return Math.hypot(twist.dx, twist.dy);
44-
});
45-
4620
/**
4721
* Main function.
4822
*
4923
* @param args The (unused) arguments to the program.
5024
*/
5125
public static void main(String... args) throws RunnerException {
52-
Options opt =
26+
var opt =
5327
new OptionsBuilder()
5428
.include(Main.class.getSimpleName())
5529
.addProfiler(GCProfiler.class)
@@ -66,14 +40,21 @@ public static void main(String... args) throws RunnerException {
6640
@Benchmark
6741
@BenchmarkMode(Mode.AverageTime)
6842
@OutputTimeUnit(TimeUnit.MICROSECONDS)
69-
public Pose2d[] transform() {
70-
return transformTraveler.solve(poses, iterations);
43+
public void cartPole() {
44+
CartPoleBenchmark.cartPole();
45+
}
46+
47+
@Benchmark
48+
@BenchmarkMode(Mode.AverageTime)
49+
@OutputTimeUnit(TimeUnit.MICROSECONDS)
50+
public Pose2d[] travelingSalesmanTransform() {
51+
return TravelingSalesmanBenchmark.transform();
7152
}
7253

7354
@Benchmark
7455
@BenchmarkMode(Mode.AverageTime)
7556
@OutputTimeUnit(TimeUnit.MICROSECONDS)
76-
public Pose2d[] twist() {
77-
return twistTraveler.solve(poses, iterations);
57+
public Pose2d[] travelingSalesmanTwist() {
58+
return TravelingSalesmanBenchmark.twist();
7859
}
7960
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) FIRST and other WPILib contributors.
2+
// Open Source Software; you can modify and/or share it under the terms of
3+
// the WPILib BSD license file in the root directory of this project.
4+
5+
package frc.robot;
6+
7+
import edu.wpi.first.math.geometry.Pose2d;
8+
import edu.wpi.first.math.geometry.Rotation2d;
9+
import edu.wpi.first.math.path.TravelingSalesman;
10+
11+
public final class TravelingSalesmanBenchmark {
12+
private TravelingSalesmanBenchmark() {
13+
// Utility class.
14+
}
15+
16+
private static final Pose2d[] poses = {
17+
new Pose2d(-1, 1, Rotation2d.kCW_90deg),
18+
new Pose2d(-1, 2, Rotation2d.kCCW_90deg),
19+
new Pose2d(0, 0, Rotation2d.kZero),
20+
new Pose2d(0, 3, Rotation2d.kCW_90deg),
21+
new Pose2d(1, 1, Rotation2d.kCCW_90deg),
22+
new Pose2d(1, 2, Rotation2d.kCCW_90deg),
23+
};
24+
private static final int iterations = 100;
25+
26+
private static final TravelingSalesman transformTraveler =
27+
new TravelingSalesman(
28+
(pose1, pose2) -> {
29+
var transform = pose2.minus(pose1);
30+
return Math.hypot(transform.getX(), transform.getY());
31+
});
32+
private static final TravelingSalesman twistTraveler =
33+
new TravelingSalesman(
34+
(pose1, pose2) -> {
35+
var twist = pose2.minus(pose1).log();
36+
return Math.hypot(twist.dx, twist.dy);
37+
});
38+
39+
/** TravelingSalesman transform benchmark. */
40+
public static Pose2d[] transform() {
41+
return transformTraveler.solve(poses, iterations);
42+
}
43+
44+
/** TravelingSalesman twist benchmark. */
45+
public static Pose2d[] twist() {
46+
return twistTraveler.solve(poses, iterations);
47+
}
48+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Copyright (c) FIRST and other WPILib contributors.
2+
// Open Source Software; you can modify and/or share it under the terms of
3+
// the WPILib BSD license file in the root directory of this project.
4+
5+
#pragma once
6+
7+
#include <benchmark/benchmark.h>
8+
#include <frc/system/NumericalIntegration.h>
9+
10+
#include <chrono>
11+
12+
#include <sleipnir/optimization/problem.hpp>
13+
14+
inline slp::VariableMatrix CartPoleDynamics(const slp::VariableMatrix& x,
15+
const slp::VariableMatrix& u) {
16+
constexpr double m_c = 5.0; // Cart mass (kg)
17+
constexpr double m_p = 0.5; // Pole mass (kg)
18+
constexpr double l = 0.5; // Pole length (m)
19+
constexpr double g = 9.806; // Acceleration due to gravity (m/s²)
20+
21+
auto q = x.segment(0, 2);
22+
auto qdot = x.segment(2, 2);
23+
auto theta = q[1];
24+
auto thetadot = qdot[1];
25+
26+
// [ m_c + m_p m_p l cosθ]
27+
// M(q) = [m_p l cosθ m_p l² ]
28+
slp::VariableMatrix M{{m_c + m_p, m_p * l * slp::cos(theta)},
29+
{m_p * l * slp::cos(theta), m_p * std::pow(l, 2)}};
30+
31+
// [0 −m_p lθ̇ sinθ]
32+
// C(q, q̇) = [0 0 ]
33+
slp::VariableMatrix C{{0, -m_p * l * thetadot * slp::sin(theta)}, {0, 0}};
34+
35+
// [ 0 ]
36+
// τ_g(q) = [-m_p gl sinθ]
37+
slp::VariableMatrix tau_g{{0}, {-m_p * g * l * slp::sin(theta)}};
38+
39+
// [1]
40+
// B = [0]
41+
constexpr Eigen::Matrix<double, 2, 1> B{{1}, {0}};
42+
43+
// q̈ = M⁻¹(q)(τ_g(q) − C(q, q̇)q̇ + Bu)
44+
slp::VariableMatrix qddot{4};
45+
qddot.segment(0, 2) = qdot;
46+
qddot.segment(2, 2) = slp::solve(M, tau_g - C * qdot + B * u);
47+
return qddot;
48+
}
49+
50+
inline void BM_CartPole(benchmark::State& state) {
51+
using namespace std::chrono_literals;
52+
53+
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
54+
for (auto _ : state) {
55+
constexpr std::chrono::duration<double> T = 5s;
56+
constexpr std::chrono::duration<double> dt = 50ms;
57+
constexpr int N = T / dt;
58+
59+
constexpr double u_max = 20.0; // N
60+
constexpr double d_max = 2.0; // m
61+
62+
constexpr Eigen::Vector<double, 4> x_initial{{0.0, 0.0, 0.0, 0.0}};
63+
constexpr Eigen::Vector<double, 4> x_final{
64+
{1.0, std::numbers::pi, 0.0, 0.0}};
65+
66+
slp::Problem problem;
67+
68+
// x = [q, q̇]ᵀ = [x, θ, ẋ, θ̇]ᵀ
69+
auto X = problem.decision_variable(4, N + 1);
70+
71+
// Initial guess
72+
for (int k = 0; k < N + 1; ++k) {
73+
X(0, k).set_value(
74+
std::lerp(x_initial[0], x_final[0], static_cast<double>(k) / N));
75+
X(1, k).set_value(
76+
std::lerp(x_initial[1], x_final[1], static_cast<double>(k) / N));
77+
}
78+
79+
// u = f_x
80+
auto U = problem.decision_variable(1, N);
81+
82+
// Initial conditions
83+
problem.subject_to(X.col(0) == x_initial);
84+
85+
// Final conditions
86+
problem.subject_to(X.col(N) == x_final);
87+
88+
// Cart position constraints
89+
problem.subject_to(X.row(0) >= 0.0);
90+
problem.subject_to(X.row(0) <= d_max);
91+
92+
// Input constraints
93+
problem.subject_to(U >= -u_max);
94+
problem.subject_to(U <= u_max);
95+
96+
// Dynamics constraints - RK4 integration
97+
for (int k = 0; k < N; ++k) {
98+
problem.subject_to(X.col(k + 1) ==
99+
frc::RK4<decltype(CartPoleDynamics),
100+
slp::VariableMatrix, slp::VariableMatrix>(
101+
CartPoleDynamics, X.col(k), U.col(k), dt));
102+
}
103+
104+
// Minimize sum squared inputs
105+
slp::Variable J = 0.0;
106+
for (int k = 0; k < N; ++k) {
107+
J += U.col(k).T() * U.col(k);
108+
}
109+
problem.minimize(J);
110+
111+
problem.solve({.diagnostics = true});
112+
}
113+
}

0 commit comments

Comments
 (0)