import numpy as np
import random

import torch

# Parameters
GRID_SIZE = 4
LEARNING_RATE = 0.01
EPSILON = "auto"  #0.1
EPISODES = 250
TERMINAL_STATE = (0, 0)
GAMMA = 0.9  # discount factor
ACTIONS = ['up', 'down', 'left', 'right']
ACTION_EFFECTS = {
		'up': (-1, 0), # (y, x) -> (y-1, x) position change
		'down': (1, 0),
		'left': (0, -1),
		'right': (0, 1)
}

ACTION_SYMBOLS = {
		'up': '⬆',
		'down': '⬇',
		'left': '⬅',
		'right':  '➡'
}

# Reward matrix
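# Every step costs -1; red cells (and attempts to leave the grid) cost -10.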
rewards = np.full((GRID_SIZE, GRID_SIZE), -1.0)  # -1 per step
rewards[:-1, 1] = -10  # 🟥 red penalty cells: column 1, rows 0 to 2
rewards[3, 3] = -10  # 🟥 red penalty cell on the start position (3, 3)

N_ACTIONS = len(ACTIONS)

# Q-function approximator: a single linear layer over one-hot state vectors.
# With one-hot inputs this behaves like a tabular Q-table whose entries are
# adjusted by gradient descent instead of direct assignment.
Q = torch.nn.Linear(GRID_SIZE * GRID_SIZE, N_ACTIONS)

optimizer = torch.optim.Adam(Q.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.MSELoss()

def is_terminal(state):
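	"""Return True if state is the terminal state."""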
	return state == TERMINAL_STATE
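
def state_to_tensor(state):
	"""One-hot encode a (y, x) grid cell as a 1 x (GRID_SIZE * GRID_SIZE) tensor."""
	y, x = state
	vec = torch.zeros((1, GRID_SIZE * GRID_SIZE))
	vec[0, y * GRID_SIZE + x] = 1.0
	return vec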

def step(state, action):
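	"""Apply action in state; return (next_state, reward).
	Moves that would leave the grid keep the agent in place and are penalised."""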
	dy, dx = ACTION_EFFECTS[action]
	y, x = state
	ny, nx = y + dy, x + dx # next state coordinates
	if 0 <= ny < GRID_SIZE and 0 <= nx < GRID_SIZE:
		next_state = (ny, nx)
		reward = rewards[next_state]
	else:
		next_state = state # illegal move, stay in place
		reward = -10  # penalty for illegal move
	return next_state, reward

def epsilon_greedy(state):
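	"""Pick an action index: random with probability EPSILON, otherwise greedy w.r.t. Q."""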
	if random.random() < EPSILON:
		return random.choice(range(len(ACTIONS)))
	else:
		state_vec = state_to_tensor(state)
		with torch.no_grad():
			q_values = Q(state_vec)
		return int(torch.argmax(q_values).item())

def update_q_values(state, action_idx, reward, next_state):
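	"""One Q-learning update: regress Q(s, a) towards the TD target r + GAMMA * max_a' Q(s', a')."""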
	state_vec = state_to_tensor(state)
	next_state_vec = state_to_tensor(next_state)

	q_values = Q(state_vec)
	q_value = q_values[0, action_idx]  # predicted Q(s, a) for the action that was taken

	# TD target: r + GAMMA * max_a' Q(s', a'); do not bootstrap from the terminal state
	with torch.no_grad():
		next_q_values = Q(next_state_vec)
		max_next_q = torch.max(next_q_values)
		if is_terminal(next_state):
			target = torch.tensor(float(reward))  # episode ends here, no future value
		else:
			target = reward + GAMMA * max_next_q

	loss = loss_fn(q_value, target)

	optimizer.zero_grad()
	loss.backward()
	optimizer.step()


# Q-learning loop
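# Each episode starts in the bottom-right corner and runs until the terminal
# state (0, 0) is reached; EPSILON decays linearly from 1 towards 0 over the episodes.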
for episode in range(EPISODES):
	state = (GRID_SIZE - 1, GRID_SIZE - 1)  # start position: bottom-right corner
	EPSILON = 1 - episode / EPISODES  # linearly decaying epsilon for exploration
	while not is_terminal(state):
		action_idx = epsilon_greedy(state)
		action = ACTIONS[action_idx]
		next_state, reward = step(state, action)
		update_q_values(state, action_idx, reward, next_state)
		state = next_state

# Extract final policy
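# For every non-terminal cell, record the arrow of the greedy (argmax-Q) action.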
policy = np.full((GRID_SIZE, GRID_SIZE), '', dtype=object)
for y in range(GRID_SIZE):
	for x in range(GRID_SIZE):
		if is_terminal((y, x)):
			policy[y, x] = '🟩'
		else:
			state_vec = state_to_tensor((y, x))
			with torch.no_grad():
				q_values = Q(state_vec)
			best_action = int(torch.argmax(q_values).item())
			policy[y, x] = ACTION_SYMBOLS[ACTIONS[best_action]]

# Display policy
for row in policy:
	print(' '.join(row))