All models are a standard torch.nn.Module so you can use them in any typical training loop. While you can write your own training loop, 🤗 Transformers provides a Trainer class for PyTorch, which contains the basic training loop and adds additional functionality for features like distributed training, mixed precision, and more.
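For example, mixed precision training can be switched on with a single argument. A minimal sketch (the variable name `fp16_args` is illustrative, and `fp16=True` assumes a GPU with half-precision support):

>>> from transformers import TrainingArguments

>>> fp16_args = TrainingArguments(output_dir="path/to/save/folder/", fp16=True)  # doctest: +SKIP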
Depending on your task, you'll typically pass the following parameters to Trainer.

You'll start with a PreTrainedModel or a torch.nn.Module:
>>> from transformers import AutoModelForSequenceClassification
>>> model = AutoModelForSequenceClassification.from_pretrained("distilbert/distilbert-base-uncased")
TrainingArguments contains the model hyperparameters you can change, like the learning rate, batch size, and the number of epochs to train for. The default values are used if you don't specify any training arguments:
>>> from transformers import TrainingArguments
>>> training_args = TrainingArguments(
...     output_dir="path/to/save/folder/",
...     learning_rate=2e-5,
...     per_device_train_batch_size=8,
...     per_device_eval_batch_size=8,
...     num_train_epochs=2,
... )
Load a preprocessing class like a tokenizer, image processor, feature extractor, or processor:

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
Load a dataset:

>>> from datasets import load_dataset

>>> dataset = load_dataset("rotten_tomatoes")  # doctest: +IGNORE_RESULT
Create a function to tokenize the dataset:

>>> def tokenize_dataset(dataset):
...     return tokenizer(dataset["text"])
Then apply it over the entire dataset with map:
>>> dataset = dataset.map(tokenize_dataset, batched=True)
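If you want to check the result, the tokenized columns now sit alongside the originals (the output shown is illustrative for the rotten_tomatoes dataset):

>>> dataset["train"].column_names  # doctest: +SKIP
['text', 'label', 'input_ids', 'attention_mask']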
A DataCollatorWithPadding to create a batch of examples from your dataset:

>>> from transformers import DataCollatorWithPadding

>>> data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
Now gather all these classes in Trainer:
>>> from transformers import Trainer
>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=dataset["train"],
...     eval_dataset=dataset["test"],
...     tokenizer=tokenizer,
...     data_collator=data_collator,
... )  # doctest: +SKIP
When you're ready, call train() to start training:
>>> trainer.train()  # doctest: +SKIP
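Once training finishes, Trainer also provides methods to evaluate on the held-out split and save the final model, for example:

>>> metrics = trainer.evaluate()  # doctest: +SKIP
>>> trainer.save_model("path/to/save/folder/")  # doctest: +SKIP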
You can customize the training loop behavior by subclassing the methods inside Trainer. This allows you to customize features such as the loss function, optimizer, and scheduler.
Take a look at the Trainer reference for which methods can be subclassed.
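For instance, subclassing Trainer and overriding compute_loss lets you swap in your own loss function. A minimal sketch, assuming a two-label classification model (`WeightedLossTrainer` and the class weights are hypothetical, and the extra `**kwargs` absorbs arguments added to compute_loss in newer transformers versions):

>>> import torch
>>> from transformers import Trainer

>>> class WeightedLossTrainer(Trainer):
...     def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
...         # Pop the labels so the model's forward pass doesn't compute its own loss
...         labels = inputs.pop("labels")
...         outputs = model(**inputs)
...         logits = outputs.logits
...         # Hypothetical class weights: penalize mistakes on label 1 twice as much
...         loss_fct = torch.nn.CrossEntropyLoss(
...             weight=torch.tensor([1.0, 2.0], device=logits.device)
...         )
...         loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
...         return (loss, outputs) if return_outputs else loss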
The other way to customize the training loop is by using Callbacks. You can use callbacks to integrate with other libraries and inspect the training loop to report on progress or stop the training early. Callbacks do not modify anything in the training loop itself. To customize something like the loss function, you need to subclass the Trainer instead.
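Here's a minimal sketch of a callback that inspects logged metrics (`PrintLossCallback` is a hypothetical name; on_log is one of the real TrainerCallback hooks). Pass it to Trainer through the callbacks parameter, e.g. `callbacks=[PrintLossCallback()]`. 🤗 Transformers also ships ready-made callbacks such as EarlyStoppingCallback for stopping training early:

>>> from transformers import TrainerCallback

>>> class PrintLossCallback(TrainerCallback):
...     def on_log(self, args, state, control, logs=None, **kwargs):
...         # Runs every time the Trainer logs metrics; read-only inspection
...         if logs is not None and "loss" in logs:
...             print(f"step {state.global_step}: loss = {logs['loss']}")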
Now that you've completed the 🤗 Transformers quick tour, check out our guides and learn how to do more specific things like writing a custom model, fine-tuning a model for a task, and training a model with a script.