
    vhQ                       d dl mZ d dlmZ d dlZd dlZd dlZd dlmZm	Z	 d dl
mZ d dlmZmZmZ d dlmZ d dlmZ d d	lmZ d d
l
mZ d dlmZ d dlmZ d dl
mZ d dlZe e!df   Z"ejF                  Z#e e$df   Z%ee#   Z& G d d      Z' G d d      Z( e(       Z)eZ*	 ejV                  e*e!f   Z,e	e,e'e(f   Z-d Z. eej^                         G d dej`                               Z/de/_1        	 	 	 	 d,dZ2ejf                   G d d             Z4d Z5 ejf                  d       G d d             Z6d-dZ7 ed d!"      	 	 	 	 d.d#       Z8d/d$Z9 ed%d!"       e:       fd&       Z; G d' d(e<      Z=d0	 d1d)Z>d* Z?d+ Z@y)2    )annotations)SequenceN)AnyUnion)config)use_cpp_classcacheuse_cpp_method)jaxlib_extension_version)
xla_client)sdymesh)AxisType)PartitionSpec)sharding.c                      e Zd ZddZddZy)AUTOc                    || _         y Nr   )selfr   s     R/opt/face_recognition/venv/lib/python3.12/site-packages/jax/_src/named_sharding.py__init__zAUTO.__init__*   s	    DI    c                    t        |      D cg c]  }t        g d       }}t        | j                  j                  |      S c c}w )NTaxesis_open)
mesh_shapedim_shardings)rangeSdyDimSdyArrayr   shape_tuple)r   ndim_r    s       r   _to_sdy_shardingzAUTO._to_sdy_sharding-   sK    #Dk+ T2 +M +tyy44"/1 1+s   AN)r   zmesh_lib.Mesh)r%   intreturnr#   )__name__
__module____qualname__r   r'    r   r   r   r   (   s    1r   r   c                      e Zd Zd Zy)UnspecifiedValuec                     y)Nr/   r-   r   s    r   __repr__zUnspecifiedValue.__repr__4   s    r   N)r*   r+   r,   r2   r-   r   r   r/   r/   3   s    r   r/   c                     t        | |||      S )Nmemory_kind_logical_device_ids)NamedSharding)r   specr5   logical_device_idss       r   _unpickle_named_shardingr:   N   s    	tT{+=
? ?r   c                     e Zd ZU dZded<   ded<   ded<   ded	<    e       d
d
d	 	 	 	 	 dd       Zd Zd Ze	d d       Z
 eedk\        d        Z eedk\        d        Zd!dZe	d"d       Ze	d#d       Ze	d$d       Ze	d%d       Ze	d%d       Ze	d#d       Zej,                  d%d       Zd&dZd'dZd(dZd)dZy
)*r7   a  A :class:`NamedSharding` expresses sharding using named axes.

  A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
  :class:`PartitionSpec` which describes how to shard an array across that
  mesh.

  A :class:`Mesh` is a multidimensional NumPy array of JAX devices,
  where each axis of the mesh has a name, e.g. ``'x'`` or ``'y'``.

  A :class:`PartitionSpec` is a tuple, whose elements can be a ``None``,
  a mesh axis, or a tuple of mesh axes. Each element describes how an input
  dimension is partitioned across zero or more mesh dimensions. For example,
  ``PartitionSpec('x', 'y')`` says that the first dimension of data
  is sharded across ``x`` axis of the mesh, and the second dimension is sharded
  across ``y`` axis of the mesh.

  The Distributed arrays and automatic parallelization
  (https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)
  tutorial has more details and diagrams that explain how
  :class:`Mesh` and :class:`PartitionSpec` are used.

  Args:
    mesh: A :class:`jax.sharding.Mesh` object.
    spec: A :class:`jax.sharding.PartitionSpec` object.

  Examples:

    >>> from jax.sharding import Mesh
    >>> from jax.sharding import PartitionSpec as P
    >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    >>> spec = P('x', 'y')
    >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
  %mesh_lib.Mesh | mesh_lib.AbstractMeshr   r   r8   
