-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMultiEurlexDataset.py
63 lines (53 loc) · 2.35 KB
/
MultiEurlexDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from sys import maxsize
from datasets import load_dataset
from torch.utils.data import Dataset
import torch
import pandas as pd
import re
import numpy as np
class MultiEurlexDataset(Dataset):
    """Torch dataset over cleaned documents of the ``multi_eurlex`` corpus.

    Each (document, language) pair that has non-empty text in one of the
    selected languages becomes one sample. Texts are cleaned once at
    construction time and tokenized lazily in ``__getitem__``.

    Attributes set by ``__init__``:
        num_labels: length of the multi-hot label vector.
        tokenizer: callable used to encode a text in ``__getitem__``.
        texts: list of cleaned document strings.
        labels: list of label-index lists, aligned with ``texts``.
        languages: list of selected language codes (empty means "all").
    """

    # Digits, newlines, slashes, brackets, colons/semicolons, dashes and the
    # various quote characters are removed from the text.
    _STRIP_CHARS = re.compile(r'[0-9\n/()\[\]\':;"„“\-»«’‘”]')
    # Runs of spaces (typically left behind by the removal above) collapse to one.
    _MULTI_SPACE = re.compile(r' +')
    # A space directly before a full stop is dropped.
    _SPACE_BEFORE_DOT = re.compile(r' \.')

    def __init__(self, split='train', languages=None, tokenizer=lambda x: x, num_labels=21):
        """Load one split of ``multi_eurlex`` and clean its documents.

        Args:
            split: name of the dataset split to read.
            languages: a language code, or a collection of codes, to keep.
                ``None`` or an empty collection keeps every language.
            tokenizer: callable invoked in ``__getitem__`` as
                ``tokenizer(text, truncation=True, max_length=512)``; it must
                return a mapping. The default identity callable does not
                accept those keyword arguments, so a real tokenizer has to be
                supplied before samples are indexed.
            num_labels: length of the multi-hot vector built per sample.
        """
        self.num_labels = num_labels

        # Normalise the language filter; a fresh list per instance avoids
        # sharing one mutable default between datasets.
        if languages is None:
            languages = []
        elif isinstance(languages, str):
            languages = [languages]
        print(languages)

        split_data = load_dataset('multi_eurlex', 'all_languages')[split]
        total = len(split_data)
        # Report progress roughly every 10%; clamp to 1 so that splits with
        # fewer than ten rows do not cause a modulo-by-zero.
        step = max(1, total // 10)

        texts = []
        labels = []
        for idx, instance in enumerate(split_data):
            if idx % step == 0:
                percentage = idx / total * 100
                print(f'{percentage:.1f}% of dataset loaded')
            for lang, text in instance['text'].items():
                if len(languages) and lang not in languages:
                    continue
                # Skip languages with no (or empty) text for this document.
                if not text:
                    continue
                texts.append(self._clean_text(text))
                labels.append(instance['labels'])

        self.tokenizer = tokenizer
        self.texts = texts
        self.labels = labels
        self.languages = languages
        print("Loading dataset done")

    @staticmethod
    def _clean_text(text):
        """Strip digits/punctuation from *text* and tidy the leftover spacing."""
        text = MultiEurlexDataset._STRIP_CHARS.sub('', text)
        text = MultiEurlexDataset._MULTI_SPACE.sub(' ', text)
        return MultiEurlexDataset._SPACE_BEFORE_DOT.sub('.', text)

    def __len__(self):
        """Return the number of (document, language) samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Tokenize sample *idx* and return its tensors plus multi-hot labels.

        Returns:
            A dict with one tensor per key produced by the tokenizer and a
            ``'labels'`` tensor of length ``num_labels``.
        """
        encodings = self.tokenizer(self.texts[idx], truncation=True, max_length=512)
        # Kept for backward compatibility: holds the most recent encoding.
        self.encodings = encodings
        item = {key: torch.tensor(val) for key, val in encodings.items()}
        item['labels'] = torch.tensor(self.multihot_encode(self.labels[idx]))
        return item

    def multihot_encode(self, labels):
        """Return a float array of length ``num_labels`` with 1 at each index in *labels*."""
        elems = np.zeros(self.num_labels)
        elems[labels] = 1
        return elems
# Example usage:
# dataset = MultiEurlexDataset()
# print(dataset.texts[0], dataset.labels[0])