1   
   2   
   3   
   4   
   5   
   6   
   7   
   8   
   9   
  10   
  11   
  12   
  13   
  14   
  15   
  16   
  17   
  18   
  19  import sys 
  20  import types 
  21  import itertools 
  22  import warnings 
  23  import decimal 
  24  import datetime 
  25  import keyword 
  26  import warnings 
  27  from array import array 
  28  from operator import itemgetter 
  29   
  30  from pyspark.rdd import RDD, PipelinedRDD 
  31  from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer 
  32   
  33  from itertools import chain, ifilter, imap 
  34   
  35  from py4j.protocol import Py4JError 
  36  from py4j.java_collections import ListConverter, MapConverter 
  37   
  38   
  39  __all__ = [ 
  40      "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", 
  41      "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", 
  42      "ShortType", "ArrayType", "MapType", "StructField", "StructType", 
  43      "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", 
  44      "SchemaRDD", "Row"] 
  48   
  49      """Spark SQL DataType""" 
  50   
  52          return self.__class__.__name__ 
   53   
  55          return hash(str(self)) 
   56   
  58          return (isinstance(other, self.__class__) and 
  59                  self.__dict__ == other.__dict__) 
   60   
  62          return not self.__eq__(other) 
    63   
  66   
  67      """Metaclass for PrimitiveType""" 
  68   
  69      _instances = {} 
  70   
   75   
  78   
  79      """Spark SQL PrimitiveType""" 
  80   
  81      __metaclass__ = PrimitiveTypeSingleton 
  82   
  84           
  85          return self is other 
    86   
  89   
  90      """Spark SQL StringType 
  91   
  92      The data type representing string values. 
  93      """ 
   94   
  97   
  98      """Spark SQL BinaryType 
  99   
 100      The data type representing bytearray values. 
 101      """ 
  102   
 105   
 106      """Spark SQL BooleanType 
 107   
 108      The data type representing bool values. 
 109      """ 
  110   
 113   
 114      """Spark SQL TimestampType 
 115   
 116      The data type representing datetime.datetime values. 
 117      """ 
  118   
 121   
 122      """Spark SQL DecimalType 
 123   
 124      The data type representing decimal.Decimal values. 
 125      """ 
  126   
 129   
 130      """Spark SQL DoubleType 
 131   
 132      The data type representing float values. 
 133      """ 
  134   
 137   
 138      """Spark SQL FloatType 
 139   
 140      The data type representing single precision floating-point values. 
 141      """ 
  142   
 145   
 146      """Spark SQL ByteType 
 147   
 148      The data type representing int values with 1 singed byte. 
 149      """ 
  150   
 153   
 154      """Spark SQL IntegerType 
 155   
 156      The data type representing int values. 
 157      """ 
  158   
 161   
 162      """Spark SQL LongType 
 163   
 164      The data type representing long values. If the any value is 
 165      beyond the range of [-9223372036854775808, 9223372036854775807], 
 166      please use DecimalType. 
 167      """ 
  168   
 171   
 172      """Spark SQL ShortType 
 173   
 174      The data type representing int values with 2 signed bytes. 
 175      """ 
  176   
 179   
 180      """Spark SQL ArrayType 
 181   
 182      The data type representing list values. An ArrayType object 
 183      comprises two fields, elementType (a DataType) and containsNull (a bool). 
 184      The field of elementType is used to specify the type of array elements. 
 185      The field of containsNull is used to specify if the array has None values. 
 186   
 187      """ 
 188   
 189 -    def __init__(self, elementType, containsNull=True): 
  190          """Creates an ArrayType 
 191   
 192          :param elementType: the data type of elements. 
 193          :param containsNull: indicates whether the list contains None values. 
 194   
 195          >>> ArrayType(StringType) == ArrayType(StringType, True) 
 196          True 
 197          >>> ArrayType(StringType, False) == ArrayType(StringType) 
 198          False 
 199          """ 
 200          self.elementType = elementType 
 201          self.containsNull = containsNull 
  202   
 204          return "ArrayType(%s,%s)" % (self.elementType, 
 205                                       str(self.containsNull).lower()) 
   206   
 209   
 210      """Spark SQL MapType 
 211   
 212      The data type representing dict values. A MapType object comprises 
 213      three fields, keyType (a DataType), valueType (a DataType) and 
 214      valueContainsNull (a bool). 
 215   
 216      The field of keyType is used to specify the type of keys in the map. 
 217      The field of valueType is used to specify the type of values in the map. 
 218      The field of valueContainsNull is used to specify if values of this 
 219      map has None values. 
 220   
 221      For values of a MapType column, keys are not allowed to have None values. 
 222   
 223      """ 
 224   
 225 -    def __init__(self, keyType, valueType, valueContainsNull=True): 
  226          """Creates a MapType 
 227          :param keyType: the data type of keys. 
 228          :param valueType: the data type of values. 
 229          :param valueContainsNull: indicates whether values contains 
 230          null values. 
 231   
 232          >>> (MapType(StringType, IntegerType) 
 233          ...        == MapType(StringType, IntegerType, True)) 
 234          True 
 235          >>> (MapType(StringType, IntegerType, False) 
 236          ...        == MapType(StringType, FloatType)) 
 237          False 
 238          """ 
 239          self.keyType = keyType 
 240          self.valueType = valueType 
 241          self.valueContainsNull = valueContainsNull 
  242   
 244          return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, 
 245                                        str(self.valueContainsNull).lower()) 
   246   
 249   
 250      """Spark SQL StructField 
 251   
 252      Represents a field in a StructType. 
 253      A StructField object comprises three fields, name (a string), 
 254      dataType (a DataType) and nullable (a bool). The field of name 
 255      is the name of a StructField. The field of dataType specifies 
 256      the data type of a StructField. 
 257   
 258      The field of nullable specifies if values of a StructField can 
 259      contain None values. 
 260   
 261      """ 
 262   
 263 -    def __init__(self, name, dataType, nullable): 
  264          """Creates a StructField 
 265          :param name: the name of this field. 
 266          :param dataType: the data type of this field. 
 267          :param nullable: indicates whether values of this field 
 268                           can be null. 
 269   
 270          >>> (StructField("f1", StringType, True) 
 271          ...      == StructField("f1", StringType, True)) 
 272          True 
 273          >>> (StructField("f1", StringType, True) 
 274          ...      == StructField("f2", StringType, True)) 
 275          False 
 276          """ 
 277          self.name = name 
 278          self.dataType = dataType 
 279          self.nullable = nullable 
  280   
 282          return "StructField(%s,%s,%s)" % (self.name, self.dataType, 
 283                                            str(self.nullable).lower()) 
   284   
 287   
 288      """Spark SQL StructType 
 289   
 290      The data type representing rows. 
 291      A StructType object comprises a list of L{StructField}s. 
 292   
 293      """ 
 294   
 296          """Creates a StructType 
 297   
 298          >>> struct1 = StructType([StructField("f1", StringType, True)]) 
 299          >>> struct2 = StructType([StructField("f1", StringType, True)]) 
 300          >>> struct1 == struct2 
 301          True 
 302          >>> struct1 = StructType([StructField("f1", StringType, True)]) 
 303          >>> struct2 = StructType([StructField("f1", StringType, True), 
 304          ...   [StructField("f2", IntegerType, False)]]) 
 305          >>> struct1 == struct2 
 306          False 
 307          """ 
 308          self.fields = fields 
  309   
 311          return ("StructType(List(%s))" % 
 312                  ",".join(str(field) for field in self.fields)) 
   313   
 316      """Parses a list of comma separated data types.""" 
 317      index = 0 
 318      datatype_list = [] 
 319      start = 0 
 320      depth = 0 
 321      while index < len(datatype_list_string): 
 322          if depth == 0 and datatype_list_string[index] == ",": 
 323              datatype_string = datatype_list_string[start:index].strip() 
 324              datatype_list.append(_parse_datatype_string(datatype_string)) 
 325              start = index + 1 
 326          elif datatype_list_string[index] == "(": 
 327              depth += 1 
 328          elif datatype_list_string[index] == ")": 
 329              depth -= 1 
 330   
 331          index += 1 
 332   
 333       
 334      datatype_string = datatype_list_string[start:index].strip() 
 335      datatype_list.append(_parse_datatype_string(datatype_string)) 
 336      return datatype_list 
  337   
 338   
 339  _all_primitive_types = dict((k, v) for k, v in globals().iteritems() 
 340                              if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) 
 344      """Parses the given data type string. 
 345   
 346      >>> def check_datatype(datatype): 
 347      ...     scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype)) 
 348      ...     python_datatype = _parse_datatype_string( 
 349      ...                          scala_datatype.toString()) 
 350      ...     return datatype == python_datatype 
 351      >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) 
 352      True 
 353      >>> # Simple ArrayType. 
 354      >>> simple_arraytype = ArrayType(StringType(), True) 
 355      >>> check_datatype(simple_arraytype) 
 356      True 
 357      >>> # Simple MapType. 
 358      >>> simple_maptype = MapType(StringType(), LongType()) 
 359      >>> check_datatype(simple_maptype) 
 360      True 
 361      >>> # Simple StructType. 
 362      >>> simple_structtype = StructType([ 
 363      ...     StructField("a", DecimalType(), False), 
 364      ...     StructField("b", BooleanType(), True), 
 365      ...     StructField("c", LongType(), True), 
 366      ...     StructField("d", BinaryType(), False)]) 
 367      >>> check_datatype(simple_structtype) 
 368      True 
 369      >>> # Complex StructType. 
 370      >>> complex_structtype = StructType([ 
 371      ...     StructField("simpleArray", simple_arraytype, True), 
 372      ...     StructField("simpleMap", simple_maptype, True), 
 373      ...     StructField("simpleStruct", simple_structtype, True), 
 374      ...     StructField("boolean", BooleanType(), False)]) 
 375      >>> check_datatype(complex_structtype) 
 376      True 
 377      >>> # Complex ArrayType. 
 378      >>> complex_arraytype = ArrayType(complex_structtype, True) 
 379      >>> check_datatype(complex_arraytype) 
 380      True 
 381      >>> # Complex MapType. 
 382      >>> complex_maptype = MapType(complex_structtype, 
 383      ...                           complex_arraytype, False) 
 384      >>> check_datatype(complex_maptype) 
 385      True 
 386      """ 
 387      index = datatype_string.find("(") 
 388      if index == -1: 
 389           
 390          index = len(datatype_string) 
 391      type_or_field = datatype_string[:index] 
 392      rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip() 
 393   
 394      if type_or_field in _all_primitive_types: 
 395          return _all_primitive_types[type_or_field]() 
 396   
 397      elif type_or_field == "ArrayType": 
 398          last_comma_index = rest_part.rfind(",") 
 399          containsNull = True 
 400          if rest_part[last_comma_index + 1:].strip().lower() == "false": 
 401              containsNull = False 
 402          elementType = _parse_datatype_string( 
 403              rest_part[:last_comma_index].strip()) 
 404          return ArrayType(elementType, containsNull) 
 405   
 406      elif type_or_field == "MapType": 
 407          last_comma_index = rest_part.rfind(",") 
 408          valueContainsNull = True 
 409          if rest_part[last_comma_index + 1:].strip().lower() == "false": 
 410              valueContainsNull = False 
 411          keyType, valueType = _parse_datatype_list( 
 412              rest_part[:last_comma_index].strip()) 
 413          return MapType(keyType, valueType, valueContainsNull) 
 414   
 415      elif type_or_field == "StructField": 
 416          first_comma_index = rest_part.find(",") 
 417          name = rest_part[:first_comma_index].strip() 
 418          last_comma_index = rest_part.rfind(",") 
 419          nullable = True 
 420          if rest_part[last_comma_index + 1:].strip().lower() == "false": 
 421              nullable = False 
 422          dataType = _parse_datatype_string( 
 423              rest_part[first_comma_index + 1:last_comma_index].strip()) 
 424          return StructField(name, dataType, nullable) 
 425   
 426      elif type_or_field == "StructType": 
 427           
 428           
 429          field_list_string = rest_part[rest_part.find("(") + 1:-1] 
 430          fields = _parse_datatype_list(field_list_string) 
 431          return StructType(fields) 
  432   
 433   
 434   
 435  _type_mappings = { 
 436      bool: BooleanType, 
 437      int: IntegerType, 
 438      long: LongType, 
 439      float: DoubleType, 
 440      str: StringType, 
 441      unicode: StringType, 
 442      decimal.Decimal: DecimalType, 
 443      datetime.datetime: TimestampType, 
 444      datetime.date: TimestampType, 
 445      datetime.time: TimestampType, 
 446  } 
 450      """Infer the DataType from obj""" 
 451      if obj is None: 
 452          raise ValueError("Can not infer type for None") 
 453   
 454      dataType = _type_mappings.get(type(obj)) 
 455      if dataType is not None: 
 456          return dataType() 
 457   
 458      if isinstance(obj, dict): 
 459          if not obj: 
 460              raise ValueError("Can not infer type for empty dict") 
 461          key, value = obj.iteritems().next() 
 462          return MapType(_infer_type(key), _infer_type(value), True) 
 463      elif isinstance(obj, (list, array)): 
 464          if not obj: 
 465              raise ValueError("Can not infer type for empty list/array") 
 466          return ArrayType(_infer_type(obj[0]), True) 
 467      else: 
 468          try: 
 469              return _infer_schema(obj) 
 470          except ValueError: 
 471              raise ValueError("not supported type: %s" % type(obj)) 
  472   
 475      """Infer the schema from dict/namedtuple/object""" 
 476      if isinstance(row, dict): 
 477          items = sorted(row.items()) 
 478   
 479      elif isinstance(row, tuple): 
 480          if hasattr(row, "_fields"):   
 481              items = zip(row._fields, tuple(row)) 
 482          elif hasattr(row, "__FIELDS__"):   
 483              items = zip(row.__FIELDS__, tuple(row)) 
 484          elif all(isinstance(x, tuple) and len(x) == 2 for x in row): 
 485              items = row 
 486          else: 
 487              raise ValueError("Can't infer schema from tuple") 
 488   
 489      elif hasattr(row, "__dict__"):   
 490          items = sorted(row.__dict__.items()) 
 491   
 492      else: 
 493          raise ValueError("Can not infer schema for type: %s" % type(row)) 
 494   
 495      fields = [StructField(k, _infer_type(v), True) for k, v in items] 
 496      return StructType(fields) 
  497   
 500      """Create an converter to drop the names of fields in obj """ 
 501      if isinstance(dataType, ArrayType): 
 502          conv = _create_converter(obj[0], dataType.elementType) 
 503          return lambda row: map(conv, row) 
 504   
 505      elif isinstance(dataType, MapType): 
 506          value = obj.values()[0] 
 507          conv = _create_converter(value, dataType.valueType) 
 508          return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) 
 509   
 510      elif not isinstance(dataType, StructType): 
 511          return lambda x: x 
 512   
 513       
 514      names = [f.name for f in dataType.fields] 
 515   
 516      if isinstance(obj, dict): 
 517          conv = lambda o: tuple(o.get(n) for n in names) 
 518   
 519      elif isinstance(obj, tuple): 
 520          if hasattr(obj, "_fields"):   
 521              conv = tuple 
 522          elif hasattr(obj, "__FIELDS__"): 
 523              conv = tuple 
 524          elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): 
 525              conv = lambda o: tuple(v for k, v in o) 
 526          else: 
 527              raise ValueError("unexpected tuple") 
 528   
 529      elif hasattr(obj, "__dict__"):   
 530          conv = lambda o: [o.__dict__.get(n, None) for n in names] 
 531   
 532      if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields): 
 533          return conv 
 534   
 535      row = conv(obj) 
 536      convs = [_create_converter(v, f.dataType) 
 537               for v, f in zip(row, dataType.fields)] 
 538   
 539      def nested_conv(row): 
 540          return tuple(f(v) for f, v in zip(convs, conv(row))) 
  541   
 542      return nested_conv 
 543   
 546      """ all the names of fields, becoming tuples""" 
 547      iterator = iter(rows) 
 548      row = iterator.next() 
 549      converter = _create_converter(row, schema) 
 550      yield converter(row) 
 551      for i in iterator: 
 552          yield converter(i) 
  553   
 554   
 555  _BRACKETS = {'(': ')', '[': ']', '{': '}'} 
 559      """ 
 560      split the schema abstract into fields 
 561   
 562      >>> _split_schema_abstract("a b  c") 
 563      ['a', 'b', 'c'] 
 564      >>> _split_schema_abstract("a(a b)") 
 565      ['a(a b)'] 
 566      >>> _split_schema_abstract("a b[] c{a b}") 
 567      ['a', 'b[]', 'c{a b}'] 
 568      >>> _split_schema_abstract(" ") 
 569      [] 
 570      """ 
 571   
 572      r = [] 
 573      w = '' 
 574      brackets = [] 
 575      for c in s: 
 576          if c == ' ' and not brackets: 
 577              if w: 
 578                  r.append(w) 
 579              w = '' 
 580          else: 
 581              w += c 
 582              if c in _BRACKETS: 
 583                  brackets.append(c) 
 584              elif c in _BRACKETS.values(): 
 585                  if not brackets or c != _BRACKETS[brackets.pop()]: 
 586                      raise ValueError("unexpected " + c) 
 587   
 588      if brackets: 
 589          raise ValueError("brackets not closed: %s" % brackets) 
 590      if w: 
 591          r.append(w) 
 592      return r 
  593   
 596      """ 
 597      Parse a field in schema abstract 
 598   
 599      >>> _parse_field_abstract("a") 
 600      StructField(a,None,true) 
 601      >>> _parse_field_abstract("b(c d)") 
 602      StructField(b,StructType(...c,None,true),StructField(d... 
 603      >>> _parse_field_abstract("a[]") 
 604      StructField(a,ArrayType(None,true),true) 
 605      >>> _parse_field_abstract("a{[]}") 
 606      StructField(a,MapType(None,ArrayType(None,true),true),true) 
 607      """ 
 608      if set(_BRACKETS.keys()) & set(s): 
 609          idx = min((s.index(c) for c in _BRACKETS if c in s)) 
 610          name = s[:idx] 
 611          return StructField(name, _parse_schema_abstract(s[idx:]), True) 
 612      else: 
 613          return StructField(s, None, True) 
  614   
 617      """ 
 618      parse abstract into schema 
 619   
 620      >>> _parse_schema_abstract("a b  c") 
 621      StructType...a...b...c... 
 622      >>> _parse_schema_abstract("a[b c] b{}") 
 623      StructType...a,ArrayType...b...c...b,MapType... 
 624      >>> _parse_schema_abstract("c{} d{a b}") 
 625      StructType...c,MapType...d,MapType...a...b... 
 626      >>> _parse_schema_abstract("a b(t)").fields[1] 
 627      StructField(b,StructType(List(StructField(t,None,true))),true) 
 628      """ 
 629      s = s.strip() 
 630      if not s: 
 631          return 
 632   
 633      elif s.startswith('('): 
 634          return _parse_schema_abstract(s[1:-1]) 
 635   
 636      elif s.startswith('['): 
 637          return ArrayType(_parse_schema_abstract(s[1:-1]), True) 
 638   
 639      elif s.startswith('{'): 
 640          return MapType(None, _parse_schema_abstract(s[1:-1])) 
 641   
 642      parts = _split_schema_abstract(s) 
 643      fields = [_parse_field_abstract(p) for p in parts] 
 644      return StructType(fields) 
  645   
 648      """ 
 649      Fill the dataType with types infered from obj 
 650   
 651      >>> schema = _parse_schema_abstract("a b c") 
 652      >>> row = (1, 1.0, "str") 
 653      >>> _infer_schema_type(row, schema) 
 654      StructType...IntegerType...DoubleType...StringType... 
 655      >>> row = [[1], {"key": (1, 2.0)}] 
 656      >>> schema = _parse_schema_abstract("a[] b{c d}") 
 657      >>> _infer_schema_type(row, schema) 
 658      StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType... 
 659      """ 
 660      if dataType is None: 
 661          return _infer_type(obj) 
 662   
 663      if not obj: 
 664          raise ValueError("Can not infer type from empty value") 
 665   
 666      if isinstance(dataType, ArrayType): 
 667          eType = _infer_schema_type(obj[0], dataType.elementType) 
 668          return ArrayType(eType, True) 
 669   
 670      elif isinstance(dataType, MapType): 
 671          k, v = obj.iteritems().next() 
 672          return MapType(_infer_type(k), 
 673                         _infer_schema_type(v, dataType.valueType)) 
 674   
 675      elif isinstance(dataType, StructType): 
 676          fs = dataType.fields 
 677          assert len(fs) == len(obj), \ 
 678              "Obj(%s) have different length with fields(%s)" % (obj, fs) 
 679          fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) 
 680                    for o, f in zip(obj, fs)] 
 681          return StructType(fields) 
 682   
 683      else: 
 684          raise ValueError("Unexpected dataType: %s" % dataType) 
  685   
 686   
 687  _acceptable_types = { 
 688      BooleanType: (bool,), 
 689      ByteType: (int, long), 
 690      ShortType: (int, long), 
 691      IntegerType: (int, long), 
 692      LongType: (long,), 
 693      FloatType: (float,), 
 694      DoubleType: (float,), 
 695      DecimalType: (decimal.Decimal,), 
 696      StringType: (str, unicode), 
 697      TimestampType: (datetime.datetime,), 
 698      ArrayType: (list, tuple, array), 
 699      MapType: (dict,), 
 700      StructType: (tuple, list), 
 701  } 
 705      """ 
 706      Verify the type of obj against dataType, raise an exception if 
 707      they do not match. 
 708   
 709      >>> _verify_type(None, StructType([])) 
 710      >>> _verify_type("", StringType()) 
 711      >>> _verify_type(0, IntegerType()) 
 712      >>> _verify_type(range(3), ArrayType(ShortType())) 
 713      >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL 
 714      Traceback (most recent call last): 
 715          ... 
 716      TypeError:... 
 717      >>> _verify_type({}, MapType(StringType(), IntegerType())) 
 718      >>> _verify_type((), StructType([])) 
 719      >>> _verify_type([], StructType([])) 
 720      >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL 
 721      Traceback (most recent call last): 
 722          ... 
 723      ValueError:... 
 724      """ 
 725       
 726      if obj is None: 
 727          return 
 728   
 729      _type = type(dataType) 
 730      if _type not in _acceptable_types: 
 731          return 
 732   
 733      if type(obj) not in _acceptable_types[_type]: 
 734          raise TypeError("%s can not accept abject in type %s" 
 735                          % (dataType, type(obj))) 
 736   
 737      if isinstance(dataType, ArrayType): 
 738          for i in obj: 
 739              _verify_type(i, dataType.elementType) 
 740   
 741      elif isinstance(dataType, MapType): 
 742          for k, v in obj.iteritems(): 
 743              _verify_type(k, dataType.keyType) 
 744              _verify_type(v, dataType.valueType) 
 745   
 746      elif isinstance(dataType, StructType): 
 747          if len(obj) != len(dataType.fields): 
 748              raise ValueError("Length of object (%d) does not match with" 
 749                               "length of fields (%d)" % (len(obj), len(dataType.fields))) 
 750          for v, f in zip(obj, dataType.fields): 
 751              _verify_type(v, f.dataType) 
  752   
 753   
 754  _cached_cls = {} 
 758      """ Restore object during unpickling. """ 
 759       
 760       
 761       
 762      k = id(dataType) 
 763      cls = _cached_cls.get(k) 
 764      if cls is None: 
 765           
 766          cls = _cached_cls.get(dataType) 
 767          if cls is None: 
 768              cls = _create_cls(dataType) 
 769              _cached_cls[dataType] = cls 
 770          _cached_cls[k] = cls 
 771      return cls(obj) 
  772   
 775      """ Create an customized object with class `cls`. """ 
 776      return cls(v) if v is not None else v 
  777   
 780      """ Create a getter for item `i` with schema """ 
 781      cls = _create_cls(dt) 
 782   
 783      def getter(self): 
 784          return _create_object(cls, self[i]) 
  785   
 786      return getter 
 787   
 790      """Return whether `dt` is or has StructType in it""" 
 791      if isinstance(dt, StructType): 
 792          return True 
 793      elif isinstance(dt, ArrayType): 
 794          return _has_struct(dt.elementType) 
 795      elif isinstance(dt, MapType): 
 796          return _has_struct(dt.valueType) 
 797      return False 
  798   
 801      """Create properties according to fields""" 
 802      ps = {} 
 803      for i, f in enumerate(fields): 
 804          name = f.name 
 805          if (name.startswith("__") and name.endswith("__") 
 806                  or keyword.iskeyword(name)): 
 807              warnings.warn("field name %s can not be accessed in Python," 
 808                            "use position to access it instead" % name) 
 809          if _has_struct(f.dataType): 
 810               
 811              getter = _create_getter(f.dataType, i) 
 812          else: 
 813              getter = itemgetter(i) 
 814          ps[name] = property(getter) 
 815      return ps 
  816   
 819      """ 
 820      Create an class by dataType 
 821   
 822      The created class is similar to namedtuple, but can have nested schema. 
 823   
 824      >>> schema = _parse_schema_abstract("a b c") 
 825      >>> row = (1, 1.0, "str") 
 826      >>> schema = _infer_schema_type(row, schema) 
 827      >>> obj = _create_cls(schema)(row) 
 828      >>> import pickle 
 829      >>> pickle.loads(pickle.dumps(obj)) 
 830      Row(a=1, b=1.0, c='str') 
 831   
 832      >>> row = [[1], {"key": (1, 2.0)}] 
 833      >>> schema = _parse_schema_abstract("a[] b{c d}") 
 834      >>> schema = _infer_schema_type(row, schema) 
 835      >>> obj = _create_cls(schema)(row) 
 836      >>> pickle.loads(pickle.dumps(obj)) 
 837      Row(a=[1], b={'key': Row(c=1, d=2.0)}) 
 838      """ 
 839   
 840      if isinstance(dataType, ArrayType): 
 841          cls = _create_cls(dataType.elementType) 
 842   
 843          class List(list): 
 844   
 845              def __getitem__(self, i): 
 846                   
 847                  return _create_object(cls, list.__getitem__(self, i)) 
  848   
 849              def __repr__(self): 
 850                   
 851                  return "[%s]" % (", ".join(repr(self[i]) 
 852                                             for i in range(len(self)))) 
 853   
 854              def __reduce__(self): 
 855                  return list.__reduce__(self) 
 856   
 857          return List 
 858   
 859      elif isinstance(dataType, MapType): 
 860          vcls = _create_cls(dataType.valueType) 
 861   
 862          class Dict(dict): 
 863   
 864              def __getitem__(self, k): 
 865                   
 866                  return _create_object(vcls, dict.__getitem__(self, k)) 
 867   
 868              def __repr__(self): 
 869                   
 870                  return "{%s}" % (", ".join("%r: %r" % (k, self[k]) 
 871                                             for k in self)) 
 872   
 873              def __reduce__(self): 
 874                  return dict.__reduce__(self) 
 875   
 876          return Dict 
 877   
 878      elif not isinstance(dataType, StructType): 
 879          raise Exception("unexpected data type: %s" % dataType) 
 880   
 881      class Row(tuple): 
 882   
 883          """ Row in SchemaRDD """ 
 884          __DATATYPE__ = dataType 
 885          __FIELDS__ = tuple(f.name for f in dataType.fields) 
 886          __slots__ = () 
 887   
 888           
 889          locals().update(_create_properties(dataType.fields)) 
 890   
 891          def __repr__(self): 
 892               
 893              return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) 
 894                                            for n in self.__FIELDS__)) 
 895   
 896          def __reduce__(self): 
 897              return (_restore_object, (self.__DATATYPE__, tuple(self))) 
 898   
 899      return Row 
 900   
 901   
 902 -class SQLContext: 
  903   
 904      """Main entry point for SparkSQL functionality. 
 905   
 906      A SQLContext can be used create L{SchemaRDD}s, register L{SchemaRDD}s as 
 907      tables, execute SQL over tables, cache tables, and read parquet files. 
 908      """ 
 909   
 910 -    def __init__(self, sparkContext, sqlContext=None): 
  911          """Create a new SQLContext. 
 912   
 913          @param sparkContext: The SparkContext to wrap. 
 914          @param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new 
 915          SQLContext in the JVM, instead we make all calls to this object. 
 916   
 917          >>> srdd = sqlCtx.inferSchema(rdd) 
 918          >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL 
 919          Traceback (most recent call last): 
 920              ... 
 921          TypeError:... 
 922   
 923          >>> bad_rdd = sc.parallelize([1,2,3]) 
 924          >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL 
 925          Traceback (most recent call last): 
 926              ... 
 927          ValueError:... 
 928   
 929          >>> from datetime import datetime 
 930          >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, 
 931          ...     b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), 
 932          ...     time=datetime(2014, 8, 1, 14, 1, 5))]) 
 933          >>> srdd = sqlCtx.inferSchema(allTypes) 
 934          >>> srdd.registerTempTable("allTypes") 
 935          >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' 
 936          ...            'from allTypes where b and i > 0').collect() 
 937          [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] 
 938          >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, 
 939          ...                     x.row.a, x.list)).collect() 
 940          [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] 
 941          """ 
 942          self._sc = sparkContext 
 943          self._jsc = self._sc._jsc 
 944          self._jvm = self._sc._jvm 
 945          self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray 
 946   
 947          if sqlContext: 
 948              self._scala_SQLContext = sqlContext 
  949   
 950      @property 
 951 -    def _ssql_ctx(self): 
  952          """Accessor for the JVM SparkSQL context. 
 953   
 954          Subclasses can override this property to provide their own 
 955          JVM Contexts. 
 956          """ 
 957          if not hasattr(self, '_scala_SQLContext'): 
 958              self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) 
 959          return self._scala_SQLContext 
  960   
 961 -    def registerFunction(self, name, f, returnType=StringType()): 
  962          """Registers a lambda function as a UDF so it can be used in SQL statements. 
 963   
 964          In addition to a name and the function itself, the return type can be optionally specified. 
 965          When the return type is not given it default to a string and conversion will automatically 
 966          be done.  For any other return type, the produced object must match the specified type. 
 967   
 968          >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) 
 969          >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() 
 970          [Row(c0=u'4')] 
 971          >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) 
 972          >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() 
 973          [Row(c0=4)] 
 974          >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) 
 975          >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect() 
 976          [Row(c0=5)] 
 977          """ 
 978          func = lambda _, it: imap(lambda x: f(*x), it) 
 979          command = (func, 
 980                     BatchedSerializer(PickleSerializer(), 1024), 
 981                     BatchedSerializer(PickleSerializer(), 1024)) 
 982          env = MapConverter().convert(self._sc.environment, 
 983                                       self._sc._gateway._gateway_client) 
 984          includes = ListConverter().convert(self._sc._python_includes, 
 985                                             self._sc._gateway._gateway_client) 
 986          self._ssql_ctx.registerPython(name, 
 987                                        bytearray(CloudPickleSerializer().dumps(command)), 
 988                                        env, 
 989                                        includes, 
 990                                        self._sc.pythonExec, 
 991                                        self._sc._javaAccumulator, 
 992                                        str(returnType)) 
  993   
 994 -    def inferSchema(self, rdd): 
  995          """Infer and apply a schema to an RDD of L{Row}s. 
 996   
 997          We peek at the first row of the RDD to determine the fields' names 
 998          and types. Nested collections are supported, which include array, 
 999          dict, list, Row, tuple, namedtuple, or object. 
1000   
1001          All the rows in `rdd` should have the same type with the first one, 
1002          or it will cause runtime exceptions. 
1003   
1004          Each row could be L{pyspark.sql.Row} object or namedtuple or objects, 
1005          using dict is deprecated. 
1006   
1007          >>> rdd = sc.parallelize( 
1008          ...     [Row(field1=1, field2="row1"), 
1009          ...      Row(field1=2, field2="row2"), 
1010          ...      Row(field1=3, field2="row3")]) 
1011          >>> srdd = sqlCtx.inferSchema(rdd) 
1012          >>> srdd.collect()[0] 
1013          Row(field1=1, field2=u'row1') 
1014   
1015          >>> NestedRow = Row("f1", "f2") 
1016          >>> nestedRdd1 = sc.parallelize([ 
1017          ...     NestedRow(array('i', [1, 2]), {"row1": 1.0}), 
1018          ...     NestedRow(array('i', [2, 3]), {"row2": 2.0})]) 
1019          >>> srdd = sqlCtx.inferSchema(nestedRdd1) 
1020          >>> srdd.collect() 
1021          [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] 
1022   
1023          >>> nestedRdd2 = sc.parallelize([ 
1024          ...     NestedRow([[1, 2], [2, 3]], [1, 2]), 
1025          ...     NestedRow([[2, 3], [3, 4]], [2, 3])]) 
1026          >>> srdd = sqlCtx.inferSchema(nestedRdd2) 
1027          >>> srdd.collect() 
1028          [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] 
1029          """ 
1030   
1031          if isinstance(rdd, SchemaRDD): 
1032              raise TypeError("Cannot apply schema to SchemaRDD") 
1033   
1034          first = rdd.first() 
1035          if not first: 
1036              raise ValueError("The first row in RDD is empty, " 
1037                               "can not infer schema") 
1038          if type(first) is dict: 
1039              warnings.warn("Using RDD of dict to inferSchema is deprecated," 
1040                            "please use pyspark.Row instead") 
1041   
1042          schema = _infer_schema(first) 
1043          rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) 
1044          return self.applySchema(rdd, schema) 
 1045   
