# import gym
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

nr_training_steps = 1500
gamma = 0.99
lr = 1e-3

# Gym Environment
# env = gym.make('CartPole-v1', render_mode="human")  # pole with visualization
env = gym.make('CartPole-v1')  # pole without visualization

# Policy Network
observation_size = 4  # pole state: (position, velocity, angle, angular-velocity)
hidden_size = 128
n_actions = 2  # left / right

actor = nn.Sequential(  # policy network
	nn.Linear(observation_size, hidden_size),  # no need to specify a batch size here
	nn.ReLU(),
	nn.Linear(hidden_size, n_actions),
	nn.Softmax(dim=-1)  # similar to CrossEntropyLoss (classification), but we build the distribution ourselves, hence the softmax
)
optimizer = optim.Adam(actor.parameters(), lr=lr)

# Value Network = Critic
value_net = nn.Sequential(
	nn.Linear(observation_size, hidden_size),
	nn.ReLU(),
	nn.Linear(hidden_size, 1)  # output: the estimated value of the state
)
value_optimizer = optim.Adam(value_net.parameters(), lr=lr)
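# Note on the setup: this is essentially REINFORCE with a learned value baseline
# (a Monte Carlo actor-critic). The actor outputs a probability distribution over
# the two actions, the critic estimates the state value V(s), and the critic's
# estimate is subtracted from the sampled return below to reduce the variance of
# the policy-gradient update.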

# load weights if available
save_file = "cartpole_ac_policy.pth"
save_file_critic = "cartpole_ac_value.pth"
try:
	actor.load_state_dict(torch.load(save_file, weights_only=True))
	value_net.load_state_dict(torch.load(save_file_critic, weights_only=True))
except FileNotFoundError:
	print("No weights available, starting from scratch.")

run_lengths = []
for training_step in range(nr_training_steps):
	state = env.reset()[0]
	log_probs = []
	values = []
	rewards = []
	done = False
	steps = 0
	while not done:
		steps += 1
		state_tensor = torch.tensor(state, dtype=torch.float32)
		probs = actor(state_tensor)
		value = value_net(state_tensor)

		action = torch.multinomial(probs, num_samples=1).item()
		log_prob = torch.log(probs[action])
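		# torch.multinomial samples the action index from the categorical
		# distribution defined by probs; torch.distributions.Categorical(probs)
		# with .sample() and .log_prob() would be an equivalent alternative.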

		log_probs.append(log_prob)
		values.append(value)

		state, reward, terminated, truncated, info = env.step(action)
		rewards.append(reward)
		done = terminated or truncated

	run_lengths.append(steps)

	# Compute returns and advantages
	# 15 lines of code for the Monte Carlo returns, still quite manageable!
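	# Discounted return, computed backwards over the episode:
	#   G_t = r_t + gamma * G_{t+1}, with G = 0 after the final step.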
	gains = []
	G = 0  # accumulate the discounted rewards into the return
	for r in reversed(rewards):
		G = r + gamma * G
		gains.insert(0, G)
	gains = torch.tensor(gains, dtype=torch.float32)  # total returns
	values = torch.stack(values).squeeze()
	with torch.no_grad():
		deltas = gains - values  # advantages = returns - estimated state values

	# Normalize advantages (optional)
	deltas = (deltas - deltas.mean()) / (deltas.std() + 1e-8)
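	# Zero-mean, unit-variance advantages keep the gradient scale roughly constant
	# across episodes of very different lengths; the 1e-8 avoids division by zero
	# for degenerate episodes.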

	log_probs = torch.stack(log_probs)
	policy_loss = -(log_probs * deltas).sum()
	value_loss = nn.functional.mse_loss(values, gains)
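	# Policy gradient with baseline: loss = -sum_t log pi(a_t|s_t) * A_t,
	# where A_t are the (normalized) advantages. The critic is fit separately
	# by regressing its value estimates onto the Monte Carlo returns (MSE).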

	optimizer.zero_grad()
	policy_loss.backward()
	optimizer.step()

	value_optimizer.zero_grad()
	value_loss.backward()
	value_optimizer.step()

	if training_step % 50 == 0:
		print("step", training_step, "policy loss", policy_loss.item(), "value loss", value_loss.item(), "steps", steps)
		print("mean run length", sum(run_lengths) / len(run_lengths))
		plt.clf()  # clear the figure so each save contains a single, up-to-date curve
		plt.plot(run_lengths)
		plt.savefig("run_lengths.png")
		# plt.show()
		# plt.pause(0.001)

# save the policy and value function
torch.save(actor.state_dict(), save_file)
torch.save(value_net.state_dict(), save_file_critic)
# close the environment
env.close()
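
# Optional: a minimal evaluation sketch (not part of the training run above) that
# rolls out the trained policy greedily with rendering enabled. Uncomment to try it.
# eval_env = gym.make('CartPole-v1', render_mode="human")
# state, _ = eval_env.reset()
# done = False
# while not done:
# 	with torch.no_grad():
# 		probs = actor(torch.tensor(state, dtype=torch.float32))
# 	action = torch.argmax(probs).item()  # greedy action instead of sampling
# 	state, reward, terminated, truncated, _ = eval_env.step(action)
# 	done = terminated or truncated
# eval_env.close()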
