import gymnasium as gym
import torch
import torch.nn as nn
from random import random
from torch import optim

# Hyperparameters
episodes = 7000
gamma = 0.99 # discount factor: how much rewards down the line are worth
lr = 0.0001 # learning rate
train = True

if train:
	env = gym.make('CartPole-v1')  # pole without visualization
else:
	env = gym.make('CartPole-v1', render_mode="human")  # pole with visualization

policy_net = nn.Sequential(
		nn.Linear(4, 128), # observation: cart position, cart velocity, pole angle, pole angular velocity
		nn.ReLU(),
		nn.Dropout(.5), # the higher, the more chaos!
		nn.Linear(128, 2),
		nn.Softmax(dim=0) 	# 2 action probabilities (not needed when training with PyTorch's cross-entropy loss, which expects raw logits)
	)
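
# Optional sanity check (a minimal sketch, not part of the original script):
# a single CartPole observation has 4 features, and even the untrained network
# should map it to 2 action probabilities that sum to 1. `dummy_obs` and
# `dummy_probs` are illustrative names only.
with torch.no_grad():
	dummy_obs = torch.zeros(4)           # placeholder observation
	dummy_probs = policy_net(dummy_obs)  # forward pass through the policy network
	assert dummy_probs.shape == (2,) and abs(dummy_probs.sum().item() - 1.0) < 1e-5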

save_file = "pole_pg.pth"
try: # load pre-trained model weights if available
	policy_net.load_state_dict(torch.load(save_file, weights_only=True))  # a saved state_dict loads fine with weights_only=True
	print("pre-trained model loaded")
except FileNotFoundError:
	print("No pre-trained model found, starting from scratch.")

optimizer = optim.Adam(policy_net.parameters(), lr=lr)

def policy_random(state):
	return env.action_space.sample()  # Random action

def policy_direct(state):
	with torch.no_grad():  # action selection only; gradients are computed later in the update step
		probability = policy_net(state)  # action probabilities from the softmax output
	if random() < probability[0].item():
		action = 0
	else:
		action = 1
	return action

def policy(state, epsilon):
	if random() < epsilon:
		action = policy_random(state)
	else:
		action = policy_direct(state)
	return action

gain_list = []
for step in range(episodes+1):
	if train:
		epsilon = max(0., 1 - 2 * step / episodes) # epsilon decays linearly from 1 to 0 over the first half of training, then stays at 0
		# epsilon =  1. - step / episodes # linear decay of epsilon from 1 to 0
	else:
		epsilon = 0 # no exploration, only exploitation, deterministic policy
	state = torch.tensor(env.reset()[0])
	done = False
	episode = []
	total_gain = 0
	while not done:
		action = policy(state, epsilon)
		next_state, reward, terminated, truncated, info = env.step(action)
		next_state = torch.tensor(next_state)
		episode.append((state, action, reward))
		state = next_state
		done = terminated or truncated
		total_gain += reward
	gain_list.append(total_gain)

	# Policy-gradient (REINFORCE) update for this episode
	Gain = 0
	# loss = 0
	for state, action, reward in reversed(episode):
		Gain = reward + gamma * Gain  # discounted return G_t, accumulated backwards through the episode
		# .detach() to keep gradients from flowing back through the whole episode
		# (the state tensor does not require gradients anyway, so this is only a safeguard)
		prob = policy_net(state.detach())[action]  # probability of the action that was actually taken (recomputed)
		log_prob = torch.log(prob)
		loss = -log_prob * Gain  # REINFORCE loss: -log π(aₜ|sₜ; θ) * Gₜ (negative, because the optimizer minimizes)
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
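
	# Alternative update (sketch of what the commented-out "loss = 0" above hints at):
	# accumulate the loss over the whole episode and take a single optimizer step
	# per episode instead of one step per transition. Kept commented out so it does
	# not run on top of the per-step update above; roughly:
	#
	# Gain, loss = 0, 0
	# for state, action, reward in reversed(episode):
	# 	Gain = reward + gamma * Gain
	# 	loss = loss - torch.log(policy_net(state)[action]) * Gain
	# optimizer.zero_grad()
	# loss.backward()
	# optimizer.step()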

	if step % 100 == 0:
		print(f"Episode {step}  gain: {total_gain}")
		torch.save(policy_net.state_dict(), save_file)

# plot the per-episode return over training
import matplotlib.pyplot as plt
plt.plot(gain_list)
plt.xlabel("episode")
plt.ylabel("total reward per episode")
plt.show()