1046 -    def applySchema(self, rdd, schema): 
 1047          """ 
1048          Applies the given schema to the given RDD of L{tuple} or L{list}s. 
1049   
1050          These tuples or lists can contain complex nested structures like 
1051          lists, maps or nested rows. 
1052   
1053          The schema should be a StructType. 
1054   
1055          It is important that the schema matches the types of the objects 
1056          in each row or exceptions could be thrown at runtime. 
1057   
1058          >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) 
1059          >>> schema = StructType([StructField("field1", IntegerType(), False), 
1060          ...     StructField("field2", StringType(), False)]) 
1061          >>> srdd = sqlCtx.applySchema(rdd2, schema) 
1062          >>> sqlCtx.registerRDDAsTable(srdd, "table1") 
1063          >>> srdd2 = sqlCtx.sql("SELECT * from table1") 
1064          >>> srdd2.collect() 
1065          [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] 
1066   
1067          >>> from datetime import datetime 
1068          >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, 
1069          ...     datetime(2010, 1, 1, 1, 1, 1), 
1070          ...     {"a": 1}, (2,), [1, 2, 3], None)]) 
1071          >>> schema = StructType([ 
1072          ...     StructField("byte1", ByteType(), False), 
1073          ...     StructField("byte2", ByteType(), False), 
1074          ...     StructField("short1", ShortType(), False), 
1075          ...     StructField("short2", ShortType(), False), 
1076          ...     StructField("int", IntegerType(), False), 
1077          ...     StructField("float", FloatType(), False), 
1078          ...     StructField("time", TimestampType(), False), 
1079          ...     StructField("map", 
1080          ...         MapType(StringType(), IntegerType(), False), False), 
1081          ...     StructField("struct", 
1082          ...         StructType([StructField("b", ShortType(), False)]), False), 
1083          ...     StructField("list", ArrayType(ByteType(), False), False), 
1084          ...     StructField("null", DoubleType(), True)]) 
1085          >>> srdd = sqlCtx.applySchema(rdd, schema) 
1086          >>> results = srdd.map( 
1087          ...     lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time, 
1088          ...         x.map["a"], x.struct.b, x.list, x.null)) 
1089          >>> results.collect()[0] 
1090          (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) 
1091   
1092          >>> srdd.registerTempTable("table2") 
1093          >>> sqlCtx.sql( 
1094          ...   "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + 
1095          ...     "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + 
1096          ...     "float + 1.5 as float FROM table2").collect() 
1097          [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] 
1098   
1099          >>> rdd = sc.parallelize([(127, -32768, 1.0, 
1100          ...     datetime(2010, 1, 1, 1, 1, 1), 
1101          ...     {"a": 1}, (2,), [1, 2, 3])]) 
1102          >>> abstract = "byte short float time map{} struct(b) list[]" 
1103          >>> schema = _parse_schema_abstract(abstract) 
1104          >>> typedSchema = _infer_schema_type(rdd.first(), schema) 
1105          >>> srdd = sqlCtx.applySchema(rdd, typedSchema) 
1106          >>> srdd.collect() 
1107          [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] 
1108          """ 
1109   
1110          if isinstance(rdd, SchemaRDD): 
1111              raise TypeError("Cannot apply schema to SchemaRDD") 
1112   
1113          if not isinstance(schema, StructType): 
1114              raise TypeError("schema should be StructType") 
1115   
1116           
1117          rows = rdd.take(10) 
1118          for row in rows: 
1119              _verify_type(row, schema) 
1120   
1121          batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) 
1122          jrdd = self._pythonToJava(rdd._jrdd, batched) 
1123          srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) 
1124          return SchemaRDD(srdd.toJavaSchemaRDD(), self) 
 1125   
