# -*- coding: utf-8 -*-
"""
Created on Wed Aug 13 11:08:23 2025

@author: TAKO
"""

# RL^2 actor-critic with randomly generated FrozenLake tasks (meta-RL)
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
from gymnasium.envs.toy_text.frozen_lake import generate_random_map  # random map generator

# ---------- Hyperparams ----------
HIDDEN_SIZE = 64
LR = 1e-3
EPISODES_PER_TASK = 100
TASKS = 150  # more tasks => better meta-learning
STEPS = 50
ENTROPY_COEF = 0.01
VALUE_COEF = 0.5
GAMMA = 0.95
MAP_SIZE = 6  # 4, 6, 8, ... (larger = harder)
MAP_P = 0.85  # probability of a frozen tile "F"; higher = fewer holes
SLIPPERY = False  # keep slipperiness off for a clearer learning signal
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ---------- Utils ----------
def one_hot(idx, size, device=DEVICE):
	t = torch.zeros(size, dtype=torch.float32, device=device)
	t[idx] = 1.0
	return t


def compute_returns(rewards, gamma):
	R = 0.0
	out = []
	for r in reversed(rewards):
		R = r + gamma * R
		out.append(R)
	out.reverse()
	return torch.tensor(out, dtype=torch.float32, device=DEVICE)
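

# Quick sanity check of compute_returns (an added illustration, not part of training):
# with gamma = 0.95, rewards [0.0, 0.0, 1.0] yield discounted returns
# [0.9025, 0.95, 1.0], since G_t = r_t + gamma * G_{t+1} is accumulated back-to-front.
assert torch.allclose(
	compute_returns([0.0, 0.0, 1.0], 0.95),
	torch.tensor([0.9025, 0.95, 1.0], device=DEVICE),
)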


def moving_avg(x, k=20):
	# simple moving average for plotting; NaN entries (episodes that never reached
	# the goal) are treated as 0 before smoothing
	x = np.array(x, dtype=float)
	w = np.ones(k) / k
	y = np.convolve(np.nan_to_num(x), w, mode="valid")
	pad = np.full(k - 1, np.nan)  # front padding keeps the output aligned with the input
	return np.concatenate([pad, y])


# ---------- Model ----------
class RL2ActorCritic(nn.Module):
	def __init__(self, obs_size, action_size, hidden_size):
		super().__init__()
		input_dim = obs_size + action_size + 1  # obs_onehot + last_action_onehot + last_reward
		self.fc_in = nn.Linear(input_dim, hidden_size)
		self.rnn = nn.GRU(hidden_size, hidden_size)
		self.pi = nn.Linear(hidden_size, action_size)
		self.v = nn.Linear(hidden_size, 1)

	def forward(self, obs, last_action, last_reward, h):
		x = torch.cat([obs, last_action, last_reward], dim=-1)  # (1, input_dim)
		x = torch.relu(self.fc_in(x)).unsqueeze(0)  # (seq=1,batch=1,hidden)
		out, h = self.rnn(x, h)  # (1,1,hidden)
		out = out.squeeze(0)  # (1, hidden)
		logits = self.pi(out)  # (1, A)
		value = self.v(out).squeeze(-1)  # (1,)
		return logits, value, h
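

# A minimal shape check for RL2ActorCritic (an added illustration; it uses the
# default 4x4 FrozenLake sizes, 16 states / 4 actions, as stand-in dimensions,
# while the real sizes are read from the probe env further below).
with torch.no_grad():
	_net = RL2ActorCritic(obs_size=16, action_size=4, hidden_size=HIDDEN_SIZE).to(DEVICE)
	_h = torch.zeros(1, 1, HIDDEN_SIZE, device=DEVICE)
	_logits, _value, _h = _net(
		torch.zeros(1, 16, device=DEVICE),  # one-hot observation
		torch.zeros(1, 4, device=DEVICE),   # one-hot previous action
		torch.zeros(1, 1, device=DEVICE),   # previous reward
		_h,
	)
	# a single step should produce logits (1, A), a value (1,), and a hidden state (1, 1, H)
	assert _logits.shape == (1, 4) and _value.shape == (1,) and _h.shape == (1, 1, HIDDEN_SIZE)
del _net, _h, _logits, _value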


# ---------- Meta-Task Factory ----------
def make_task_env(task_id, size=MAP_SIZE, p=MAP_P, slippery=SLIPPERY):
	# each task gets a freshly generated random map
	render_mode = None  # no rendering during training
	if task_id == TASKS - 1:
		render_mode = "human"  # render the final task for visualization
	desc = generate_random_map(size=size, p=p)  # list of map-row strings
	env = gym.make("FrozenLake-v1", desc=desc, is_slippery=slippery, render_mode=render_mode)
	# (optional) vary the seed so the reset state is deterministic per task:
	env.reset(seed=task_id)
	return env, desc


# ---------- Probe env to read space sizes ----------
probe_env = gym.make("FrozenLake-v1", desc=generate_random_map(size=MAP_SIZE, p=MAP_P), is_slippery=SLIPPERY)
obs_size = probe_env.observation_space.n
action_size = probe_env.action_space.n
probe_env.close()

net = RL2ActorCritic(obs_size, action_size, HIDDEN_SIZE).to(DEVICE)
opt = optim.Adam(net.parameters(), lr=LR)

# --- Load pre-trained model if available ---
try:
	net.load_state_dict(torch.load("rl2_frozenlake.pth", map_location=DEVICE, weights_only=True))
	print("Pre-trained model loaded.")
except Exception as e:
	print("No pre-trained model found, starting from scratch.")
	print(f"Error: {e}")

episode_returns, steps_to_goal = [], []

# ---------- Meta-Training ----------
for task in range(TASKS):
	env, desc = make_task_env(task)
	# Inner loop: keep the hidden state across episodes of the SAME task (RL^2)
	h = torch.zeros(1, 1, HIDDEN_SIZE, device=DEVICE)

	for ep in range(EPISODES_PER_TASK):
		obs, _ = env.reset()
		obs_oh = one_hot(obs, obs_size)
		last_action = torch.zeros(action_size, device=DEVICE)
		last_reward = torch.zeros(1, device=DEVICE)

		log_probs, values, rewards, entropies = [], [], [], []
		steps, reached = 0, False

		for t in range(STEPS):
			logits, value, h = net(
				obs_oh.unsqueeze(0),
				last_action.unsqueeze(0),
				last_reward.view(1, 1),
				h
			)
			dist = torch.distributions.Categorical(logits=logits)
			action = dist.sample()
			log_prob = dist.log_prob(action)
			entropy = dist.entropy()

			obs_next, reward, terminated, truncated, _ = env.step(action.item())

			# (optional) reward shaping, e.g. a small per-step penalty:
			# reward = reward - 0.01

			log_probs.append(log_prob)
			values.append(value)
			rewards.append(float(reward))
			entropies.append(entropy)

			steps += 1
			if reward == 1.0:
				reached = True

			obs_oh = one_hot(obs_next, obs_size)
			last_action = one_hot(action.item(), action_size)
			last_reward = torch.tensor([reward], dtype=torch.float32, device=DEVICE)

			if terminated or truncated:
				break

		# per-episode update (a single backward pass)
		R = compute_returns(rewards, GAMMA)
		V = torch.cat(values)  # (T,)
		A = R - V.detach()  # Advantage
		policy_loss = -(torch.cat(log_probs) * A).mean()
		value_loss = nn.functional.mse_loss(V, R)
		entropy_loss = -torch.cat(entropies).mean()

		loss = policy_loss + VALUE_COEF * value_loss + ENTROPY_COEF * entropy_loss
		opt.zero_grad(set_to_none=True)
		loss.backward()
		nn.utils.clip_grad_norm_(net.parameters(), 1.0)
		opt.step()

		# crucial: detach the hidden state so the next episode in this task does not backprop through the old graph
		h = h.detach()

		episode_returns.append(sum(rewards))
		steps_to_goal.append(steps if reached else math.nan)

	env.close()

print("Meta-Training abgeschlossen.")

# --- save model ---
torch.save(net.state_dict(), "rl2_frozenlake.pth")
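
# --- (optional) evaluation sketch ---
# A minimal sketch (added illustration, not part of the original training loop) of how
# the meta-trained policy could be probed on a fresh random map: the hidden state is
# kept across episodes of the SAME task, so later episodes can exploit what earlier
# ones revealed about the map. Call it manually if desired, e.g.
# print(evaluate_on_new_task(episodes=5)).
def evaluate_on_new_task(episodes=5, greedy=True):
	env, _ = make_task_env(task_id=0)  # fresh random map, no rendering
	h = torch.zeros(1, 1, HIDDEN_SIZE, device=DEVICE)
	returns = []
	with torch.no_grad():
		for _ in range(episodes):
			obs, _ = env.reset()
			obs_oh = one_hot(obs, obs_size)
			last_action = torch.zeros(action_size, device=DEVICE)
			last_reward = torch.zeros(1, device=DEVICE)
			total = 0.0
			for _ in range(STEPS):
				logits, _, h = net(obs_oh.unsqueeze(0), last_action.unsqueeze(0), last_reward.view(1, 1), h)
				if greedy:
					action = int(torch.argmax(logits, dim=-1).item())
				else:
					action = int(torch.distributions.Categorical(logits=logits).sample().item())
				obs, reward, terminated, truncated, _ = env.step(action)
				total += float(reward)
				obs_oh = one_hot(obs, obs_size)
				last_action = one_hot(action, action_size)
				last_reward = torch.tensor([reward], dtype=torch.float32, device=DEVICE)
				if terminated or truncated:
					break
			returns.append(total)
	env.close()
	return returns
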

# ---------- Visualization ----------
import matplotlib.pyplot as plt

ma_ret = moving_avg(episode_returns, 20)
ma_steps = moving_avg(steps_to_goal, 20)

plt.figure()
plt.plot(episode_returns, '.', alpha=0.4, label="Return/Episode")
plt.plot(ma_ret, linewidth=2, label="Moving Avg (20)")
plt.title("Meta-RL (RL²) – Episoden-Return über zufällige FrozenLake-Tasks")
plt.xlabel("Episode")
plt.ylabel("Return")
plt.legend()
plt.grid(True)
plt.show()

plt.figure()
plt.plot(steps_to_goal, '.', alpha=0.4, label="Steps to goal")
plt.plot(ma_steps, linewidth=2, label="Moving Avg (20)")
plt.title("Meta-RL (RL²) – steps to goal")
plt.xlabel("Episode")
plt.ylabel("Steps (NaN = goal not reached)")
plt.legend()
plt.grid(True)
plt.show()
