package org.scava.util.proxy;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import scala.ScalaObject;
/**
 * Factory that creates proxies for abstract Java classes (or for Scala trait
 * interfaces themselves) whose abstract methods are inherited from Scala
 * traits.
 * <p>
 * Proxy classes are generated at runtime with a subset of
 * <a href="http://asm.ow2.org/">ASM</a>. Every abstract method that stems from
 * a Scala trait is implemented by delegating to the static method the Scala
 * compiler creates for that trait in its {@code <Trait>$class} class.
 * <p>
 * Generated proxy classes are cached per proxied class. Instances of this
 * factory are not thread-safe: the cache is an unsynchronized {@link HashMap}.
 *
 * @author noctarius
 */
public final class ScavaProxyFactory {

    /** Cache of generated proxy classes, keyed by the class they proxy. */
    private final Map<Class<?>, Class<?>> proxyClassMapping = new HashMap<Class<?>, Class<?>>();

    /**
     * Creates a new instance of this ScavaProxyFactory.
     *
     * @return A new factory instance for proxy creation
     */
    public static ScavaProxyFactory newInstance() {
        return new ScavaProxyFactory();
    }

    // Prevent external instantiation; use newInstance() instead.
    private ScavaProxyFactory() {
    }

    /**
     * Creates a new Proxy by dynamically subclassing the given class and
     * implementing all abstract methods by their Scala Trait counterparts,
     * which the Scala Compiler creates as static methods in special classes.<br>
     * The classloader that loaded the given class is used to search for the
     * classes and interfaces it uses.
     *
     * @param <T>
     *            The type of the class to be proxied
     * @param clazz
     *            The class a Proxy will be created for
     * @return The Trait Proxy implementing the abstract methods
     */
    public <T> T newProxyInstance(Class<T> clazz) {
        return newProxyInstance(clazz, clazz.getClassLoader());
    }

