applicata regola nuova, continua a non andare bene nemmeno con lo xor

This commit is contained in:
2025-02-22 08:37:13 +01:00
parent 46cca564f4
commit e787843a67
4 changed files with 80 additions and 40 deletions

View File

@@ -1,14 +1,11 @@
#include <time.h>
#include "percettroni.h"
//#include "mnist/mnist_manager.h"
//#include "cifar_10/cifar10_manager.h"
#include "xor_manager.h"
//Scelgo quale categoria voglio identificare. nel caso dello xor -1
#define CATEGORIA -1
#define NUM_LAYERS 2
#define PERCETTRONI_LAYER_0 2
#define MAX_EPOCHE 10000
#define NUM_LAYERS 3
#define PERCETTRONI_LAYER_0 4
#define MAX_EPOCHE 1000000
byte get_out_corretto(byte);
void stampa_layer_indirizzo(Layer*);
@@ -29,7 +26,7 @@ void main() {
ReteNeurale rete_neurale;
ReteNeurale *puntatore_rete = caricaReteNeurale(file_pesi);
if(puntatore_rete == NULL) {
rete_neurale = init_rete_neurale(NUM_LAYERS, PERCETTRONI_LAYER_0, N_INPUTS);
rete_neurale = inizializza_rete_neurale(NUM_LAYERS, PERCETTRONI_LAYER_0, N_INPUTS);
} else {
rete_neurale = *puntatore_rete;
free(puntatore_rete);
@@ -62,10 +59,8 @@ void main() {
sigmoidi[j] = funzioni_attivazione_layer_double(rete_neurale.layers[j], sigmoidi[j-1]);
}
byte output_corretto = get_out_corretto(set.istanze[indice_set].classificazione);
//Se prevede male
if(previsione(sigmoidi[NUM_LAYERS-1][0]) != output_corretto) {
@@ -90,9 +85,13 @@ void main() {
*/
double gradiente_errore = (output_corretto - sigmoidi[NUM_LAYERS-1][0]);
//Derivata funzione di perdita
double gradiente_errore = -(output_corretto - sigmoidi[NUM_LAYERS-1][0]);
//Derivata funzione attivazione
double derivata_sigmoide_out = sigmoidi[NUM_LAYERS-1][0] * (1 - sigmoidi[NUM_LAYERS-1][0]);
if (derivata_sigmoide_out == 0.0) derivata_sigmoide_out = 1;
//Gradiente del percettrone output
gradienti[NUM_LAYERS-1][0] = gradiente_errore * derivata_sigmoide_out;
//Ricorda di partire dal penultimo layer in quanto l'ultimo è già fatto
@@ -100,20 +99,38 @@ void main() {
/* A questo punto ho tutti i gradienti dei percettroni, non mi resta che trovare i gradienti dei pesi e correggerli
*/
//Correggo il livello output
for(int indice_peso = 0; indice_peso < rete_neurale.layers[NUM_LAYERS-1].percettroni[0].size; indice_peso ++) {
//Determino gradiente del peso
double gradiente_peso = gradienti[NUM_LAYERS-1][0] * sigmoidi[NUM_LAYERS-2][indice_peso];
rete_neurale.layers[NUM_LAYERS-1].percettroni[0].pesi[indice_peso] -= gradiente_peso * LRE;
}
rete_neurale.layers[NUM_LAYERS-1].percettroni[0].bias -= gradienti[NUM_LAYERS-1][0] * LRE;
//Applico la correzione dal penultimo layer andando indietro fino al secondo (il primo si fa diverso)
for(int indice_layer = NUM_LAYERS - 2; indice_layer > 0; indice_layer--) {
for(int indice_layer = NUM_LAYERS - 2; indice_layer >= 0; indice_layer--) {
//Applico la correzione a tutti i percettroni del layer dal primo a seguire
for(int indice_percettrone = 0; indice_percettrone <= rete_neurale.layers[indice_layer].size; indice_percettrone++) {
correggi_pesi_percettrone();
for(int indice_percettrone = 0; indice_percettrone < rete_neurale.layers[indice_layer].size; indice_percettrone++) {
//Devo prendere il gradiente del percettrone e moltiplicarlo con gli input associati ai pesi
if(indice_layer != 0) {
correggi_pesi_percettrone_double(&rete_neurale.layers[indice_layer].percettroni[indice_percettrone], indice_layer, sigmoidi, gradienti[indice_layer][indice_percettrone]);
} else {
correggi_pesi_percettrone_byte(&rete_neurale.layers[0].percettroni[indice_percettrone], set.istanze[indice_set], gradienti[0][indice_percettrone], indice_percettrone);
}
}
}
//gradienti[NUM_LAYERS-1][0] = (output_corretto - sigmoidi[NUM_LAYERS-1][0]);
errore_totale += gradienti[NUM_LAYERS-1][0];
correggi_layer_interni(&rete_neurale, gradienti, sigmoidi);
correggi_layer_input(&rete_neurale.layers[0], gradienti, sigmoidi, set.istanze[indice_set].dati, NUM_LAYERS);
//correggi_layer_interni(&rete_neurale, gradienti, sigmoidi);
//correggi_layer_input(&rete_neurale.layers[0], gradienti, sigmoidi, set.istanze[indice_set].dati, NUM_LAYERS);
}
else
{
@@ -126,9 +143,8 @@ void main() {
}
}
printf("Errore: %f\n", errore_totale);
//printf("\tRisposte corrette: %d\n", corrette);
printf("Errore: %f\n", errore_totale / 4);
printf("\tRisposte corrette: %d\n", corrette);
if(corrette == set.size) {
break;