import gymnasium as gym
import torch
import torch.nn as nn
from random import random
from torch import optim

# Minimal DQN implementation for CartPole-v1
# Hyperparameters
episodes = 5000
gamma = 0.99
lr = 0.0001 # learning rate
hidden_dim = 128
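# Note: epsilon (the exploration rate) is not a fixed hyperparameter here; it
# decays linearly from 1 towards 0 over the episodes (see the training loop below).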


# render_mode = "human"
render_mode = None
env = gym.make('CartPole-v1', render_mode=render_mode)  # CartPole environment (set render_mode = "human" above to watch it)
state_dimension = 4  # pole state: (position, velocity, angle, angular-velocity)
action_space_dims = 2 # left / right
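# These dimensions match the environment spec and could also be read from it:
# env.observation_space.shape[0] == 4 and env.action_space.n == 2.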

# state => q-values for all actions
q_value_net = nn.Sequential(
	nn.Linear(state_dimension, hidden_dim),
	nn.ReLU(),
	nn.Linear(hidden_dim, action_space_dims),  # 2 q-values: left / right
)
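# The net maps a 4-dim state to 2 Q-values, e.g.
#   q_value_net(torch.zeros(state_dimension))  ->  tensor([0.01, -0.03], grad_fn=...)
# (the numbers are illustrative only; an untrained net outputs arbitrary values)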


# load existing model if available
try:
	q_value_net.load_state_dict(torch.load('dqn_cartpole.pth', weights_only=True))
	print("Loaded existing model.")
except Exception as e:
	print(e)
	print("Starting with new model.")

optimizer = optim.Adam(q_value_net.parameters(), lr=lr)

def policy_random(state):
	return env.action_space.sample()  # Random action 0 or 1

def policy_q(state):  # greedy action: argmax over the Q-values
	with torch.no_grad():  # no gradients needed for action selection
		action_values = q_value_net(state)  # vector of q-values, one per action
	return int(torch.argmax(action_values))  # 0 = left, 1 = right

def policy(state, epsilon):  # epsilon-greedy: explore with probability epsilon, otherwise act greedily
	if random() < epsilon:
		action = policy_random(state)
	else:
		action = policy_q(state)
	return action

def training_step(replay):
	loss = 0
	# Iterate over all but the last entry: replay was built in reverse order, so
	# its last entry is the episode's first transition, whose prev_state is None.
	# (Shuffling the entries would also be fine, since the TD steps are independent.)
	for i in range(len(replay) - 1):
		state, action, reward, next_state = replay[i]
		prediction = q_value_net(state)[action]  # Q(s, a) for the action actually taken
		with torch.no_grad():  # the Bellman target is treated as a constant (semi-gradient TD)
			target = reward + gamma * q_value_net(next_state).max()  # discounted future reward
		# Bellman equation: Q(s, a) <- r + γ * max(Q(s', a'))
		loss += (prediction - target)**2  # sum of squared TD errors over the episode
	optimizer.zero_grad()  # reset gradients to zero
	loss.backward()  # backpropagation: compute gradients of the summed loss
	optimizer.step()  # take one gradient step towards the minimum
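# Note: terminal transitions are not treated specially here (no done flag is stored
# in the replay tuples), so the last target still bootstraps from the final state.
# A full DQN would use just `reward` as the target when the episode terminated.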

gewinn_liste = []  # total reward (gain) per episode, for plotting
for step in range(episodes):
	epsilon = 1 - step / episodes  # linear epsilon decay, necessary for exploration
	state = torch.tensor(env.reset()[0])
	done = False
	replay = []  # 'batch' of transitions for one episode
	prev_state = None
	gain = 0
	while not done:  # play one episode
		action = policy(state, epsilon)
		state, reward, terminated, truncated, info = env.step(action)
		state = torch.tensor(state)
		# store (s, a, r, s'); prev_state is None on the very first step, which is
		# why training_step skips the last entry of the (reversed) replay list
		replay.insert(0, (prev_state, action, reward, state))  # prepend => reverse order
		prev_state = state
		done = terminated or truncated  # pole fell over or time limit reached
		gain += reward
	gewinn_liste.append(gain)  # for plotting
	training_step(replay)  # train on this episode's transitions

	if step % 100 == 0:
		print("Episode", step, "gain:", gain)
		# save a model checkpoint
		torch.save(q_value_net.state_dict(), 'dqn_cartpole.pth')
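
env.close()  # release the environment's resources once training is finished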

# plot
import matplotlib.pyplot as plt
plt.plot(gewinn_liste)
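plt.xlabel('episode')  # axis labels make the learning curve easier to read
plt.ylabel('total reward (gain)')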
plt.show()