from random import random

class Casino:
		def __init__(self):
				self.bandits = [2,5,9,4,7] # maximaler Gewinn pro Bandit

		def step(self, action, state=None): # wähle Automat 0…5
			return self.bandits[action] * random() * 2

env = Casino()
# Aufgabe: Wähle den Automaten mit dem höchsten Gewinn. PER POLICY FINDEN!!

# epsilon = "automatic see below"  # Exploration rate for epsilon-greedy policy
epsilon = 1

Besuche = [0,0,0,0,0] #  count visits to each bandit
Gewinne = [0,0,0,0,0] #  accumulate rewards for each bandit
Q = [0, 0, 0, 0, 0] # Q-Werte für jeden Bandit  "Erwartete Belohnung pro Bandit/Aktion"
# DurchschnittsGewinnProAktion  ø G(A) = Mittelwert der Gewinne pro Aktion = Erwartungswert des Gewinn pro Aktion
# Q(A) = E(G) = E(R) hier! weil ein Spiel = ein Zug ist

def our_policy(epsilon):
		if random() < epsilon:  # Explore with probability epsilon
			action = int(random() * len(env.bandits))  # Random action
		else:  # Exploit the best known action: argmax
			action = max(range(len(Q)), key=lambda a: Q[a]) # argmax
		return action

Averages = [] # zum Plotten der Gewinne
def train(epochs=1000):
		global epsilon
		total_reward = 0
		for step in range(epochs):
			if step>100:
				epsilon = 0 # KEIN ZUFALL MEHR
			# epsilon = max(0, (1 - 5*step / epochs))  # Decrease epsilon over time
			action = our_policy(epsilon)
			reward = env.step(action)
			total_reward += reward
			Besuche[action] += 1
			Gewinne[action] += reward
			Q[action] = Gewinne[action] / Besuche[action]
			if step  % 101 == 0:
				average_reward = total_reward / (step + 1)  # Erfolg messen mit moving average reward
				Averages.append(average_reward)
				print(f"Episode {step}: Action {action}, Average Reward {average_reward:.2f} total_reward {total_reward:.2f}")

train()


# plot
import matplotlib.pyplot as plt
def plot_results():
		plt.figure(figsize=(10, 5))
		plt.plot(Averages, label='Durchschnitts Gewinn')
		plt.xlabel('Episodes * 100')
		plt.ylabel('Gewinn')
		plt.title('Durchschnitts Gewinn über Episoden')
		plt.legend()
		plt.grid()
		plt.show()
plot_results()