Читать книгу Искусственный интеллект. Машинное обучение онлайн

Рассмотрим пример кода для обучения трансформера на задаче машинного перевода с использованием библиотеки PyTorch и библиотеки для работы с естественным языком – Transformers.

```python

import torch

from transformers import BertTokenizer, BertModel, BertForMaskedLM

from torch.utils.data import Dataset, DataLoader

# Подготовка данных

class TranslationDataset(Dataset):

def __init__(self, texts, tokenizer, max_length=128):

self.texts = texts

self.tokenizer = tokenizer

self.max_length = max_length

def __len__(self):

return len(self.texts)

def __getitem__(self, idx):

input_text = self.texts[idx][0]

target_text = self.texts[idx][1]

input_encoding = self.tokenizer(input_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")

target_encoding = self.tokenizer(target_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")

return {"input_ids": input_encoding["input_ids"], "attention_mask": input_encoding["attention_mask"],

"labels": target_encoding["input_ids"], "decoder_attention_mask": target_encoding["attention_mask"]}

# Создание модели трансформера

model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# Обучение модели

train_dataset = TranslationDataset(train_data, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

criterion = torch.nn.CrossEntropyLoss()

model.train()

for epoch in range(num_epochs):

total_loss = 0

for batch in train_loader:

input_ids = batch["input_ids"]

attention_mask = batch["attention_mask"]

labels = batch["labels"]

decoder_attention_mask = batch["decoder_attention_mask"]

optimizer.zero_grad()

outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, decoder_attention_mask=decoder_attention_mask)

loss = outputs.loss

loss.backward()

optimizer.step()

total_loss += loss.item()

print(f"Epoch {epoch+1}, Loss: {total_loss}")

# Использование модели для перевода

input_text = "This is a sample sentence to translate."