import gymnasium as gym
import torch
import torch.nn as nn
from random import random
from torch import optim

# Minimal DQN implementation for CartPole-v1
# Hyperparameters
episodes = 50000
gamma = 0.99
lr = 0.0001 # learning rate
hidden_dim = 128

# render_mode = "human"  # uncomment to watch the pole while training
render_mode = None  # no rendering (faster training)
env = gym.make('CartPole-v1', render_mode=render_mode)  # CartPole environment
state_dimension = 4  # pole state: (position, velocity, angle, angular-velocity)
action_space_dims = 2 # left / right

# state => q-values for all actions
q_value_net = nn.Sequential(
    nn.Linear(state_dimension, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, action_space_dims),  # 2 Q-values: left / right
)
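
# Illustrative sanity check (can be uncommented): a 4-dim state maps to one Q-value per action.
# print(q_value_net(torch.zeros(state_dimension)).shape)  # torch.Size([2])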

# Target network for stable learning
target_net = nn.Sequential(
    nn.Linear(state_dimension, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, action_space_dims),
)
# load existing model if available
try:
    q_value_net.load_state_dict(torch.load('dqn_cartpole.pth', weights_only=True))
    print("Loaded existing model.")
except Exception as e:
    print(e)
    print("Starting with a new model.")

# OR clone via library:
# import copy
# target_net = copy.deepcopy(q_value_net)
target_net.load_state_dict(q_value_net.state_dict())
target_net.eval()
target_update_freq = 10  # update every N episodes
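
# Alternative (not used here): instead of periodic hard copies, some DQN variants use
# soft / Polyak target updates after every training step. A minimal sketch, assuming a
# small mixing factor tau:
# tau = 0.005
# with torch.no_grad():
#     for target_param, param in zip(target_net.parameters(), q_value_net.parameters()):
#         target_param.mul_(1 - tau).add_(tau * param)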

optimizer = optim.Adam(q_value_net.parameters(), lr=lr)

def policy_random(state):
    return env.action_space.sample()  # random action: 0 or 1

def policy_q(state):  # greedy action: argmax of the Q-values
    with torch.no_grad():  # no gradients needed for action selection
        action_values = q_value_net(state)  # vector of Q-values
    if action_values[0] > action_values[1]:  # take the argmax
        action = 0  # left
    else:
        action = 1  # right
    return action

def policy(state, epsilon):  # epsilon-greedy policy
    if random() < epsilon:
        action = policy_random(state)  # explore
    else:
        action = policy_q(state)  # exploit
    return action
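
# Usage sketch (illustrative, not part of training): epsilon = 1.0 always explores,
# epsilon = 0.0 always exploits the greedy argmax.
# example_state = torch.zeros(state_dimension)
# print(policy(example_state, 0.5))  # -> 0 or 1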

def training_step(replay):
	loss = 0
	for state, action, reward, next_state, terminated in replay:
		prediction = q_value_net(state)[action]  # Q(s, a) from the online network
		with torch.no_grad():  # no gradients through the target network
			# Bellman target: Q(s, a) <- r + γ * max_a' Q_target(s', a')
			if terminated:
				goal = reward  # terminal state: no future reward to bootstrap
			else:
				goal = reward + gamma * target_net(next_state).max()  # discounted future reward
		loss += (prediction - goal) ** 2  # squared TD error, summed over the episode
	optimizer.zero_grad()  # reset gradients to zero
	loss.backward()  # backpropagation: compute the gradients
	optimizer.step()  # take a gradient step toward the minimum
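
# Worked example of the target above (illustrative numbers): with reward = 1.0,
# gamma = 0.99 and max_a' Q_target(s', a') = 20.0, the goal is 1.0 + 0.99 * 20.0 = 20.8.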

gewinn_liste = []  # episode returns ("Gewinn" = gain)
for step in range(episodes):
	epsilon = 1 - step / episodes  # linear epsilon decay, necessary for exploration
	state = torch.tensor(env.reset()[0])
	done = False
	replay = []  # transitions of this episode ('batch')
	gain = 0
	while not done:  # play one episode
		action = policy(state, epsilon)
		next_state, reward, terminated, truncated, info = env.step(action)
		next_state = torch.tensor(next_state)
		replay.append((state, action, reward, next_state, terminated))
		state = next_state
		done = terminated or truncated  # pole fell over or time limit reached
		gain += reward
	gewinn_liste.append(gain)  # for plotting
	training_step(replay)  # train on this episode's transitions

	if step % target_update_freq == 0:
		target_net.load_state_dict(q_value_net.state_dict())  # copy the trained weights into the target network

	if step % 100 == 0:
		print("Episode", step, "gain:", gain)
		# save a model checkpoint
		torch.save(q_value_net.state_dict(), 'dqn_cartpole.pth')

# plot
import matplotlib.pyplot as plt
plt.plot(gewinn_liste)
plt.show()
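
# Optional evaluation sketch (assumption, not part of the training run above):
# load the saved weights and watch one greedy episode with rendering.
# eval_env = gym.make('CartPole-v1', render_mode='human')
# q_value_net.load_state_dict(torch.load('dqn_cartpole.pth', weights_only=True))
# state = torch.tensor(eval_env.reset()[0])
# done = False
# while not done:
#     action = policy_q(state)  # greedy policy, no exploration
#     state, reward, terminated, truncated, info = eval_env.step(action)
#     state = torch.tensor(state)
#     done = terminated or truncated
# eval_env.close()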