"""Monte Carlo control on Gymnasium's Acrobot-v1: a small Q-network is trained
with epsilon-greedy exploration, a shaped reward, and per-episode returns."""
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from random import random
from torch import optim

# Hyperparameters
episodes = 5000
gamma = 0.99  # discount factor: weight of future rewards relative to immediate ones
lr = 0.0001  # learning rate
train = True  # False: render the environment and act greedily (epsilon = 0)

if train:
	env = gym.make('Acrobot-v1')
else:
	env = gym.make('Acrobot-v1', render_mode="human")  #  with visualization
save_file = "acrobot_mc.pth"

observation_dimensions = env.observation_space.shape[0]  # 6 for Acrobot-v1
action_dimensions = env.action_space.n  # 3 discrete actions: torque -1, 0 or +1

q_values = nn.Sequential(
	nn.Linear(observation_dimensions, 128),
	nn.ReLU(),
	nn.Linear(128, action_dimensions),
)
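
# Sanity check (a sketch, safe to delete): one Acrobot observation has 6
# components and there are 3 discrete actions, so a single forward pass
# should return a tensor of shape (3,):
# with torch.no_grad():
# 	print(q_values(torch.zeros(observation_dimensions)).shape)  # torch.Size([3])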

optimizer = optim.Adam(q_values.parameters(), lr=lr)

try:
	checkpoint = torch.load(save_file)
	q_values.load_state_dict(checkpoint["model_state"])
	optimizer.load_state_dict(checkpoint["optimizer_state"])
	print("💡 pre-trained model loaded")
except FileNotFoundError:
	print("⚠️ No pre-trained model found, starting from scratch.")

def policy_random(state):
	return env.action_space.sample()  # Random action

def policy_q(state):
	with torch.no_grad():  # action selection needs no gradient tracking
		action_values = q_values(state)  # vector of Q-values, one per action
	action = torch.argmax(action_values).item()  # index of the largest Q-value
	return action

def policy(state, epsilon):
	if random() < epsilon:
		action = policy_random(state)
	else:
		action = policy_q(state)
	return action
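
# Example usage (a sketch; the 0.1 is only an illustration value): with
# epsilon = 0.1, roughly one action in ten is sampled uniformly at random.
# s0 = torch.tensor(env.reset()[0], dtype=torch.float32)
# a0 = policy(s0, epsilon=0.1)  # an int in {0, 1, 2}: Acrobot's three torque actions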

def custom_reward(reward, state):
	# Reward shaping for Acrobot: the raw reward is a sparse -1 per step.
	# The observation is (cos(theta1), sin(theta1), cos(theta2), sin(theta2),
	# angular velocity of theta1, angular velocity of theta2).
	cos_theta1, sin_theta1, cos_theta2, sin_theta2, v1, v2 = state
	theta1 = np.arctan2(sin_theta1, cos_theta1)
	theta2 = np.arctan2(sin_theta2, cos_theta2)
	# Height of the tip relative to the pivot, from -2 (hanging down) to +2 (straight up)
	height = -(np.cos(theta1) + np.cos(theta1 + theta2))
	return reward + height * 10  # add a scaled height bonus to the step penalty
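
# Worked example of the shaping: hanging straight down, theta1 = theta2 = 0, so
# height = -(1 + 1) = -2 and the shaped reward is -1 + 10 * (-2) = -21; once the
# tip swings above the goal line (height > 1) the height bonus alone contributes
# more than +10, turning the sparse step penalty into a dense learning signal.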

gain_list = []
for ep in range(episodes):
	if train:
		# epsilon decays linearly from 1 to 0 over the first half of training,
		# then the policy stays fully greedy
		epsilon = max(0.0, 1 - 2 * ep / episodes)
	else:
		epsilon = 0  # no exploration, only exploitation: deterministic greedy policy
	state = torch.tensor(env.reset()[0], dtype=torch.float32)  # reset() returns (observation, info)
	done = False
	episode = []
	total_gain = 0
	while not done:
		action = policy(state, epsilon)
		next_state, reward, terminated, truncated, info = env.step(action)
		reward = custom_reward(reward, next_state)
		next_state = torch.tensor(next_state)
		episode.append((state, action, reward))  # store the transition for the end-of-episode update
		state = next_state
		done = terminated or truncated
		total_gain += reward
	gain_list.append(total_gain)

	# Monte Carlo update: walk the episode backwards, accumulating the
	# discounted return G for every visited (state, action) pair and
	# regressing the predicted Q-value towards it.
	G = 0.0
	for state, action, reward in reversed(episode):
		G = reward + gamma * G  # discounted return from this step onwards
		prediction = q_values(state)[action]  # current Q-value estimate
		loss = (prediction - G) ** 2  # squared error against the observed return
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

	if ep % 100 == 0:
		print(f"Episode {ep}: gain {total_gain:.1f}")
		torch.save({
			"model_state": q_values.state_dict(),
			"optimizer_state": optimizer.state_dict()
		}, save_file)
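
# Note on the update above: one optimizer step is taken per visited transition.
# A common alternative (a sketch only, not what this script runs) is to sum the
# squared errors over the whole episode and take a single step per episode:
#
# 	G, losses = 0.0, []
# 	for state, action, reward in reversed(episode):
# 		G = reward + gamma * G
# 		losses.append((q_values(state)[action] - G) ** 2)
# 	optimizer.zero_grad()
# 	torch.stack(losses).sum().backward()
# 	optimizer.step()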

# Plot the total (shaped) gain per episode
import matplotlib.pyplot as plt
plt.plot(gain_list)
plt.show()
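
# Optional (a sketch; numpy is already imported): the per-episode gain is noisy,
# so a moving average is usually easier to read than the raw curve.
# window = 100
# if len(gain_list) >= window:
# 	smoothed = np.convolve(gain_list, np.ones(window) / window, mode="valid")
# 	plt.plot(smoothed)
# 	plt.xlabel("episode")
# 	plt.ylabel(f"mean gain over the last {window} episodes")
# 	plt.show()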