jdk源码解析八之RPC实现(包含序列化源码解析)

    技术2023-11-16  99

    文章目录

    实现简单的RPC功能(RpcFramework、Service、Consumer、Provider);序列化解释;ObjectInputStream(反序列化、填充字段值、读取类描述以及属性);ObjectOutputStream

    实现简单的RPC功能

    RpcFramework

    package rpc; import java.io.*; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.net.ServerSocket; import java.net.Socket; /** * @author WangChao * @create 2017/10/4 20:26 */ public class RpcFramework { /** * 暴露服务 * * @param port * @throws Exception */ public static void export(final Object object, int port) throws Exception { ServerSocket serverSocket = new ServerSocket(port); while (true) { final Socket socket = serverSocket.accept(); new Thread(new Runnable() { public void run() { ObjectInputStream inputStream = null; try { inputStream = new ObjectInputStream(socket.getInputStream()); String methodName = inputStream.readUTF(); Class<?>[] paramsType = (Class<?>[]) inputStream.readObject(); Object[] params = (Object[]) inputStream.readObject(); ObjectOutputStream outputStream = new ObjectOutputStream(socket.getOutputStream()); Object invoke = object.getClass().getMethod(methodName, paramsType).invoke(object, params); outputStream.writeObject(invoke); } catch (Exception e) { e.printStackTrace(); } finally { try { inputStream.close(); } catch (IOException e) { e.printStackTrace(); } } } }).start(); } } @SuppressWarnings("unchecked") public static <T> T refer(final Class<T> interfaceClass, final int port) { return (T) Proxy.newProxyInstance(interfaceClass.getClassLoader(), new Class[]{interfaceClass}, new InvocationHandler() { public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { Socket socket = new Socket("127.0.0.1", port); ObjectOutputStream outputStream = null; ObjectInputStream inputStream = null; Object result = null; try { outputStream = new ObjectOutputStream(socket.getOutputStream()); outputStream.writeUTF(method.getName()); outputStream.writeObject(method.getParameterTypes()); outputStream.writeObject(args); inputStream = new ObjectInputStream(socket.getInputStream()); result = inputStream.readObject(); } catch (Exception e) { } finally { 
inputStream.close(); outputStream.close(); socket.close(); } return result; } }); } }

    Service

    // ---- Service.java ----
    package service;

    /**
     * Demo service contract used by the Provider/Consumer example.
     *
     * @author WangChao
     * @create 2017/10/4 20:24
     */
    public interface Service {
        /** Greeting with an int argument; ServiceImpl prints the greeting followed by {@code i}. */
        void hello(int i);

        /** Greeting without arguments; ServiceImpl prints a fixed greeting. */
        void hello();

        /** Returns an int value; ServiceImpl always returns 0. */
        int getValue();
    }

    // ---- ServiceImpl.java ----
    package service;

    /**
     * Demo implementation of {@link Service} that prints greetings to standard output.
     *
     * @author WangChao
     * @create 2017/10/4 20:24
     */
    public class ServiceImpl implements Service {
        @Override
        public void hello() {
            System.out.println("大家好啊");
        }

        @Override
        public int getValue() {
            return 0;
        }

        @Override
        public void hello(int i) {
            System.out.println("大家好啊" + i);
        }
    }

    Consumer

    import rpc.RpcFramework;
    import service.Service;

    /**
     * Client side of the demo: obtains a remote {@link Service} proxy for port 1234
     * and calls each of its methods.
     *
     * @author WangChao
     * @create 2017/10/4 20:50
     */
    public class Consumer {
        public static void main(String[] args) {
            final int port = 1234;
            Service remote = RpcFramework.refer(Service.class, port);

            // Each call below is a separate request/response exchange with the provider.
            remote.hello(1);
            remote.hello();

            int value = remote.getValue();
            System.out.println(value);
        }
    }

    Provider

    import rpc.RpcFramework;
    import service.Service;
    import service.ServiceImpl;

    /**
     * Server side of the demo: exports a {@link ServiceImpl} on port 1234.
     *
     * @author WangChao
     * @create 2017/10/4 20:51
     */
    public class Provider {
        public static void main(String[] args) throws Exception {
            final int port = 1234;
            final Service implementation = new ServiceImpl();
            // export() loops accepting connections, so this call does not return.
            RpcFramework.export(implementation, port);
        }
    }

    序列化解释

    package org.example.io; import java.io.Serializable; //类通过实现 java.io.Serializable 接口以启用其序列化功能。 public class Person implements Serializable { //通过id来判断类是否改变 private static final long serialVersionUID = 3; private transient String name; //不需要序列化的成员变量 private int age; public Person(String name, int age) { this.name = name; this.age = age; } @Override public String toString() { return "Person [name=" + name + ", age=" + age + "]"; } }

    0xACED:根据协议文档的约定,与许多二进制文件格式一样,序列化流需要在头部写入魔数(magic number,即 STREAM_MAGIC),即 Java 对象序列化后的流的魔数约定为 0xACED;

    0x0005:接着是流协议版本号(STREAM_VERSION),类型为 short(2字节),值为 5,十六进制即 0x0005;

    Java 的 byte 类型长度为1字节,所以 0x73、0x72 各自直接对应字节流中的一个字节:0x73(TC_OBJECT)表示这是一个新对象,0x72(TC_CLASSDESC)表示接下来开始是对该对象所属类的描述。

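    接着是类名长度 0x0015(2字节)。这是十六进制,即十进制的 21,正好等于类全限定名 org.example.io.Person 的字符数(类名用 UTF 格式写出:先写2字节长度,再写内容)。之后依次是:类全限定名;序列化 id(serialVersionUID,8字节,此处为 0x0000000000000003);0x02:类标志位 SC_SERIALIZABLE,表示该类可序列化;0x0001:字段个数,只有一个字段(name 被 transient 修饰,不参与序列化,只剩 age);0x49:字段类型码,即 ASCII 字符 'I',表示 int 类型;0x0003:字段名长度;接着是字段名 age;0x78(TC_ENDBLOCKDATA):该类描述的附加数据块结束;0x70(TC_NULL):超类描述为 null,即没有可序列化的超类;最后 0x00000002:字段 age 的值(int,4字节)。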

    ObjectInputStream

    /**
     * Creates an ObjectInputStream that reads from {@code in}
     * (excerpt from the JDK's java.io.ObjectInputStream).
     * The stream header (two shorts: magic number and version) is read and
     * checked here via readStreamHeader().
     */
    public ObjectInputStream(InputStream in) throws IOException {
        // Subclass verification. The original note says "check whether this is a subclass";
        // presumably a security check for subclasses - confirm against verifySubclass().
        verifySubclass();
        bin = new BlockDataInputStream(in);
        handles = new HandleTable(10);
        vlist = new ValidationList();
        serialFilter = ObjectInputFilter.Config.getSerialFilter();
        enableOverride = false;
        // Verify the magic number and the stream version
        readStreamHeader();
        bin.setBlockDataMode(true);
    }

    /**
     * Reads two shorts and throws StreamCorruptedException unless they equal
     * STREAM_MAGIC and STREAM_VERSION.
     */
    protected void readStreamHeader()
        throws IOException, StreamCorruptedException
    {
        short s0 = bin.readShort();
        short s1 = bin.readShort();
        // Magic number identifying what kind of stream this is; expected 0xACED
        if (s0 != STREAM_MAGIC ||
            // Stream protocol version (not a class version); expected 0x0005
            s1 != STREAM_VERSION) {
            throw new StreamCorruptedException(
                String.format("invalid stream header: %04X%04X", s0, s1));
        }
    }

    反序列化

    // Reads the next object from the input stream and returns it.
    public final Object readObject()
        throws IOException, ClassNotFoundException
    {
        // When enableOverride is true, delegate to readObjectOverride(),
        // i.e. the method a subclass overrides.
        if (enableOverride) {
            return readObjectOverride();
        }

        // if nested read, passHandle contains handle of enclosing object
        int outerHandle = passHandle;
        try {
            // Read the object itself
            Object obj = readObject0(false);
            handles.markDependency(outerHandle, passHandle);
            ClassNotFoundException ex = handles.lookupException(passHandle);
            if (ex != null) {
                throw ex;
            }
            // depth is incremented/decremented around each readObject0 call, so
            // depth == 0 means no enclosing read is in progress (outermost call).
            if (depth == 0) {
                vlist.doCallbacks();
            }
            return obj;
        } finally {
            passHandle = outerHandle;
            if (closed && depth == 0) {
                clear();
            }
        }
    }

    // Hook for subclasses that enable overriding; this base version returns null.
    protected Object readObjectOverride()
        throws IOException, ClassNotFoundException
    {
        return null;
    }

    // Peeks the next type code (TC_*) and dispatches to the matching reader.
    private Object readObject0(boolean unshared) throws IOException {
        boolean oldMode = bin.getBlockDataMode();
        if (oldMode) {
            int remain = bin.currentBlockRemaining();
            if (remain > 0) {
                throw new OptionalDataException(remain);
            } else if (defaultDataEnd) {
                /*
                 * Fix for 4360508: stream is currently at the end of a field
                 * value block written via default serialization; since there
                 * is no terminating TC_ENDBLOCKDATA tag, simulate
                 * end-of-custom-data behavior explicitly.
                 */
                throw new OptionalDataException(true);
            }
            bin.setBlockDataMode(false);
        }

        byte tc;
        // Consume any TC_RESET markers that precede the real type code
        while ((tc = bin.peekByte()) == TC_RESET) {
            bin.readByte();
            handleReset();
        }

        depth++;
        totalObjectRefs++;
        try {
            switch (tc) {
                case TC_NULL:
                    return readNull();

                case TC_REFERENCE:
                    return readHandle(unshared);

                case TC_CLASS:
                    return readClass(unshared);

                case TC_CLASSDESC:
                case TC_PROXYCLASSDESC:
                    return readClassDesc(unshared);

                case TC_STRING:
                case TC_LONGSTRING:
                    return checkResolve(readString(unshared));

                case TC_ARRAY:
                    return checkResolve(readArray(unshared));

                case TC_ENUM:
                    return checkResolve(readEnum(unshared));

                // 0x73 (TC_OBJECT): read a new ordinary object
                case TC_OBJECT:
                    return checkResolve(readOrdinaryObject(unshared));

                case TC_EXCEPTION:
                    IOException ex = readFatalException();
                    throw new WriteAbortedException("writing aborted", ex);

                case TC_BLOCKDATA:
                case TC_BLOCKDATALONG:
                    if (oldMode) {
                        bin.setBlockDataMode(true);
                        bin.peek();             // force header read
                        throw new OptionalDataException(
                            bin.currentBlockRemaining());
                    } else {
                        throw new StreamCorruptedException(
                            "unexpected block data");
                    }

                case TC_ENDBLOCKDATA:
                    if (oldMode) {
                        throw new OptionalDataException(true);
                    } else {
                        throw new StreamCorruptedException(
                            "unexpected end of block data");
                    }

                default:
                    throw new StreamCorruptedException(
                        String.format("invalid type code: %02X", tc));
            }
        } finally {
            depth--;
            bin.setBlockDataMode(oldMode);
        }
    }

    // Reads a TC_OBJECT: class descriptor, new instance, then its field data.
    private Object readOrdinaryObject(boolean unshared)
        throws IOException
    {
        if (bin.readByte() != TC_OBJECT) {
            throw new InternalError();
        }

        // Read the class descriptor
        ObjectStreamClass desc = readClassDesc(false);
        desc.checkDeserialize();

        Class<?> cl = desc.forClass();
        if (cl == String.class || cl == Class.class
                || cl == ObjectStreamClass.class) {
            throw new InvalidClassException("invalid class descriptor");
        }

        Object obj;
        try {
            obj = desc.isInstantiable() ? desc.newInstance() : null;
        } catch (Exception ex) {
            throw (IOException) new InvalidClassException(
                desc.forClass().getName(),
                "unable to create instance").initCause(ex);
        }

        passHandle = handles.assign(unshared ? unsharedMarker : obj);
        ClassNotFoundException resolveEx = desc.getResolveException();
        if (resolveEx != null) {
            handles.markException(passHandle, resolveEx);
        }

        if (desc.isExternalizable()) {
            readExternalData((Externalizable) obj, desc);
        } else {
            // Populate the object's field data
            readSerialData(obj, desc);
        }

        handles.finish(passHandle);

        // If the class defines readResolve, its result may replace obj
        if (obj != null &&
            handles.lookupException(passHandle) == null &&
            desc.hasReadResolveMethod())
        {
            Object rep = desc.invokeReadResolve(obj);
            if (unshared && rep.getClass().isArray()) {
                rep = cloneArray(rep);
            }
            if (rep != obj) {
                // Filter the replacement object
                if (rep != null) {
                    if (rep.getClass().isArray()) {
                        filterCheck(rep.getClass(), Array.getLength(rep));
                    } else {
                        filterCheck(rep.getClass(), -1);
                    }
                }
                handles.setObject(passHandle, obj = rep);
            }
        }

        return obj;
    }

    // Reads a class descriptor, dispatching on the peeked type code.
    private ObjectStreamClass readClassDesc(boolean unshared)
        throws IOException
    {
        byte tc = bin.peekByte();
        ObjectStreamClass descriptor;
        switch (tc) {
            case TC_NULL:
                descriptor = (ObjectStreamClass) readNull();
                break;
            case TC_REFERENCE:
                descriptor = (ObjectStreamClass) readHandle(unshared);
                break;
            case TC_PROXYCLASSDESC:
                descriptor = readProxyDesc(unshared);
                break;
            // 0x72 (TC_CLASSDESC): description of a (non-proxy) class
            case TC_CLASSDESC:
                descriptor = readNonProxyDesc(unshared);
                break;
            default:
                throw new StreamCorruptedException(
                    String.format("invalid type code: %02X", tc));
        }
        return descriptor;
    }

    // Reads a TC_CLASSDESC entry and resolves it to a local class.
    private ObjectStreamClass readNonProxyDesc(boolean unshared)
        throws IOException
    {
        if (bin.readByte() != TC_CLASSDESC) {
            throw new InternalError();
        }

        ObjectStreamClass desc = new ObjectStreamClass();
        int descHandle = handles.assign(unshared ? unsharedMarker : desc);
        passHandle = NULL_HANDLE;

        ObjectStreamClass readDesc = null;
        try {
            // Read the class description itself (via readNonProxy)
            readDesc = readClassDescriptor();
        } catch (ClassNotFoundException ex) {
            throw (IOException) new InvalidClassException(
                "failed to read class descriptor").initCause(ex);
        }

        Class<?> cl = null;
        ClassNotFoundException resolveEx = null;
        bin.setBlockDataMode(true);
        final boolean checksRequired = isCustomSubclass();
        try {
            if ((cl = resolveClass(readDesc)) == null) {
                resolveEx = new ClassNotFoundException("null class");
            } else if (checksRequired) {
                ReflectUtil.checkPackageAccess(cl);
            }
        } catch (ClassNotFoundException ex) {
            resolveEx = ex;
        }

        // Call filterCheck on the class before reading anything else
        filterCheck(cl, -1);

        // Skips trailing custom data; presumably up to TC_ENDBLOCKDATA (0x78) - confirm
        skipCustomData();

        try {
            totalObjectRefs++;
            depth++;
            // The nested readClassDesc(false) presumably reads the superclass
            // descriptor (TC_NULL, 0x70, when there is none)
            desc.initNonProxy(readDesc, cl, resolveEx, readClassDesc(false));
        } finally {
            depth--;
        }

        handles.finish(descHandle);
        passHandle = descHandle;

        return desc;
    }

    // Creates a fresh ObjectStreamClass and fills it from the stream via readNonProxy.
    protected ObjectStreamClass readClassDescriptor()
        throws IOException, ClassNotFoundException
    {
        ObjectStreamClass desc = new ObjectStreamClass();
        desc.readNonProxy(this);
        return desc;
    }

    填充字段值

    // Reads the serial data for each slot of desc's class-data layout
    // (presumably one slot per serializable class in the hierarchy).
    private void readSerialData(Object obj, ObjectStreamClass desc)
        throws IOException
    {
        ObjectStreamClass.ClassDataSlot[] slots = desc.getClassDataLayout();
        for (int i = 0; i < slots.length; i++) {
            ObjectStreamClass slotDesc = slots[i].desc;

            if (slots[i].hasData) {
                if (obj == null || handles.lookupException(passHandle) != null) {
                    defaultReadFields(null, slotDesc); // skip field values
                } else if (slotDesc.hasReadObjectMethod()) {
                    // Class defines its own readObject: invoke it with a callback context
                    ThreadDeath t = null;
                    boolean reset = false;
                    SerialCallbackContext oldContext = curContext;
                    if (oldContext != null)
                        oldContext.check();
                    try {
                        curContext = new SerialCallbackContext(obj, slotDesc);

                        bin.setBlockDataMode(true);
                        slotDesc.invokeReadObject(obj, this);
                    } catch (ClassNotFoundException ex) {
                        /*
                         * In most cases, the handle table has already
                         * propagated a CNFException to passHandle at this
                         * point; this mark call is included to address cases
                         * where the custom readObject method has cons'ed and
                         * thrown a new CNFException of its own.
                         */
                        handles.markException(passHandle, ex);
                    } finally {
                        do {
                            try {
                                curContext.setUsed();
                                if (oldContext!= null)
                                    oldContext.check();
                                curContext = oldContext;
                                reset = true;
                            } catch (ThreadDeath x) {
                                t = x;  // defer until reset is true
                            }
                        } while (!reset);
                        if (t != null)
                            throw t;
                    }

                    /*
                     * defaultDataEnd may have been set indirectly by custom
                     * readObject() method when calling defaultReadObject() or
                     * readFields(); clear it to restore normal read behavior.
                     */
                    defaultDataEnd = false;
                } else {
                    // Default path: populate field values directly
                    defaultReadFields(obj, slotDesc);
                }

                if (slotDesc.hasWriteObjectData()) {
                    skipCustomData();
                } else {
                    bin.setBlockDataMode(false);
                }
            } else {
                if (obj != null &&
                    slotDesc.hasReadObjectNoDataMethod() &&
                    handles.lookupException(passHandle) == null)
                {
                    slotDesc.invokeReadObjectNoData(obj);
                }
            }
        }
    }

    // Reads primitive then object field values for one class; assigns them when obj != null.
    private void defaultReadFields(Object obj, ObjectStreamClass desc)
        throws IOException
    {
        Class<?> cl = desc.forClass();
        if (cl != null && obj != null && !cl.isInstance(obj)) {
            throw new ClassCastException();
        }

        // Total byte size of the primitive field values
        int primDataSize = desc.getPrimDataSize();
        if (primVals == null || primVals.length < primDataSize) {
            primVals = new byte[primDataSize];
        }
        // Read all primitive field bytes in one go
        bin.readFully(primVals, 0, primDataSize, false);
        // Assign the primitive values to the object
        if (obj != null) {
            desc.setPrimFieldValues(obj, primVals);
        }

        int objHandle = passHandle;
        ObjectStreamField[] fields = desc.getFields(false);
        Object[] objVals = new Object[desc.getNumObjFields()];
        int numPrimFields = fields.length - objVals.length;
        // Object-typed fields come after the primitive ones in `fields`; each is read recursively
        for (int i = 0; i < objVals.length; i++) {
            ObjectStreamField f = fields[numPrimFields + i];
            objVals[i] = readObject0(f.isUnshared());
            if (f.getField() != null) {
                handles.markDependency(objHandle, passHandle);
            }
        }
        if (obj != null) {
            desc.setObjFieldValues(obj, objVals);
        }
        passHandle = objHandle;
    }

    读取类描述以及属性

    // Fills this descriptor from the stream: class name, serialVersionUID, flags, field descriptors.
    void readNonProxy(ObjectInputStream in)
        throws IOException, ClassNotFoundException
    {
        // Fully qualified class name
        name = in.readUTF();
        // Next 8 bytes: the serialVersionUID
        suid = Long.valueOf(in.readLong());
        isProxy = false;

        // 1-byte flags; e.g. 0x02 means the class is Serializable
        byte flags = in.readByte();
        hasWriteObjectData =
            ((flags & ObjectStreamConstants.SC_WRITE_METHOD) != 0);
        hasBlockExternalData =
            ((flags & ObjectStreamConstants.SC_BLOCK_DATA) != 0);
        externalizable =
            ((flags & ObjectStreamConstants.SC_EXTERNALIZABLE) != 0);
        boolean sflag =
            ((flags & ObjectStreamConstants.SC_SERIALIZABLE) != 0);
        if (externalizable && sflag) {
            throw new InvalidClassException(
                name, "serializable and externalizable flags conflict");
        }
        serializable = externalizable || sflag;
        isEnum = ((flags & ObjectStreamConstants.SC_ENUM) != 0);
        if (isEnum && suid.longValue() != 0L) {
            throw new InvalidClassException(name,
                "enum descriptor has non-zero serialVersionUID: " + suid);
        }

        // 2 bytes: number of fields described for this class
        int numFields = in.readShort();
        if (isEnum && numFields != 0) {
            throw new InvalidClassException(name,
                "enum descriptor has non-zero field count: " + numFields);
        }
        fields = (numFields > 0) ?
            new ObjectStreamField[numFields] : NO_FIELDS;
        // Read each serialized field descriptor
        for (int i = 0; i < numFields; i++) {
            // 1-byte field type code, e.g. 0x49 ('I') for int
            char tcode = (char) in.readByte();
            // Field name: readUTF first reads the length, then that many bytes of the name
            String fname = in.readUTF();
            // 'L' (object) and '[' (array) codes are followed by a type string;
            // primitive codes are their own signature
            String signature = ((tcode == 'L') || (tcode == '[')) ?
                in.readTypeString() : new String(new char[] { tcode });
            try {
                fields[i] = new ObjectStreamField(fname, signature, false);
            } catch (RuntimeException e) {
                throw (IOException) new InvalidClassException(name,
                    "invalid descriptor for field " + fname).initCause(e);
            }
        }
        computeFieldOffsets();
    }

    ObjectOutputStream

    Processed: 0.012, SQL: 9