Author: cryu854

Description:
Arbitrary Style Transfer in TensorFlow.js
Primary language: Python
Repository: git://github.com/cryu854/ArbitraryStyle-tfjs.git
Created: 2020-08-24T16:25:55Z
Project page: https://github.com/cryu854/ArbitraryStyle-tfjs

License:



Arbitrary Style Transfer in TensorFlow.js

This is an implementation of "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization" in TensorFlow 2 and TensorFlow.js. Demo website: https://cryu854.github.io/ArbitraryStyle-tfjs/

The model runs entirely in your browser, so your images are never uploaded to a server.

The network architecture is composed of an encoder, a decoder, and an AdaIN layer. The encoder is fixed to the first few layers (up to relu4_1) of a pre-trained VGG-19. The decoder mostly mirrors the encoder, with all pooling layers replaced by nearest-neighbor up-sampling to reduce checkerboard effects. Set REFLECT_PADDING=True to use reflection padding in both the encoder and the decoder to avoid border artifacts; note, however, that the resulting model cannot be deployed in the browser.
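The AdaIN layer aligns the per-channel mean and standard deviation of the content features with those of the style features. The following is a minimal NumPy sketch, assuming channels-last (H, W, C) feature maps for a single image. The function name `adain` and the `eps` value are illustrative and are not taken from this repository; a TensorFlow version would typically compute the same statistics with `tf.nn.moments` over the spatial axes.

```python
import numpy as np

def adain(content_feat, style_feat, eps=1e-5):
    """Adaptive Instance Normalization for (H, W, C) feature maps (illustrative sketch).

    Each channel of content_feat is normalized to zero mean and unit variance
    over the spatial axes, then rescaled by the per-channel standard deviation
    and shifted by the per-channel mean of style_feat. The two maps may have
    different spatial sizes; only the channel count must match.
    """
    c_mean = content_feat.mean(axis=(0, 1), keepdims=True)
    c_std = np.sqrt(content_feat.var(axis=(0, 1), keepdims=True) + eps)  # eps avoids division by zero
    s_mean = style_feat.mean(axis=(0, 1), keepdims=True)
    s_std = style_feat.std(axis=(0, 1), keepdims=True)
    return s_std * (content_feat - c_mean) / c_std + s_mean

# Usage with random stand-ins for relu4_1 features (512 channels).
rng = np.random.default_rng(0)
content = rng.normal(2.0, 1.0, size=(32, 32, 512))
style = rng.normal(-1.0, 3.0, size=(24, 24, 512))
target = adain(content, style)
print(target.shape)
```

The output keeps the spatial layout of the content features but takes on the style's channel statistics, and the decoder turns it back into an image.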

Image Stylization



Stylize an image

Use main.py to stylize a content image in an arbitrary style.
Stylization takes 29 ms per frame (256x256) on a GTX 1080 Ti.

Example usage:

  python main.py inference --content ./path/to/content.jpg \
    --style ./path/to/style.jpg \
    --alpha 1.0 \
    --model ./path/to/pre-trained_model
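You can also generate the same command from Python, for example to render one content image at several `--alpha` values. This sketch uses only the flags shown above. `build_inference_cmd` is a hypothetical helper, and the paths are placeholders.

```python
import shlex

def build_inference_cmd(content, style, model, alpha=1.0):
    """Return the argv list for `python main.py inference` (hypothetical helper).

    Only the flags documented in this README are used.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be between 0 and 1")
    return [
        "python", "main.py", "inference",
        "--content", content,
        "--style", style,
        "--alpha", str(alpha),
        "--model", model,
    ]

# Print one command per alpha value for a content-style trade-off sweep.
# To actually run them, pass each list to subprocess.run(cmd, check=True).
for alpha in (0.25, 0.5, 0.75, 1.0):
    cmd = build_inference_cmd("./path/to/content.jpg", "./path/to/style.jpg",
                              "./path/to/pre-trained_model", alpha=alpha)
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

The helper returns an argument list rather than a single string, so it can be passed to `subprocess.run` without invoking a shell. `shlex.quote` is used instead of `shlex.join` because `shlex.join` requires Python 3.8, while the requirements below list 3.7.5.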

Content-style trade-off

Use --alpha to adjust the stylization intensity. The value should be between 0 and 1; the default is 1.
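In the AdaIN paper, this trade-off is implemented by interpolating feature maps before decoding: T(c, s, alpha) = g((1 - alpha) * f(c) + alpha * AdaIN(f(c), f(s))), where f is the encoder and g is the decoder. The sketch below assumes main.py handles --alpha the same way; that is an assumption, not something confirmed by this README. `blend_features` is an illustrative name.

```python
import numpy as np

def blend_features(content_feat, adain_feat, alpha=1.0):
    """Interpolate between encoder features and AdaIN output before decoding.

    alpha=0 keeps the content features unchanged (no stylization);
    alpha=1 uses the full AdaIN output (maximum stylization).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be between 0 and 1")
    return (1.0 - alpha) * content_feat + alpha * adain_feat

# Usage (decoder is a hypothetical callable mapping features to an image):
#   image = decoder(blend_features(f_c, adain(f_c, f_s), alpha=0.6))
```

Blending in feature space rather than pixel space means that even at low alpha, the decoder still produces a single coherent image instead of a cross-fade.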



Training

Use main.py to train a new style transfer network.
Training takes 2.5 to 3 hours on a GTX 1080 Ti.
Before running it, download the MSCOCO and WikiArt datasets.

Example usage:

  python main.py train --content ./path/to/MSCOCO_dataset \
    --style ./path/to/WikiArt_dataset \
    --batch 8 \
    --debug True \
    --validate_content ./path/to/validate/content.jpg \
    --validate_style ./path/to/validate/style.jpg
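This README does not spell out the training objective. In the AdaIN paper, only the decoder is trained. The loss has two parts: a content loss between the re-encoded decoder output and the AdaIN target features, and a style loss that matches the per-channel mean and standard deviation of VGG activations at several layers (relu1_1 to relu4_1). Below is a NumPy sketch of that objective. It assumes mean squared error for each term and a hypothetical `style_weight`; whether main.py makes exactly these choices is not confirmed here.

```python
import numpy as np

def mean_std(feat, eps=1e-5):
    """Per-channel mean and std over the spatial axes of an (H, W, C) map."""
    mean = feat.mean(axis=(0, 1))
    std = np.sqrt(feat.var(axis=(0, 1)) + eps)
    return mean, std

def mse(a, b):
    return float(np.mean((a - b) ** 2))

def adain_losses(output_feat, target_feat, output_layers, style_layers, style_weight=10.0):
    """Content + style loss in the spirit of the AdaIN paper (sketch).

    output_feat:   relu4_1 features of the stylized output, re-encoded
    target_feat:   the AdaIN target t = AdaIN(f(c), f(s))
    output_layers: VGG activations of the output, e.g. relu1_1 .. relu4_1
    style_layers:  VGG activations of the style image at the same layers
    style_weight:  hypothetical weighting; the repository's value is not known
    """
    content_loss = mse(output_feat, target_feat)
    style_loss = 0.0
    for out_l, sty_l in zip(output_layers, style_layers):
        out_mean, out_std = mean_std(out_l)
        sty_mean, sty_std = mean_std(sty_l)
        style_loss += mse(out_mean, sty_mean) + mse(out_std, sty_std)
    return content_loss, style_loss, content_loss + style_weight * style_loss
```

Because the style loss compares only channel statistics, the style and output layers may have different spatial sizes.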

Convert a pre-trained model to a TensorFlow.js model

Use the TensorFlow.js converter to generate a web-friendly JSON model.
If you use reflection padding in the encoder or decoder, the converter will not work properly, because the current version of TensorFlow.js does not support the MirrorPad operator.
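Reflection padding fills the border with mirrored copies of interior values instead of zeros, which is why it avoids border artifacts. In TensorFlow, `tf.pad(..., mode="REFLECT")` performs this padding and emits the MirrorPad op that the converter rejects. The following NumPy illustration compares the two padding modes; it is not code from this repository, and `pad_spatial` is an illustrative name.

```python
import numpy as np

def pad_spatial(feat, p, reflect=True):
    """Pad the H and W axes of an (H, W, C) array by p pixels on each side.

    reflect=True mirrors interior values (the same rule as tf.pad mode="REFLECT");
    reflect=False pads with zeros, as Conv2D(padding="same") effectively does.
    """
    widths = ((p, p), (p, p), (0, 0))  # leave the channel axis unpadded
    if reflect:
        return np.pad(feat, widths, mode="reflect")
    return np.pad(feat, widths, mode="constant", constant_values=0)

img = np.arange(9, dtype=float).reshape(3, 3, 1)
print(pad_spatial(img, 1, reflect=True)[..., 0])
print(pad_spatial(img, 1, reflect=False)[..., 0])
```

With zero padding, every convolution near the edge sees artificial zeros, which the network can learn to render as a visible frame. Mirrored values keep the edge statistics close to those of the image interior.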

Example usage:

  tensorflowjs_converter --input_format=tf_saved_model --saved_model_tags=serve models/model models/web_model
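After conversion, one quick sanity check is to list the op types recorded in the generated model.json, for example to confirm that no MirrorPad node slipped in. This standard-library sketch assumes the graph-model layout, in which model.json stores a `modelTopology` object whose `node` list holds entries with an `"op"` field. It also assumes the file is written to models/web_model/model.json, matching the output directory in the command above. The helper names are hypothetical.

```python
import json
from collections import Counter

def graph_ops(model_json):
    """Count op types in a TF.js graph-model topology (assumed layout)."""
    nodes = model_json.get("modelTopology", {}).get("node", [])
    return Counter(node.get("op", "?") for node in nodes)

def blocked_ops(ops, blocked=("MirrorPad",)):
    """Return the op types from `blocked` that appear in the graph."""
    return sorted(op for op in ops if op in blocked)

def load_model_json(path="models/web_model/model.json"):
    with open(path) as f:
        return json.load(f)

# Usage:
#   ops = graph_ops(load_model_json())
#   print(ops.most_common(10))
#   print(blocked_ops(ops))  # an empty list means no MirrorPad nodes
```

This only inspects top-level graph nodes; ops inside function definitions in the topology's `library` section are not counted.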

Requirements

  • TensorFlow >= 2.0
  • Python 3.7.5, Pillow 7.0.0, Numpy 1.18
  • If you want to convert a pre-trained model to a TensorFlow.js model:
    • tensorflowjs >= 2.0

Attributions/Thanks

  • Some images and documentation were borrowed from Xun Huang’s AdaIN-style
  • Some tfjs code formatting was borrowed from the TensorFlow.js MobileNet example