1126 -    def registerRDDAsTable(self, rdd, tableName): 
 1127          """Registers the given RDD as a temporary table in the catalog. 
1128   
1129          Temporary tables exist only during the lifetime of this instance of 
1130          SQLContext. 
1131   
1132          >>> srdd = sqlCtx.inferSchema(rdd) 
1133          >>> sqlCtx.registerRDDAsTable(srdd, "table1") 
1134          """ 
1135          if (rdd.__class__ is SchemaRDD): 
1136              srdd = rdd._jschema_rdd.baseSchemaRDD() 
1137              self._ssql_ctx.registerRDDAsTable(srdd, tableName) 
1138          else: 
1139              raise ValueError("Can only register SchemaRDD as table") 
 1140   
1141 -    def parquetFile(self, path): 
 1142          """Loads a Parquet file, returning the result as a L{SchemaRDD}. 
1143   
1144          >>> import tempfile, shutil 
1145          >>> parquetFile = tempfile.mkdtemp() 
1146          >>> shutil.rmtree(parquetFile) 
1147          >>> srdd = sqlCtx.inferSchema(rdd) 
1148          >>> srdd.saveAsParquetFile(parquetFile) 
1149          >>> srdd2 = sqlCtx.parquetFile(parquetFile) 
1150          >>> sorted(srdd.collect()) == sorted(srdd2.collect()) 
1151          True 
1152          """ 
1153          jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD() 
1154          return SchemaRDD(jschema_rdd, self) 
 1155   
