Skip to content

Commit

Permalink
mwe - generative adverserial network
Browse files Browse the repository at this point in the history
  • Loading branch information
a-mhamdi committed Jul 4, 2024
1 parent 7fefa0f commit e66f121
Showing 1 changed file with 82 additions and 0 deletions.
82 changes: 82 additions & 0 deletions Codes/Julia/Part-3/gan.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
####################################
#= Generative Adverserial Network =#
####################################

using Flux
using Flux: @epochs
using Images: Gray
using ProgressMeter

## Generator
### The generator network should take a noise vector as input and return a synthetic sample.
function generator(; latent_dim=64, img_shape=(28,28,1,1))
return Chain(
Dense(latent_dim, 128, relu),
Dense(128, 256, relu),
Dense(256, prod(img_shape), tanh),
x -> reshape(x, img_shape)
)
end

## Discriminator
### The discriminator network should take a sample as input and return a score indicating the probability that the sample is real.
function discriminator(; img_shape=(28,28,1,1))
return Chain(
x -> reshape(x, :, size(x, 4)),
Dense(prod(img_shape), 256, relu),
Dense(256, 128, relu),
Dense(128, 1, sigmoid)
)
end

## Loss functions
bce_loss(y_true, y_pred) = Flux.binarycrossentropy(y_pred, y_true)

## Training function
function train_gan(gen, disc, opt_gen, opt_disc; n_epochs=128, latent_dim=64)
@showprogress for epoch in 1:n_epochs

## Train the discriminator `disc`
noise = randn(Float32, latent_dim, 1)
fake_imgs = gen(noise) # pass the noise through the generator to get a synthetic sample
real_imgs = rand(Float32, size(fake_imgs)...)

disc_loss = bce_loss(ones(Float32, 1, 1), disc(real_imgs)) +
bce_loss(zeros(Float32, 1, 1), disc(fake_imgs)) # compute the loss for the real and synthetic samples
grads = gradient(() -> disc_loss, Flux.params(disc))
Flux.update!(opt_disc, Flux.params(disc), grads) # update the discriminator weights

## Train the generator `gen`
noise = randn(Float32, latent_dim, 1)
gen_loss = bce_loss(ones(Float32, 1, 1), disc(gen(noise))) # compute the loss for the synthetic samples
grads = gradient(() -> gen_loss, Flux.params(gen))
Flux.update!(opt_gen, Flux.params(gen), grads) # update the generator weights

println("Epoch $(epoch): Discriminator loss = $(disc_loss), Generator loss = $(gen_loss)")
sleep(.1)
end
end

## Setup the GAN
gen = generator()
disc = discriminator()

opt_gen = ADAM(0.0002, (0.5, 0.999))
opt_disc = ADAM(0.0002, (0.5, 0.999))

## Train the GAN
train_gan(gen, disc, opt_gen, opt_disc)

## Generate and plot some images
latent_dim = 64
noise = randn(Float32, latent_dim, 16)
generated_images = [gen(noise[:, i]) for i in 1:16]

using Plots
plot_images = [plot(Gray.(generated_images[i])[:,:,1,1]) for i in 1:16]
plot(
plot_images...,
layout =(4,4),
title = ["($i)" for j in 1:1, i in 1:11], titleloc = :right, titlefont = font(8),
size =(800, 800)
)

0 comments on commit e66f121

Please sign in to comment.