Читать книгу 120 практических задач онлайн
model.add(layers.BatchNormalization())
model.add(layers.Dense(np.prod(image_size) * 3, activation='tanh'))
model.add(layers.Reshape((image_size[0], image_size[1], 3)))
return model
# Дискриминатор
def build_discriminator():
model = tf.keras.Sequential()
model.add(layers.Flatten(input_shape=image_size + (3,)))
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
return model
# Сборка модели GAN
generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
gan_input = layers.Input(shape=(100,))
generated_image = generator(gan_input)
discriminator.trainable = False
gan_output = discriminator(generated_image)
gan = tf.keras.Model(gan_input, gan_output)
gan.compile(optimizer='adam', loss='binary_crossentropy')
```
3. Обучение модели
```python
import tensorflow as tf
# Гиперпараметры
epochs = 10000
batch_size = 64
sample_interval = 200
latent_dim = 100
# Генерация меток
real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))
for epoch in range(epochs):
# Обучение дискриминатора
idx = np.random.randint(0, train_images.shape[0], batch_size)
real_images = train_images[idx]
noise = np.random.normal(0, 1, (batch_size, latent_dim))
fake_images = generator.predict(noise)
d_loss_real = discriminator.train_on_batch(real_images, real_labels)
d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# Обучение генератора
noise = np.random.normal(0, 1, (batch_size, latent_dim))
g_loss = gan.train_on_batch(noise, real_labels)
# Печать прогресса
if epoch % sample_interval == 0:
print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100*d_loss[1]}] [G loss: {g_loss}]")
sample_images(generator)
def sample_images(generator, image_grid_rows=4, image_grid_columns=4):
noise = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, latent_dim))