1156 -    def jsonFile(self, path, schema=None): 
 1157          """ 
1158          Loads a text file storing one JSON object per line as a 
1159          L{SchemaRDD}. 
1160   
1161          If the schema is provided, applies the given schema to this 
1162          JSON dataset. 
1163   
1164          Otherwise, it goes through the entire dataset once to determine 
1165          the schema. 
1166   
1167          >>> import tempfile, shutil 
1168          >>> jsonFile = tempfile.mkdtemp() 
1169          >>> shutil.rmtree(jsonFile) 
1170          >>> ofn = open(jsonFile, 'w') 
1171          >>> for json in jsonStrings: 
1172          ...   print>>ofn, json 
1173          >>> ofn.close() 
1174          >>> srdd1 = sqlCtx.jsonFile(jsonFile) 
1175          >>> sqlCtx.registerRDDAsTable(srdd1, "table1") 
1176          >>> srdd2 = sqlCtx.sql( 
1177          ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, " 
1178          ...   "field6 as f4 from table1") 
1179          >>> for r in srdd2.collect(): 
1180          ...     print r 
1181          Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) 
1182          Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) 
1183          Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) 
1184          >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) 
1185          >>> sqlCtx.registerRDDAsTable(srdd3, "table2") 
1186          >>> srdd4 = sqlCtx.sql( 
1187          ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, " 
1188          ...   "field6 as f4 from table2") 
1189          >>> for r in srdd4.collect(): 
1190          ...    print r 
1191          Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) 
1192          Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) 
1193          Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) 
1194          >>> schema = StructType([ 
1195          ...     StructField("field2", StringType(), True), 
1196          ...     StructField("field3", 
1197          ...         StructType([ 
1198          ...             StructField("field5", 
1199          ...                 ArrayType(IntegerType(), False), True)]), False)]) 
1200          >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema) 
1201          >>> sqlCtx.registerRDDAsTable(srdd5, "table3") 
1202          >>> srdd6 = sqlCtx.sql( 
1203          ...   "SELECT field2 AS f1, field3.field5 as f2, " 
1204          ...   "field3.field5[0] as f3 from table3") 
1205          >>> srdd6.collect() 
1206          [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] 
1207          """ 
1208          if schema is None: 
1209              srdd = self._ssql_ctx.jsonFile(path) 
1210          else: 
1211              scala_datatype = self._ssql_ctx.parseDataType(str(schema)) 
1212              srdd = self._ssql_ctx.jsonFile(path, scala_datatype) 
1213          return SchemaRDD(srdd.toJavaSchemaRDD(), self) 
 1214   
