""" ==================================================================================

This code is adapted from the following tutorial:
   https://www.tensorflow.org/api_docs/python/tf/data/experimental/make_csv_dataset

This code DOES NOT REPLACE this tutorial, its ONLY OBJECTIVE is to demonstrate
how tensorflow applications can easily be run on Olympe, using our GPU nodes

USAGE:

1/ Install tensorflow 2.x as explained here: https://www.calmip.univ-toulouse.fr/spip.php?article770

2/ Retrieve the data files while connected on the front nodes:
   wget https://storage.googleapis.com/tf-datasets/titanic/train.csv
   wget https://storage.googleapis.com/tf-datasets/titanic/eval.csv

3/ With an editor, create a file called slurm_script.bash and copy-paste the following:

#! /bin/bash

#SBATCH -N 1
#SBATCH -n 9
#SBATCH --ntasks-per-core=1
#SBATCH --mem 50G
#SBATCH --gres=gpu:1

module purge
module load tensorflow/2.4.1 

EXE=path/to/load-csv.py

WORK=${SLURM_JOBID}
mkdir $WORK 

cp $0 $WORK
cd $WORK
ln -s ../eval.csv
ln -s ../train.csv

python $EXE

# END OF SBATCH SCRIPT ----------------

4/ Run your job:

   sbatch slurm_script.bash

5/ Check the execution of your job with:
   squeue -u $(whoami)

6/ When your job is running, check that everything is OK with:
   placement --checkme
   (NOTE - not very useful for this toy code, but mandatory when running YOUR code !)

7/ Read the output:

   less slurm-12345.out
   12345 being replaced with the actual jobid 

8/ That's all !

   Have fun... or send a mail to support.calmip@univ-toulouse.fr

=====================================================================================""" 

# SETUP -------------------------------------------------------------------------------
import functools
import numpy as np
import tensorflow as tf


# Which version of tensorflow are we using ?
print ("USING TENSORFLOW VERSION " + tf.__version__)
print ()

train_file_path = "train.csv"
test_file_path  = "eval.csv"

# Make numpy values easier to read.
np.set_printoptions(precision=3, suppress=True)

# LOAD DATA ---------------------------------------------------------------------------
# The only column you need to identify explicitly is the one with the value that the model is intended to predict. 
LABEL_COLUMN = 'survived'
LABELS = [0, 1]

# get_dataset
def get_dataset(file_path, **kwargs):
  dataset = tf.data.experimental.make_csv_dataset(
      file_path,
      batch_size=5, # Artificially small to make examples easier to show.
      label_name=LABEL_COLUMN,
      na_value="?",
      num_epochs=1,
      ignore_errors=True, 
      **kwargs)
  return dataset

raw_train_data = get_dataset(train_file_path)
raw_test_data = get_dataset(test_file_path)

# ---------------------------------------------------------
def show_batch(dataset):
  """Print every feature of the first batch of `dataset`, one line per feature.

  Args:
    dataset: object whose `take(1)` yields (features, label) pairs, `features`
      being a mapping from a feature name to a value exposing `.numpy()`.
  """
  line_format = "{:20s}: {}"
  for features, _ in dataset.take(1):  # the label is not displayed
    for name, tensor in features.items():
      print(line_format.format(name, tensor.numpy()))

# Print one batch of the raw training data, one line per feature column.
show_batch(raw_train_data)

# CONTINUOUS DATA
# Define a general preprocessor that selects a list of numeric features and packs them into a single column:
class PackNumericFeatures:
  """Callable that gathers several numeric features into one 'numeric' feature.

  An instance can be passed to `Dataset.map`: it receives the (features, labels)
  pair of a dataset element and returns the transformed pair.
  """

  def __init__(self, names):
    # Names of the features to pack, in the order in which they are stacked.
    self.names = names

  def __call__(self, features, labels):
    """Replace the features listed in `self.names` with one stacked tensor.

    Args:
      features: dict of features; the named entries are removed from it and a
        'numeric' entry is added (the dict is modified in place).
      labels: passed through untouched.

    Returns:
      The (features, labels) pair.
    """
    # Remove every numeric feature from the dict first ...
    columns = [features.pop(name) for name in self.names]
    # ... then convert them to float32 and stack them along a new last axis.
    features['numeric'] = tf.stack(
        [tf.cast(column, tf.float32) for column in columns], axis=-1)
    return features, labels

# Columns of the CSV files that are packed into the single 'numeric' feature.
NUMERIC_FEATURES = ['age', 'n_siblings_spouses', 'parch', 'fare']

# The same packing step is applied to the training and to the test dataset.
pack_numeric = PackNumericFeatures(NUMERIC_FEATURES)
packed_train_data = raw_train_data.map(pack_numeric)
packed_test_data = raw_test_data.map(pack_numeric)

