import gymnasium as gym
import torch
import torch.nn as nn
from random import random
from torch import optim

# Hyperparameters
episodes = 50000
gamma = 0.95 # discount factor: how much future rewards are worth
lr = 0.0001 # learning rate
# train = True
train = False

if train:
	env = gym.make('FrozenLake-v1')  # FrozenLake without rendering (faster training)
else:
	env = gym.make('FrozenLake-v1', render_mode="human")  # FrozenLake with human rendering (visualization)


policy_net = nn.Sequential(
		nn.Linear(env.observation_space.n, 128), # 16 one-hot encoded positions of the agent on the 4x4 grid
		nn.ReLU(),
		# nn.Dropout(.5), # the higher, the more chaos!
		nn.Linear(128, env.action_space.n),
		nn.Softmax(dim=-1) 	# 4 action probabilities (a softmax is not needed with PyTorch's cross-entropy loss, which expects logits)
	)
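
# Quick sanity check (illustrative): the network maps a one-hot state vector of length 16
# (env.observation_space.n) to a probability distribution over the 4 actions (env.action_space.n);
# before training the output is roughly uniform.
# with torch.no_grad():
# 	print(policy_net(torch.eye(16)[0]))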

save_file = "frozen4.pth"
# try: # load pre-trained model weights if available
# 	policy_net.load_state_dict(torch.load(save_file, weights_only=False))
# 	print("pre-trained model loaded")
# except:
# 	print("No pre-trained model found, starting from scratch.")

optimizer = optim.Adam(policy_net.parameters(), lr=lr)

def policy_random(state):
	return env.action_space.sample()  # Random action

def policy_direct(state):
	probs = policy_net(state)  # vector of action probabilities (softmax output)
	action = torch.multinomial(probs, num_samples=1).item()  # sample one action, convert to int
	return action
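
# Note: policy_direct always samples from the softmax distribution, so even with epsilon = 0
# the behaviour stays stochastic; a fully greedy alternative (not used here) would be
# torch.argmax(probs).item().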

def policy(state, epsilon):
	if random() < epsilon:
		action = policy_random(state)
	else:
		action = policy_direct(state)
	return action

def own_reward(reward, next_state_nr):
	# custom reward shaping for FrozenLake (the environment itself only returns +1 at the goal)
	if next_state_nr == 15: # goal state
		return 10
	elif next_state_nr == 0: # start state
		return 0
	elif next_state_nr in [5, 7, 11, 12]: # hole states
		return -1
	else:
		return 0.01 + reward # small positive reward for all other states
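
# For reference, the default FrozenLake-v1 4x4 map (is_slippery=True by default) is
# 	S F F F
# 	F H F H
# 	F F F H
# 	H F F G
# i.e. the holes are states 5, 7, 11, 12 and the goal is state 15, matching own_reward above.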

gain_list = []
for step in range(episodes+1):
	if train:
		epsilon = max(0., 1 - 2 * step / episodes) # linear decay from 1 to 0 over the first half of the episodes, then 0
		# epsilon =  1. - step / episodes # linear decay of epsilon from 1 to 0
	else:
		epsilon = 0 # no random exploration; actions are still sampled from the learned policy
	state_number = env.reset()[0] # 0…15: 0 is the start state, 15 is the goal state
	# convert state_number to one-hot encoded tensor
	state = torch.nn.functional.one_hot(torch.tensor(state_number), num_classes=16).float()
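	# e.g. state 2 -> tensor([0., 0., 1., 0., …, 0.]) of length 16 (illustrative)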
	done = False
	trajectory = []
	total_gain = 0
	while not done:
		action = policy(state, epsilon)
		next_state_nr, reward, terminated, truncated, info = env.step(action)
		reward = own_reward(reward, next_state_nr)
		# FrozenLake encodes the state not as (x, y) but as a single number from 0 to 15
		# networks work better with one-hot "buckets" than with a raw integer
		next_state = torch.nn.functional.one_hot(torch.tensor(next_state_nr), num_classes=16).float()
		trajectory.append((state, action, reward))
		state = next_state
		done = terminated or truncated
		total_gain += reward
	gain_list.append(total_gain)

	# REINFORCE policy-gradient update
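	# REINFORCE objective per episode: loss = -∑ₜ log π(aₜ|sₜ; θ) · Gₜ,
	# where the return Gₜ = rₜ + γ·rₜ₊₁ + γ²·rₜ₊₂ + … is accumulated backwards below.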
	Gain = 0
	loss = 0
	for state, action, reward in reversed(trajectory):
		Gain = reward + gamma * Gain
		# .detach() to prevent the gradient from being propagated back through the entire episode / being used twice
		probs = policy_net(state.detach())
		prob = probs[action]  # probability of the action actually taken (recomputed, it was not stored in the trajectory)
		log_prob = torch.log(prob)
		loss -= log_prob * Gain  # -∑ₜ log π(aₜ|sₜ; θ) · Gₜ
	loss = torch.clamp(loss, -10.0, 10.0)  # clip the loss to avoid exploding gradients
	optimizer.zero_grad()
	loss.backward()
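	# Note: clamping the loss (above) zeroes the gradient whenever the loss leaves the clamp range;
	# an alternative (not used here) would be gradient-norm clipping, e.g.
	# torch.nn.utils.clip_grad_norm_(policy_net.parameters(), max_norm=1.0)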
	optimizer.step()

	if step % 100 == 0:
		print(f"Episode {step}  gain: {total_gain} loss : {loss.item()}")
		torch.save(policy_net.state_dict(), save_file)

# plot the gain per episode
import matplotlib.pyplot as plt
plt.plot(gain_list)
plt.show()
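
# Optional (illustrative): the per-episode gains are noisy; a simple moving average
# makes the learning trend easier to see (window size 100 is an arbitrary choice).
# window = 100
# if len(gain_list) > window:
# 	smoothed = [sum(gain_list[i:i + window]) / window for i in range(len(gain_list) - window + 1)]
# 	plt.plot(smoothed)
# 	plt.show()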