1215 -    def jsonRDD(self, rdd, schema=None): 
 1216          """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. 
1217   
1218          If the schema is provided, applies the given schema to this 
1219          JSON dataset. 
1220   
1221          Otherwise, it goes through the entire dataset once to determine 
1222          the schema. 
1223   
1224          >>> srdd1 = sqlCtx.jsonRDD(json) 
1225          >>> sqlCtx.registerRDDAsTable(srdd1, "table1") 
1226          >>> srdd2 = sqlCtx.sql( 
1227          ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, " 
1228          ...   "field6 as f4 from table1") 
1229          >>> for r in srdd2.collect(): 
1230          ...     print r 
1231          Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) 
1232          Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) 
1233          Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) 
1234          >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) 
1235          >>> sqlCtx.registerRDDAsTable(srdd3, "table2") 
1236          >>> srdd4 = sqlCtx.sql( 
1237          ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, " 
1238          ...   "field6 as f4 from table2") 
1239          >>> for r in srdd4.collect(): 
1240          ...     print r 
1241          Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) 
1242          Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) 
1243          Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) 
1244          >>> schema = StructType([ 
1245          ...     StructField("field2", StringType(), True), 
1246          ...     StructField("field3", 
1247          ...         StructType([ 
1248          ...             StructField("field5", 
1249          ...                 ArrayType(IntegerType(), False), True)]), False)]) 
1250          >>> srdd5 = sqlCtx.jsonRDD(json, schema) 
1251          >>> sqlCtx.registerRDDAsTable(srdd5, "table3") 
1252          >>> srdd6 = sqlCtx.sql( 
1253          ...   "SELECT field2 AS f1, field3.field5 as f2, " 
1254          ...   "field3.field5[0] as f3 from table3") 
1255          >>> srdd6.collect() 
1256          [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] 
1257   
1258          >>> sqlCtx.jsonRDD(sc.parallelize(['{}', 
1259          ...         '{"key0": {"key1": "value1"}}'])).collect() 
1260          [Row(key0=None), Row(key0=Row(key1=u'value1'))] 
1261          >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}', 
1262          ...         '{"key0": {"key1": "value1"}}'])).collect() 
1263          [Row(key0=None), Row(key0=Row(key1=u'value1'))] 
1264          """ 
1265   
1266          def func(iterator): 
1267              for x in iterator: 
1268                  if not isinstance(x, basestring): 
1269                      x = unicode(x) 
1270                  if isinstance(x, unicode): 
1271                      x = x.encode("utf-8") 
1272                  yield x 
 1273          keyed = rdd.mapPartitions(func) 
1274          keyed._bypass_serializer = True 
1275          jrdd = keyed._jrdd.map(self._jvm.BytesToString()) 
1276          if schema is None: 
1277              srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) 
1278          else: 
1279              scala_datatype = self._ssql_ctx.parseDataType(str(schema)) 
1280              srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) 
1281          return SchemaRDD(srdd.toJavaSchemaRDD(), self) 
 1282   
1283 -    def sql(self, sqlQuery): 
 1284          """Return a L{SchemaRDD} representing the result of the given query. 
1285   
1286          >>> srdd = sqlCtx.inferSchema(rdd) 
1287          >>> sqlCtx.registerRDDAsTable(srdd, "table1") 
1288          >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") 
1289          >>> srdd2.collect() 
1290          [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] 
1291          """ 
1292          return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self) 
 1293   
1294 -    def table(self, tableName): 
 1295          """Returns the specified table as a L{SchemaRDD}. 
1296   
1297          >>> srdd = sqlCtx.inferSchema(rdd) 
1298          >>> sqlCtx.registerRDDAsTable(srdd, "table1") 
1299          >>> srdd2 = sqlCtx.table("table1") 
1300          >>> sorted(srdd.collect()) == sorted(srdd2.collect()) 
1301          True 
1302          """ 
1303          return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self) 
 1304   
1305 -    def cacheTable(self, tableName): 
 1306          """Caches the specified table in-memory.""" 
1307          self._ssql_ctx.cacheTable(tableName) 
 1308   
1309 -    def uncacheTable(self, tableName): 
 1310          """Removes the specified table from the in-memory cache.""" 
1311          self._ssql_ctx.uncacheTable(tableName) 
 1312   
1313   
class HiveContext(SQLContext):

    """A variant of Spark SQL that integrates with data stored in Hive.

    Configuration for Hive is read from hive-site.xml on the classpath.
    It supports running both SQL and HiveQL commands.
    """

    def __init__(self, sparkContext, hiveContext=None):
        """Create a new HiveContext.

        @param sparkContext: The SparkContext to wrap.
        @param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new
        HiveContext in the JVM, instead we make all calls to this object.
        """
        SQLContext.__init__(self, sparkContext)

        # Only keep an explicitly supplied JVM context; otherwise _ssql_ctx
        # creates one lazily through _get_hive_ctx().
        if hiveContext is not None:
            self._scala_HiveContext = hiveContext

    @property
    def _ssql_ctx(self):
        """Return the JVM Hive context, creating it on first access.

        Raises an Exception with build instructions when the JVM-side Hive
        context cannot be reached through Py4J (Py4JError).
        """
        try:
            if not hasattr(self, '_scala_HiveContext'):
                self._scala_HiveContext = self._get_hive_ctx()
            return self._scala_HiveContext
        except Py4JError as e:
            raise Exception("You must build Spark with Hive. "
                            "Export 'SPARK_HIVE=true' and run "
                            "sbt/sbt assembly", e)

    def _get_hive_ctx(self):
        """Create the JVM HiveContext; subclasses override this to use another JVM class."""
        return self._jvm.HiveContext(self._jsc.sc())

    def hiveql(self, hqlQuery):
        """
        DEPRECATED: Use sql()
        """
        # stacklevel=2 attributes the warning to the caller, not this module.
        warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by "
                      "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
                      DeprecationWarning, stacklevel=2)
        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self)

    def hql(self, hqlQuery):
        """
        DEPRECATED: Use sql()
        """
        warnings.warn("hql() is deprecated as the sql function now parses using HiveQL by "
                      "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
                      DeprecationWarning, stacklevel=2)
        # Call the JVM directly rather than self.hiveql(), which would emit a
        # second, confusing deprecation warning about hiveql().
        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self)
  1365   
1366   
class LocalHiveContext(HiveContext):

    """Starts up an instance of hive where metadata is stored locally.

    An in-process metadata store is created with data stored in ./metadata.
    Warehouse data is stored in ./warehouse.

    >>> import os
    >>> hiveCtx = LocalHiveContext(sc)
    >>> try:
    ...     supress = hiveCtx.sql("DROP TABLE src")
    ... except Exception:
    ...     pass
    >>> kv1 = os.path.join(os.environ["SPARK_HOME"],
    ...        'examples/src/main/resources/kv1.txt')
    >>> supress = hiveCtx.sql(
    ...     "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src"
    ...        % kv1)
    >>> results = hiveCtx.sql("FROM src SELECT value"
    ...      ).map(lambda r: int(r.value.split('_')[1]))
    >>> num = results.count()
    >>> reduce_sum = results.reduce(lambda x, y: x + y)
    >>> num
    500
    >>> reduce_sum
    130091
    """

    def __init__(self, sparkContext, sqlContext=None):
        """Create a new LocalHiveContext (deprecated; use HiveContext).

        @param sparkContext: The SparkContext to wrap.
        @param sqlContext: An optional JVM context, forwarded to
        HiveContext.__init__ as its hiveContext argument.
        """
        HiveContext.__init__(self, sparkContext, sqlContext)
        # stacklevel=2 attributes the warning to the code constructing this object.
        warnings.warn("LocalHiveContext is deprecated. "
                      "Use HiveContext instead.", DeprecationWarning, stacklevel=2)

    def _get_hive_ctx(self):
        """Create the JVM LocalHiveContext wrapping this SparkContext."""
        return self._jvm.LocalHiveContext(self._jsc.sc())
  1403   
1404   
class TestHiveContext(HiveContext):

    """HiveContext variant whose JVM side is a TestHiveContext."""

    def _get_hive_ctx(self):
        """Create the JVM TestHiveContext wrapping this SparkContext."""
        jvm_factory = self._jvm.TestHiveContext
        return jvm_factory(self._jsc.sc())
  1409   
1412      row = Row(*values) 
1413      row.__FIELDS__ = fields 
1414      return row 
 1415   
1416   
class Row(tuple):

    """
    A row in L{SchemaRDD}. The fields in it can be accessed like attributes.

    Row can be used to create a row object by using named arguments,
    the fields will be sorted by names.

    >>> row = Row(name="Alice", age=11)
    >>> row
    Row(age=11, name='Alice')
    >>> row.name, row.age
    ('Alice', 11)

    Row also can be used to create another Row like class, then it
    could be used to create Row objects, such as

    >>> Person = Row("name", "age")
    >>> Person
    <Row(name, age)>
    >>> Person("Alice", 11)
    Row(name='Alice', age=11)

    Looking up a name that is not a field raises AttributeError, so
    getattr() with a default and hasattr() behave as expected:

    >>> getattr(row, "height", None) is None
    True
    """

    def __new__(cls, *args, **kwargs):
        if args and kwargs:
            raise ValueError("Can not use both args "
                             "and kwargs to create Row")
        if args:
            # Positional values: a plain tuple without field names. This is
            # also how Row "classes" such as Row("name", "age") are built.
            return tuple.__new__(cls, args)

        elif kwargs:
            # Named values: store them sorted by field name and remember the names.
            names = sorted(kwargs.keys())
            values = tuple(kwargs[n] for n in names)
            row = tuple.__new__(cls, values)
            row.__FIELDS__ = names
            return row

        else:
            raise ValueError("No args or kwargs")

    # A Row built from field names acts like a class: calling it with values
    # creates a new Row whose field names are this Row's elements.
    def __call__(self, *args):
        """create new Row object"""
        return _create_row(self, args)

    def __getattr__(self, item):
        """Return the value of field `item`, or raise AttributeError.

        Only reached when normal attribute lookup fails. Dunder names are
        rejected up front, so a lookup of ``__FIELDS__`` on a Row without
        field names fails cleanly instead of recursing.
        """
        if item.startswith("__"):
            raise AttributeError(item)
        try:
            # Linear search over the field names.
            return self[self.__FIELDS__.index(item)]
        except (AttributeError, ValueError, IndexError):
            # AttributeError: this Row has no field names at all.
            # ValueError: `item` is not one of the field names.
            # IndexError: fewer values than field names.
            raise AttributeError(item)

    def __reduce__(self):
        # Named rows pickle through _create_row so the field names survive the
        # round trip; nameless rows use the default tuple reduction.
        if hasattr(self, "__FIELDS__"):
            return (_create_row, (self.__FIELDS__, tuple(self)))
        else:
            return tuple.__reduce__(self)

    def __repr__(self):
        if hasattr(self, "__FIELDS__"):
            return "Row(%s)" % ", ".join("%s=%r" % (k, v)
                                         for k, v in zip(self.__FIELDS__, self))
        else:
            return "<Row(%s)>" % ", ".join(self)
  1488   
1491   
1492      """An RDD of L{Row} objects that has an associated schema. 
1493   
1494      The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can 
1495      utilize the relational query api exposed by SparkSQL. 
1496   
1497      For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the 
1498      L{SchemaRDD} is not operated on directly, as it's underlying 
1499      implementation is an RDD composed of Java objects. Instead it is 
1500      converted to a PythonRDD in the JVM, on which Python operations can 
1501      be done. 
1502   
1503      This class receives raw tuples from Java but assigns a class to it in 
1504      all its data-collection methods (mapPartitionsWithIndex, collect, take, 
1505      etc) so that PySpark sees them as Row objects with named fields. 
1506      """ 
1507   
1508 -    def __init__(self, jschema_rdd, sql_ctx): 
 1509          self.sql_ctx = sql_ctx 