# Print one batch again: the numeric columns now appear as one 'numeric' entry.
show_batch(packed_train_data)

# Data Normalization: continuous data should always be normalized.
# pandas is used to compute summary statistics of the numeric columns of the
# training file.
import pandas as pd
desc = pd.read_csv(train_file_path)[NUMERIC_FEATURES].describe()
# In `desc` each row is a statistic and each column a feature: pick the 'mean'
# and 'std' rows as arrays ordered like NUMERIC_FEATURES.
MEAN = np.array(desc.loc['mean'])
STD = np.array(desc.loc['std'])

def normalize_numeric_data(data, mean, std):
  """Standardize `data`: subtract `mean`, then divide the result by `std`.

  Args:
    data: values to normalize.
    mean: value(s) subtracted from `data`.
    std: value(s) the centered data is divided by.

  Returns:
    (data - mean) / std
  """
  centered = data - mean
  return centered / std

# Bind the statistics computed on the training file, so that the normalizer
# only takes the data to normalize.
normalizer = functools.partial(normalize_numeric_data, mean=MEAN, std=STD)

# A single feature column describes the packed 'numeric' feature: one value per
# entry of NUMERIC_FEATURES, normalized with the function above.
# (The bare `numeric_column` expression of the notebook version was removed:
# in a script it displays nothing.)
numeric_column = tf.feature_column.numeric_column(
    'numeric', normalizer_fn=normalizer, shape=[len(NUMERIC_FEATURES)])
numeric_columns = [numeric_column]

# Layer that only handles the numeric column.
numeric_layer = tf.keras.layers.DenseFeatures(numeric_columns)

# CATEGORICAL DATA
# Vocabulary of each categorical column. An entry must be spelled exactly as in
# the CSV files, otherwise it never matches: the port is 'Southampton' (the
# former 'Southhampton' entry could not match any row).
CATEGORIES = {
    'sex': ['male', 'female'],
    'class' : ['First', 'Second', 'Third'],
    'deck' : ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'],
    'embark_town' : ['Cherbourg', 'Southampton', 'Queenstown'],
    'alone' : ['y', 'n']
}

# One indicator column per categorical feature, built from its vocabulary.
categorical_columns = []
for feature, vocab in CATEGORIES.items():
  cat_col = tf.feature_column.categorical_column_with_vocabulary_list(
        key=feature, vocabulary_list=vocab)
  categorical_columns.append(tf.feature_column.indicator_column(cat_col))

# COMBINED PREPROCESSING LAYER
# A single DenseFeatures layer takes the categorical columns followed by the
# numeric ones.
all_feature_columns = categorical_columns + numeric_columns
preprocessing_layer = tf.keras.layers.DenseFeatures(all_feature_columns)

# BUILD THE MODEL
# Preprocessing, then two hidden layers of 128 ReLU units, then one output unit
# without activation (the loss below is configured with from_logits=True).
hidden_layers = [tf.keras.layers.Dense(128, activation='relu') for _ in range(2)]
output_layer = tf.keras.layers.Dense(1)
model = tf.keras.Sequential([preprocessing_layer, *hidden_layers, output_layer])

model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])

# TRAIN, EVALUATE, PREDICT
# The training batches go through a shuffle buffer of 500 elements; the test
# data is used as is.
train_data = packed_train_data.shuffle(500)
test_data = packed_test_data

model.fit(train_data, epochs=20)

# CHECK THE ACCURACY
test_loss, test_accuracy = model.evaluate(test_data)

print('\n\nTest Loss {}, Test Accuracy {}\n\n'.format(test_loss, test_accuracy))

# Use tf.keras.Model.predict to infer labels on a batch or a dataset of batches.
# The model has no final activation, so these values are logits.
predictions = model.predict(test_data)

# Show some results
# make_csv_dataset shuffles the rows by default (get_dataset does not disable
# it), so two separate passes over test_data do not visit the rows in the same
# order: the rows of `predictions` cannot be paired with labels read in another
# pass. The predictions displayed here are therefore computed on the very batch
# the labels come from. Re-batching also guarantees N_SHOWN examples whatever
# the batch size of the dataset, without loading the whole dataset in memory.
N_SHOWN = 10
for features, labels in test_data.unbatch().batch(N_SHOWN).take(1):
  # Logits -> probabilities; one row (of a single value) per example.
  probabilities = tf.sigmoid(model.predict_on_batch(features)).numpy()
  for probability, survived in zip(probabilities, labels.numpy()):
    print("Predicted survival: {:.2%}".format(probability[0]),
          " | Actual outcome: ",
          ("SURVIVED" if bool(survived) else "DIED"))

# -----------------------------------------------------------------

