train.py 539 B

123456789101112131415161718192021222324
  1. #!/usr/bin/python
  2. import gpt_2_simple as gpt2
  3. import sys
  4. import os
  5. import requests
  6. model_name = "124M"
  7. if not os.path.isdir(os.path.join("models", model_name)):
  8. print(f"Downloading {model_name} model...")
  9. gpt2.download_gpt2(model_name=model_name) # model is saved into current directory under /models/124M/
  10. file_name = sys.argv[1]
  11. sess = gpt2.start_tf_sess()
  12. gpt2.finetune(sess,
  13. file_name,
  14. model_name=model_name,
  15. steps=100) # steps is max number of training steps
  16. gpt2.generate(sess)