1510          self._sc = sql_ctx._sc 
1511          clsName = jschema_rdd.getClass().getName() 
1512          assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD" 
1513          self._jschema_rdd = jschema_rdd 
1514   
1515          self.is_cached = False 
1516          self.is_checkpointed = False 
1517          self.ctx = self.sql_ctx._sc 
1518           
1519          self._jrdd_deserializer = BatchedSerializer(PickleSerializer()) 
 1520   
1521      @property 
1523          """Lazy evaluation of PythonRDD object. 
1524   
1525          Only done when a user calls methods defined by the 
1526          L{pyspark.rdd.RDD} super class (map, filter, etc.). 
1527          """ 
1528          if not hasattr(self, '_lazy_jrdd'): 
1529              self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython() 
1530          return self._lazy_jrdd 
 1531   
1532      @property 
1534          return self._jrdd.id() 
 1535   
1537          """Save the contents as a Parquet file, preserving the schema. 
1538   
1539          Files that are written out using this method can be read back in as 
1540          a SchemaRDD using the L{SQLContext.parquetFile} method. 
1541   
1542          >>> import tempfile, shutil 
1543          >>> parquetFile = tempfile.mkdtemp() 
1544          >>> shutil.rmtree(parquetFile) 
1545          >>> srdd = sqlCtx.inferSchema(rdd) 
1546          >>> srdd.saveAsParquetFile(parquetFile) 
1547          >>> srdd2 = sqlCtx.parquetFile(parquetFile) 
1548          >>> sorted(srdd2.collect()) == sorted(srdd.collect()) 
1549          True 
1550          """ 
1551          self._jschema_rdd.saveAsParquetFile(path) 
 1552   
1554          """Registers this RDD as a temporary table using the given name. 
1555   
1556          The lifetime of this temporary table is tied to the L{SQLContext} 
1557          that was used to create this SchemaRDD. 
1558   
1559          >>> srdd = sqlCtx.inferSchema(rdd) 
1560          >>> srdd.registerTempTable("test") 
1561          >>> srdd2 = sqlCtx.sql("select * from test") 
1562          >>> sorted(srdd.collect()) == sorted(srdd2.collect()) 
1563          True 
1564          """ 
1565          self._jschema_rdd.registerTempTable(name) 
 1566   
1568          warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) 
1569          self.registerTempTable(name) 
 1570   
1571 -    def insertInto(self, tableName, overwrite=False): 
 1572          """Inserts the contents of this SchemaRDD into the specified table. 
1573   
1574          Optionally overwriting any existing data. 
1575          """ 
1576          self._jschema_rdd.insertInto(tableName, overwrite) 
 1577   
1579          """Creates a new table with the contents of this SchemaRDD.""" 
1580          self._jschema_rdd.saveAsTable(tableName) 
 1581   
1583          """Returns the schema of this SchemaRDD (represented by 
1584          a L{StructType}).""" 
1585          return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString()) 
 1586   
1588          """Returns the output schema in the tree format.""" 
1589          return self._jschema_rdd.schemaString() 
 1590   
1592          """Prints out the schema in the tree format.""" 
1593          print self.schemaString() 
 1594   
1596          """Return the number of elements in this RDD. 
1597   
1598          Unlike the base RDD implementation of count, this implementation 
1599          leverages the query optimizer to compute the count on the SchemaRDD, 
1600          which supports features such as filter pushdown. 
1601   
1602          >>> srdd = sqlCtx.inferSchema(rdd) 
1603          >>> srdd.count() 
1604          3L 
1605          >>> srdd.count() == srdd.map(lambda x: x).count() 
1606          True 
1607          """ 
1608          return self._jschema_rdd.count() 
 1609   
1611          """ 
1612          Return a list that contains all of the rows in this RDD. 
1613   
1614          Each object in the list is on Row, the fields can be accessed as 
1615          attributes. 
1616          """ 
1617          rows = RDD.collect(self) 
1618          cls = _create_cls(self.schema()) 
1619          return map(cls, rows) 
 1620   
1621       
1622       
1624          """ 
1625          Return a new RDD by applying a function to each partition of this RDD, 
1626          while tracking the index of the original partition. 
1627   
1628          >>> rdd = sc.parallelize([1, 2, 3, 4], 4) 
1629          >>> def f(splitIndex, iterator): yield splitIndex 
1630          >>> rdd.mapPartitionsWithIndex(f).sum() 
1631          6 
1632          """ 
1633          rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) 
1634   
1635          schema = self.schema() 
1636   
1637          def applySchema(_, it): 
1638              cls = _create_cls(schema) 
1639              return itertools.imap(cls, it) 
 1640   
1641          objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning) 
1642          return objrdd.mapPartitionsWithIndex(f, preservesPartitioning) 
 1643   
1644       
1645       
1646       
1648          self.is_cached = True 
1649          self._jschema_rdd.cache() 
1650          return self 
 1651   
1653          self.is_cached = True 
1654          javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) 
1655          self._jschema_rdd.persist(javaStorageLevel) 
1656          return self 
 1657   
1659          self.is_cached = False 
1660          self._jschema_rdd.unpersist(blocking) 
1661          return self 
 1662   
1664          self.is_checkpointed = True 
1665          self._jschema_rdd.checkpoint() 
 1666   
1669   
1671          checkpointFile = self._jschema_rdd.getCheckpointFile() 
1672          if checkpointFile.isPresent(): 
1673              return checkpointFile.get() 
 1674   
1675 -    def coalesce(self, numPartitions, shuffle=False): 
 1678   
1682   
1684          if (other.__class__ is SchemaRDD): 
1685              rdd = self._jschema_rdd.intersection(other._jschema_rdd) 
1686              return SchemaRDD(rdd, self.sql_ctx) 
1687          else: 
1688              raise ValueError("Can only intersect with another SchemaRDD") 
 1689   
1693   
1694 -    def subtract(self, other, numPartitions=None): 
 1695          if (other.__class__ is SchemaRDD): 
1696              if numPartitions is None: 
1697                  rdd = self._jschema_rdd.subtract(other._jschema_rdd) 
1698              else: 
1699                  rdd = self._jschema_rdd.subtract(other._jschema_rdd, 
1700                                                   numPartitions) 
1701              return SchemaRDD(rdd, self.sql_ctx) 
1702          else: 
1703              raise ValueError("Can only subtract another SchemaRDD") 
 1704   
1707      import doctest 
1708      from array import array 
1709      from pyspark.context import SparkContext 
1710       
1711      import pyspark.sql 
1712      from pyspark.sql import Row, SQLContext 
1713      globs = pyspark.sql.__dict__.copy() 
1714       
1715       
1716      sc = SparkContext('local[4]', 'PythonTest', batchSize=2) 
1717      globs['sc'] = sc 
1718      globs['sqlCtx'] = SQLContext(sc) 
1719      globs['rdd'] = sc.parallelize( 
1720          [Row(field1=1, field2="row1"), 
1721           Row(field1=2, field2="row2"), 
1722           Row(field1=3, field2="row3")] 
1723      ) 
1724      jsonStrings = [ 
1725          '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', 
1726          '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' 
1727          '"field6":[{"field7": "row2"}]}', 
1728          '{"field1" : null, "field2": "row3", ' 
1729          '"field3":{"field4":33, "field5": []}}' 
1730      ] 
1731      globs['jsonStrings'] = jsonStrings 
1732      globs['json'] = sc.parallelize(jsonStrings) 
1733      (failure_count, test_count) = doctest.testmod( 
1734          pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS) 
1735      globs['sc'].stop() 
1736      if failure_count: 
1737          exit(-1) 
 1738   
1739   
# Script entry point: executing this module directly calls `_test()`.
# Importing the module does not trigger it.
if "__main__" == __name__:
    _test()
1742