wcag_AI_validation/scripts/finetuning_inference_time_s.../analisi_training.ipynb

53 KiB

In [1]:
import json
import matplotlib.pyplot as plt

# --- Configuration ---------------------------------------------------------------
# Location of the trainer state to analyse (a local absolute path: adjust for your
# machine) and the file the loss plot is written to.
# Raw strings keep every Windows backslash literal. In a plain literal, "\a", "\n" and
# "\t" would become control characters, and "\c", "\M", ... are invalid escape
# sequences that emit a SyntaxWarning on Python 3.12+.
TRAINER_STATE_PATH = (
    r"C:\cartella_condivisa\MachineLearning\HIISlab\accessibility\notebook_miei"
    r"\LLM_accessibility_validator\scripts\modello_finetunato"
    r"\gemma-finetuned-wcag_google_gemma-3-4b-it\checkpoint-386\trainer_state.json"
)
LOSS_PLOT_PATH = "training_validation_loss.png"


def extract_loss_series(log_history):
    """Split a trainer ``log_history`` list into training and validation loss series.

    Parameters
    ----------
    log_history : list of dict
        The entries stored under ``"log_history"`` in ``trainer_state.json``.
        Entries containing ``"loss"`` are treated as training logs. Entries containing
        ``"eval_loss"`` are treated as validation logs. Both kinds must carry ``"epoch"``.

    Returns
    -------
    tuple of four lists
        ``(epoch_train, train_losses, epoch_eval, eval_losses)``, each in log order.

    Raises
    ------
    KeyError
        If a loss entry has no ``"epoch"`` key.
    """
    epoch_train, train_losses = [], []
    epoch_eval, eval_losses = [], []
    for log in log_history:
        # Two independent checks (not elif): an entry carrying both keys feeds both series.
        if "loss" in log:
            epoch_train.append(log["epoch"])
            train_losses.append(log["loss"])
        if "eval_loss" in log:
            epoch_eval.append(log["epoch"])
            eval_losses.append(log["eval_loss"])
    return epoch_train, train_losses, epoch_eval, eval_losses


def plot_loss_curves(epoch_train, train_losses, epoch_eval, eval_losses,
                     output_path=LOSS_PLOT_PATH):
    """Plot training and validation loss against epoch and save the figure.

    The figure is written to ``output_path`` at 300 dpi and returned, so the notebook
    can display or tweak it further.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epoch_train, train_losses, label="Training Loss", marker="o")
    ax.plot(epoch_eval, eval_losses, label="Validation Loss", marker="s")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Training and Validation Loss per Epoch")
    ax.legend()
    ax.grid(True)
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    return fig


# --- Load the trainer state and plot the loss curves --------------------------------
try:
    # JSON is UTF-8 by spec, so do not rely on the platform default encoding
    # (cp1252 on Windows).
    with open(TRAINER_STATE_PATH, "r", encoding="utf-8") as f:
        trainer_state = json.load(f)

    # Later cells read these globals, so keep them at module level.
    log_history = trainer_state["log_history"]
    epoch_train, train_losses, epoch_eval, eval_losses = extract_loss_series(log_history)

    plot_loss_curves(epoch_train, train_losses, epoch_eval, eval_losses)
    print(f"Plot saved successfully as '{LOSS_PLOT_PATH}'")

except FileNotFoundError:
    print(f"Error: trainer_state.json not found at {TRAINER_STATE_PATH}")
except json.JSONDecodeError:
    print("Error: Invalid JSON format in trainer_state.json")
except KeyError as e:
    print(f"Error: Missing key in trainer_state.json: {e}")
except Exception as e:
    print(f"Error plotting loss curves: {e}")
Plot saved successfully as 'training_validation_loss.png'
In [3]:
# Re-derive the per-step training losses (entries that carry a "loss" key) from the
# loaded trainer log history, and display them.
# NOTE(review): the loading cell already computes train_losses the same way, and another
# cell in this notebook repeats this extraction; consider keeping a single copy.
train_losses = [
    entry["loss"]
    for entry in filter(lambda item: "loss" in item, log_history)
]
train_losses
Out[3]:
[2.3849,
 1.8586,
 1.3981,
 0.9627,
 0.5417,
 0.4255,
 0.3821,
 0.2748,
 0.2816,
 0.2828,
 0.3021,
 0.2577,
 0.2615,
 0.2625,
 0.2563,
 0.1982,
 0.2049,
 0.184,
 0.2318,
 0.179,
 0.1723,
 0.206,
 0.1662,
 0.1601,
 0.1844,
 0.1933,
 0.1735,
 0.2075,
 0.1678,
 0.124,
 0.1692,
 0.1691,
 0.1348,
 0.162,
 0.1676,
 0.1481,
 0.2371,
 0.1879,
 0.1574,
 0.1517,
 0.1618,
 0.1406,
 0.1655,
 0.144,
 0.1587,
 0.1951,
 0.0937,
 0.1655,
 0.1462,
 0.1481,
 0.1345,
 0.1589,
 0.1524,
 0.1382,
 0.1297,
 0.1381,
 0.1307,
 0.1307,
 0.1606,
 0.1956,
 0.1129,
 0.1245,
 0.2099,
 0.1413,
 0.114,
 0.1246,
 0.1536,
 0.1429,
 0.1331,
 0.1333,
 0.1333,
 0.1212,
 0.1087,
 0.1107,
 0.1234,
 0.1178,
 0.1044]
In [2]:
# Collect the training-loss value from every log entry that records one, in log order,
# and show the resulting list.
# NOTE(review): duplicates an extraction done elsewhere in this notebook; its recorded
# output differs from the other copy, which suggests the outputs came from different
# kernel sessions. Re-run from a fresh kernel to confirm.
train_losses = [
    record["loss"]
    for record in log_history
    if "loss" in record
]
train_losses
Out[2]:
[2.3642,
 1.8398,
 1.373,
 0.939,
 0.5372,
 0.4201,
 0.3826,
 0.2763,
 0.2852,
 0.2859,
 0.3063,
 0.2583,
 0.2643,
 0.26,
 0.262,
 0.2012,
 0.2069,
 0.1823,
 0.2331,
 0.1813,
 0.1705,
 0.2069,
 0.1706,
 0.1592,
 0.1857,
 0.1902,
 0.1731,
 0.2095,
 0.1649,
 0.1259,
 0.1698,
 0.1618,
 0.1342,
 0.1661,
 0.1652,
 0.152,
 0.2305,
 0.1884,
 0.1544,
 0.152,
 0.1596,
 0.1406,
 0.1658,
 0.1437,
 0.1578,
 0.1922,
 0.093,
 0.1675,
 0.1467,
 0.1495,
 0.1357,
 0.1595,
 0.1492,
 0.1409,
 0.1276,
 0.1427,
 0.1291,
 0.1308,
 0.1599,
 0.1966,
 0.1123,
 0.124,
 0.2122,
 0.142,
 0.1126,
 0.1276,
 0.1533,
 0.1418,
 0.1328,
 0.1362,
 0.1321,
 0.1236,
 0.108,
 0.1094,
 0.1242,
 0.1162,
 0.1027]