minor optimizations
This commit is contained in:
parent
5d58695eb4
commit
9750472027
|
|
@ -8,6 +8,7 @@ import android.util.Log
|
||||||
import com.example.livingai.R
|
import com.example.livingai.R
|
||||||
import com.example.livingai.domain.camera.*
|
import com.example.livingai.domain.camera.*
|
||||||
import com.example.livingai.domain.model.camera.*
|
import com.example.livingai.domain.model.camera.*
|
||||||
|
import com.example.livingai.utils.SignedMask
|
||||||
import com.example.livingai.utils.SilhouetteManager
|
import com.example.livingai.utils.SilhouetteManager
|
||||||
import org.tensorflow.lite.Interpreter
|
import org.tensorflow.lite.Interpreter
|
||||||
import org.tensorflow.lite.support.common.FileUtil
|
import org.tensorflow.lite.support.common.FileUtil
|
||||||
|
|
@ -247,8 +248,8 @@ class MockPoseAnalyzer : PoseAnalyzer {
|
||||||
val reference = SilhouetteManager.getWeightedMask(input.orientation)
|
val reference = SilhouetteManager.getWeightedMask(input.orientation)
|
||||||
?: return Instruction("Silhouette missing", isValid = false)
|
?: return Instruction("Silhouette missing", isValid = false)
|
||||||
|
|
||||||
val refH = reference.size
|
val refH = reference.mask.size
|
||||||
val refW = reference[0].size
|
val refW = reference.mask[0].size
|
||||||
|
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
// 2. Crop only animal region (BIG WIN)
|
// 2. Crop only animal region (BIG WIN)
|
||||||
|
|
@ -319,18 +320,17 @@ class MockPoseAnalyzer : PoseAnalyzer {
|
||||||
*/
|
*/
|
||||||
private fun calculateSignedSimilarityFast(
|
private fun calculateSignedSimilarityFast(
|
||||||
mask: ByteArray,
|
mask: ByteArray,
|
||||||
reference: Array<FloatArray>
|
reference: SignedMask
|
||||||
): Float {
|
): Float {
|
||||||
|
|
||||||
var score = 0f
|
var score = 0f
|
||||||
var maxScore = 0f
|
val maxScore = reference.maxValue
|
||||||
var idx = 0
|
var idx = 0
|
||||||
|
|
||||||
for (y in reference.indices) {
|
for (y in reference.mask.indices) {
|
||||||
val row = reference[y]
|
val row = reference.mask[y]
|
||||||
for (x in row.indices) {
|
for (x in row.indices) {
|
||||||
val r = row[x]
|
val r = row[x]
|
||||||
if (r > 0f) maxScore += r
|
|
||||||
score += mask[idx++] * r
|
score += mask[idx++] * r
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,15 +9,20 @@ import com.example.livingai.R
|
||||||
import java.util.concurrent.ConcurrentHashMap
|
import java.util.concurrent.ConcurrentHashMap
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
|
|
||||||
|
data class SignedMask(
|
||||||
|
val mask: Array<FloatArray>,
|
||||||
|
val maxValue: Float
|
||||||
|
)
|
||||||
|
|
||||||
object SilhouetteManager {
|
object SilhouetteManager {
|
||||||
|
|
||||||
private val originals = ConcurrentHashMap<String, Bitmap>()
|
private val originals = ConcurrentHashMap<String, Bitmap>()
|
||||||
private val invertedPurple = ConcurrentHashMap<String, Bitmap>()
|
private val invertedPurple = ConcurrentHashMap<String, Bitmap>()
|
||||||
private val weightedMasks = ConcurrentHashMap<String, Array<FloatArray>>()
|
private val weightedMasks = ConcurrentHashMap<String, SignedMask>()
|
||||||
|
|
||||||
fun getOriginal(name: String): Bitmap? = originals[name]
|
fun getOriginal(name: String): Bitmap? = originals[name]
|
||||||
fun getInvertedPurple(name: String): Bitmap? = invertedPurple[name]
|
fun getInvertedPurple(name: String): Bitmap? = invertedPurple[name]
|
||||||
fun getWeightedMask(name: String): Array<FloatArray>? = weightedMasks[name]
|
fun getWeightedMask(name: String): SignedMask? = weightedMasks[name]
|
||||||
|
|
||||||
fun initialize(context: Context, width: Int, height: Int) {
|
fun initialize(context: Context, width: Int, height: Int) {
|
||||||
val resources = context.resources
|
val resources = context.resources
|
||||||
|
|
@ -88,7 +93,7 @@ object SilhouetteManager {
|
||||||
bitmap: Bitmap,
|
bitmap: Bitmap,
|
||||||
fadeInside: Int = 10,
|
fadeInside: Int = 10,
|
||||||
fadeOutside: Int = 20
|
fadeOutside: Int = 20
|
||||||
): Array<FloatArray> {
|
): SignedMask {
|
||||||
|
|
||||||
val w = bitmap.width
|
val w = bitmap.width
|
||||||
val h = bitmap.height
|
val h = bitmap.height
|
||||||
|
|
@ -99,24 +104,19 @@ object SilhouetteManager {
|
||||||
fun idx(x: Int, y: Int) = y * w + x
|
fun idx(x: Int, y: Int) = y * w + x
|
||||||
|
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
// 1. Binary mask: inside = 1, outside = 0
|
// 1. Binary mask
|
||||||
// Assumption: NON-transparent pixels are inside the object
|
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
val inside = IntArray(w * h)
|
val inside = IntArray(w * h)
|
||||||
for (i in pixels.indices) {
|
for (i in pixels.indices) {
|
||||||
val alpha = pixels[i] ushr 24
|
inside[i] = if ((pixels[i] ushr 24) > 0) 1 else 0
|
||||||
inside[i] = if (alpha > 0) 1 else 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
// 2. Distance to nearest OUTSIDE pixel (for inside pixels)
|
// 2. Distance transform (inside → outside)
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
val distInside = IntArray(w * h) { Int.MAX_VALUE }
|
val distInside = IntArray(w * h) { Int.MAX_VALUE }
|
||||||
for (i in inside.indices) {
|
for (i in inside.indices) if (inside[i] == 0) distInside[i] = 0
|
||||||
if (inside[i] == 0) distInside[i] = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// forward pass
|
|
||||||
for (y in 0 until h) {
|
for (y in 0 until h) {
|
||||||
for (x in 0 until w) {
|
for (x in 0 until w) {
|
||||||
val i = idx(x, y)
|
val i = idx(x, y)
|
||||||
|
|
@ -127,7 +127,6 @@ object SilhouetteManager {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// backward pass
|
|
||||||
for (y in h - 1 downTo 0) {
|
for (y in h - 1 downTo 0) {
|
||||||
for (x in w - 1 downTo 0) {
|
for (x in w - 1 downTo 0) {
|
||||||
val i = idx(x, y)
|
val i = idx(x, y)
|
||||||
|
|
@ -139,14 +138,11 @@ object SilhouetteManager {
|
||||||
}
|
}
|
||||||
|
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
// 3. Distance to nearest INSIDE pixel (for outside pixels)
|
// 3. Distance transform (outside → inside)
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
val distOutside = IntArray(w * h) { Int.MAX_VALUE }
|
val distOutside = IntArray(w * h) { Int.MAX_VALUE }
|
||||||
for (i in inside.indices) {
|
for (i in inside.indices) if (inside[i] == 1) distOutside[i] = 0
|
||||||
if (inside[i] == 1) distOutside[i] = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// forward pass
|
|
||||||
for (y in 0 until h) {
|
for (y in 0 until h) {
|
||||||
for (x in 0 until w) {
|
for (x in 0 until w) {
|
||||||
val i = idx(x, y)
|
val i = idx(x, y)
|
||||||
|
|
@ -157,7 +153,6 @@ object SilhouetteManager {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// backward pass
|
|
||||||
for (y in h - 1 downTo 0) {
|
for (y in h - 1 downTo 0) {
|
||||||
for (x in w - 1 downTo 0) {
|
for (x in w - 1 downTo 0) {
|
||||||
val i = idx(x, y)
|
val i = idx(x, y)
|
||||||
|
|
@ -169,30 +164,33 @@ object SilhouetteManager {
|
||||||
}
|
}
|
||||||
|
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
// 4. Build signed weight map [-1, +1]
|
// 4. Build signed mask + track max value
|
||||||
// --------------------------------------------------------------------
|
// --------------------------------------------------------------------
|
||||||
val result = Array(h) { FloatArray(w) }
|
val result = Array(h) { FloatArray(w) }
|
||||||
|
var maxValue = Float.NEGATIVE_INFINITY
|
||||||
|
|
||||||
for (y in 0 until h) {
|
for (y in 0 until h) {
|
||||||
for (x in 0 until w) {
|
for (x in 0 until w) {
|
||||||
val i = idx(x, y)
|
val i = idx(x, y)
|
||||||
|
|
||||||
val weight = if (inside[i] == 1) {
|
val weight = if (inside[i] == 1) {
|
||||||
// Inside: +1 → 0
|
|
||||||
val d = distInside[i]
|
val d = distInside[i]
|
||||||
if (d >= fadeInside) 1f
|
if (d >= fadeInside) 1f
|
||||||
else d.toFloat() / fadeInside
|
else d.toFloat() / fadeInside
|
||||||
} else {
|
} else {
|
||||||
// Outside: 0 → -1
|
|
||||||
val d = distOutside[i]
|
val d = distOutside[i]
|
||||||
val v = -(d.toFloat() / fadeOutside)
|
(-d.toFloat() / fadeOutside).coerceAtLeast(-1f)
|
||||||
v.coerceAtLeast(-1f)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result[y][x] = weight
|
result[y][x] = weight
|
||||||
|
if (weight > maxValue) maxValue = weight
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return SignedMask(
|
||||||
|
mask = result,
|
||||||
|
maxValue = maxValue
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue