2 changes: 1 addition & 1 deletion encexp/__init__.py
@@ -17,4 +17,4 @@
if not '-m' in sys.argv:
from encexp.text_repr import EncExpT, SeqTM, TextModel

__version__ = "0.1.4"
__version__ = "0.1.5"
107 changes: 83 additions & 24 deletions encexp/build_encexp.py
@@ -26,7 +26,7 @@
from microtc.utils import tweet_iterator, Counter
import encexp
from encexp.text_repr import SeqTM, EncExpT
from encexp.utils import progress_bar
from encexp.utils import progress_bar, uniform_sample
from encexp.download import download


@@ -118,7 +118,7 @@ class Train:
"""Train"""
text_model: SeqTM=None
min_pos: int=512
max_pos: int=int(2**15)
max_pos: int=int(2**14)
min_neg: int=int(2**14)
filename: str=None
use_tqdm: bool=True
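Note on the thresholds: min_pos gates which labels receive a classifier, max_pos caps the positives per label, and min_neg floors the negatives. A minimal sketch of that arithmetic, matching the logic visible in labels and training_set_texts below; the budget helper is illustrative, not project API.

# Sketch: how Train's thresholds interact (illustrative helper).
min_pos, max_pos, min_neg = 512, 2**14, 2**14

def budget(label_freq):
    """Positive/negative budget for one label's binary problem."""
    if label_freq < min_pos:
        return 0, 0                       # label skipped: too few positives
    n_pos = min(max_pos, label_freq)      # cap positives, cf. training_set_texts
    n_neg = max(n_pos, min_neg)           # at least min_neg negatives
    return n_pos, n_neg

print(budget(300))     # (0, 0)
print(budget(1000))    # (1000, 16384)
print(budget(50000))   # (16384, 16384)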
@@ -159,17 +159,38 @@ def labels(self):
"""Labels"""
if hasattr(self, '_labels'):
return self._labels
cnt = Counter()
labels_freq = Counter()
with open(self.filename, encoding='utf-8') as fpt:
for line in fpt:
line = line.strip()
labels, text = line.split('\t')
labels = labels.split()
cnt.update(labels)
labels = sorted([k for k, v in cnt.items() if v >= self.min_pos])
labels_freq.update(labels)
labels = sorted([k for k, v in labels_freq.items() if v >= self.min_pos])
self.labels = labels
self.labels_freq = cnt
self.labels_freq = labels_freq
if self.keep_unfreq and self.self_supervised:
cnt = Counter()
with open(self.filename, encoding='utf-8') as fpt:
for line in fpt:
line = line.strip()
labels, text = line.split('\t')
labels = labels.split()
_labels_freq = [(k, labels_freq[k])
for k in labels]
klass, _ = min(_labels_freq, key=lambda x: x[1])
cnt.update([klass])
self.neg_freq = cnt
return labels

@property
def neg_freq(self):
"""Frequency in the negative label"""
return self._neg_freq

@neg_freq.setter
def neg_freq(self, value):
self._neg_freq = value
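When keep_unfreq is enabled in self-supervised mode, each training line is attributed to the rarest of its labels, and neg_freq counts those assignments so the negative pool can be sized per label. A minimal sketch of the rarest-label assignment, shown here with collections.Counter (the project imports Counter from microtc.utils):

# Sketch: rarest-label assignment behind neg_freq.
from collections import Counter

labels_freq = Counter(['mx'] * 100 + ['ar'] * 5)
line_labels = ['mx', 'ar']
klass, _ = min([(k, labels_freq[k]) for k in line_labels],
               key=lambda x: x[1])
print(klass)   # 'ar': the line counts toward its scarcest label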

@labels.setter
def labels(self, value):
@@ -195,20 +216,25 @@ def filter_tokens(self, tokens, label):
if not self.self_supervised:
return tokens
return [x for x in tokens if x != label]
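In self-supervised mode, filter_tokens removes the target token from its own positive examples, forcing the classifier to learn from context rather than the token itself:

# Sketch of the filtering effect.
tokens = ['hola', 'mundo', 'hola']
label = 'hola'
print([x for x in tokens if x != label])   # ['mundo']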

def training_set(self, label):
"""Training set"""
def training_set_texts(self, label):
"""Training set texts"""
self.text_model.disable_text_transformations = True
tokenize = self.text_model.tokenize
max_pos = min(self.max_pos,
self.labels_freq[label])
num_neg = max(max_pos, self.min_neg)
POS = []
NEG = []
labels_freq = [(k, v) for k, v in self.labels_freq.items() if k != label]
labels_freq = self.labels_freq
if self.keep_unfreq and self.self_supervised:
labels_freq = self.neg_freq
labels_freq = {k: v for k, v in labels_freq.items() if k != label}
if not self.keep_unfreq:
labels_freq = {None: num_neg}
NEG = NegDataset(num_neg, labels_freq)
with open(self.filename, encoding='utf-8') as fpt:
for line in fpt:
if len(POS) >= max_pos and len(NEG) >= num_neg:
if len(POS) >= max_pos and NEG.full:
break
line = line.strip()
labels, text = line.split('\t')
@@ -219,20 +245,20 @@ def training_set(self, label):
_ = self.filter_tokens(tokens, label)
POS.append(_)
continue
klass, _ = min(labels_freq, key=lambda x: x[1])
neg = dict(tokens=tokens, label=klass)
if len(NEG) < num_neg:
NEG.append(neg)
continue
k = randint(0, len(NEG) - 1)
if not self.keep_unfreq:
NEG[k] = neg
continue
if self.labels_freq[NEG[k]['label']] > self.labels_freq[neg['label']]:
NEG[k] = neg
if self.keep_unfreq:
labels_freq = [(k, self.labels_freq[k])
for k in labels]
klass, _ = min(labels_freq, key=lambda x: x[1])
else:
klass = None
NEG.add(tokens, klass)
return NEG.dataset(), POS

def training_set(self, label):
"""Training set"""
NEG, POS = self.training_set_texts(label)
if len(NEG) == 0 or len(POS) == 0:
return None
NEG = [x['tokens'] for x in NEG]
X = self.transform(POS + NEG)
y = [1] * len(POS) + [-1] * len(NEG)
return X, np.array(y)
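Each label therefore yields an independent binary problem: X stacks the transformed positive and negative texts, and y holds +1/-1. A hedged sketch of consuming it; the LinearSVC pairing is an assumption for illustration, not taken from this diff:

# Sketch: fitting one label's binary problem (estimator choice assumed).
from sklearn.svm import LinearSVC

def fit_label(train, label):
    out = train.training_set(label)   # None when POS or NEG is empty
    if out is None:
        return None
    X, y = out                        # y: +1 for POS rows, -1 for NEG rows
    return LinearSVC().fit(X, y)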
@@ -307,6 +333,39 @@ def delete_tmps(self, args):
os.rmdir(self.identifier)


class NegDataset:
"""Uniform sample of the negatives"""
def __init__(self, N: int, freq: dict):
keys = list(freq)
cnt = uniform_sample(N,
np.array([freq[x] for x in keys]))
self.cnt = {k: v for k, v in zip(keys, cnt)}
self.elements = {k: list() for k in keys}
self.tot = N
self.size = 0

def add(self, data: str, label: str):
"""Add element"""
cnt = self.cnt[label]
dataset = self.elements[label]
if len(dataset) < cnt:
self.size += 1
dataset.append(data)

def dataset(self):
"""Dataset"""
values = []
for v in self.elements.values():
values.extend(v)
shuffle(values)
return values

@property
def full(self):
"""Indicate whether the dataset has all the elements required"""
return self.tot - self.size <= 0
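A walk-through of the new reservoir, consistent with test_NegDataset below; it assumes uniform_sample allocates the N slots as evenly as possible while capping each label at its availability, redistributing the surplus of scarce labels:

# Sketch: NegDataset quotas under the assumed uniform_sample semantics.
neg = NegDataset(500, {'mx': 1000, 'ar': 100, 'es': 10})
for k in range(1000):
    neg.add(f'mx {k}', 'mx')   # adds beyond the 'mx' quota are ignored
print(neg.full)                # False: 'ar' and 'es' quotas are still open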


def main(args):
"""CLI"""
filename = args.file[0]
39 changes: 31 additions & 8 deletions encexp/tests/test_build_encexp.py
@@ -19,7 +19,7 @@
# from encexp.tests.test_utils import samples
from encexp.utils import load_dataset
from encexp.text_repr import SeqTM, EncExpT
from encexp.build_encexp import Dataset, EncExpDataset, Train, main
from encexp.build_encexp import Dataset, EncExpDataset, Train, main, NegDataset


def test_Dataset_output_filename():
@@ -109,7 +109,7 @@ def test_Train_labels():

def test_Train_training_set():
"""Test Train"""

dataset = load_dataset('mx')
seq = SeqTM(lang='es', token_max_filter=2**13)
ds = EncExpDataset(text_model=clone(seq))
@@ -121,14 +121,17 @@
X, y = train.training_set(labels[0])
assert X.shape[0] == len(y) and X.shape[1] == len(seq.names)
# cnt = np.where((X > 0).sum(axis=0).A1)[0].shape
train.keep_unfreq = True
train = Train(text_model=seq, min_pos=32,
keep_unfreq=True,
filename=ds.output_filename)
labels = train.labels
X, y = train.training_set(labels[0])
_, freq = np.unique(y, return_counts=True)
assert freq[0] > freq[1]
train.min_neg = 0
X, y = train.training_set(labels[0])
_, freq = np.unique(y, return_counts=True)
assert freq[0] == freq[1]
# train.min_neg = 0
# X, y = train.training_set(labels[0])
# _, freq = np.unique(y, return_counts=True)
# assert freq[0] == freq[1]
os.unlink(ds.output_filename)

# cnt2 = np.where((X > 0).sum(axis=0).A1)[0].shape
@@ -207,4 +210,24 @@ class A:
A.file = [dataset]
A.voc_size_exponent = 13
A.n_jobs = -1
main(A)
main(A)


def test_NegDataset():
"""Test NegDataset"""
freq = {'mx': 1000, 'ar': 100, 'es': 10}
neg = NegDataset(500, freq)
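# 'es' and 'ar' are capped at their availability (10 and 100), so 390
# of the 500 slots presumably go to 'mx' via uniform_sample's surplus.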
for k in range(510):
neg.add(f'mx {k}', 'mx')
assert len(neg.elements['mx']) == 390
for k in range(110):
neg.add(f'ar {k}', 'ar')
assert len(neg.elements['ar']) == 100
for k in range(20):
neg.add(f'es {k}', 'es')
assert neg.full
assert len(neg.dataset()) == 500
neg = NegDataset(500, {None: 500})
for k in range(510):
neg.add(f'unico {k}', None)
assert neg.full
10 changes: 6 additions & 4 deletions encexp/tests/test_text_repr.py
@@ -138,6 +138,7 @@ def test_EncExpT_tailored():
D = list(tweet_iterator(dataset))
enc = EncExpT(lang='es', pretrained=False)
enc.tailored(D, tsv_filename='tailored.tsv',
min_pos=32,
filename='tailored.json.gz')
assert enc.weights.shape[0] == 2**14
assert enc.weights.shape[1] == 90
@@ -162,6 +163,7 @@ def test_EncExpT_tailored_intercept():
enc = EncExpT(lang='es', with_intercept=True,
pretrained=False)
enc.tailored(D, tsv_filename='tailored.tsv',
min_pos=32,
filename='tailored_intercept.json.gz')
assert enc.weights.shape[0] == 2**14
assert enc.weights.shape[1] == 90
@@ -185,15 +187,15 @@ def test_EncExpT_tailored_add():
dataset = load_dataset('mx')
D = list(tweet_iterator(dataset))
enc = EncExpT(lang='es', token_max_filter=2**13)
enc.tailored(D)
enc.tailored(D, min_pos=32)


def test_EncExpT_tailored_no_neg():
"""Test EncExpT tailored"""
dataset = load_dataset('mx')
D = [f'{text} de' for text in tweet_iterator(dataset)]
enc = EncExpT(lang='es', token_max_filter=2**13)
enc.tailored(D)
enc.tailored(D, min_pos=32)


def test_EncExpT_tailored_2cl():
@@ -203,7 +205,7 @@ def test_EncExpT_tailored_2cl():
enc = EncExpT(lang='es', pretrained=False,
with_intercept=True,
token_max_filter=2**13)
enc.tailored(D, self_supervised=False)
enc.tailored(D, self_supervised=False, min_pos=32)
assert enc.names.tolist() == ['ar', 'mx']


@@ -538,4 +540,4 @@ def test_TextModel_diac():
# cv=sss,
# n_jobs=1,
# scoring='f1_macro').fit(mx + ar, y)
# assert grid.best_score_ > 0.7
# assert grid.best_score_ > 0.7
8 changes: 5 additions & 3 deletions encexp/text_repr.py
@@ -441,7 +441,7 @@ class EncExpT(Identifier):
with_intercept: bool=False
merge_encode: bool=True
distance: bool=False
keep_unfreq: bool=True
keep_unfreq: bool=False

@property
def seqTM(self):
@@ -609,8 +609,9 @@ def add(self, data: Iterable):
def tailored(self, D: Iterable=None,
filename: str=None,
tsv_filename: str=None,
min_pos: int=32,
max_pos: int=int(2**15),
min_pos: int=512,
min_neg: int=int(2**14),
max_pos: int=int(2**14),
n_jobs: int=-1,
self_supervised: bool=True,
ds: object=None,
@@ -653,6 +654,7 @@ def set_weights(data):
filename=ds.output_filename,
use_tqdm=self.use_tqdm,
min_pos=min_pos,
min_neg=min_neg,
max_pos=max_pos,
n_jobs=n_jobs,
with_intercept=self.with_intercept,
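Usage note: with min_pos now defaulting to 512 (and min_neg exposed), small corpora need an explicit override, which is why the updated tests pass min_pos=32. A sketch mirroring test_EncExpT_tailored:

# Sketch: calling the updated tailored() on the small 'mx' sample.
from microtc.utils import tweet_iterator
from encexp.text_repr import EncExpT
from encexp.utils import load_dataset

D = list(tweet_iterator(load_dataset('mx')))
enc = EncExpT(lang='es', pretrained=False)
enc.tailored(D, tsv_filename='tailored.tsv',
             min_pos=32,   # below the new 512 default
             filename='tailored.json.gz')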
9 changes: 5 additions & 4 deletions encexp/utils.py
@@ -181,28 +181,29 @@ def inner(texts):


def load_dataset(country: Union[str, list],
lang: str='es',
return_X_y:bool=False):
"""Country identification dataset"""
if not isdir(MODELS):
os.mkdir(MODELS)
if isinstance(country, str):
country = [country]
for cntr in country:
url = f'{DialectID_URL}/es-{cntr}-sample.json.zip'
filename=join(MODELS, f'es-{cntr}-sample.json.zip')
url = f'{DialectID_URL}/{lang}-{cntr}-sample.json.zip'
filename=join(MODELS, f'{lang}-{cntr}-sample.json.zip')
if isfile(filename):
continue
Download(url, filename)
with ZipFile(filename, "r") as fpt:
fpt.extractall(path=MODELS,
pwd="ingeotec".encode("utf-8"))
if len(country) == 1 and return_X_y is False:
return join(MODELS, f'es-{country[0]}-sample.json')
return join(MODELS, f'{lang}-{country[0]}-sample.json')
assert return_X_y
X = []
y = []
for cntr in country:
_ = join(MODELS, f'es-{cntr}-sample.json')
_ = join(MODELS, f'{lang}-{cntr}-sample.json')
_ = list(tweet_iterator(_))
X.extend(_)
y.extend([cntr] * len(_))
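The new lang parameter replaces the hard-coded es- prefix in the sample URLs; existing calls keep working through the default. A usage sketch:

# Sketch: load_dataset with the new lang parameter.
from encexp.utils import load_dataset

path = load_dataset('mx')            # resolves es-mx-sample.json as before
X, y = load_dataset(['mx', 'ar'],    # multiple countries need return_X_y
                    lang='es',
                    return_X_y=True)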