BigBird: Chinese Generative Long-Text Summarization in Practice - Training Code

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author   :    Junhui Yu
# Time     :    2023/2/27 14:55

import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '0'  # '0' keeps CUDA kernel launches asynchronous; set to '1' only when debugging CUDA errors

import logging
import datasets
import numpy as np
import lawrouge
from transformers import (
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    BigBirdPegasusForConditionalGeneration,
    BertTokenizer,
    BigBirdConfig
)

from datasets import load_dataset

logger = logging.getLogger("YUNLP")
logging.basicConfig(level=logging.INFO)

# Load the NLPCC Chinese summarization data (a JSON file with "title" and "content" fields).
dataset = load_dataset('json', data_files="./data/nlpcc_data/nlpcc_data.json")
dataset = dataset.shuffle(seed=42)

model_path = "./bigbird"

config = BigBirdConfig.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BigBirdPegasusForConditionalGeneration.from_pretrained(model_path, config=config)
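# Note: ./bigbird is assumed to be a local Chinese BigBird-Pegasus checkpoint with a
# BERT-style vocabulary, which is why BertTokenizer is used here rather than a
# sentencepiece-based Pegasus tokenizer.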


def flatten(example):
    return {
        "text": example["content"],
        "summary": example["title"],
    }


dataset = dataset["train"].map(flatten, remove_columns=["title", "content"])  # , remove_columns=["title", "content"]

max_input_length = 2048
max_target_length = 1024


def preprocess_function(examples):
    inputs = [doc for doc in examples["text"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


dataset = dataset.shuffle(seed=42)

train_data_txt, validation_data_txt = dataset.train_test_split(test_size=0.1, shuffle=True, seed=42).values()
tokenized_datasets = datasets.DatasetDict({
    "train": train_data_txt,
    "validation": validation_data_txt
}).map(preprocess_function, batched=True)

args = Seq2SeqTrainingArguments(
    output_dir="./bigbird",
    num_train_epochs=5,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-04,
    warmup_steps=1000,
    weight_decay=0.0001,
    label_smoothing_factor=0.15,
    predict_with_generate=True,
    logging_dir="logs",
    logging_strategy="epoch",
    logging_steps=1,  # ignored when logging_strategy="epoch"
    save_total_limit=2,
    evaluation_strategy="epoch",
    eval_steps=500,  # ignored when evaluation_strategy="epoch"
    gradient_accumulation_steps=1,
    generation_max_length=64,
    generation_num_beams=1,
)
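
# The data collator below pads inputs and labels dynamically per batch; by default,
# DataCollatorForSeq2Seq pads labels with -100 so padded positions are ignored by the loss.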

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Labels were padded with -100 by the collator; restore pad_token_id so they can be decoded.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds = ["".join(pred.replace(" ", "")) for pred in decoded_preds]
    decoded_labels = ["".join(label.replace(" ", "")) for label in decoded_labels]
    labels_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in labels]

    # lawrouge raises an error on empty hypotheses, so substitute a placeholder for empty predictions.
    for i, pred in enumerate(decoded_preds):
        if pred == "":
            decoded_preds[i] = "decoding error, skipping..."
    rouge = lawrouge.Rouge()
    result = rouge.get_scores(decoded_preds, decoded_labels, avg=True)
    result = {'rouge-1': result['rouge-1']['f'], 'rouge-2': result['rouge-2']['f'], 'rouge-l': result['rouge-l']['f']}
    result = {key: value * 100 for key, value in result.items()}
    result["gen_len"] = np.mean(labels_lens)
    return result


trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

train_result = trainer.train()
print(train_result)
trainer.save_model()
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
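
After training finishes, the saved checkpoint can be loaded for inference. The snippet below is a minimal sketch, assuming the fine-tuned model was written to ./bigbird by trainer.save_model() above; the input document is a placeholder, and the beam size is an illustrative choice rather than a setting taken from the training script.

import torch
from transformers import BigBirdPegasusForConditionalGeneration, BertTokenizer

model_dir = "./bigbird"  # directory written by trainer.save_model() above
tokenizer = BertTokenizer.from_pretrained(model_dir)
model = BigBirdPegasusForConditionalGeneration.from_pretrained(model_dir)
model.eval()

document = "..."  # placeholder: a long Chinese document to summarize
inputs = tokenizer(document, max_length=2048, truncation=True, return_tensors="pt")
with torch.no_grad():
    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=64,  # mirrors generation_max_length used during evaluation
        num_beams=4,    # illustrative; evaluation above used generation_num_beams=1
    )
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True).replace(" ", "")
print(summary)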