str | None_memory_kindtuple[int, ...] | Noner6   Nr4   c               |    || _         || _        || _        || _        t	        | j                   | j                         y r   )r   r8   r>   r6   check_pspec)r   r   r8   r5   r6   s        r   r   zNamedSharding.__init__|   s4     DIDI#D2D		499%r   c                    | j                   dnd| j                    }| j                  dnd| j                   }t        | j                         }d| d| j                   | | dS )N z, memory_kind=z, logical_device_ids=zNamedSharding(mesh=z, spec=))r5   r6   strr   r8   )r   memldi	mesh_reprs       r   r2   zNamedSharding.__repr__   sw      ("t?O?O>P.QC))12"4#;#;"<= tyy>"I 7499+cU3%qIIr   c                j    t         | j                  | j                  | j                  | j                  ffS r   )r:   r   r8   r5   r6   r1   s    r   
__reduce__zNamedSharding.__reduce__   s1    $YY		4#3#3T5M5MNP Pr   c                    | j                   S r   )r>   r1   s    r   r5   zNamedSharding.memory_kind   s    r   ia  c                    t        | d      s<t        | j                  | j                  | j                  | j
                  f      | _        | j                  S )N_hash)hasattrhashr   r5   r8   r6   rM   r1   s    r   __hash__zNamedSharding.__hash__   sC    4!99d&&		43K3K
LNdj::r   c                ,   t        |t              sy| |u ry| j                  |j                  k7  s2| j                  |j                  k7  s| j                  |j                  k7  ry| j
                  |j
                  u xs | j
                  |j
                  k(  S NFT)
isinstancer7   r8   r5   r6   r   )r   others     r   __eq__zNamedSharding.__eq__   sy    e]+u}		UZZu000##u'@'@@99

"=dii5::&==r   c           
         t        |      t        | j                        k  rEt        |      dk(  rdnd}t        d|  dt        | j                         dt        |       d|       y )Nr   z, For scalars the PartitionSpec should be P()rC   z	Sharding z+ is only valid for values of rank at least z%, but was applied to a value of rank .)lenr8   
ValueError)r   
aval_shape	extra_msgs      r   check_compatible_avalz#NamedSharding.check_compatible_aval   sr    
:TYY'*o* B02 dVF^A_Qyk+, , (r   c                .    | j                   j                  S r   )r   sizer1   s    r   num_deviceszNamedSharding.num_devices   s    99>>r   c                    t        | j                  t        j                        rt	        d      | j                  j
                  S )Nz>device_set is not implemented for `jax.sharding.AbstractMesh`.)rS   r   mesh_libAbstractMeshrY   _flat_devices_setr1   s    r   
