import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from random import random

from gymnasium.wrappers import TimeLimit
from torch import optim

# Hyperparameters
episodes = 5000
gamma = 0.99 # discount factor: how much future rewards are worth relative to immediate ones
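# e.g. with gamma = 0.99 a reward received 100 steps later is weighted by 0.99**100 ≈ 0.37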
lr = 0.0001 # learning rate
train = True
# train = False

task = 'MountainCarContinuous-v0'
save_file = task + '.pt'
if train:
	render_mode = None  # without visualization
else:
	render_mode = "human"  # with visualization
env = gym.make(task, render_mode=render_mode)

# The episode terminates when the car reaches the goal (position >= 0.45) and is
# truncated after 999 steps by the TimeLimit wrapper of the default registration.
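# A shorter cap (sketch, value not tuned) can make early training episodes cheaper:
# env = TimeLimit(env, max_episode_steps=500)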
observation_dimensions = env.observation_space.shape[0]
action_dimensions = env.action_space.shape[0]

# Small MLP: maps the observation to one output per action dimension
q_values = nn.Sequential(
	nn.Linear(observation_dimensions, 128),
	nn.ReLU(),
	nn.Linear(128, action_dimensions),
)
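
# Quick shape check (illustrative): one observation in, one continuous action value out
assert q_values(torch.zeros(observation_dimensions)).shape == (action_dimensions,)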

optimizer = optim.Adam(q_values.parameters(), lr=lr)

try:
	checkpoint = torch.load(save_file, weights_only=True)
	q_values.load_state_dict(checkpoint["model_state"])
	optimizer.load_state_dict(checkpoint["optimizer_state"])
	print("pre-trained model loaded")
except FileNotFoundError:
	print("No pre-trained model found, starting from scratch.")

def policy_random(state):
	# Exploration: sample an action uniformly from the action space
	return np.random.uniform(env.action_space.low, env.action_space.high)

def policy_q(state):
	# Exploitation: use the network output directly as the action
	with torch.no_grad():
		action_values = q_values(state)
		return action_values.numpy()
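# Note on policy_q: the network output is not bounded to the action space [-1, 1].
# The environment clamps the applied force itself; an explicit clip (illustrative) would be
# np.clip(action_values.numpy(), env.action_space.low, env.action_space.high).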

def policy(state, epsilon):
	# Epsilon-greedy: explore with probability epsilon, otherwise act greedily
	if random() < epsilon:
		return policy_random(state)
	else:
		return policy_q(state)

def custom_reward(reward, next_state):
	# Reward shaping: add a bonus proportional to the car's speed to reward building momentum
	position, velocity = next_state
	return reward + abs(velocity) * 100
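# For scale: the car's speed in this task is capped at about 0.07, so the bonus is at most
# roughly 7 per step, which dwarfs the environment's small per-step action cost.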

gain_list = []
for episode in range(episodes):
	if train:
		epsilon = 1 - episode / episodes # linear decay of epsilon from 1 to 0
	else:
		epsilon = 0 # no exploration, only exploitation: deterministic policy
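	# e.g. with 5000 episodes: episode 0 -> epsilon 1.0 (fully random),
	# episode 2500 -> 0.5, the final episodes -> close to 0 (almost greedy)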
	state = torch.tensor(env.reset()[0])
	done = False
	episode_replay = [] # (state, action, reward) tuples collected during the episode
	total_gain = 0
	while not done: # episode loop: play until terminated or truncated
		action = policy(state, epsilon)
		next_state, reward, terminated, truncated, info = env.step(action)
		reward = custom_reward(reward, next_state)
		next_state = torch.tensor(next_state)
		episode_replay.append((state, action, reward))
		state = next_state
		done = terminated or truncated
		total_gain += reward
	gain_list.append(total_gain)

	# Monte Carlo update: learn from the completed episode
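	# Walking the episode backwards yields the discounted return of each visited state:
	# G_t = r_t + gamma * G_{t+1}. E.g. for rewards [1, 2, 3] and gamma = 0.99:
	# G_2 = 3, G_1 = 2 + 0.99 * 3 = 4.97, G_0 = 1 + 0.99 * 4.97 ≈ 5.92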
	Gain = 0
	for state, action, reward in reversed(episode_replay):
		Gain = reward + gamma * Gain
		# Train the model: regress the network output toward the observed return
		if train:
			prediction = q_values(state)
			loss = (prediction.squeeze() - Gain) ** 2
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

	if episode % 100 == 0:
		print(f"Episode {episode}  gain: {total_gain}")
		if train:
			torch.save({
				"model_state": q_values.state_dict(),
				"optimizer_state": optimizer.state_dict()
			}, save_file)

# Plot learning progress: total (shaped) gain per episode
import matplotlib.pyplot as plt
plt.plot(gain_list)
plt.show()
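
env.close()

# Optional (illustrative sketch): the raw per-episode gain is noisy; a moving average
# makes the trend easier to see. The window of 50 episodes is an arbitrary choice.
# window = 50
# smoothed = np.convolve(gain_list, np.ones(window) / window, mode="valid")
# plt.plot(smoothed)
# plt.show()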