Same content as: https://blog.csdn.net/yuhengshi/article/details/120970903
Environment
python==3.7
transformers==4.9.2
rouge-score==0.0.4
Data preparation
Put the data in a single txt file, one example per line; each line holds the source URL, the article body, and the reference summary, separated by \t (the code below skips any line that does not split into exactly three fields).
Building the dataset
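For concreteness, here is a minimal sketch of producing such a file; the URLs, article bodies, and summaries below are made up:

# Each line: <url>\t<article body>\t<reference summary>, one example per line.
rows = [
    ("https://example.com/a1",
     "The European Parliament met today to debate the annual budget ...",
     "Parliament debated the annual budget."),
    ("https://example.com/a2",
     "Officials announced a new trade agreement after months of talks ...",
     "A trade deal was announced."),
]
with open('data_no_daily_news.txt', 'w') as f:
    for url, text, summary in rows:
        f.write(f"{url}\t{text}\t{summary}\n")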
from datasets import Dataset

class Data:
    def __init__(self, data_path, tokenizer):
        self.path = data_path
        self.max_input_length = 1024
        self.max_target_length = 150
        # self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)
        self.tokenizer = tokenizer

    def preprocess(self, train_scale=0.8):
        with open(self.path, 'r') as f:
            raw_data = f.readlines()
        print(f"=======data_len:{len(raw_data)}")
        start = int(len(raw_data) * train_scale)
        print(f"======train_len:{start}")
        raw_train_data = raw_data[:start]
        raw_test_data = raw_data[start:]

        raw_train_test_data = {'train': {'id': [], 'document': [], 'summary': []},
                               'test': {'id': [], 'document': [], 'summary': []}}
        for i, item in enumerate(raw_train_data):
            if len(item.split('\t')) != 3:
                continue
            url, text, label = item.split('\t')
            # document is the model input, summary is the label
            raw_train_test_data['train']['id'].append(i)
            raw_train_test_data['train']['summary'].append(label.strip())
            raw_train_test_data['train']['document'].append(text.strip())
        for j, item in enumerate(raw_test_data):
            if len(item.split('\t')) != 3:
                continue
            url, text, label = item.split('\t')
            raw_train_test_data['test']['id'].append(i + j + 1)
            raw_train_test_data['test']['summary'].append(label.strip())
            raw_train_test_data['test']['document'].append(text.strip())

        def preprocess_function(examples):
            # tokenize the document (model input), truncated/padded to 1024 tokens
            inputs = examples['document']
            model_inputs = self.tokenizer(inputs, max_length=self.max_input_length,
                                          padding='max_length', truncation=True)
            # tokenize the summary (label) with the target tokenizer
            with self.tokenizer.as_target_tokenizer():
                labels = self.tokenizer(examples['summary'], max_length=self.max_target_length,
                                        padding='max_length', truncation=True)
            model_inputs['labels'] = labels['input_ids']
            return model_inputs

        train_dataset = Dataset.from_dict(raw_train_test_data['train'])
        test_dataset = Dataset.from_dict(raw_train_test_data['test'])
        tokenized_train_dataset = train_dataset.map(preprocess_function)
        tokenized_test_dataset = test_dataset.map(preprocess_function)
        return tokenized_train_dataset, tokenized_test_dataset
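A quick usage sketch (assuming the tokenizer from the next section is already loaded, and with a stand-in file path): preprocess() returns two datasets.Dataset objects that keep the raw columns and add the tokenized ones.

# Hypothetical smoke test; 'data.txt' stands in for the real data path.
train_ds, test_ds = Data('data.txt', tokenizer).preprocess()
print(train_ds.column_names)          # id, document, summary + input_ids, attention_mask, labels
print(len(train_ds[0]['input_ids']))  # 1024, padded to max_input_length
print(len(train_ds[0]['labels']))     # 150, padded to max_target_length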
Loading the model
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import BartForConditionalGeneration

checkpoint = "distilbart-xsum-9-6"
model = BartForConditionalGeneration.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
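As a quick sanity check that the checkpoint and tokenizer loaded correctly, you can run a zero-shot summary before any fine-tuning; the input text here is invented:

text = ("The commission proposed new rules on data privacy on Tuesday, "
        "drawing praise from regulators and criticism from industry groups.")
batch = tokenizer([text], max_length=1024, truncation=True, return_tensors='pt')
generated = model.generate(batch['input_ids'], num_beams=4, max_length=60, early_stopping=True)
print(tokenizer.decode(generated[0], skip_special_tokens=True))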
Metrics
import numpy as np
import nltk
from rouge_score import rouge_scorer, scoring

def compute(predictions, references, rouge_types=None, use_aggregator=True, use_stemmer=False):
    if rouge_types is None:
        rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
    scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=use_stemmer)
    if use_aggregator:
        aggregator = scoring.BootstrapAggregator()
    else:
        scores = []
    for ref, pred in zip(references, predictions):
        score = scorer.score(ref, pred)
        if use_aggregator:
            aggregator.add_scores(score)
        else:
            scores.append(score)
    if use_aggregator:
        result = aggregator.aggregate()
    else:
        result = {}
        for key in scores[0]:
            result[key] = list(score[key] for score in scores)
    return result

# metrics
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # ROUGE expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    result = compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract the mid f-measure from each aggregated score
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}
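compute_metrics depends on nltk.sent_tokenize, which needs NLTK's punkt model downloaded once per machine; the toy call below (made-up strings) just shows the shape compute() returns:

nltk.download('punkt')  # one-time download required by nltk.sent_tokenize

# Toy sanity check of the ROUGE wrapper on a made-up pair:
toy = compute(predictions=["the cat sat on the mat"],
              references=["the cat was on the mat"])
print(toy['rouge1'].mid.fmeasure)  # aggregated scores come back as low/mid/high intervals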
Training
Hyperparameter configuration
batch_size = 1
args = Seq2SeqTrainingArguments(
    "/data/yuhengshi/europe_summary/model",
    evaluation_strategy='steps',
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.1,
    save_steps=200,
    save_total_limit=10,
    num_train_epochs=5,
    predict_with_generate=True,
    fp16=True,
    eval_steps=200,
    logging_dir="/data/yuhengshi/europe_summary/log",
    logging_first_step=True,
)
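Since per_device_train_batch_size is 1 (1024-token inputs are memory-hungry), one common tweak is gradient accumulation to simulate a larger effective batch. A sketch of the same configuration with one added standard argument; the value 8 is a hypothetical choice, not from the original post:

# Variant: accumulate gradients over 8 steps for an effective batch size of 8.
args = Seq2SeqTrainingArguments(
    "/data/yuhengshi/europe_summary/model",
    evaluation_strategy='steps',
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=8,  # hypothetical value; tune for your GPU
    weight_decay=0.1,
    save_steps=200,
    save_total_limit=10,
    num_train_epochs=5,
    predict_with_generate=True,
    fp16=True,
    eval_steps=200,
    logging_dir="/data/yuhengshi/europe_summary/log",
    logging_first_step=True,
)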
Training with the transformers API
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
data = Data('/data/yuhengshi/europe_summary/data_no_daily_news.txt', tokenizer)
tokenized_train_dataset, tokenized_test_dataset = data.preprocess()

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
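Note that the snippet above only constructs the trainer; fine-tuning and a final evaluation on the test split are each one call:

trainer.train()            # fine-tune; loss and ROUGE are logged every 200 steps
print(trainer.evaluate())  # final rouge1/rouge2/rougeL/rougeLsum f-measures and gen_len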
Results
From the evaluation steps below, pick a checkpoint where both the loss and the ROUGE scores look good.
Prediction: generating summaries
def predict(sentence):
    # truncation=True added so inputs longer than 1024 tokens do not overflow the encoder
    inputs = tokenizer([sentence], max_length=1024, truncation=True, return_tensors='pt')
    summary_ids = model.generate(inputs['input_ids'],
                                 num_beams=70,
                                 max_length=150,
                                 min_length=50,
                                 early_stopping=True)
    summary = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False)
               for g in summary_ids]
    return ''.join(summary)
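A usage sketch with an invented article; note that num_beams=70 is an unusually wide beam and makes decoding slow, so values around 4 to 10 are more typical when latency matters:

article = ("The European Commission on Wednesday unveiled a plan to cut emissions "
           "by 55 percent by 2030, setting up lengthy negotiations with member states.")
print(predict(article))  # prints an abstractive summary of 50 to 150 tokens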