Читать книгу Искусственный интеллект. Машинное обучение онлайн
Рассмотрим пример кода для обучения трансформера на задаче машинного перевода с использованием библиотеки PyTorch и библиотеки для работы с естественным языком – Transformers.
```python
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM
from torch.utils.data import Dataset, DataLoader
# Подготовка данных
class TranslationDataset(Dataset):
def __init__(self, texts, tokenizer, max_length=128):
self.texts = texts
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
input_text = self.texts[idx][0]
target_text = self.texts[idx][1]
input_encoding = self.tokenizer(input_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
target_encoding = self.tokenizer(target_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
return {"input_ids": input_encoding["input_ids"], "attention_mask": input_encoding["attention_mask"],
"labels": target_encoding["input_ids"], "decoder_attention_mask": target_encoding["attention_mask"]}
# Создание модели трансформера
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
# Обучение модели
train_dataset = TranslationDataset(train_data, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()
model.train()
for epoch in range(num_epochs):
total_loss = 0
for batch in train_loader:
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
labels = batch["labels"]
decoder_attention_mask = batch["decoder_attention_mask"]
optimizer.zero_grad()
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, decoder_attention_mask=decoder_attention_mask)
loss = outputs.loss
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {total_loss}")
# Использование модели для перевода
input_text = "This is a sample sentence to translate."