import numpy as np
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import matplotlib.pyplot as plt
import base64
from IPython.display import HTML, display
import imageio
from tqdm import tqdm
from functools import partial
class HarmonicOscillator:
    """Underdamped harmonic oscillator with a closed-form solution.

    The solution evaluated by :meth:`analytical_solution` is
    ``exp(-d*t) * 2*A * cos(phi + w*t)`` where ``w = sqrt(w0**2 - d**2)``,
    ``phi = arctan(-d / w)`` and ``A = 1 / (2*cos(phi))``; with these
    constants the solution equals 1 at ``t = 0``.
    """

    def __init__(self, damping_coefficient, natural_frequency):
        """Store the oscillator parameters and derive the solution constants.

        Args:
            damping_coefficient: damping term ``d``.
            natural_frequency: undamped frequency ``w0``.

        Raises:
            ValueError: if ``damping_coefficient`` is not strictly less than
                ``natural_frequency`` (the oscillator would not be underdamped).
        """
        # Raise explicitly instead of using `assert`: assertions are removed
        # under `python -O`, which would let an invalid configuration through
        # and produce NaN from the square root below.
        if not (damping_coefficient < natural_frequency):
            raise ValueError(
                "Damping coefficient must be less than natural frequency for underdamped oscillator"
            )
        self.d = damping_coefficient
        self.w0 = natural_frequency
        # Damped angular frequency, phase offset and amplitude factor.
        self.w = np.sqrt(self.w0**2 - self.d**2)
        self.phi = np.arctan(-self.d / self.w)
        self.A = 1 / (2 * np.cos(self.phi))

    def analytical_solution(self, t):
        """Evaluate the closed-form solution at times ``t``.

        Returns:
            The solution values reshaped into a column of shape ``(-1, 1)``.
        """
        cos_part = np.cos(self.phi + self.w * t)
        exp_part = np.exp(-self.d * t)
        u = exp_part * 2 * self.A * cos_part
        return u.reshape(-1, 1)

    def generate_data(self):
        """Build the dense reference curve and a sparse training subset.

        Returns:
            A tuple ``(t, u, t_data, u_data)``: 500 evenly spaced times on
            ``[0, 1]`` as a column, the exact solution at those times, and
            every 20th sample among the first 200 (10 points) of each.
        """
        t = np.linspace(0, 1, 500).reshape(-1, 1)
        u = self.analytical_solution(t)
        # Training observations cover only the early part of the time range.
        t_data = t[0:200:20]
        u_data = u[0:200:20]
        return t, u, t_data, u_data
class NeuralNetwork(nnx.Module):
    """Fully connected network: 1 input, three tanh hidden layers of 32 units, 1 output."""

    def __init__(self, rngs: nnx.Rngs):
        """Create the four linear layers, all drawing initial values from ``rngs``."""
        hidden_width = 32
        # Layers are created in forward order so the RNG stream is consumed
        # in the same sequence as the layers are applied.
        self.layer1 = nnx.Linear(1, hidden_width, rngs=rngs)
        self.layer2 = nnx.Linear(hidden_width, hidden_width, rngs=rngs)
        self.layer3 = nnx.Linear(hidden_width, hidden_width, rngs=rngs)
        self.layer4 = nnx.Linear(hidden_width, 1, rngs=rngs)

    def __call__(self, x):
        """Apply the hidden layers with tanh activations, then the linear output layer."""
        activation = x
        for hidden_layer in (self.layer1, self.layer2, self.layer3):
            activation = jnp.tanh(hidden_layer(activation))
        # The output layer has no activation.
        return self.layer4(activation)
@nnx.jit
def compute_physics_loss(model, t_physics, mu, k_param):
    """Return the mean squared residual of ``u'' + mu*u' + k_param*u`` at ``t_physics``.

    Args:
        model: callable network mapping a ``(1, 1)`` input to a ``(1, 1)`` output.
        t_physics: 2-D array of collocation times; column 0 supplies the
            time values used for the derivatives.
        mu: coefficient multiplying the first derivative.
        k_param: coefficient multiplying the network output.

    Returns:
        Scalar mean of the squared residual over all collocation points.
    """

    def scalar_output(time_value):
        # Feed a single time as a (1, 1) batch and unwrap the scalar prediction
        # so that jax.grad can differentiate it.
        return model(time_value.reshape(1, -1))[0, 0]

    first_derivative = jax.grad(scalar_output)
    second_derivative = jax.grad(first_derivative)

    time_column = t_physics[:, 0]
    # Network value and its first and second time derivatives per point.
    u = jax.vmap(scalar_output)(t_physics)
    du_dt_vals = jax.vmap(first_derivative)(time_column)
    d2u_dt2_vals = jax.vmap(second_derivative)(time_column)

    # Residual of the differential equation.
    ode_residual = d2u_dt2_vals + mu * du_dt_vals + k_param * u
    return jnp.mean(ode_residual ** 2)
