import gymnasium as gym
import torch
import torch.nn as nn
from random import random
from torch import optim

# Minimal DQN implementation for CartPole-v1
# Hyperparameters
episodes = 1500
gamma = 0.99
lr = 0.0001 # learning rate
hidden_dim = 128

# Plain DQN without batching or a replay buffer; Dyna-Q-style planning via the learned dynamics model below

# render_mode = "human"
render_mode = None
env = gym.make('CartPole-v1', render_mode=render_mode)  # pole with visualization
state_dimension = 4  # pole state: (position, velocity, angle, angular-velocity)
action_space_dims = 2 # left / right
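# A small optional sketch (not used below): the dimensions could also be read from the
# environment spaces instead of being hard-coded:
# state_dimension = env.observation_space.shape[0]  # 4 for CartPole-v1
# action_space_dims = env.action_space.n            # 2 for CartPole-v1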

# state => q-values for all actions
q_value_net = nn.Sequential(
		nn.Linear(state_dimension, hidden_dim),
		nn.ReLU(),
		nn.Linear(hidden_dim, action_space_dims), # 2 q-values for left / right
	)

# Learned Dynamics Model!
model = nn.Sequential(
		nn.Linear(state_dimension + action_space_dims, hidden_dim),
		nn.ReLU(),
		nn.Linear(hidden_dim, state_dimension)  # predicts next state only; add +1 output for reward and +1 for a done signal if desired
)
model_optimizer = optim.Adam(model.parameters(), lr=lr)
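
# A hedged sketch of the variant mentioned in the comment above (not used here): widening
# the output by +1 for the reward and +1 for a done logit would let the model predict the
# full transition, at the cost of a slightly more involved training target.
# model_full = nn.Sequential(
# 	nn.Linear(state_dimension + action_space_dims, hidden_dim),
# 	nn.ReLU(),
# 	nn.Linear(hidden_dim, state_dimension + 2),  # next state + predicted reward + done logit
# )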


# load existing model if available
try:
	q_value_net.load_state_dict(torch.load('pole_dyna-q.pt',weights_only=True))
	print("Loaded existing model.")
except Exception as e:
	print(e)
	print("Starting with new model.")

optimizer = optim.Adam(q_value_net.parameters(), lr=lr)

def policy_random(state):
	return env.action_space.sample()  # Random action 0 or 1

def policy_q(state):  # Greedy action based on Q-values (argmax)
	with torch.no_grad():  # no gradients needed for action selection
		action_values = q_value_net(state)  # vector of q-values
	return int(torch.argmax(action_values))  # 0 = left, 1 = right

def policy(state, epsilon):
	if random() < epsilon:
		action = policy_random(state)
	else:
		action = policy_q(state)
	return action

def training_step(state, action, reward, next_state):
	if state is None: return # no previous state, nothing to learn
	prediction = q_value_net(state)[action]
	with torch.no_grad():  # the TD target is treated as a constant: no gradient flows through it
		target = reward + gamma * q_value_net(next_state).max()  # discounted future reward
		# Q(s, a) <- r + γ * max_a' Q(s', a')  # Bellman update
	loss = (prediction - target)**2  # squared TD error; target Q(s, a) is the 'expected gain'
	optimizer.zero_grad()  # reset gradients to zero
	loss.backward()  # backpropagation: compute gradients
	optimizer.step()  # take one gradient step towards the minimum
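
# A hedged variant (defined but not used in the loop below): for a strictly correct TD
# target, the bootstrap term should be dropped on terminal transitions. training_step_done
# is a hypothetical name for such an update that also takes the done flag.
def training_step_done(state, action, reward, next_state, done):
	prediction = q_value_net(state)[action]
	with torch.no_grad():
		# no future reward beyond a terminal state
		target = torch.tensor(reward) if done else reward + gamma * q_value_net(next_state).max()
	loss = (prediction - target)**2
	optimizer.zero_grad()
	loss.backward()
	optimizer.step()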

def train_model(action, state, next_state):
	# Train the dynamics model on the observed real transition (state, action) -> next_state
	action_tensor = torch.nn.functional.one_hot(torch.tensor(action), num_classes=action_space_dims).float()
	state_and_action = torch.cat([state, action_tensor])  # concatenate state + one-hot action
	predicted_next_state = model(state_and_action)  # predict the next state
	model_loss = nn.functional.mse_loss(predicted_next_state, next_state)  # loss = (predicted - target)^2
	model_optimizer.zero_grad()
	model_loss.backward()
	model_optimizer.step()
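
# Dyna-Q idea: real transitions train both the Q-network (direct RL) and the dynamics model;
# the model is then used in simulate() below to generate extra, imagined Q updates (planning).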

visited_states = []  # (state, action) pairs observed in the real environment

def simulate():
	# Dyna-Q planning step: one imagined Q update from a random previously visited (state, action) pair
	random_index = int(random() * len(visited_states))
	s, a = visited_states[random_index]  # random previously visited state and its action
	state_tensor = torch.as_tensor(s, dtype=torch.float32)
	action_tensor = torch.nn.functional.one_hot(torch.tensor(a), num_classes=action_space_dims).float()
	state_and_action = torch.cat([state_tensor, action_tensor])  # concatenate state + one-hot action
	with torch.no_grad():
		predicted_next_state = model(state_and_action)  # imagined next state from the learned model
	reward = 1.0  # CartPole pays a constant reward of 1 per step until termination, so learning it would be overkill
	training_step(state_tensor, a, reward, predicted_next_state)


gewinn_liste = []  # per-episode gain ("Gewinn"), for plotting
for episode in range(episodes + 1):
	epsilon = 1 - episode / episodes  # linear epsilon decay, necessary for exploration
	state = torch.tensor(env.reset()[0])
	done = False
	gain = 0
	while not done:  # play one episode
		action = policy(state, epsilon)
		visited_states.append((state, action))  # remember the real (state, action) pair for planning
		next_state, reward, terminated, truncated, info = env.step(action)
		next_state = torch.tensor(next_state)
		train_model(action, state, next_state)  # train dynamics model with real experience
		training_step(state, action, reward, next_state)  # direct RL update from real experience
		state = next_state
		done = terminated or truncated  # pole fell over or time is up!
		gain += reward
		if episode > 500:  # Dyna-Q planning once the model has seen some real experience
			simulate()

	gewinn_liste.append(gain) # for plotting

	if episode % 100 == 0:
		print("Episode", episode, " gain:", gain)
		torch.save(q_value_net.state_dict(), 'pole_dyna-q.pt')

# plot
import matplotlib.pyplot as plt
plt.plot(gewinn_liste)
plt.show()
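
# A minimal evaluation sketch (commented out; assumes the same environment and a purely
# greedy policy, i.e. epsilon = 0, with rendering enabled):
# eval_env = gym.make('CartPole-v1', render_mode="human")
# state = torch.tensor(eval_env.reset()[0])
# done = False
# while not done:
# 	state, reward, terminated, truncated, info = eval_env.step(policy_q(state))
# 	state = torch.tensor(state)
# 	done = terminated or truncated
# eval_env.close()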