import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt

# ---- Training hyperparameters ----
nr_training_steps = 7000  # number of episodes; the loop below does one policy update per episode
gamma = 0.9  # discount factor applied to future rewards
lr = 1e-3  # learning rate handed to the optimizer

# ---- Gym MountainCar environment ----
# "discrete":   drive left / right                      -> MountainCar-v0
# "continuous": accelerate left / right in [-1.0, 1.0]  -> MountainCarContinuous-v0
gameType = "discrete"
# gameType = "continuous"
# render_mode = "human"  # uncomment to watch the car
render_mode = None

save_file = "mountainCar_policy_" + gameType + ".pth"
# The continuous variant gets an explicit 400-step episode limit; the discrete
# variant keeps gym's default limit for MountainCar-v0.
_env_kwargs = (
    dict(id="MountainCarContinuous-v0", max_episode_steps=400)
    if gameType == "continuous"
    else dict(id="MountainCar-v0")
)
env = gym.make(render_mode=render_mode, **_env_kwargs)

#######################################################
### Step 1: Build the neural network for the policy ###
#######################################################
# Policy Network
# Policy network: maps a car state to action probabilities (discrete variant)
# or to the mean of a Gaussian over the applied force (continuous variant).
observation_size = 2  # car state: (position, velocity)
hidden_size = 128

if gameType == "continuous":
    n_actions = 1  # directional force applied to the car (continuous version)
else:
    n_actions = 3  # left, right, no action (discrete version)

# nn.Linear treats leading dimensions as batch dimensions, so no batch size is
# declared here: a single state and a batch of states both pass through.
policy = nn.Sequential(
    nn.Linear(observation_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, n_actions),  # action scores instead of Q-values
)
# Output head depends on the action space. These layers have no parameters, so
# the saved state_dict keys are the same for both variants.
if gameType == "discrete":
    policy.add_module("softmax", nn.Softmax(dim=-1))  # normalize over the last (action) dim
if gameType == "continuous":
    policy.add_module("tanh", nn.Tanh())  # mean action squashed into [-1.0, 1.0]

# Load previously saved weights if available. Snapshot the fresh weights first:
# load_state_dict() can copy the compatible tensors before raising on the
# incompatible ones, which would leave a half-loaded network behind.
_initial_weights = {name: t.clone() for name, t in policy.state_dict().items()}
try:
    policy.load_state_dict(torch.load(save_file, weights_only=True))
except FileNotFoundError:
    print("No weights available, starting from scratch.")
except Exception as e:  # best effort: corrupt or architecture-incompatible checkpoint
    print(e)
    policy.load_state_dict(_initial_weights)  # undo any partial load
    print("No weights available, starting from scratch.")

optimizer = optim.Adam(policy.parameters(), lr=lr)

#############################
### Step 2: Training loop ###
#############################

def own_reward(state, reward, speed_scale=10000.):
    """Return the environment reward plus a bonus for the car's absolute speed.

    Args:
        state: observation indexable as (position, velocity); only the
            velocity ``state[1]`` is used.
        reward: reward returned by the environment for this step.
        speed_scale: multiplier applied to the velocity before taking its
            absolute value (default 10000., the original hard-coded value).

    Returns:
        ``reward + float(abs(state[1] * speed_scale))``.
    """
    # Reward speed in either direction so swinging back and forth pays off too.
    return reward + float(abs(state[1] * speed_scale))

max_reach = []  # rightmost position reached per episode (the goal lies to the right)
min_reach = []  # leftmost position reached per episode
victories = []  # cumulative number of episodes that reached the goal, per episode
victorie = 0  # running victory counter
losses = []  # policy loss per episode, for plotting

