# import gym
import gymnasium as gym
import torch
import torch.nn as nn # neural networks
import torch.optim as optim
import matplotlib.pyplot as plt
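
# Actor-critic on LunarLander-v3: the policy network (actor) is trained with REINFORCE
# on Monte Carlo returns, and a value network (critic) provides a learned baseline
# that reduces the variance of the policy gradient.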

nr_training_steps = 4000  # number of training episodes
gamma = 0.99  # discount factor for the returns
lr = 1e-3  # learning rate for both the policy and the value optimizer

render_mode = "human"  # rendering every episode slows training; set to None to train headless
# render_mode = None
env = gym.make("LunarLander-v3", render_mode=render_mode, max_episode_steps=5000)

# Policy Network
observation_size=8 # lunar lander state: (position x, position y, velocity x, velocity y, angle, angular velocity, left engine, right engine)
hidden_size = 128
n_actions = 4  # do nothing, fire left engine, fire main engine, fire right engine.

policy = nn.Sequential(  # PyTorch handles the batch dimension implicitly
	nn.Linear(observation_size, hidden_size),  # no batch size has to be specified here
	nn.ReLU(),
	# nn.Linear(hidden_size, hidden_size),
	# nn.ReLU(),
	nn.Linear(hidden_size, n_actions),
	nn.Softmax(dim=-1)
)

optimizer = optim.Adam(policy.parameters(), lr=lr)

# Value network (critic): estimates V(s), used below as a baseline for the policy gradient
value_function = nn.Sequential(
	nn.Linear(observation_size, hidden_size),
	nn.ReLU(),
	nn.Linear(hidden_size, 1)
)
value_optimizer = optim.Adam(value_function.parameters(), lr=lr)
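# Actor and critic are updated independently, each with its own Adam optimizer.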

# load weights if available
try:
	policy.load_state_dict(torch.load("lunar_ac_policy.pth", weights_only=True))
	value_function.load_state_dict(torch.load("lunar_ac_value.pth", weights_only=True))
except FileNotFoundError:
	print("No weights available, starting from scratch.")

run_lengths = []
for training_step in range(nr_training_steps):
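	# Roll out one complete episode and collect log-probabilities, value estimates and rewards.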
	state = env.reset()[0]
	log_probs = []
	values = []
	rewards = []

	done = False
	steps = 0
	while not done:
		steps += 1
		state_tensor = torch.tensor(state, dtype=torch.float32)
		probs = policy(state_tensor)
		value = value_function(state_tensor)

		action = torch.multinomial(probs, num_samples=1).item()  # sample an action from the softmax distribution
		log_prob = torch.log(probs[action])  # log pi(a|s), needed for the policy gradient

		log_probs.append(log_prob)
		values.append(value)

		state, reward, terminated, truncated, info = env.step(action)
		rewards.append(reward)
		done = terminated or truncated

	run_lengths.append(steps)

	# Compute returns and advantages
	# About 15 lines of code for the Monte Carlo returns, still quite manageable!
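	# Recurrence: G_t = r_t + gamma * G_{t+1}, evaluated backwards over the episode.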
	returns = []
	G = 0  # running sum that accumulates the discounted rewards into the return
	for r in reversed(rewards):
		G = r + gamma * G
		returns.insert(0, G)
	returns = torch.tensor(returns, dtype=torch.float32)  # total discounted return G_t for every step
	values = torch.stack(values).squeeze(-1)  # critic estimates V(s_t), shape (T,)
	advantages = returns - values.detach()  # advantage = return minus the estimated value of the state

	# Normalize advantages (optional)
	advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

	log_probs = torch.stack(log_probs)
	policy_loss = -(log_probs * advantages).sum()  # REINFORCE: maximize sum_t log pi(a_t|s_t) * A_t
	value_loss = nn.functional.mse_loss(values, returns)  # regress the critic onto the Monte Carlo returns

	optimizer.zero_grad()
	policy_loss.backward()
	optimizer.step()

	value_optimizer.zero_grad()
	value_loss.backward()
	value_optimizer.step()

	if training_step % 50 == 0:
		print("step", training_step, "policy loss", policy_loss.item(), "value loss", value_loss.item(), "steps", steps)
		print("mean run length", sum(run_lengths) / len(run_lengths))
		plt.clf()  # redraw instead of stacking a new curve onto the old ones
		plt.plot(run_lengths)
		plt.savefig("run_lengths.png")
		plt.pause(0.001)  # short non-blocking draw; plt.show() would block training until the window is closed

# save the policy and value function
torch.save(policy.state_dict(), "lunar_ac_policy.pth")
torch.save(value_function.state_dict(), "lunar_ac_value.pth")
# close the environment
env.close()