    /**
     * Creates a new Proxy by dynamically subclassing the given class and
     * implementing all abstract methods by their Scala Trait counterparts,
     * which the Scala Compiler creates as static methods in special classes.<br>
     * A special classloader can be given to search for classes and interfaces
     * that are used by the given class. If a proxy class for {@code clazz} is
     * already cached, it is reused regardless of {@code classLoader}.<br>
     * If {@code clazz} is itself an interface (a Scala Trait), the proxy
     * extends {@link Object} and implements that interface.
     *
     * @param <T>
     *            The type of the class to be proxied
     * @param clazz
     *            The class a Proxy will be created for
     * @param classLoader
     *            The Classloader to be used for searching interfaces and
     *            classes; {@code null} means the classloader of {@code clazz}
     * @return The Trait Proxy implementing the abstract methods
     * @throws NullPointerException
     *             if {@code clazz} is {@code null}
     * @throws IllegalArgumentException
     *             if an abstract method of {@code clazz} is not declared by a
     *             Scala Trait, or is inherited from two unrelated Scala Traits
     * @throws IllegalStateException
     *             if the static implementation class of a Trait is missing
     * @throws RuntimeException
     *             wrapping instantiation or access failures of the proxy
     */
    @SuppressWarnings("unchecked")
    public <T> T newProxyInstance(final Class<T> clazz, ClassLoader classLoader) {
        if (clazz == null) {
            throw new NullPointerException("clazz cannot be null");
        }
        if (classLoader == null) {
            classLoader = clazz.getClassLoader();
        }
        Class<T> proxyClass = (Class<T>) proxyClassMapping.get(clazz);
        if (proxyClass == null) {
            proxyClass = generateProxyClass(clazz, classLoader);
            proxyClassMapping.put(clazz, proxyClass);
        }
        try {
            return proxyClass.newInstance();
        } catch (InstantiationException e) {
            throw new RuntimeException(e);
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Removes the cached proxy class of a bound class, so the next
     * {@code newProxyInstance} call for it generates a new proxy class.
     *
     * @param clazz
     *            Class to be removed from ProxyClass Cache
     */
    public void unbindClass(Class<?> clazz) {
        proxyClassMapping.remove(clazz);
    }

    /** Generates the bytecode of the proxy for {@code clazz} and defines it. */
    private <T> Class<T> generateProxyClass(final Class<T> clazz,
            final ClassLoader classLoader) {
        final Map<String, MethodDefinition> traitMethods = collectTraitMethods(
                clazz, classLoader);

        // An interface cannot be used as a superclass: proxy a trait interface
        // by extending Object and implementing the interface instead.
        final Class<?> superClass = clazz.isInterface() ? Object.class : clazz;
        final String[] interfaces = clazz.isInterface()
                ? new String[] { Type.getInternalName(clazz) }
                : internalNamesOf(clazz.getInterfaces());

        final String proxyClassName = clazz.getSimpleName() + "$SCAVAPROXY";
        final ClassWriter cw = new ClassWriter(0);
        // The fourth argument is the *generic* signature; the proxy class is
        // not generic, so no signature is emitted.
        cw.visit(Opcodes.V1_5, Opcodes.ACC_PUBLIC + Opcodes.ACC_SYNTHETIC
                + Opcodes.ACC_SUPER, proxyClassName, null,
                Type.getInternalName(superClass), interfaces);
        createConstructor(cw, superClass);
        for (MethodDefinition def : traitMethods.values()) {
            pushInvokeOfStaticTraitMethod(cw, def);
        }
        cw.visitEnd();

        return new ASMClassLoader<T>(classLoader).loadClass(cw.toByteArray());
    }

    /**
     * Collects a definition for every public abstract method of {@code clazz},
     * keyed by method name plus descriptor so each method is generated once.
     * When the same method is inherited from a trait and from a sub-trait
     * overriding it, the more specific (sub-)trait wins.
     *
     * @throws IllegalArgumentException
     *             if an abstract method is not declared by a Scala Trait
     *             interface, or two unrelated Scala Traits declare it
     */
    private Map<String, MethodDefinition> collectTraitMethods(
            final Class<?> clazz, final ClassLoader classLoader) {
        final Map<String, MethodDefinition> definitions = new HashMap<String, MethodDefinition>();
        for (Method method : clazz.getMethods()) {
            if (!Modifier.isAbstract(method.getModifiers())) {
                continue;
            }
            final Class<?> interfaze = method.getDeclaringClass();
            if (!interfaze.isInterface()
                    || !ScalaObject.class.isAssignableFrom(interfaze)) {
                throw new IllegalArgumentException(clazz.getName()
                        + " is no legal Scala Trait");
            }
            final String key = method.getName() + Type.getMethodDescriptor(method);
            final MethodDefinition known = definitions.get(key);
            if (known == null || known.interfaceClass.isAssignableFrom(interfaze)) {
                // First declaration seen, or a more specific trait overriding it
                definitions.put(key, buildMethodDefinition(classLoader, method,
                        interfaze));
            } else if (!interfaze.isAssignableFrom(known.interfaceClass)) {
                // Two unrelated traits: emitting both would duplicate the method
                throw new IllegalArgumentException(clazz.getName()
                        + " inherits " + method.getName()
                        + " from unrelated Scala Traits "
                        + known.interfaceClass.getName() + " and "
                        + interfaze.getName());
            }
        }
        return definitions;
    }

    /** Maps each class to its JVM internal name. */
    private static String[] internalNamesOf(final Class<?>[] classes) {
        final String[] names = new String[classes.length];
        for (int i = 0; i < classes.length; i++) {
            names[i] = Type.getInternalName(classes[i]);
        }
        return names;
    }

    /**
     * Emits a public final method that forwards {@code this} and all
     * parameters to the static trait implementation and returns its result.
     */
    private void pushInvokeOfStaticTraitMethod(final ClassWriter cw,
            final MethodDefinition def) {
        final MethodWriter mw = cw.visitMethod(Opcodes.ACC_PUBLIC
                + Opcodes.ACC_SYNTHETIC + Opcodes.ACC_FINAL, def.name,
                def.description, null, def.buildExceptionArray());
        mw.visitCode();
        final int argumentSlots = prepareStackForExecution(mw, def);
        mw.visitMethodInsn(Opcodes.INVOKESTATIC,
                Type.getInternalName(def.scalaStaticTraitClass), def.name,
                def.buildStaticMethodDescriptor());
        pushReturnOpcodeByValue(mw, def);
        // Max stack/locals are supplied explicitly. Locals hold `this` plus the
        // parameters; the operand stack peaks either with every argument
        // pushed or with the (possibly two-slot) returned value.
        mw.visitMaxs(Math.max(argumentSlots, slotSize(def.returnType)),
                argumentSlots);
        mw.visitEnd();
    }

    /** Emits the return instruction matching the method's return type. */
    private void pushReturnOpcodeByValue(final MethodWriter mw,
            final MethodDefinition def) {
        switch (def.returnType.getSort()) {
        case Type.VOID:
            mw.visitInsn(Opcodes.RETURN);
            break;
        case Type.BOOLEAN:
        case Type.BYTE:
        case Type.SHORT:
        case Type.CHAR:
        case Type.INT:
            mw.visitInsn(Opcodes.IRETURN);
            break;
        case Type.LONG:
            mw.visitInsn(Opcodes.LRETURN);
            break;
        case Type.FLOAT:
            mw.visitInsn(Opcodes.FRETURN);
            break;
        case Type.DOUBLE:
            mw.visitInsn(Opcodes.DRETURN);
            break;
        default:
            mw.visitInsn(Opcodes.ARETURN);
        }
    }

    /**
     * Pushes {@code this} followed by every parameter onto the operand stack.
     *
     * @return the number of local variable slots used by {@code this} and the
     *         parameters
     */
    private int prepareStackForExecution(final MethodWriter mw,
            final MethodDefinition def) {
        // Load this instance on stack
        mw.visitVarInsn(Opcodes.ALOAD, 0);
        int slot = 1;
        for (Type type : def.types) {
            mw.visitVarInsn(loadOpcodeFor(type), slot);
            // long and double parameters occupy two local variable slots
            slot += slotSize(type);
        }
        return slot;
    }

    /** Returns the load instruction matching a parameter's type. */
    private static int loadOpcodeFor(final Type type) {
        switch (type.getSort()) {
        case Type.BOOLEAN:
        case Type.BYTE:
        case Type.SHORT:
        case Type.CHAR:
        case Type.INT:
            return Opcodes.ILOAD;
        case Type.LONG:
            return Opcodes.LLOAD;
        case Type.FLOAT:
            return Opcodes.FLOAD;
        case Type.DOUBLE:
            return Opcodes.DLOAD;
        default:
            return Opcodes.ALOAD;
        }
    }

    /** Returns how many local variable / operand stack slots a value takes. */
    private static int slotSize(final Type type) {
        switch (type.getSort()) {
        case Type.VOID:
            return 0;
        case Type.LONG:
        case Type.DOUBLE:
            return 2;
        default:
            return 1;
        }
    }

    /** Emits a public no-arg constructor that calls {@code superClass()}. */
    private void createConstructor(final ClassWriter cw, final Class<?> superClass) {
        final MethodWriter mv = cw.visitMethod(Opcodes.ACC_PUBLIC, "<init>",
                "()V", null, null);
        mv.visitCode();
        mv.visitVarInsn(Opcodes.ALOAD, 0);
        mv.visitMethodInsn(Opcodes.INVOKESPECIAL,
                Type.getInternalName(superClass), "<init>", "()V");
        mv.visitInsn(Opcodes.RETURN);
        mv.visitMaxs(1, 1);
        mv.visitEnd();
    }

    /**
     * Builds the definition of a trait method, resolving the trait's static
     * implementation class through {@code classLoader}.
     *
     * @throws IllegalStateException
     *             if the static implementation class cannot be found
     */
    private MethodDefinition buildMethodDefinition(
            final ClassLoader classLoader, final Method method,
            final Class<?> interfaze) {
        try {
            final Class<?> scalaStaticTraitClass = retrieveScalaTraitClass(
                    classLoader, interfaze);
            return new MethodDefinition(scalaStaticTraitClass, interfaze,
                    method.getName(), Type.getMethodDescriptor(method),
                    toTypes(method.getParameterTypes()),
                    toTypes(method.getExceptionTypes()),
                    Type.getType(method.getReturnType()));
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(
                    "Scala Static Trait Class could not be found", e);
        }
    }

    /** Converts classes to their ASM {@link Type}s. */
    private static Type[] toTypes(final Class<?>[] classes) {
        final Type[] types = new Type[classes.length];
        for (int i = 0; i < classes.length; i++) {
            types[i] = Type.getType(classes[i]);
        }
        return types;
    }

    /** Loads the {@code <Trait>$class} class holding the static trait methods. */
    private Class<?> retrieveScalaTraitClass(final ClassLoader classLoader,
            final Class<?> interfaze) throws ClassNotFoundException {
        return classLoader.loadClass(interfaze.getName() + "$class");
    }

    /** Single-use classloader that defines one generated proxy class. */
    private static class ASMClassLoader<T> extends ClassLoader {
        private ASMClassLoader(ClassLoader parent) {
            super(parent);
        }

        @SuppressWarnings("unchecked")
        public Class<T> loadClass(final byte[] data) {
            return (Class<T>) defineClass(null, data, 0, data.length);
        }
    }

    /** Everything needed to emit one forwarding method of the proxy. */
    private static class MethodDefinition {
        /** The {@code <Trait>$class} class holding the static implementation. */
        final Class<?> scalaStaticTraitClass;
        /** The trait interface declaring the method. */
        final Class<?> interfaceClass;
        final String name;
        /** JVM method descriptor of the proxied method. */
        final String description;
        /** Parameter types, in declaration order. */
        final Type[] types;
        final Type[] exceptions;
        final Type returnType;

        private MethodDefinition(final Class<?> scalaStaticTraitClass,
                final Class<?> interfaceClass, final String name,
                final String description, final Type[] types,
                final Type[] exceptions, final Type returnType) {
            this.scalaStaticTraitClass = scalaStaticTraitClass;
            this.interfaceClass = interfaceClass;
            this.name = name;
            this.description = description;
            this.types = types;
            this.exceptions = exceptions;
            this.returnType = returnType;
        }

        /** Returns the internal names of the declared exception types. */
        private String[] buildExceptionArray() {
            final String[] exceptionClasses = new String[exceptions.length];
            for (int i = 0; i < exceptions.length; i++) {
                exceptionClasses[i] = exceptions[i].getInternalName();
            }
            return exceptionClasses;
        }

        /**
         * Returns the descriptor of the static trait method: the trait
         * interface as first parameter, followed by the original parameters.
         */
        private String buildStaticMethodDescriptor() {
            final Type[] methodTypes = new Type[types.length + 1];
            methodTypes[0] = Type.getType(interfaceClass);
            System.arraycopy(types, 0, methodTypes, 1, types.length);
            return Type.getMethodDescriptor(returnType, methodTypes);
        }
    }
}