device_setzNamedSharding.device_set   s8    $))X223
JL L99&&&r   c                    t        | j                  t        j                        rt	        d      | j                  j
                  S )NzF_device_assignment is not implemented for `jax.sharding.AbstractMesh`.)rS   r   ra   rb   rY   _flat_devices_tupler1   s    r   _device_assignmentz NamedSharding._device_assignment   s8    $))X223 7 8 899(((r   c                    t        | j                  t        j                        rt	        d      t
        j                  j                  r| j                  j                  S | j                  j                   S )NzHis_fully_addressable is not implemented for `jax.sharding.AbstractMesh`.)rS   r   ra   rb   rY   r   enable_empty_arraysvalue_internal_device_listis_fully_addressableis_multi_processr1   s    r   rl   z"NamedSharding.is_fully_addressable   s^    $))X223 6 7 7 !!''''<<<yy))))r   c                N    t        | j                  t        j                        ryyrR   )rS   r   ra   rb   r1   s    r   _is_concretezNamedSharding._is_concrete   s    $))X223r   c                    t        | j                  t        j                        rt	        d      | j                  j
                  S )NzGaddressable_devices is not implemented for `jax.sharding.AbstractMesh`.)rS   r   ra   rb   rY   _local_devices_setr1   s    r   addressable_devicesz!NamedSharding.addressable_devices   s:    $))X223 6 7 7 99'''r   c                    | j                   j                  dk(  ryt        | j                        }| j                   j                  }d}|D ]
  }|||   z  } |dk(  S )N   T)r   r^   get_array_mappingr8   shape)r   array_mappingr   num_partitionsnames        r   is_fully_replicatedz!NamedSharding.is_fully_replicated   s^    yy~~%dii0MJN )
4((n)Qr   c                &    | j                  |      S )N)r5   )update)r   kinds     r   with_memory_kindzNamedSharding.with_memory_kind   s    ;;4;((r   c           	     $   |j                  d| j                        }t        |t              st        | }t	        |j                  d| j
                        ||j                  d| j                        |j                  d| j                              S )Nr8   r   r5   r6   )r   r8   r5   r6   )popr8   rS   r   r7   r   r5   r6   )r   kwargsr8   s      r   r|   zNamedSharding.update   s{    ::fdii(DdM*D!dZZ		*JJ}d.>.>?"JJ'<'+'?'?A	B Br   c                    t        | |      S r   )"named_sharding_to_xla_hlo_sharding)r   num_dimensionss     r   _to_xla_hlo_shardingz"NamedSharding._to_xla_hlo_sharding   s    -dNCCr   c                   t        |      D cg c]  }t        g d       }}t        | j                        D ]D  \  }}|t        j
                  u rd||   _        #|&t        |t              r|n|f}|||   _	        F t        | j                  j                  || j                  | j                  j                        S c c}w )NFr   T)r   r    r9   unreduced_axes)r!   r"   	enumerater8   r   UNCONSTRAINEDr   rS   tupler   r#   r   r$   r6   	unreduced)r   r   r&   r    idim_specs         r   r'   zNamedSharding._to_sdy_sharding   s    #N35 U3 5M 5 + )8	]00	0#'a )(E:8 (a) tyy44"/'+'?'?#'99#6#68 85s   C)r   r<   r8   r   r5   r=   )r)   r=   )rZ   Shaper)   None)r)   r(   )r)   zset[Device])r)   XLADeviceAssignment)r)   bool)r}   rE   r)   r7   )r)   r7   r   r(   r)   zxc.HloSharding)r   r(   r)   r#   )r*   r+   r,   __doc____annotations__r
   r   r2   rJ   propertyr5   r   rP   rU   r\   r_   rd   rg   rl   ro   rr   	functoolscached_propertyrz   r~   r|   r   r'   r-   r   r   r7   r7   S   sa    D 	.--- !%$&7&?L&& &JP   *c12 3 *c12	> 3	>,   ' ' ) ) * *  
 ( (  )	BD8r   r7   zjax.shardingc                    t        | t        t        f      r| S t        j                         }t        |       D ];  \  }}||t        j                  u rt        |t              r|n|f}|D ]  }|||<   	 = |S r   )	rS   r   r/   collectionsOrderedDictr   r   r   r   )axis_resourcesdr   r   axiss        r   ru   ru     s     '7 89!>* ga|t}:::dE*4D ag	 
(r   c                  D    e Zd ZU ded<   ded<   dZded<   ddZd	 Zd
 Zy)r"   zSequence[str]r   r   r   Nz
int | Nonepriorityc                    t         j                  j                  | j                  D cg c]!  }t         j                  j                  |      # c}| j
                   | j                        S c c}w )N)	is_closedr   )r   DimensionShardingAttrgetr   AxisRefAttrr   r   )r   r   s     r   buildzSdyDim.build   sT    $$((/3yy9t		T	"9ll"T]] ) < <9s   &A-c                *    d| j                          dS )NzSdyDim(rD   _custom_reprr1   s    r   r2   zSdyDim.__repr__%  s    T&&()++r   c                    dj                  d | j                  D              }d}| j                  r| j                  rdnd}| j                  dnd| j                   }d| | d| S )	N, c              3  (   K   | ]
  }d | d   yw)'Nr-   ).0as     r   	<genexpr>z&SdyDim._custom_repr.<locals>.<genexpr>)  s     6qAaS(6s   rC   z, ??p{})joinr   r   r   )r   	axes_repr	open_reprpriority_reprs       r   r   zSdyDim._custom_repr(  se    		6DII66II||99%#i--/Bq5HM	{9+R77r   )r)   zsdy.DimensionShardingAttr)r*   r+   r,   r   r   r   r2   r   r-   r   r   r"   r"     s&    -(J<
,8r   r"   c                :      sy|J t         fd|D              S )Nr-   c              3  2   K   | ]  \  }}|v s|  y wr   r-   )r   nr&   r   s      r   r   z_get_axes.<locals>.<genexpr>6  s     5TQ19q5s   )r   )r   r   s   ` r   	_get_axesr   0  s(    				 
5Z5	55r   T)kw_onlyc                  d    e Zd ZU ded<   ded<   dZded<   dZd	ed
<    e       Zded<   ddZd Z	y)r#   z"tuple[tuple[str, int], ...] | Noner   zSequence[SdyDim]r    Nr?   r9   r-   ztuple[str, ...]replicated_axeszfrozenset[str]r   c                >   | j                    t        j                  j                  g       }n~| j                  g nt        | j                        }t        j                  j                  | j                   D cg c]%  \  }}t        j                  j                  ||      ' c}}|      }t        | j                  | j                         }t        | j                  | j                         }t        j                  j                  || j                  D cg c]  }|j                          c}|D cg c]!  }t        j                  j                  |      # c}|D cg c]!  }t        j                  j                  |      # c}      S c c}}w c c}w c c}w c c}w )N)r   r   )r   r   MeshAttrr   r9   listMeshAxisAttrr   r   r   TensorShardingAttrr    r   r   )	r   	mesh_attrrG   ry   r^   r   r   dim_shardingr   s	            r   r   zSdyArray.build@  s;   ,,""2&i**2R$))* 
,,"">Boo
N
d3d+
N
i   4 4dooFOt22DOODN!!%%262D2DE,			E?NOt,,T2O>LMd++D1M	 & O O O 	FOMs   8*F

F
0&F&Fc                    dj                  d | j                  D              }| j                  d| j                   nd}| j                  rd| j                   nd}d| d| | dS )	Nr   c              3  <   K   | ]  }|j                           y wr   r   )r   r   s     r   r   z$SdyArray.__repr__.<locals>.<genexpr>S  s      "6"6s   z, device_ids=rC   z, replicated_axes=z
SdyArray([]rD   )r   r    r9   r   )r   dim_sharding_reprdevice_id_reprrars       r   r2   zSdyArray.__repr__R  s    		 "6"&"4"4"6 6 00< &d&=&=%>?BD  ""   4 456(* )*!N+;C5BBr   )r)   zsdy.TensorShardingAttr)
r*   r+   r,   r   r9   r   	frozensetr   r   r2   r-   r   r   r#   r#   8  s<    00!!/3,3%'/?'#,;...O$Cr   r#   c                   j                   rg g }}| j                  D ]S  }|j                  |j                  s|j                  st        g d      n|       |j                  |j                         U t        j                        t        |      z
  }t        fd|D              }t        | j                  || j                  |      S | S )NTr   c              3  v   K   | ]0  }j                   |   t        j                  j                  k(  r| 2 y wr   )_name_to_typera   r   Explicitr   rr   s     r   r   z5modify_sdy_sharding_wrt_axis_types.<locals>.<genexpr>g  s9      T!#11!48I8I8R8RR  Ts   69)r   r    r9   r   )_any_axis_autor    appendr   r   r"   extendset
axis_namesr   r#   r   r9   )sdy_shardingr   r    	used_axesr   remaining_axesr   s    `     r   "modify_sdy_sharding_wrt_axis_typesr   ^  s    	!29M'' "#&& "r48@ACqvv	
 )C	N:N T~ T TO|66"/'3'F'F$35 5 
r   i   F)max_sizetrace_context_in_keyc                   | j                   j                  }t        | j                        }t	        | j                   j
                        D ci c]  \  }}||
 }}}i }t        | j                   j                        }|rS| j                   j
                  }	|D ]8  }
t        j                  j                  j                  ||	j                  |
      <   : g }t	        |j                               D ]   \  }\  }}||vs|j                  ||f       " t        |      t        |      k(  r |st        j                   j#                         S g }dg|z  }t%        |j                         d       D ])  \  }}||xx   ||   z  cc<   |j                  ||          + g }|r#t'        j(                  t*              }t'        j(                  d       }|D ch c]  }|d   	 c}j-                  t/        |j1                                     sJ |D ]Z  \  }}|j3                  |t        j                  j                  j4                        }||   j                  |       ||xx   |z  cc<   \ t%        |j                         d       D ];  \  }}|j                  |       |j                  ||          |j7                  |       = |}| j                   j8                  }| j:                  #t        j                   j=                  ||||      S t        j                   j?                  tA        jB                  | j:                        jE                  |      jE                  |      jG                  |      jE                  |      |      S c c}}w c c}w )	Nrt   c                    | d   S Nrt   r-   xs    r   <lambda>z4named_sharding_to_xla_hlo_sharding.<locals>.<lambda>  s
    qt r   )keyc                      yr   r-   r-   r   r   r   z4named_sharding_to_xla_hlo_sharding.<locals>.<lambda>  s    r   r   c                     | d   j                   S )Nr   )rj   r   s    r   r   z4named_sharding_to_xla_hlo_sharding.<locals>.<lambda>  s    qtzz r   )dimsreshape_dimstranspose_permsubgroup_types)r   )$r   rv   ru   r8   r   r   r   manual_axesxc
OpShardingTypeMANUALindexitemsr   rX   HloSharding	replicatesortedr   defaultdictr   
issupersetr   keysr   
REPLICATEDr   
axis_sizesr6   	iota_tilesubgroup_with_device_orderingnpasarrayreshape	transpose)r   r   r   rw   r   ry   mesh_axis_posspecial_axesr   r   manual_axisreplicated_mesh_axes	axis_nameaxis_valmesh_permutationnew_mesh_shapeposlast_tile_dimsaxes_by_typesize_by_typer   r^   tyr   r   r   s                             r   r   r   p  s-    yy*#DII.-*3DII4H4H*IJwq$47J-J,$))//0+%%J" N46MM4F4F4M4Ml:##K01N "+J,<,<,>"? 1a	)X%!!1h-01 		#j/1,>>##%%3'.---/^D 1idC3:d++M$/01 .)4)@)@)FL**95L./QAaD/::3|?P?P?R;STTT' 4Ar}}11<<=b2a 2$ <--/5IJ $DBL,-d#$* 
$%%,	%>>##=M% $ ' ' >>77


4++,	ww|,YY7G-H	~ 8 7 7y K6 0s   NNc                ~   | s
t               S d}t        j                  t              }| j	                         D ]!  \  }}||   j                  |       ||kD  s |}# g }t        |dz         D ]H  }||   }|r.|j                  t        |      dk(  r|d   n
t        |             8|j                  d        J t        | S )Nrt   r   )	r   r   r   r   r   r   r!   rX   r   )rw   	max_indexreverse_mapr   r   
partitionsr   s          r   array_mapping_to_axis_resourcesr    s    	?)''-+"((* kdEd#yi *Q aq>D3t9>QuT{C 

	##r      c                N    t        |d|        t        | |       t        | |       y )NzNamedSharding spec)_check_unique_resources_check_mesh_resource_axis_check_mesh_unreduced)r   r8   _manual_axess      r   rA   rA     s#    $ 4d;D$'d#r   c                  $     e Zd Z fdZd Z xZS )DuplicateSpecErrorc                N    t         |   |       || _        || _        || _        y r   )superr   messager   pspec)r   r  r   r  	__class__s       r   r   zDuplicateSpecError.__init__  s%    	GWDLDIDJr   c                    | j                    S r   )r  r1   s    r   __str__zDuplicateSpecError.__str__  s    ll^r   )r*   r+   r,   r   r!  __classcell__)r  s   @r   r  r    s    r   r  c           
     r   i }d}| D ]T  }|t         j                  u s|t        |t              r|n|f}|D ]#  }|j	                  |d      }|dkD  rd}|dz   ||<   % V |rR|j                         D 	cg c]  \  }}	|	dkD  s| }
}}	t        d| d|  dt        j                  |
       ||       y c c}	}w )	NFr   Trt   z	A single zP specification can map every mesh axis to at most one positional dimension, but z has duplicate entries for )r  r   r  )	r   r   rS   r   r   r   r  ra   	show_axes)r  arg_namer   resource_counts	duplicater   resourcecountr   cmultiple_usess              r   r  r    s    -//) ,aM'''195!tA ,!!(A.e		"'!)oh	,	, #2#8#8#:D41aa!eQDMD
z "338' :&&}568     Ds   3B3B3c                    |D ]  t         j                  u st        t              rnfD ]F  }| j                  vst        d| d| dt         j                  j                                d       t         fdD              rt        d| d dd	j                   fd
D               d       t        j                   j                  vr-t         j                  |v rt        | d j                         y y )NzResource axis: z of z is not found in mesh: rW   c              3  b   K   | ]&  }j                   d       j                   |   k(   ( yw)r   N)r   )r   r   r   r   s     r   r   z,_check_mesh_resource_axis.<locals>.<genexpr>  s0     LQt!!!A$'4+=+=a+@@Ls   ,/zAAxisTypes should be the same in a tuple subset of PartitionSpec: z. Got subset z with axis types: (r   c              3  N   K   | ]  }t        j                  |           y wr   )rE   r   r   s     r   r   z,_check_mesh_resource_axis.<locals>.<genexpr>  s!     FqD$6$6q$9 :Fs   "%rD   z[ cannot contain `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh axis_types: )r   r   rS   r   r   rY   rv   r   allr   r   Auto_axis_types_dict)r   r  r   r   s   `  @r   r  r    s<    KaM'''195!tA B	
$//	!aSUG ,%%*4::??+<%=$>aAB 	BB
 L!LLWM! %iiFAFFGqJK KK mm4000!!U*
