codec.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from .core import encode, decode, alabel, ulabel, IDNAError
  2. import codecs
  3. import re
  4. _unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')
  5. class Codec(codecs.Codec):
  6. def encode(self, data, errors='strict'):
  7. if errors != 'strict':
  8. raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
  9. if not data:
  10. return "", 0
  11. return encode(data), len(data)
  12. def decode(self, data, errors='strict'):
  13. if errors != 'strict':
  14. raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
  15. if not data:
  16. return '', 0
  17. return decode(data), len(data)
  18. class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
  19. def _buffer_encode(self, data, errors, final):
  20. if errors != 'strict':
  21. raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
  22. if not data:
  23. return ('', 0)
  24. labels = _unicode_dots_re.split(data)
  25. trailing_dot = ''
  26. if labels:
  27. if not labels[-1]:
  28. trailing_dot = '.'
  29. del labels[-1]
  30. elif not final:
  31. # Keep potentially unfinished label until the next call
  32. del labels[-1]
  33. if labels:
  34. trailing_dot = '.'
  35. result = []
  36. size = 0
  37. for label in labels:
  38. result.append(alabel(label))
  39. if size:
  40. size += 1
  41. size += len(label)
  42. # Join with U+002E
  43. result = '.'.join(result) + trailing_dot
  44. size += len(trailing_dot)
  45. return (result, size)
  46. class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
  47. def _buffer_decode(self, data, errors, final):
  48. if errors != 'strict':
  49. raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
  50. if not data:
  51. return ('', 0)
  52. labels = _unicode_dots_re.split(data)
  53. trailing_dot = ''
  54. if labels:
  55. if not labels[-1]:
  56. trailing_dot = '.'
  57. del labels[-1]
  58. elif not final:
  59. # Keep potentially unfinished label until the next call
  60. del labels[-1]
  61. if labels:
  62. trailing_dot = '.'
  63. result = []
  64. size = 0
  65. for label in labels:
  66. result.append(ulabel(label))
  67. if size:
  68. size += 1
  69. size += len(label)
  70. result = '.'.join(result) + trailing_dot
  71. size += len(trailing_dot)
  72. return (result, size)
  73. class StreamWriter(Codec, codecs.StreamWriter):
  74. pass
  75. class StreamReader(Codec, codecs.StreamReader):
  76. pass
  77. def getregentry():
  78. return codecs.CodecInfo(
  79. name='idna',
  80. encode=Codec().encode,
  81. decode=Codec().decode,
  82. incrementalencoder=IncrementalEncoder,
  83. incrementaldecoder=IncrementalDecoder,
  84. streamwriter=StreamWriter,
  85. streamreader=StreamReader,
  86. )