import numpy as np
import random

import torch

# Parameters
GRID_SIZE = 4
LEARNING_RATE = 0.1
EPSILON = 1.0  # initial exploration rate; decayed per episode in the training loop
EPISODES = 500
TERMINAL_STATE = (0, 0)
GAMMA = 0.9  # discount factor
ACTIONS = ['up', 'down', 'left', 'right']
ACTION_EFFECTS = {
		'up': (-1, 0), # (y, x) -> (y-1, x) position change
		'down': (1, 0),
		'left': (0, -1),
		'right': (0, 1)
}

ACTION_SYMBOLS = {
		'up': '⬆',
		'down': '⬇',
		'left': '⬅',
		'right':  '➡'
}

# Reward matrix
rewards = np.full((GRID_SIZE, GRID_SIZE), -1.0) # -1 per step
rewards[:-1, 1] = -10  # 🟥 Red wall cells
rewards[3, 3] = -10
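# With GRID_SIZE = 4 the resulting reward layout is (row 0 at the top):
#   -1  -10   -1   -1
#   -1  -10   -1   -1
#   -1  -10   -1   -1
#   -1   -1   -1  -10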

n_actions = len(ACTIONS)

# Q-table initialization
Q = np.zeros((GRID_SIZE, GRID_SIZE, n_actions)) # (y, x, action) -> Q-value
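
# Function-approximation sketch (illustrative, not used by the tabular updates
# below): the torch import hints at a learned Q-function; a minimal linear
# variant maps a one-hot state encoding to one Q-value per action. The names
# q_net and encode_state are placeholders introduced for this sketch.
q_net = torch.nn.Linear(GRID_SIZE * GRID_SIZE, n_actions, bias=True)

def encode_state(state):
	y, x = state
	one_hot = torch.zeros(GRID_SIZE * GRID_SIZE)
	one_hot[y * GRID_SIZE + x] = 1.0
	return one_hot

# Example: q_net(encode_state((3, 3))) yields a tensor of shape (n_actions,).
# Training it would need gradient steps on the TD error instead of the in-place
# table update used below.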

def is_terminal(state):
	return state == TERMINAL_STATE

def step(state, action):
	dy, dx = ACTION_EFFECTS[action]
	y, x = state
	ny, nx = y + dy, x + dx # next state coordinates
	if 0 <= ny < GRID_SIZE and 0 <= nx < GRID_SIZE:
		next_state = (ny, nx)
		reward = rewards[next_state]
	else:
		next_state = state # illegal move, stay in place
		reward = -10 # penalty for illegal move
	return next_state, reward
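
# Quick sanity check (illustrative, assumes GRID_SIZE = 4): moving 'down' from the
# bottom-right corner leaves the grid, so the agent stays put and is penalised,
# while moving 'up' enters (2, 3) at the regular step cost.
assert step((3, 3), 'down') == ((3, 3), -10)
assert step((3, 3), 'up') == ((2, 3), -1.0)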

def epsilon_greedy(state):
	if random.random() < EPSILON:
		return random.choice(range(len(ACTIONS)))
	else:
		y, x = state
		return int(np.argmax(Q[y, x])) # exploit: best known action for this state

def update_q_values(state, action_idx, reward, next_state):
	y, x = state
	ny, nx = next_state
	# Q-learning update (off-policy)
	best_next_quality = np.max(Q[ny, nx])
	target = reward + GAMMA * best_next_quality # new estimate, off-policy (no epsilon)
	old_estimate = Q[y, x, action_idx]
	temporal_difference = target - old_estimate # TD error: how far the old estimate deviates from the target
	Q[y, x, action_idx] += LEARNING_RATE * temporal_difference # nudge the estimate toward the target

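# Worked example (illustrative numbers): with LEARNING_RATE = 0.1 and GAMMA = 0.9,
# a step that earns reward -1 from a state with old_estimate = 0 and
# best_next_quality = 0 gives target = -1 + 0.9 * 0 = -1, a TD error of
# -1 - 0 = -1, and the stored Q-value moves from 0 to 0 + 0.1 * (-1) = -0.1.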

# Q-learning loop
for episode in range(EPISODES):
	state = (GRID_SIZE - 1, GRID_SIZE - 1) # start position (bottom-right corner)
	EPSILON = 1 - episode / EPISODES  # linearly decaying epsilon for exploration
	while not is_terminal(state):
		action_idx = epsilon_greedy(state)
		action = ACTIONS[action_idx]
		next_state, reward = step(state, action)
		update_q_values(state, action_idx, reward, next_state)
		state = next_state

# Extract final policy
policy = np.full((GRID_SIZE, GRID_SIZE), '', dtype=object)
for y in range(GRID_SIZE):
	for x in range(GRID_SIZE):
		if is_terminal((y, x)):
			policy[y, x] = '🟩'
		else:
			best_action = np.argmax(Q[y, x]) # greedy, no epsilon
			policy[y, x] = ACTION_SYMBOLS[ACTIONS[best_action]]

# Display policy
for row in policy:
	print(' '.join(row))
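
# Optional check (addition, not part of the original output): print the greedy
# state values max_a Q(s, a) to see how the step costs propagate from the goal.
values = Q.max(axis=2)
for row in values:
	print(' '.join(f'{v:6.2f}' for v in row))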