import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from random import random

from gymnasium.wrappers import TimeLimit
from torch import optim

# Hyperparameters
episodes = 5000
gamma = 0.99 # discount factor: how much rewards down the line are worth
lr = 0.0001 # learning rate
train = True
# train = False

task = 'Pendulum-v1'
save_file = "Softmax" + task + '.pt'
if train:
	render_mode = None  # without visualization
else:
	render_mode = "human"  # with visualization
env = TimeLimit(gym.make(task, render_mode=render_mode), max_episode_steps=200)

# Pendulum-v1 has no terminal state of its own; the TimeLimit wrapper above caps each episode at 200 steps.
observation_dimensions = env.observation_space.shape[0]
# For a discrete action space one could use env.action_space.n; Pendulum-v1 is continuous, so discretize the torque range instead:
discrete_actions = np.linspace(env.action_space.low[0], env.action_space.high[0], num=11, dtype=np.float32)
action_dimensions = len(discrete_actions)
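# With Pendulum-v1's default torque bounds of [-2.0, 2.0], this yields 11 evenly spaced torques
# (-2.0, -1.6, ..., 1.6, 2.0); an action index i selects discrete_actions[i].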

# Q-network: maps an observation to one Q-value per discrete action
q_values = nn.Sequential(
	nn.Linear(observation_dimensions, 128),
	nn.ReLU(),
	nn.Linear(128, action_dimensions),
)

optimizer = optim.Adam(q_values.parameters(), lr=lr)

try:
	checkpoint = torch.load(save_file)
	q_values.load_state_dict(checkpoint["model_state"])
	optimizer.load_state_dict(checkpoint["optimizer_state"])
	print("Pre-trained model loaded.")
except FileNotFoundError:
	print("No pre-trained model found, starting from scratch.")

def policy_random(state):
	# Uniform random action index (exploration)
	return np.random.choice(range(action_dimensions))

def policy_q(state):
	# Greedy action index with respect to the current Q-values (exploitation)
	with torch.no_grad():
		action_values = q_values(state)
		return torch.argmax(action_values).item()

def policy(state, epsilon):
	# Epsilon-greedy: explore with probability epsilon, otherwise exploit
	if random() < epsilon:
		return policy_random(state)
	else:
		return policy_q(state)
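
# Optional reward shaping used by the commented-out call inside the training loop below.
# This is only an illustrative sketch: the scaling constant 16.27 (the magnitude of the worst
# possible Pendulum-v1 reward) is an assumption of this sketch, not part of the original script;
# next_state is passed through so state-dependent shaping could be added.
def custom_reward(reward, next_state):
	# Rescale rewards from roughly [-16.27, 0] towards [-1, 0] to keep the Q-targets small
	return reward / 16.27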

gain_list = []
for episode_idx in range(episodes):
	if train:
		epsilon = 1 - episode_idx / episodes  # linear decay of epsilon from 1 to 0
	else:
		epsilon = 0  # no exploration, purely greedy (deterministic) policy
	state = torch.tensor(env.reset()[0], dtype=torch.float32)
	done = False
	episode = []
	total_gain = 0
	while not done:
		action_idx = policy(state, epsilon)
		# Execute the torque that belongs to the recorded action index, so the Monte Carlo
		# update below trains the Q-value of the action that was actually taken.
		action = np.array([discrete_actions[action_idx]], dtype=np.float32)
		next_state, reward, terminated, truncated, info = env.step(action)
		# reward = custom_reward(reward, next_state)  # optional reward shaping (sketch above the training loop)
		next_state = torch.tensor(next_state, dtype=torch.float32)
		episode.append((state, action_idx, reward))
		state = next_state
		done = terminated or truncated
		total_gain += reward
	gain_list.append(total_gain)

	# Monte Carlo update: walk the episode backwards, accumulate the discounted return,
	# and regress Q(state, action) towards it.
	Gain = 0
	for state, action, reward in reversed(episode):
		Gain = reward + gamma * Gain
		# Train the model on the observed return
		if train:
			prediction = q_values(state)[action]
			loss = (prediction - Gain) ** 2
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

	if episode_idx % 100 == 0:
		print(f"Episode {episode_idx}  gain: {total_gain:.1f}")
		if train:
			torch.save({
				"model_state": q_values.state_dict(),
				"optimizer_state": optimizer.state_dict()
			}, save_file)

# Plot the per-episode gain
import matplotlib.pyplot as plt
plt.plot(gain_list)
plt.xlabel("Episode")
plt.ylabel("Total gain")
plt.show()
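
# Optional: raw per-episode gains are noisy, so a moving average shows the learning trend more
# clearly. This snippet is an illustrative addition; the 50-episode window is an arbitrary choice.
window = 50
if len(gain_list) >= window:
	smoothed = np.convolve(gain_list, np.ones(window) / window, mode="valid")
	plt.plot(smoothed)
	plt.xlabel("Episode")
	plt.ylabel(f"Total gain ({window}-episode moving average)")
	plt.show()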