-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMultiEvalDataset.py
52 lines (42 loc) · 1.97 KB
/
MultiEvalDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from datasets import load_dataset
from torch.utils.data import Dataset
import pandas as pd
import re
class MultiEvalDataset(Dataset):
    """MultiEURLEX split flattened to one row per (document, language) pair.

    ``self.data`` is a DataFrame with columns ``celex_id``, ``lang``,
    ``document`` (cleaned text) and ``labels``; indexing the dataset yields
    ``(document, labels)`` tuples.
    """

    # Digits and assorted punctuation / quote characters deleted from the text.
    # Newlines are NOT in this class: they are turned into spaces first so the
    # words on either side of a line break are not glued together.
    _DROP_CHARS = re.compile(r"[0-9/()\[\]':;\"„“\-»«’‘”]")
    # Runs of spaces collapsed to a single space.
    _MULTI_SPACE = re.compile(r" +")
    # A space directly before a full stop (e.g. left behind by deletions).
    _SPACE_BEFORE_DOT = re.compile(r" \.")

    def __init__(self, split='train', languages=None):
        """Load ``split`` of MultiEURLEX (``all_languages``) and flatten it.

        Args:
            split: Name of the split to take from the loaded dataset.
            languages: A single language code, an iterable of codes, or
                ``None`` / empty to keep every language found in the data.
        """
        dataset = load_dataset('multi_eurlex', 'all_languages')
        if languages is None:
            languages = []
        elif isinstance(languages, str):
            languages = [languages]
        else:
            languages = list(languages)
        print(languages)
        self.data = self._flatten(dataset[split], languages)
        self.languages = languages

    @classmethod
    def _clean_text(cls, text):
        """Return *text* with digits/punctuation removed and spacing normalised."""
        text = text.replace('\n', ' ')
        text = cls._DROP_CHARS.sub('', text)
        text = cls._MULTI_SPACE.sub(' ', text)
        text = cls._SPACE_BEFORE_DOT.sub('.', text)
        # Newline-to-space conversion can leave leading/trailing blanks.
        return text.strip()

    @classmethod
    def _flatten(cls, rows, languages):
        """Build the per-language DataFrame from *rows*.

        Args:
            rows: Sized iterable of records; each record is indexed by
                ``'celex_id'``, ``'text'`` (a language -> text mapping) and
                ``'labels'``.
            languages: Language codes to keep; empty keeps all of them.

        Returns:
            DataFrame with one row per non-empty (record, language) text.
        """
        wanted = set(languages)
        total = len(rows)
        # Report progress roughly every 10%; max() avoids a modulo-by-zero
        # when the split has fewer than 10 records.
        step = max(1, total // 10)
        records = {'celex_id': [], 'lang': [], 'document': [], 'labels': []}
        for idx, instance in enumerate(rows):
            if idx % step == 0:
                percentage = idx / total * 100
                print(f'{percentage:.4f}% of dataset loaded')
            for lang, text in instance['text'].items():
                if wanted and lang not in wanted:
                    continue
                if not text:  # skip missing/empty translations
                    continue
                records['celex_id'].append(instance['celex_id'])
                records['lang'].append(lang)
                records['labels'].append(instance['labels'])
                records['document'].append(cls._clean_text(text))
        return pd.DataFrame.from_dict(records)

    def __len__(self):
        """Return the number of (document, language) rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(document, labels)`` for row label *idx*."""
        return self.data.loc[idx, 'document'], self.data.loc[idx, 'labels']
# Example usage:
# dataset = MultiEvalDataset()
# print(dataset.data)