import unittest

import torch

from Andromeda.model import AndromedaTokenizer
|
|
|
|
class TestAndromedaTokenizer(unittest.TestCase):
    """Unit tests for ``AndromedaTokenizer``.

    A fresh tokenizer is built in ``setUp`` before each test, so the tests
    do not share tokenizer state.
    """

    # Sample inputs shared by the tokenization and round-trip tests.
    SAMPLE_TEXTS = ("Hello, world!", "Andromeda is great.")

    def setUp(self):
        """Create the tokenizer under test."""
        self.tokenizer = AndromedaTokenizer()

    def test_initialization(self):
        """The wrapped tokenizer exists and has the expected configuration."""
        inner = self.tokenizer.tokenizer
        self.assertIsNotNone(inner, "Tokenizer is not initialized.")
        self.assertEqual(inner.eos_token, "<eos>", "EOS token is not correctly set.")
        self.assertEqual(inner.pad_token, "<pad>", "PAD token is not correctly set.")
        self.assertEqual(inner.model_max_length, 8192, "Model max length is not correctly set.")

    def test_tokenize_texts(self):
        """``tokenize_texts`` yields one ``torch.Tensor`` row per input text."""
        texts = list(self.SAMPLE_TEXTS)
        tokenized_texts = self.tokenizer.tokenize_texts(texts)
        self.assertEqual(
            tokenized_texts.shape[0],
            len(texts),
            "Number of tokenized texts does not match input.",
        )
        # Needs the module-level ``import torch``; without it this line
        # raises NameError instead of exercising the tokenizer.
        self.assertTrue(
            all(isinstance(row, torch.Tensor) for row in tokenized_texts),
            "Not all tokenized texts are PyTorch tensors.",
        )

    def test_decode(self):
        """Decoding each tokenized row reproduces the original text."""
        texts = list(self.SAMPLE_TEXTS)
        tokenized_texts = self.tokenizer.tokenize_texts(texts)
        # NOTE(review): an exact round trip presumes decode() drops any
        # padding/special tokens that tokenize_texts may add to equalise row
        # lengths -- confirm against AndromedaTokenizer.
        decoded_texts = [self.tokenizer.decode(row) for row in tokenized_texts]
        self.assertEqual(decoded_texts, texts, "Decoded texts do not match original texts.")

    def test_len(self):
        """``len(tokenizer)`` is a positive integer."""
        num_tokens = len(self.tokenizer)
        self.assertIsInstance(num_tokens, int, "Number of tokens is not an integer.")
        self.assertGreater(num_tokens, 0, "Number of tokens is not greater than 0.")
|
|
# Run the test suite only when this file is executed as a script,
# not when it is imported by a test runner.
if __name__ == "__main__":
    unittest.main()