rinna-japanese-gpt2による文書生成(2)

タグ:

Google Colaboratorを使わずに手元で動作させる場合は以下の様なコードになります。

# Python Environment:
# for Python 3.8
# Modules:
# pip install transformers==4.9.2
# pip install sentencepiece==0.1.96
# pip install torch==1.9.0
from transformers import T5Tokenizer, AutoModelForCausalLM, pipeline


def main():

    RINNA_GPT2_MODEL = "rinna/japanese-gpt2-medium"
    TEXT = """
    年も暮れようとする頃の話。
    雪深い山の中に住む夫婦は、正月を迎えるにも米一粒すら残っていなかった。
    そこで女房が作った髪飾りのかせ玉を町に売りに行くことにした。
    男が地蔵峠を通ると、"""

    tokenizer = T5Tokenizer.from_pretrained(RINNA_GPT2_MODEL)
    tokenizer.do_lower_case = True

    model = AutoModelForCausalLM.from_pretrained(RINNA_GPT2_MODEL)

    input = tokenizer.encode(TEXT, return_tensors="pt")
    output = model.generate(
        input, do_sample=True, max_length=120, num_return_sequences=3
    )

    print("----\n")
    for idx, s in enumerate(tokenizer.batch_decode(output)):
        print("出力結果 {:d})\n{:s}\n".format(idx, s.replace("</s>", "</s>\n")))


if __name__ == "__main__":
    main()