Catching Some GANemons
Since doing a university assignment on GANs last year, I’ve enjoyed keeping up with and playing around with new GAN algorithms. My assignment was on Wasserstein GANs (WGANs), a then fairly new way of training GANs that allows for faster and more stable training - check out the paper here if you’re interested. I then extended this to WGAN-GP, the gradient-penalty WGAN, which fixes a big issue in the Wasserstein algorithm through the use of gradient penalties (paper here). I implemented both in a repository on my GitHub and played with a few datasets, mainly pokemon and anime face ones - for more details, please see the repository itself.
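To give a flavour of what that penalty involves, here is a minimal TF1-style sketch of the gradient-penalty term from the WGAN-GP paper; the function and argument names are mine for illustration, not taken from my repository:

```python
import tensorflow as tf

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    # sample points on straight lines between real and generated images
    eps = tf.random_uniform([tf.shape(real)[0], 1, 1, 1], 0.0, 1.0)
    interp = eps * real + (1.0 - eps) * fake
    # penalize the critic's gradient norm at those points for deviating from 1
    grads = tf.gradients(critic(interp), [interp])[0]
    norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return lambda_gp * tf.reduce_mean(tf.square(norms - 1.0))
```

This term is added to the critic’s loss in place of the weight clipping the original WGAN used to enforce its Lipschitz constraint.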
More recently, I was impressed (like many) by the StyleGAN developed by NVIDIA, and since I already had a repository for messing with GAN algorithms, I decided to implement (or rather, adapt another person’s implementation - I didn’t have that much free time) a StyleGAN and train it on my old toy datasets. After some tuning, I trained a very basic StyleGAN on both a small-ish anime face dataset and a tiny pokemon image dataset (essentially just the official Sugimori art for all pokemon up to generation 6) - again, see my repository for details. The results aren’t the best, but they are, at the very least, interesting to look at.
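A quick aside on what makes StyleGAN tick: the learned “style” is injected into each layer of the generator through adaptive instance normalization (AdaIN). As a rough sketch of the operation (this mirrors the idea behind the custom AdaInstanceNormalization layer that shows up in the export script below, not its exact implementation):

```python
import tensorflow as tf

def adain(x, gamma, beta, eps=1e-5):
    # instance normalization: normalize each sample's feature map per channel
    mean, var = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
    normed = (x - mean) / tf.sqrt(var + eps)
    # re-scale and re-shift using statistics predicted from the style vector
    return gamma * normed + beta
```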
I didn’t stop there: after having successfully trained these models, I have now packaged them into a Docker container and am serving them using TensorFlow Serving (tf-serving), which makes turning TensorFlow models into REST APIs nice and easy. After wrapping these behind a basic Python API, I now have a fun little web API for generating anime faces and pokemon images, as you can see below!
Anime Faces
If you reload the page, the faces will change - each one is simply loaded from an API endpoint that generates a random face each time it is called. Again, the results aren’t the best: I only trained this model for around one day on my personal desktop, with a small dataset.
GANemon
Like the anime faces, these images will change if you refresh the page. Again, the results aren’t the best: I only trained this model for around one to two days on my personal desktop, with a small dataset.
Architecture
To be more specific about my setup, I first trained these models on their respective datasets for 1,000,000 steps. See my pokemongenerator repository for the hyperparameters and details on these models. I then took the saved Keras model for just the generator and transformed it into a frozen graph and then a SavedModel using the following code (warning: it is fairly ad hoc, and uses deprecated TensorFlow functionality):
```python
import sys

from keras.models import load_model
import keras.backend as K
import tensorflow as tf

from style import AdaIN


# from https://stackoverflow.com/questions/45466020/how-to-export-keras-h5-to-tensorflow-pb
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables())
                                .difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph


# we take in the model name as an argument
model_name = sys.argv[1]
K.set_learning_phase(0)
# I had my weights in a weights folder. custom_objects is needed for a custom Keras layer
model = load_model(f'./weights/{model_name}.h5',
                   custom_objects={'AdaInstanceNormalization': AdaIN.AdaInstanceNormalization})
# freeze our graph
frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in model.outputs])
# begin the process of saving the model
builder = tf.saved_model.builder.SavedModelBuilder(f'tfserve/{model_name}')
sigs = {}
with tf.Session() as sess:
    tf.import_graph_def(frozen_graph, name="")
    sess.run(tf.global_variables_initializer())
    # to get these names, I had to manually inspect the graph. There are many better
    # ways to do this, such as naming the tensors.
    output = tf.get_default_graph().get_tensor_by_name("conv2d_40/Tanh_1:0")
    input_1 = tf.get_default_graph().get_tensor_by_name("input_2_1:0")
    input_2 = tf.get_default_graph().get_tensor_by_name("input_3_1:0")
    input_3 = tf.get_default_graph().get_tensor_by_name("input_4_1:0")
    sigs[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
        tf.saved_model.signature_def_utils.predict_signature_def(
            {"in1": input_1, "in2": input_2, "in3": input_3}, {"out": output})
    builder.add_meta_graph_and_variables(sess, [tf.saved_model.SERVING],
                                         signature_def_map=sigs)
builder.save()
```
After running this script, the SavedModel artifacts (a saved_model.pb file and a variables folder) should be put in a directory structure like so:

```
MODEL_NAME
└── 1
    ├── saved_model.pb
    └── variables
```
The ‘1’ is the model version number. More nuanced setups may serve multiple versions of a model, but for this project I just used the one.
Then, you can use the tf-serving Docker image to serve this model. I modified the basic tensorflow/serving image to store my SavedModel within the container and deployed it using Heroku’s Docker deployment functionality. This is actually enough to have a TensorFlow-based API up and running, which is pretty cool!
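Once the container is running, querying the model is just an HTTP POST to tf-serving’s standard REST endpoint. As a rough sketch - the host, port, model name, and input shapes here are illustrative, and the real shapes must match the generator’s inputs:

```python
import numpy as np
import requests

# these shapes are placeholders; they must match the generator's actual inputs
latent = np.random.normal(size=(1, 512)).tolist()
noise = np.random.uniform(size=(1, 64, 64, 1)).tolist()

payload = {
    "signature_name": "serving_default",
    # the keys match the signature defined in the export script above
    "inputs": {"in1": latent, "in2": latent, "in3": noise},
}
resp = requests.post("http://localhost:8501/v1/models/ganemon:predict", json=payload)
image_tensor = np.array(resp.json()["outputs"])
```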
However, considering I had two models I wanted to deploy and serve, and I wanted to handle preprocessing and image construction myself, I ‘wrapped’ the API with a basic Python Flask API. This API is what is actually being called above: the browser hits the Flask API, which generates the random latent inputs, calls the tf-serving container, and turns the GAN’s response into an actual image.
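As a minimal sketch of what such a wrapper can look like (the endpoint name, latent size, and serving URL are illustrative, not lifted from my actual code):

```python
import io

import numpy as np
import requests
from flask import Flask, send_file
from PIL import Image

app = Flask(__name__)
TFSERVE_URL = "http://localhost:8501/v1/models/ganemon:predict"  # placeholder URL

@app.route("/generate")
def generate():
    # generate the random inputs server-side; shapes are placeholders
    latent = np.random.normal(size=(1, 512)).tolist()
    noise = np.random.uniform(size=(1, 64, 64, 1)).tolist()
    payload = {"inputs": {"in1": latent, "in2": latent, "in3": noise}}
    out = requests.post(TFSERVE_URL, json=payload).json()["outputs"]
    # map the generator's tanh output from [-1, 1] to [0, 255] and encode as PNG
    pixels = ((np.array(out)[0] + 1.0) * 127.5).astype(np.uint8)
    buf = io.BytesIO()
    Image.fromarray(pixels).save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")
```

The galleries above then just point image tags at endpoints like this one, which is why the pictures change on every refresh.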
Conclusion
This was a fun little project to work on, and I’d definitely recommend trying to serve your own toy models on the internet! If they’re small enough, there are free hosting options, and it’s really neat to have something ‘live’ on the internet that you created. Hopefully this is useful to you if you ever want to host your own TF models, or do something similar. 😄