#!/usr/bin/env python3
import pygame
import sys
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque


# Hyperparameters
BATCH_SIZE = 64          # transitions sampled from replay memory per update
GAMMA = 0.9              # discount applied to the target network's next-state value
# NOTE: the three EPSILON_* values below are not read by select_action(),
# which derives its exploration rate from blind_training_steps instead.
EPSILON_START = 1.0
EPSILON_END = 0.01
EPSILON_DECAY = 10_000
TARGET_UPDATE = 100      # steps between target-network syncs (and checkpoint saves)
LEARNING_RATE = 2e-4     # Adam step size

# Ball settings
BALL_SIZE = 20
BALL_SPEED = 5           # horizontal pixels per step; sign gives the direction
BALL_SPEED_Y = 5         # vertical pixels per step; sign gives the direction
blind_training_steps = 100_000  # Number of steps to train without graphics


# Screen dimensions
WIDTH = 800
HEIGHT = 600

# Colors (RGB tuples)
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)

# Paddle settings
PADDLE_WIDTH = 10
PADDLE_HEIGHT = 100
PADDLE_SPEED = 7

# Initialize Pygame and open the game window
pygame.init()
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption('Pong DQL')

# Paddle positions: both start vertically centred, one near each side edge
left_paddle = pygame.Rect(
    10,
    (HEIGHT - PADDLE_HEIGHT) // 2,
    PADDLE_WIDTH,
    PADDLE_HEIGHT,
)
right_paddle = pygame.Rect(
    WIDTH - 20,
    (HEIGHT - PADDLE_HEIGHT) // 2,
    PADDLE_WIDTH,
    PADDLE_HEIGHT,
)

# Ball position: starts in the middle of the screen
ball = pygame.Rect(
    WIDTH // 2 - BALL_SIZE // 2,
    HEIGHT // 2 - BALL_SIZE // 2,
    BALL_SIZE,
    BALL_SIZE,
)