@nnx.jit
def train_step_pinn(model, optimizer, t_data, u_data, t_physics, mu, k_param):
    """Apply one optimizer update using the data loss plus the weighted physics loss.

    The objective is the mean squared error on ``(t_data, u_data)`` plus
    ``1e-4`` times ``compute_physics_loss`` evaluated at ``t_physics``.

    Returns:
        The combined loss value computed before the update.
    """

    def total_loss(net):
        prediction_error = net(t_data) - u_data
        data_term = jnp.mean(prediction_error ** 2)
        # The physics residual is scaled by a fixed weight of 1e-4.
        physics_term = 1e-4 * compute_physics_loss(net, t_physics, mu, k_param)
        return data_term + physics_term

    loss_and_grad = nnx.value_and_grad(total_loss)
    loss, grads = loss_and_grad(model)
    optimizer.update(grads)
    return loss
@nnx.jit
def train_step_standard(model, optimizer, t_data, u_data):
    """Apply one optimizer update using only the mean squared data error.

    Returns:
        The loss value computed before the update.
    """

    def mse_loss(net):
        prediction_error = net(t_data) - u_data
        return jnp.mean(prediction_error ** 2)

    loss_and_grad = nnx.value_and_grad(mse_loss)
    loss, grads = loss_and_grad(model)
    optimizer.update(grads)
    return loss
def train_standard_nn(ho, visualizer):
    """Fit the network to the sparse data only and save/display a progress GIF.

    Args:
        ho: object whose ``generate_data()`` returns ``(t, u, t_data, u_data)``.
        visualizer: object providing ``plot_result``, ``create_animation``
            and ``display_animation``.
    """
    total_steps = 1000
    snapshot_interval = 100
    gif_path = 'oscillator_nn.gif'

    # Convert every array returned by generate_data to a JAX array.
    t, u_exact, t_data, u_data = (jnp.array(arr) for arr in ho.generate_data())

    # Create model and optimizer.
    model = NeuralNetwork(rngs=nnx.Rngs(0))
    optimizer = nnx.Optimizer(model, optax.adam(1e-3))

    frames = []
    for step in tqdm(range(total_steps)):
        train_step_standard(model, optimizer, t_data, u_data)
        if (step + 1) % snapshot_interval != 0:
            continue
        # Render the current fit over the full time grid for the animation.
        snapshot = visualizer.plot_result(
            t, u_exact, t_data, u_data, model(t), iteration=step
        )
        frames.append(snapshot)

    visualizer.create_animation(frames, filename=gif_path)
    visualizer.display_animation(gif_path)
def train_pinn(ho, visualizer):
    """Fit the network with data and physics losses and save/display a progress GIF.

    Args:
        ho: object exposing ``generate_data()``, ``d`` and ``w0``.
        visualizer: object providing ``plot_result``, ``create_animation``
            and ``display_animation``.
    """
    total_steps = 20000
    snapshot_interval = 1500
    gif_path = 'oscillator_pinn.gif'

    # Convert every array returned by generate_data to a JAX array.
    t, u_exact, t_data, u_data = (jnp.array(arr) for arr in ho.generate_data())

    # 30 evenly spaced collocation times on [0, 1], as a column.
    t_physics = jnp.array(np.linspace(0, 1, 30).reshape(-1, 1))

    # Coefficients of the first-derivative and zeroth-order residual terms.
    mu = jnp.array(2 * ho.d)
    k_param = jnp.array(ho.w0 ** 2)

    # Create model and optimizer.
    model = NeuralNetwork(rngs=nnx.Rngs(0))
    optimizer = nnx.Optimizer(model, optax.adam(1e-4))

    frames = []
    for step in tqdm(range(total_steps)):
        train_step_pinn(model, optimizer, t_data, u_data, t_physics, mu, k_param)
        if (step + 1) % snapshot_interval != 0:
            continue
        # Render the current fit, marking the collocation points.
        snapshot = visualizer.plot_result(
            t, u_exact, t_data, u_data, model(t),
            t_physics=t_physics, iteration=step
        )
        frames.append(snapshot)

    visualizer.create_animation(frames, filename=gif_path)
    visualizer.display_animation(gif_path)
if __name__ == "__main__":
    # Damping (2) is below the natural frequency (20), as the oscillator requires.
    ho = HarmonicOscillator(
        damping_coefficient=2,
        natural_frequency=20,
    )
    # NOTE(review): Visualizer is not defined or imported in the visible part
    # of this file — confirm it is in scope before running.
    visualizer = Visualizer()

    # Train the data-only network first, then the physics-informed one.
    train_standard_nn(ho, visualizer)
    train_pinn(ho, visualizer)