#!/usr/local/bin/python3
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms

# Hyperparameters
learning_rate = 1e-2      # Adam step size
# training_epochs = 125   # number of training epochs (author's note: ~94% accuracy in about 1 minute)
training_epochs = 1_000   # number of training epochs (author's note: ~96% accuracy in about 5 minutes)
batch_size = 60_000       # samples per batch; per the original note this is the number of training samples, i.e. full-batch training

# torch.set_num_threads(1) # more doesn't help!!

# MNIST train/test splits (download=True fetches them into data/MNIST/ if missing).
# Both splits convert images with transforms.ToTensor() when indexed.
mnist_train = dsets.MNIST(
	root='data/MNIST/',
	train=True,
	download=True,
	transform=transforms.ToTensor(),
)
mnist_test = dsets.MNIST(
	root='data/MNIST/',
	train=False,
	download=True,
	transform=transforms.ToTensor(),
)
# Shuffled training batches; drop_last=True discards a trailing partial batch.
data_loader = torch.utils.data.DataLoader(
	dataset=mnist_train,
	batch_size=batch_size,
	shuffle=True,
	drop_last=True,
)

# print(len(mnist_train), len(mnist_test))

# Simplest possible alternative: one linear layer (a plain matrix multiplication).
# model = torch.nn.Linear(28 * 28, 10, bias=True)

HIDDEN_SIZE = 5000

# The layer order fixes the state_dict keys ('0.*' for the first Linear, '2.*' for
# the second), which must match any checkpoint loaded into this model.
# NOTE(review): with the activation below commented out there is no nonlinearity
# between the two Linear layers, so the network is still a linear classifier.
# Inserting a layer would renumber the keys and existing checkpoints would no longer load.
model = torch.nn.Sequential(
	torch.nn.Linear(in_features=28 * 28, out_features=HIDDEN_SIZE, bias=True),  # expand the 784 pixels into HIDDEN_SIZE features
	# torch.nn.LeakyReLU(),  # optional nonlinearity (alternatives: tanh, sigmoid)
	torch.nn.Dropout(p=0.2),  # each activation is zeroed with probability 0.2 (training mode only)
	torch.nn.Linear(in_features=HIDDEN_SIZE, out_features=10, bias=True),  # one output score per class
	# No Softmax layer: CrossEntropyLoss computes the softmax itself.
)

# Resume from a previous run's weights when a usable snapshot exists.
try:
	model.load_state_dict(torch.load('weights/mnist_snapshot.pth', weights_only=True))
	print("pre-trained model loaded")
except FileNotFoundError:
	print("No pre-trained model found, starting from scratch.")
except Exception as err:
	# Best effort: an unreadable or architecture-incompatible snapshot (e.g. after
	# changing HIDDEN_SIZE) should not abort training. However, the real reason must
	# be visible instead of being misreported as "not found". Unlike the former bare
	# `except:`, this lets KeyboardInterrupt/SystemExit propagate.
	print(f"Could not load pre-trained model ({err!r}), starting from scratch.")


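# Expansion: the first Linear layer fans the 784 pixel inputs out into HIDDEN_SIZE features
# Rule of thumb: number of output features <= number of input features * 100

# Compression: the last Linear layer condenses this information into 10 features (one per digit class)
# Rule of thumb: number of output features >= number of input features / 100
# i.e. the data is thinned out inside the network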


# Loss (cost ≈ error ≈ distance) and optimizer.
# CrossEntropyLoss computes the softmax internally, so the model emits raw class scores.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

def accuracy(threshold=0.90, net=None, dataset=None):
	"""Evaluate the classifier on the whole test split; stop the script when good enough.

	Switches the network to eval mode (disabling Dropout), classifies every test
	image in a single forward pass and prints the fraction of correct predictions.

	Args:
		threshold: accuracy above which training is considered finished; when it is
			exceeded the process exits with status 0.
		net: model to evaluate; defaults to the module-level ``model``.
		dataset: object exposing ``.data`` (raw images, 28*28 values each) and
			``.targets`` (labels); defaults to the module-level ``mnist_test``.

	Returns:
		The measured accuracy as a float (only reached when it is <= threshold).

	NOTE(review): exiting here also skips any weight-saving code that runs after
	the training loop.
	"""
	import sys  # sys.exit does not depend on the site-provided ``exit`` builtin

	net = model if net is None else net
	dataset = mnist_test if dataset is None else dataset

	with torch.no_grad():  # evaluation only, no gradients needed
		net.eval()
		# ``.data`` holds the raw 0-255 pixel values. Training batches pass through
		# transforms.ToTensor(), which scales to [0, 1], so scale identically here.
		X_test = dataset.data.view(-1, 28 * 28).float() / 255.0
		logits = net(X_test)
		correct = torch.argmax(logits, 1) == dataset.targets
		acc = float(correct.sum()) / len(logits)
	print('Accuracy:', acc)
	if acc > threshold:
		print('Accuracy:', acc, '=> stop training')
		sys.exit(0)
	return acc

# Training: each epoch first evaluates (accuracy() may end the script), then makes
# one pass over data_loader.
for epoch in range(training_epochs):
	accuracy()  # leaves the model in eval mode
	model.train()  # re-enable Dropout; nothing in the batch loop changes the mode, so once per epoch suffices
	cost = None
	for X, Y in data_loader:
		optimizer.zero_grad()  # reset gradients left over from the previous step
		hypothesis = model(X.view(-1, 28 * 28))  # flatten each image into a row of 784 values
		cost = criterion(hypothesis, Y)  # loss of this batch
		cost.backward()  # backpropagation: compute gradients
		optimizer.step()  # apply the parameter update
	if cost is None:
		# Otherwise the print below would fail with an opaque NameError.
		raise RuntimeError(
			'data_loader produced no batches (with drop_last=True this happens when '
			'batch_size=%d exceeds the number of training samples)' % batch_size
		)
	# Reports the loss of the last batch of the epoch.
	print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(cost.item()), flush=True)

import os

# Save only the weights (state_dict), then reload them to check the round trip.
os.makedirs('weights', exist_ok=True)  # torch.save does not create missing directories
torch.save(model.state_dict(), 'weights/mnist_snapshot.pth')
model.load_state_dict(torch.load('weights/mnist_snapshot.pth', weights_only=True))  # restore weights

# Save the complete model (architecture + weights, pickled) and load it back.
# Save and load use the same path, so the file just written is the one read.
torch.save(model, 'weights/mnist_model_full.pth')
# A pickled nn.Module is not a plain tensor container, so weights_only=False is
# required (newer PyTorch releases default to True). This unpickles arbitrary
# code from the file: only load checkpoints you created yourself.
model = torch.load('weights/mnist_model_full.pth', weights_only=False)
# nn.Module has no save()/load() methods; torch.save/torch.load above are the API.
# The state_dict snapshot stores weights only: build the architecture first,
# then call load_state_dict on it.

# Show one randomly chosen MNIST test image (raw pixel values, grayscale).
import matplotlib.pyplot as plt
import numpy as np
import random

idx = random.randrange(len(mnist_test))  # same draw as randint(0, len(mnist_test) - 1)
img = mnist_test.data[idx].numpy()
plt.imshow(img, cmap='gray')
plt.show()

# Weight visualization (not implemented yet); starting points:
# print(model.state_dict())
# print(model.state_dict().keys())
# print(model.state_dict().values())
