Πώς να εκτυπώσετε τον αριθμό των παραμέτρων μοντέλου στο PyTorch

Pos Na Ektyposete Ton Arithmo Ton Parametron Montelou Sto Pytorch



Το PyTorch είναι ένα δημοφιλές πλαίσιο που χρησιμοποιείται στη βαθιά μάθηση. Προσφέρει πολλαπλές δυνατότητες για τη δημιουργία πολύπλοκων νευρωνικών δικτύων (NN). Οι χρήστες μπορούν να εκτελούν λειτουργίες εκπαίδευσης μοντέλων με αυτό το πλαίσιο. Ωστόσο, οι χρήστες πρέπει να είναι εξοικειωμένοι με τον αριθμό των παραμέτρων πριν εκπαιδεύσουν το μοντέλο.

Αυτό το ιστολόγιο θα περιγράφει:

Ποιες είναι οι παράμετροι στο PyTorch;

Στο PyTorch, το ' nn.Ενότητα Η κλάση χρησιμοποιείται για τον καθορισμό των μοντέλων. Περιλαμβάνει όλες τις λειτουργίες και τα επίπεδα που συνθέτουν το μοντέλο. Κάθε επίπεδο περιέχει ένα σύνολο παραμέτρων. Οι παράμετροι ενημερώνονται βασικά κατά την εκπαίδευση για να ελαχιστοποιηθεί το σφάλμα μεταξύ των πραγματικών τιμών και των προβλέψεων του μοντέλου.







Γιατί οι χρήστες πρέπει να ελέγχουν τις παραμέτρους του μοντέλου;

Κατά την εκπαίδευση του μοντέλου, οι χρήστες πρέπει να γνωρίζουν για τον αριθμό των παραμέτρων του μοντέλου τους, επειδή χρειάζεται πολλή μνήμη και επεξεργαστική ισχύ. Εάν είναι εξοικειωμένοι με τον αριθμό των παραμέτρων του μοντέλου, μπορούν εύκολα να αξιολογήσουν την ποσότητα της μνήμης που θα απαιτηθεί και τον χρόνο που θα χρειαστεί για την εκπαίδευση, κάτι που βοηθά τους χρήστες να βελτιστοποιήσουν τη διαδικασία εκπαίδευσής τους και να αποτρέψουν την εξάντληση του συστήματος. χώρος.



Πώς να εμφανίσετε τον αριθμό των παραμέτρων μοντέλου στο PyTorch;

Ο ' nn.Ενότητα 'Η τάξη έχει ' Παράμετροι() ' μέθοδος που χρησιμοποιείται για την προβολή του αριθμού των παραμέτρων του μοντέλου στο μοντέλο PyTorch. Για να λάβετε όλα τα στοιχεία, το ' num1() Χρησιμοποιείται η μέθοδος.



Για να κατανοήσουμε την έννοια που συζητήθηκε προηγουμένως, ας ρίξουμε μια ματιά στον παρεχόμενο κώδικα:





εισαγωγή δάδα. nn όπως και nn

τάξη Μοντέλο NN ( nn. Μονάδα μέτρησης ) :
def __ζεστό__ ( εαυτός ) :
σούπερ ( Μοντέλο NN , εαυτός ) . __ζεστό__ ( )
εαυτός . fc1 = nn. Γραμμικός ( 10 , πενήντα )
εαυτός . fc2 = nn. Γραμμικός ( πενήντα , 1 )

def προς τα εμπρός ( εαυτός , Εγώ ) :
Εγώ = εαυτός . fc1 ( Εγώ )
Εγώ = εαυτός . fc2 ( Εγώ )
ΕΠΙΣΤΡΟΦΗ Εγώ

my_model = Μοντέλο NN ( )
t_params = άθροισμα ( Π. δώσε όνομα ( ) Για Π σε my_model. Παράμετροι ( ) )
Τυπώνω ( φά 'Συνολικός αριθμός παραμέτρων: {t_params}' )

Στον παραπάνω κωδικό:

  • Αρχικά, ορίζουμε ένα μοντέλο που έχει δύο γραμμικά επίπεδα.
  • Στη συνέχεια, δημιουργήστε το παράδειγμα του μοντέλου και χρησιμοποιήστε το ' Παράμετροι() ' μέθοδος ανάκτησης όλων των παραμέτρων.
  • Στη συνέχεια, εφαρμόζουμε την έκφραση γεννήτριας για να υπολογίσουμε όλες τις παραμέτρους αθροίζοντας τον αριθμό των στοιχείων κάθε παραμέτρου.
  • Τέλος, καλέστε το « Τυπώνω() Δήλωση για την εμφάνιση των τιμών που προκύπτουν στην οθόνη:



Στον κώδικα που περιγράφεται παραπάνω, έχουμε εμφανίσει μόνο τον συνολικό αριθμό παραμέτρων, εάν θέλετε να λάβετε το όνομα και το μέγεθος της παραμέτρου, μπορούν να χρησιμοποιηθούν οι ακόλουθες γραμμές κώδικα:

Για όνομα , παραμ σε my_model. state_dict ( ) . είδη ( ) :

Τυπώνω ( όνομα , παραμ. Μέγεθος ( ) )

Εδώ:

  • ' state_dict() ” είναι το αντικείμενο λεξικού Python που χρησιμοποιείται για την αποθήκευση και τη φόρτωση μοντέλων από το PyTorch.
  • ' είδος() Χρησιμοποιείται η μέθοδος για την επιστροφή της λίστας με όλα τα κλειδιά του λεξικού μαζί με τις τιμές.
  • ' Τυπώνω() Η δήλωση ' χρησιμοποιείται για την εκτύπωση του ονόματος και του μεγέθους της παραμέτρου περνώντας το ' Μέγεθος() ” μέθοδος και παράμετρος:

Αυτό είναι όλο! Συγκεντρώσαμε τον ευκολότερο τρόπο εκτύπωσης του αριθμού των παραμέτρων μοντέλου στο PyTorch.

συμπέρασμα

Στο PyTorch, το ' nn.Ενότητα Η κλάση χρησιμοποιείται για τον ορισμό των μοντέλων που περιλαμβάνουν όλες τις λειτουργίες και τα επίπεδα που συνθέτουν το μοντέλο. Ο ' nn.Ενότητα 'Η τάξη έχει ' Παράμετροι() ' μέθοδος που χρησιμοποιείται για την προβολή του αριθμού των παραμέτρων του μοντέλου στο μοντέλο PyTorch. Αυτή η εγγραφή κατέδειξε τη μέθοδο για την εκτύπωση του αριθμού των παραμέτρων του μοντέλου στο PyTorch.