import numpy as np
import random

# Parameters
GRID_SIZE = 4
ALPHA = 0.1  # learning rate
EPSILON = 1.0  # exploration rate; overwritten each episode by the decay schedule below
EPISODES = 5000
TERMINAL_STATE = (0, 0)
GAMMA = 0.9  # discount factor; at 1 the algorithm runs in circles and does not converge
ACTIONS = ['up', 'down', 'left', 'right']
ACTION_EFFECTS = {
		'up': (-1, 0), # (y, x) -> (y-1, x) position change
		'down': (1, 0),
		'left': (0, -1),
		'right': (0, 1)
}

ACTION_SYMBOLS = {
		'up': '⬆',
		'down': '⬇',
		'left': '⬅',
		'right':  '➡'
}

# Reward matrix
rewards = np.full((GRID_SIZE, GRID_SIZE), -1.0)  # -1 per step
rewards[:-1, 1] = -10  # 🟥 red penalty cells ("wall") in column 1, rows 0-2
rewards[3, 3] = -10    # 🟥 penalty on the start corner as well
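# Resulting reward grid (rows top to bottom, with GRID_SIZE = 4):
#   -1  -10   -1   -1
#   -1  -10   -1   -1
#   -1  -10   -1   -1
#   -1   -1   -1  -10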


# Q-table initialization
Q = np.zeros((GRID_SIZE, GRID_SIZE, len(ACTIONS)))
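# One entry per (row, column, action), indexed as Q[y, x, action_index].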

def is_terminal(state):
	return state == TERMINAL_STATE

def step(state, action):
	dy, dx = ACTION_EFFECTS[action]
	y, x = state
	ny, nx = y + dy, x + dx # next state coordinates
	if 0 <= ny < GRID_SIZE and 0 <= nx < GRID_SIZE:
		next_state = (ny, nx)
		reward = rewards[next_state]
	else:
		next_state = state # illegal move, stay in place
		reward = -10  # penalty for illegal move
	return next_state, reward
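# Example: step((3, 3), 'up') returns ((2, 3), -1.0); step((3, 3), 'down') would
# leave the grid, so the agent stays at (3, 3) and receives -10.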

def epsilon_greedy(state):
	if random.random() < EPSILON:
		return random.choice(range(len(ACTIONS)))
	else:
		y, x = state
		return np.argmax(Q[y, x])
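# Note: np.argmax breaks ties towards the lowest action index, i.e. 'up' is preferred.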

def update_q_values(state, action_idx, reward, next_state):
	y, x = state
	ny, nx = next_state
	# On-policy target: sample the next action from the same epsilon-greedy
	# behaviour policy and bootstrap from its Q-value (SARSA-style update:
	# Q(s, a) += ALPHA * (r + GAMMA * Q(s', a') - Q(s, a))).
	next_action_index = epsilon_greedy(next_state)
	policy_next_reward = Q[ny, nx, next_action_index]
	target = reward + GAMMA * policy_next_reward
	Q[y, x, action_idx] += ALPHA * (target - Q[y, x, action_idx])


# SARSA loop
for episode in range(EPISODES):
	state = (GRID_SIZE - 1, GRID_SIZE - 1)
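	# Every episode starts in the bottom-right corner; the goal is TERMINAL_STATE = (0, 0).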
	EPSILON = 1 - episode / EPISODES  # decaying epsilon for exploration
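	# episode 0 -> EPSILON = 1.0 (purely random); final episode -> EPSILON ~ 0 (nearly greedy)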
	while not is_terminal(state):
		action_idx = epsilon_greedy(state)
		action = ACTIONS[action_idx]
		next_state, reward = step(state, action)
		update_q_values(state, action_idx, reward, next_state)
		state = next_state

# Extract final policy
policy = np.full((GRID_SIZE, GRID_SIZE), '', dtype=object)
for y in range(GRID_SIZE):
	for x in range(GRID_SIZE):
		if is_terminal((y, x)):
			policy[y, x] = '🟩'
		else:
			best_action = np.argmax(Q[y, x])
			policy[y, x] = ACTION_SYMBOLS[ACTIONS[best_action]]

# Display policy
for row in policy:
	print(' '.join(row))
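
# Optional extra (not part of the original script, just an illustrative sketch):
# print the greedy state-value estimate V(s) = max_a Q(s, a) under the learned Q-table.
V = Q.max(axis=2)
print()
print(np.round(V, 2))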