' --.	01 1 + 1r   c           	         |j                   D ]x  }|| j                  vrt        d| d| j                  d|      | j                  |   t        j
                  t        j                  fv s_t        d|j                    d|         |j                  D ]x  }|| j                  vrt        d| d| j                  d|      | j                  |   t        j
                  t        j                  fv s_t        d|j                   d|         y )NzUnreduced axes z! is not found in mesh.axis_names=z. Got pspec=z[Unreduced axes can only refer to mesh axes that is of type `Explicit`. Got unreduced axes: z and mesh: zReduced axes zWReduced axes can only refer to mesh axes that is of type `Explicit`. Got reduced axes: )r   r   rY   r   r   r0  Manualreduced)r   r  us      r   r  r    s;   ?? 	aA3@/A B  ! @@..3oo-> ?6 	 == 	a!>doo-? @  ! @@,,1MM? ;6 	r   )r   z'PartitionSpec | AUTO | UnspecifiedValuer)   ArrayMappingOrAutoOrUnspecified)r   r#   r   )rw   ArrayMappingr   )r  r   r%  rE   r)   r   )A
__future__r   collections.abcr   r   dataclassesr   typingr   r   jax._srcr   jax._src.utilr   r	   r
   jax._src.libr   r   r   jax._src.lib.mlir.dialectsr   r   ra   jax._src.meshr   jax._src.partition_specr   r   	JShardingnumpyr   r   r(   r   DevicesliceIndexr   r   r/   UNSPECIFIEDMeshAxisNamer   r7  r6  r:   r7   Shardingr+   ru   	dataclassr"   r   r#   r   r   r  r   rA   	Exceptionr  r  r  r  r-   r   r   <module>rL     s   # $      > > 1 ) * % " 1 * c3h	eSjv& 	1 	1     &&|S'89"'d<L(L"M ?
 r t8I&& t8 !t8l * ;$ 8 8 8*6 t$ C  C % CJ$ 51C7C7"0C7 2C7L$& %0)2 $ 1$
  !% ,1,r   