for training_step in range(nr_training_steps):
    state, _ = env.reset()
    log_probs = []  # log-probability of each action taken (kept in the autograd graph)
    rewards = []  # shaped reward for each action taken
    # Track how far the car got in each direction, starting from its reset position.
    best_reach = left_reach = float(state[0])
    done = False
    while not done:
        state_tensor = torch.tensor(state, dtype=torch.float32)
        # Forward pass: action probabilities (discrete) or action mean (continuous).
        probs = policy(state_tensor)

        if gameType == "continuous":
            # Gaussian policy around the network's mean action (-1 = left ... 1 = right).
            mu = probs
            std = torch.tensor(0.1)  # fixed std; a larger value means more exploration
            dist = torch.distributions.Normal(mu, std)
            action = dist.sample()
            log_prob = dist.log_prob(action)  # log-prob of the unclipped sample
            # Clip to [-1.0, 1.0] before handing the force to the environment.
            action = (action.clamp(-1.0, 1.0).detach().item(),)
        elif gameType == "discrete":
            # Sample one concrete action from the action probabilities. Categorical
            # clamps tiny probabilities before the log, avoiding log(0) = -inf.
            dist = torch.distributions.Categorical(probs=probs)
            action_tensor = dist.sample()
            log_prob = dist.log_prob(action_tensor)  # 0-D tensor
            action = action_tensor.item()
        else:
            raise ValueError(f"unknown game type: {gameType!r}")

        log_probs.append(log_prob)

        state, reward, terminated, truncated, _ = env.step(action)
        rewards.append(own_reward(state, reward))
        # terminated: car reached the target; truncated: too many steps done.
        done = terminated or truncated
        if terminated:
            # Count wins via the env's own goal test, which also covers the
            # continuous variant's goal position (a fixed x > 0.5 test does not).
            victorie += 1
        position = float(state[0])
        best_reach = max(best_reach, position)
        left_reach = min(left_reach, position)

    max_reach.append(best_reach)
    min_reach.append(left_reach)
    victories.append(victorie)

    # One log-prob per step as a flat (T,) tensor; continuous log-probs are (1,) each.
    logs = torch.stack(log_probs).reshape(-1)

    # Monte-Carlo discounted returns: G_t = r_t + gamma * G_{t+1}, built backwards.
    returns = []
    G = 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()  # O(n) instead of repeated insert(0, ...)
    returns = torch.tensor(returns, dtype=torch.float32)

    # Normalize returns (zero mean, unit std) for numerical stability. The
    # unbiased std of a single element is NaN, which would poison the weights,
    # so one-step episodes keep their raw return.
    if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

    # Policy gradient update (REINFORCE loss: -sum_t log pi(a_t|s_t; theta) * G_t)
    optimizer.zero_grad()
    loss = -(logs * returns).sum()
    loss.backward()  # compute gradients for the network parameters
    optimizer.step()  # apply the gradients to the weights
    losses.append(loss.detach().item())

    if training_step % 20 == 0:
        print("step", training_step, "loss", loss.item())
        print("max_reaches", sum(max_reach) / len(max_reach))
        print("min_reaches", sum(min_reach) / len(min_reach))
        print("victories", victorie)
        torch.save(policy.state_dict(), save_file)

# Persist the final weights too; the periodic save above can miss the last updates.
torch.save(policy.state_dict(), save_file)
env.close()

# Plot training progress.
PLOT_WINDOW = 20  # number of episodes averaged into one plotted point


def _chunk_means(values, size):
    """Average ``values`` over consecutive chunks of ``size`` items.

    The last chunk may be shorter than ``size``; it is averaged over its
    actual length rather than divided by ``size``.
    """
    return [
        sum(values[start:start + size]) / len(values[start:start + size])
        for start in range(0, len(values), size)
    ]


max_reach = _chunk_means(max_reach, PLOT_WINDOW)
min_reach = _chunk_means(min_reach, PLOT_WINDOW)
losses = _chunk_means(losses, PLOT_WINDOW)

fig, (ax1, ax2, ax3) = plt.subplots(3)
ax1.set_title("Reach")
ax1.plot(max_reach, label="Max Reach", color="tab:blue")
ax1.plot(min_reach, label="Min Reach", color="tab:orange")
ax1.set_xlabel(f"episode / {PLOT_WINDOW}")
ax1.legend()
ax2.set_title("Loss")
ax2.plot(losses, label="Loss", color="tab:red")
ax2.set_xlabel(f"episode / {PLOT_WINDOW}")
ax3.set_title("Victory")
ax3.plot(victories, label="Victories", color="tab:green")  # cumulative count, not averaged
ax3.set_xlabel("episode")
fig.tight_layout()  # keep titles and axis labels of the stacked plots from overlapping
plt.savefig("progress.png")  # save BEFORE show(); the figure may be closed once show() returns
plt.show()
