import gymnasium as gym
import torch
import torch.nn as nn
from random import random
from torch import optim

# Hyperparameters
episodes = 5000
gamma = 0.9 # discount factor: how much rewards down the line are worth
lr = 0.001 # learning rate
train = True
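# set train = False to run the saved policy with rendering and no exploration (epsilon = 0)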

if train:
	env = gym.make('Taxi-v3')  # Taxi environment without visualization
else:
	env = gym.make('Taxi-v3', render_mode="human")  # Taxi environment with visualization

observation_dimensions = env.observation_space.n
action_dimensions = env.action_space.n

q_values = nn.Sequential(
	nn.Linear(observation_dimensions, 128),  # input: one-hot encoded state
	nn.ReLU(),
	nn.Linear(128, action_dimensions),  # output: one Q-value per action
)

try:  # load pre-trained model weights if available
	q_values.load_state_dict(torch.load("taxi_mc.pth", weights_only=False))
	print("pre-trained model loaded")
except FileNotFoundError:
	print("No pre-trained model found, starting from scratch.")

optimizer = optim.Adam(q_values.parameters(), lr=lr)

def policy_random(state):
	return env.action_space.sample()  # Random action

def policy_q(state):
	with torch.no_grad():
		action_values = q_values(state)  # vector of Q-values, one per action
	# argmax returns the index of the maximum value, i.e. the greedy action
	return torch.argmax(action_values).item()

def policy(state, epsilon):
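	# epsilon-greedy: take a random action with probability epsilon, otherwise the greedy one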
	if random() < epsilon:
		action = policy_random(state)
	else:
		action = policy_q(state)
	return action

gain_list = []
for step in range(episodes):
	if train:
		epsilon = 1 - step / episodes # linear decay of epsilon from 1 to 0
	else:
		epsilon = 0 # no exploration, only exploitation, deterministic policy
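	# env.reset() returns (observation, info); Taxi's observation is a single discrete state index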
	state = torch.tensor(env.reset()[0])
	done = False
	episode = []
	total_gain = 0
	while not done:
		# encode the discrete state as a one-hot vector and add a batch dimension
		state = nn.functional.one_hot(state, num_classes=observation_dimensions).float().unsqueeze(0)
		action = policy(state, epsilon)
		next_state, reward, terminated, truncated, info = env.step(action)
		next_state = torch.tensor(next_state)
		episode.append((state, action, reward))
		state = next_state
		done = terminated or truncated
		total_gain += reward
	gain_list.append(total_gain)

	# Monte Carlo update: regress Q(s, a) toward the observed return of each visit
	Gain = 0
	for state, action, reward in reversed(episode):
		Gain = reward + gamma * Gain  # return G_t = r_t + gamma * G_{t+1}
		predictions = q_values(state)  # stored state already carries the batch dimension
		prediction = predictions[0][action]
		# Train the model: squared error between the predicted Q-value and the return
		loss = (prediction - Gain) ** 2
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

	if step % 100 == 0:
		print(f"Episode {step}  gain: {total_gain}")
		torch.save(q_values.state_dict(), "taxi_mc.pth")

# plot
import matplotlib.pyplot as plt
plt.plot(gain_list)
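# label the axes so the learning curve is readable
plt.xlabel("Episode")
plt.ylabel("Total reward per episode")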
plt.show()
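
# Optional extra (not in the original script): a smoothed view of the same learning curve.
# The moving-average window of 100 episodes is an arbitrary choice made here for illustration.
import numpy as np
window = 100
if len(gain_list) >= window:
	smoothed = np.convolve(gain_list, np.ones(window) / window, mode="valid")
	plt.plot(smoothed)
	plt.xlabel("Episode")
	plt.ylabel(f"Average reward (window={window})")
	plt.show()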