# DQL settings
class DQN(nn.Module):
    """Three-layer fully connected network: 6 inputs -> 256 -> 256 -> 2 outputs.

    The two outputs are consumed elsewhere in this file as one value per
    action index (0 or 1). The layer attribute names are part of the saved
    ``state_dict`` keys, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(6, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)

    def forward(self, x):
        """Apply two ReLU-activated hidden layers, then a linear output layer."""
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)

def get_state():
    """Return the current observation as a float32 NumPy vector of six values.

    In order: ball x, ball y, ball x minus right-paddle x, ball y minus
    right-paddle y, right-paddle x, right-paddle y. All values are read from
    the module-level ``ball`` and ``right_paddle`` rects.
    """
    ball_x, ball_y = ball.x, ball.y
    paddle_x, paddle_y = right_paddle.x, right_paddle.y
    features = [
        ball_x,
        ball_y,
        ball_x - paddle_x,
        ball_y - paddle_y,
        paddle_x,
        paddle_y,
    ]
    return np.array(features, dtype=np.float32)

def select_action(state, steps_done):
    """Pick action 0 or 1 for ``state`` using an epsilon-greedy rule.

    The exploration probability falls linearly from 1 at step 0 to 0 once
    ``steps_done`` reaches ``blind_training_steps``. When not exploring, the
    action is the argmax of the online network ``dqn`` on ``state``.
    """
    explore_probability = max(0, 1 - (steps_done / blind_training_steps))
    if random.random() >= explore_probability:
        # Greedy branch: no gradients are needed just to choose an action.
        with torch.no_grad():
            q_values = dqn(torch.tensor(state).unsqueeze(0))
            return q_values.argmax().item()
    return random.randint(0, 1)

def optimize_model():
    """Run one optimizer step on a random minibatch drawn from ``memory``.

    Returns immediately while ``memory`` holds fewer than ``BATCH_SIZE``
    transitions. For each sampled ``(state, action, reward, next_state)``
    the target is ``reward + GAMMA * max(target_dqn(next_state))``, and
    ``criterion`` compares it with the online network's output for the
    action that was taken.
    """
    if len(memory) < BATCH_SIZE:
        return

    sampled = random.sample(memory, BATCH_SIZE)
    # Transpose the list of transition tuples into four per-field columns.
    states, actions, rewards, next_states = zip(*sampled)

    # Stack the per-transition NumPy vectors first so each batch is one array.
    state_batch = torch.tensor(np.array(states))
    next_state_batch = torch.tensor(np.array(next_states))
    action_batch = torch.tensor(actions).unsqueeze(1)
    reward_batch = torch.tensor(rewards)

    # Online-network output for the action actually taken in each transition.
    predicted = dqn(state_batch).gather(1, action_batch)
    # Largest target-network output per next state; detached so no gradient
    # flows into the target network.
    best_next = target_dqn(next_state_batch).max(1)[0].detach()
    target = (reward_batch + (GAMMA * best_next)).unsqueeze(1)

    loss = criterion(predicted, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Initialize DQN
dqn = DQN()
# Resume from a previously saved checkpoint when one is available.
save_path = 'dqn2.pth'
try:
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source (here: files written by this script).
    dqn.load_state_dict(torch.load(save_path))
except FileNotFoundError:
    print('No model found, starting from scratch')
except Exception as load_error:
    # The file exists but is unreadable or does not match this architecture.
    # Stay best-effort, but report the real cause instead of claiming the
    # file is missing. KeyboardInterrupt/SystemExit are no longer swallowed.
    print('Could not load model from', save_path, '-', load_error)
    print('Starting from scratch')
    # load_state_dict may have copied some tensors before raising, so rebuild
    # the network to guarantee a genuinely fresh start.
    dqn = DQN()

# The target network starts as an exact copy of the online network and is
# only ever used for inference.
target_dqn = DQN()
target_dqn.load_state_dict(dqn.state_dict())
target_dqn.eval()

optimizer = optim.Adam(dqn.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()
# Replay memory: the oldest transitions are dropped once 10000 are stored.
memory = deque(maxlen=10000)

def calculate_distance_reward():
    """Return a non-positive shaping term based on vertical misalignment.

    The value is minus the absolute difference between the ball's and the
    right paddle's vertical centres, divided by ``HEIGHT``.
    """
    vertical_gap = ball.centery - right_paddle.centery
    if vertical_gap < 0:
        vertical_gap = -vertical_gap
    return -vertical_gap / HEIGHT

# Main game loop
steps_done = 0
# A single Clock for the whole run: tick() paces a frame relative to the
# previous tick() on the same object, so it must not be recreated per frame.
clock = pygame.time.Clock()
while True:
    # Keep the window responsive and allow it to be closed at any time.
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            pygame.quit()
            sys.exit()

    # AI for left paddle: scripted opponent that follows the ball's centre
    if left_paddle.centery < ball.centery and left_paddle.bottom < HEIGHT:
        left_paddle.y += PADDLE_SPEED
    if left_paddle.centery > ball.centery and left_paddle.top > 0:
        left_paddle.y -= PADDLE_SPEED

    # DQL for right paddle: action 0 moves up, action 1 moves down
    state = get_state()
    action = select_action(state, steps_done)
    if action == 0 and right_paddle.top > 0:
        right_paddle.y -= PADDLE_SPEED
    if action == 1 and right_paddle.bottom < HEIGHT:
        right_paddle.y += PADDLE_SPEED

    # Move the ball
    ball.x += BALL_SPEED
    ball.y += BALL_SPEED_Y

    # Ball collision with top and bottom
    if ball.top <= 0 or ball.bottom >= HEIGHT:
        BALL_SPEED_Y = -BALL_SPEED_Y

    # Ball collision with paddles. The direction is set explicitly rather
    # than negated: a ball that still overlaps the paddle on the next frame
    # would otherwise flip back and forth inside it, earning the hit reward
    # on every frame without ever leaving the paddle.
    if ball.colliderect(left_paddle):
        BALL_SPEED = abs(BALL_SPEED)
    elif ball.colliderect(right_paddle):
        BALL_SPEED = -abs(BALL_SPEED)

    # Ball out of bounds: recentre it and serve towards the right paddle
    if ball.left <= 0 or ball.right >= WIDTH:
        ball.x, ball.y = WIDTH // 2 - BALL_SIZE // 2, HEIGHT // 2 - BALL_SIZE // 2
        BALL_SPEED = abs(BALL_SPEED)

    # Reward calculation: +1 on frames where the ball touches the right
    # paddle, -1 on every other frame
    reward = 1 if ball.colliderect(right_paddle) else -1
    # reward += calculate_distance_reward() # Fast learning, cheating!

    # Store transition in memory
    next_state = get_state()
    memory.append((state, action, reward, next_state))

    # Optimize the model
    optimize_model()

    # Update the target network
    if steps_done % TARGET_UPDATE == 0:
        target_dqn.load_state_dict(dqn.state_dict())
        # save model
        torch.save(dqn.state_dict(), save_path)

    steps_done += 1
    print(steps_done, end='\r')

    # Draw everything
    if steps_done >= blind_training_steps:  # Enable graphics after initial training
        screen.fill(BLACK)
        pygame.draw.rect(screen, WHITE, left_paddle)
        pygame.draw.rect(screen, WHITE, right_paddle)
        pygame.draw.ellipse(screen, WHITE, ball)
        pygame.draw.aaline(screen, WHITE, (WIDTH // 2, 0), (WIDTH // 2, HEIGHT))

        pygame.display.flip()
        clock.tick(60)  # maximum frame rate in FPS.
    elif steps_done < 2:
        # Shown once, on the first step; the window then stays static until
        # blind training has finished.
        print('Training without graphics', end='\r')
        screen.fill(BLACK)
        font = pygame.font.Font(None, 36)
        text = font.render('Training without graphics', True, WHITE)
        screen.blit(text, (WIDTH // 2 - text.get_width() // 2, HEIGHT // 2 - text.get_height() // 2))
        pygame.display.flip()




