adding additional checks in pipeline, to reduce segmentation load

This commit is contained in:
SaiD 2025-12-13 13:38:09 +05:30
parent 9750472027
commit 986c0da505
4 changed files with 332 additions and 369 deletions

View File

@ -2,7 +2,6 @@ package com.example.livingai.data.camera
import android.content.Context import android.content.Context
import android.graphics.Bitmap import android.graphics.Bitmap
import android.graphics.Color
import android.graphics.RectF import android.graphics.RectF
import android.util.Log import android.util.Log
import com.example.livingai.R import com.example.livingai.R
@ -10,225 +9,200 @@ import com.example.livingai.domain.camera.*
import com.example.livingai.domain.model.camera.* import com.example.livingai.domain.model.camera.*
import com.example.livingai.utils.SignedMask import com.example.livingai.utils.SignedMask
import com.example.livingai.utils.SilhouetteManager import com.example.livingai.utils.SilhouetteManager
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentation
import com.google.mlkit.vision.segmentation.subject.SubjectSegmenterOptions
import kotlinx.coroutines.suspendCancellableCoroutine
import org.tensorflow.lite.Interpreter import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.FileUtil import org.tensorflow.lite.support.common.FileUtil
import java.io.IOException import java.io.IOException
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.ByteOrder import java.nio.ByteOrder
import kotlin.math.abs
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentation
import com.google.mlkit.vision.segmentation.subject.SubjectSegmenterOptions
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlin.coroutines.resume import kotlin.coroutines.resume
import kotlin.math.abs
import kotlin.math.min
/* ============================================================= */
/* ORIENTATION CHECKER */
/* ============================================================= */
class DefaultOrientationChecker : OrientationChecker { class DefaultOrientationChecker : OrientationChecker {
override suspend fun analyze(input: PipelineInput): Instruction { override suspend fun analyze(input: PipelineInput): Instruction {
val orientationLower = input.orientation.lowercase()
val isPortraitRequired = orientationLower == "front" || orientationLower == "back"
// Corrected Logic: val isPortraitRequired =
// 90 or 270 degrees means the device is held in PORTRAIT input.orientation.lowercase() == "front" ||
val isDevicePortrait = input.deviceOrientation == 90 || input.deviceOrientation == 270 input.orientation.lowercase() == "back"
// 0 or 180 degrees means the device is held in LANDSCAPE
val isDeviceLandscape = input.deviceOrientation == 0 || input.deviceOrientation == 180
var isValid = true val isPortrait = input.deviceOrientation == 90 || input.deviceOrientation == 270
var message = "Orientation Correct" val isLandscape = input.deviceOrientation == 0 || input.deviceOrientation == 180
if (isPortraitRequired && !isDevicePortrait) { val valid = if (isPortraitRequired) isPortrait else isLandscape
isValid = false
message = "Turn to portrait mode"
} else if (!isPortraitRequired && !isDeviceLandscape) {
isValid = false
message = "Turn to landscape mode"
}
val animRes = if (!isValid) R.drawable.ic_launcher_foreground else null
return Instruction( return Instruction(
message = message, message = if (valid) "Orientation Correct"
animationResId = animRes, else if (isPortraitRequired) "Turn to portrait mode"
isValid = isValid, else "Turn to landscape mode",
result = OrientationResult(input.deviceOrientation, if (isPortraitRequired) CameraOrientation.PORTRAIT else CameraOrientation.LANDSCAPE) animationResId = if (valid) null else R.drawable.ic_launcher_foreground,
isValid = valid,
result = OrientationResult(
input.deviceOrientation,
if (isPortraitRequired) CameraOrientation.PORTRAIT else CameraOrientation.LANDSCAPE
)
) )
} }
} }
/* ============================================================= */
/* TILT CHECKER */
/* ============================================================= */
class DefaultTiltChecker : TiltChecker { class DefaultTiltChecker : TiltChecker {
override suspend fun analyze(input: PipelineInput): Instruction { override suspend fun analyze(input: PipelineInput): Instruction {
Log.d("TiltChecker", "Required Orientation: ${input.requiredOrientation}, Pitch: ${input.devicePitch}, Roll: ${input.deviceRoll}, Azimuth: ${input.deviceAzimuth}")
val tolerance = 25.0f val tolerance = 25f
val isLevel: Boolean
if (input.requiredOrientation == CameraOrientation.PORTRAIT) { val isLevel = when (input.requiredOrientation) {
// Ideal for portrait: pitch around -90, roll around 0 CameraOrientation.PORTRAIT ->
val idealPitch = -90.0f abs(input.devicePitch + 90f) <= tolerance
isLevel = abs(input.devicePitch - idealPitch) <= tolerance CameraOrientation.LANDSCAPE ->
} else { // LANDSCAPE abs(input.devicePitch) <= tolerance
// Ideal for landscape: pitch around 0, roll around +/-90
val idealPitch = 0.0f
isLevel = abs(input.devicePitch - idealPitch) <= tolerance
} }
val message = if (isLevel) "Device is level" else "Keep the phone straight"
return Instruction( return Instruction(
message = message, message = if (isLevel) "Device is level" else "Keep the phone straight",
isValid = isLevel, isValid = isLevel,
result = TiltResult(input.deviceRoll, input.devicePitch, isLevel) result = TiltResult(input.deviceRoll, input.devicePitch, isLevel)
) )
} }
} }
/* ============================================================= */
/* TFLITE OBJECT DETECTOR (PRIMARY + REFERENCE OBJECTS) */
/* ============================================================= */
class TFLiteObjectDetector(context: Context) : ObjectDetector { class TFLiteObjectDetector(context: Context) : ObjectDetector {
private var interpreter: Interpreter? = null private var interpreter: Interpreter? = null
private var labels: List<String> = emptyList() private var labels: List<String> = emptyList()
private var modelInputWidth: Int = 0 private var inputW = 0
private var modelInputHeight: Int = 0 private var inputH = 0
private var maxDetections: Int = 25 private var maxDetections = 25
init { init {
try { try {
val modelBuffer = FileUtil.loadMappedFile(context, "efficientdet-lite0.tflite") interpreter = Interpreter(
interpreter = Interpreter(modelBuffer) FileUtil.loadMappedFile(context, "efficientdet-lite0.tflite")
)
labels = FileUtil.loadLabels(context, "labels.txt") labels = FileUtil.loadLabels(context, "labels.txt")
val inputTensor = interpreter?.getInputTensor(0) val inputShape = interpreter!!.getInputTensor(0).shape()
val inputShape = inputTensor?.shape() inputW = inputShape[1]
if (inputShape != null && inputShape.size >= 3) { inputH = inputShape[2]
modelInputWidth = inputShape[1]
modelInputHeight = inputShape[2]
} else {
Log.e("TFLiteObjectDetector", "Invalid input tensor shape.")
}
val outputTensor = interpreter?.getOutputTensor(0) maxDetections = interpreter!!.getOutputTensor(0).shape()[1]
val outputShape = outputTensor?.shape()
if (outputShape != null && outputShape.size >= 2) {
maxDetections = outputShape[1]
Log.d("TFLiteObjectDetector", "Max detections from model: $maxDetections")
}
Log.d("TFLiteObjectDetector", "TFLite model loaded successfully.")
} catch (e: IOException) { } catch (e: IOException) {
Log.e("TFLiteObjectDetector", "Error loading TFLite model or labels from assets.", e) Log.e("Detector", "Failed to load model", e)
Log.e("TFLiteObjectDetector", "Please ensure 'efficientdet-lite0.tflite' and 'labelmap.txt' are in the 'app/src/main/assets' directory.")
interpreter = null interpreter = null
} }
} }
override suspend fun analyze(input: PipelineInput): Instruction { override suspend fun analyze(input: PipelineInput): Instruction {
if (interpreter == null) {
return Instruction("Object detector not initialized. Check asset files.", isValid = false) val image = input.image
?: return Instruction("Waiting for camera", isValid = false)
val resized = Bitmap.createScaledBitmap(image, inputW, inputH, true)
val buffer = bitmapToBuffer(resized)
val locations = Array(1) { Array(maxDetections) { FloatArray(4) } }
val classes = Array(1) { FloatArray(maxDetections) }
val scores = Array(1) { FloatArray(maxDetections) }
val count = FloatArray(1)
interpreter?.runForMultipleInputsOutputs(
arrayOf(buffer),
mapOf(0 to locations, 1 to classes, 2 to scores, 3 to count)
)
val detections = mutableListOf<Detection>()
for (i in 0 until count[0].toInt()) {
if (scores[0][i] < 0.5f) continue
val label = labels.getOrElse(classes[0][i].toInt()) { "Unknown" }
val b = locations[0][i]
detections += Detection(
label,
scores[0][i],
RectF(
b[1] * image.width,
b[0] * image.height,
b[3] * image.width,
b[2] * image.height
)
)
} }
val image = input.image ?: return Instruction("Waiting for camera...", isValid = false) val primary = detections
.filter { it.label.equals(input.targetAnimal, true) }
.maxByOrNull { it.confidence }
val resizedBitmap = Bitmap.createScaledBitmap(image, modelInputWidth, modelInputHeight, true) val refs = detections
val byteBuffer = convertBitmapToByteBuffer(resizedBitmap) .filter { it !== primary }
.mapIndexed { i, d ->
// Define model outputs with the correct size
val outputLocations = Array(1) { Array(maxDetections) { FloatArray(4) } }
val outputClasses = Array(1) { FloatArray(maxDetections) }
val outputScores = Array(1) { FloatArray(maxDetections) }
val numDetections = FloatArray(1)
val outputs: MutableMap<Int, Any> = HashMap()
outputs[0] = outputLocations
outputs[1] = outputClasses
outputs[2] = outputScores
outputs[3] = numDetections
interpreter?.runForMultipleInputsOutputs(arrayOf(byteBuffer), outputs)
val detectedObjects = mutableListOf<Detection>()
val detectionCount = numDetections[0].toInt()
for (i in 0 until detectionCount) {
val score = outputScores[0][i]
if (score > 0.5f) { // Confidence threshold
val classIndex = outputClasses[0][i].toInt()
val label = labels.getOrElse(classIndex) { "Unknown" }
val location = outputLocations[0][i]
// TF Lite model returns ymin, xmin, ymax, xmax in normalized coordinates
val ymin = location[0] * image.height
val xmin = location[1] * image.width
val ymax = location[2] * image.height
val xmax = location[3] * image.width
val boundingBox = RectF(xmin, ymin, xmax, ymax)
detectedObjects.add(Detection(label, score, boundingBox))
}
}
val targetAnimalDetected = detectedObjects.find { it.label.equals(input.targetAnimal, ignoreCase = true) }
val isValid = targetAnimalDetected != null
val message = if (isValid) {
"${input.targetAnimal} Detected"
} else {
if (detectedObjects.isEmpty()) "No objects detected" else "Animal not detected, move closer or point camera to the animal"
}
val refObjects = detectedObjects
.filter { it !== targetAnimalDetected }
.mapIndexed { index, detection ->
ReferenceObject( ReferenceObject(
id = "ref_$index", id = "ref_$i",
label = detection.label, label = d.label,
bounds = detection.bounds, bounds = d.bounds,
relativeHeight = detection.bounds.height() / image.height, relativeHeight = d.bounds.height() / image.height,
relativeWidth = detection.bounds.width() / image.width, relativeWidth = d.bounds.width() / image.width,
distance = 1.0f // Placeholder distance = null
) )
} }
return Instruction( return Instruction(
message = message, message = if (primary != null) "Cow detected" else "Cow not detected",
isValid = isValid, isValid = primary != null,
result = DetectionResult( result = DetectionResult(
isAnimalDetected = isValid, isAnimalDetected = primary != null,
animalBounds = targetAnimalDetected?.bounds, animalBounds = primary?.bounds,
referenceObjects = refObjects, referenceObjects = refs,
label = targetAnimalDetected?.label, label = primary?.label,
confidence = targetAnimalDetected?.confidence ?: 0f confidence = primary?.confidence ?: 0f
) )
) )
} }
private fun convertBitmapToByteBuffer(bitmap: Bitmap): ByteBuffer { private fun bitmapToBuffer(bitmap: Bitmap): ByteBuffer {
val byteBuffer = ByteBuffer.allocateDirect(1 * modelInputWidth * modelInputHeight * 3) val buffer = ByteBuffer.allocateDirect(inputW * inputH * 3)
byteBuffer.order(ByteOrder.nativeOrder()) buffer.order(ByteOrder.nativeOrder())
val intValues = IntArray(modelInputWidth * modelInputHeight) val pixels = IntArray(inputW * inputH)
bitmap.getPixels(intValues, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height) bitmap.getPixels(pixels, 0, inputW, 0, 0, inputW, inputH)
var pixel = 0 for (p in pixels) {
for (i in 0 until modelInputWidth) { buffer.put(((p shr 16) and 0xFF).toByte())
for (j in 0 until modelInputHeight) { buffer.put(((p shr 8) and 0xFF).toByte())
val `val` = intValues[pixel++] buffer.put((p and 0xFF).toByte())
// Assuming model expects UINT8 [0, 255]
byteBuffer.put(((`val` shr 16) and 0xFF).toByte())
byteBuffer.put(((`val` shr 8) and 0xFF).toByte())
byteBuffer.put((`val` and 0xFF).toByte())
} }
} return buffer
return byteBuffer
} }
data class Detection(val label: String, val confidence: Float, val bounds: RectF) data class Detection(val label: String, val confidence: Float, val bounds: RectF)
} }
/* ============================================================= */
/* POSE ANALYZER (ALIGNMENT → CROP → SEGMENT) */
/* ============================================================= */
class MockPoseAnalyzer : PoseAnalyzer { class MockPoseAnalyzer : PoseAnalyzer {
private val segmenter by lazy { private val segmenter by lazy {
val options = SubjectSegmenterOptions.Builder() SubjectSegmentation.getClient(
SubjectSegmenterOptions.Builder()
.enableForegroundConfidenceMask() .enableForegroundConfidenceMask()
.build() .build()
SubjectSegmentation.getClient(options) )
} }
override suspend fun analyze(input: PipelineInput): Instruction { override suspend fun analyze(input: PipelineInput): Instruction {
@ -236,130 +210,124 @@ class MockPoseAnalyzer : PoseAnalyzer {
val detection = input.previousDetectionResult val detection = input.previousDetectionResult
?: return Instruction("No detection", isValid = false) ?: return Instruction("No detection", isValid = false)
val bounds = detection.animalBounds val cowBox = detection.animalBounds
?: return Instruction("Animal not detected", isValid = false) ?: return Instruction("Cow not detected", isValid = false)
val image = input.image val image = input.image
?: return Instruction("No image", isValid = false) ?: return Instruction("No image", isValid = false)
// -------------------------------------------------------------------- val silhouette = SilhouetteManager.getSilhouette(input.orientation)
// 1. Reference silhouette (FloatArray)
// --------------------------------------------------------------------
val reference = SilhouetteManager.getWeightedMask(input.orientation)
?: return Instruction("Silhouette missing", isValid = false) ?: return Instruction("Silhouette missing", isValid = false)
val refH = reference.mask.size val align = checkAlignment(cowBox, silhouette.boundingBox, 0.15f)
val refW = reference.mask[0].size if (align.issue != AlignmentIssue.OK) {
return alignmentToInstruction(align)
}
// --------------------------------------------------------------------
// 2. Crop only animal region (BIG WIN)
// --------------------------------------------------------------------
val cropped = Bitmap.createBitmap( val cropped = Bitmap.createBitmap(
image, image,
bounds.left.toInt(), cowBox.left.toInt(),
bounds.top.toInt(), cowBox.top.toInt(),
bounds.width().toInt(), cowBox.width().toInt(),
bounds.height().toInt() cowBox.height().toInt()
) )
val scaled = Bitmap.createScaledBitmap(cropped, refW, refH, true) val resized = Bitmap.createScaledBitmap(
cropped,
silhouette.croppedBitmap.width,
silhouette.croppedBitmap.height,
true
)
// -------------------------------------------------------------------- val mask = segment(resized)
// 3. Get binary segmentation mask (ByteArray)
// --------------------------------------------------------------------
val mask = getAnimalMaskFast(scaled)
?: return Instruction("Segmentation failed", isValid = false) ?: return Instruction("Segmentation failed", isValid = false)
// -------------------------------------------------------------------- val score = similarity(mask, silhouette.signedMask)
// 4. Fast signed similarity
// --------------------------------------------------------------------
val score = calculateSignedSimilarityFast(mask, reference)
val valid = score >= 0.40f val valid = score >= 0.40f
return Instruction( return Instruction(
message = if (valid) "Pose Correct" else "Pose Incorrect (%.2f)".format(score), message = if (valid) "Pose Correct" else "Adjust Position",
isValid = valid, isValid = valid,
result = detection result = detection
) )
} }
/** private suspend fun segment(bitmap: Bitmap): ByteArray? =
* ML Kit Binary mask (0/1) using ByteArray
*/
private suspend fun getAnimalMaskFast(bitmap: Bitmap): ByteArray? =
suspendCancellableCoroutine { cont -> suspendCancellableCoroutine { cont ->
segmenter.process(InputImage.fromBitmap(bitmap, 0))
val image = InputImage.fromBitmap(bitmap, 0) .addOnSuccessListener { r ->
segmenter.process(image) val buf = r.foregroundConfidenceMask
.addOnSuccessListener { result -> ?: return@addOnSuccessListener cont.resume(null)
val mask = result.foregroundConfidenceMask buf.rewind()
?: run {
cont.resume(null)
return@addOnSuccessListener
}
val buffer = mask
buffer.rewind()
val out = ByteArray(bitmap.width * bitmap.height) val out = ByteArray(bitmap.width * bitmap.height)
for (i in out.indices) out[i] = if (buf.get() > 0.5f) 1 else 0
for (i in out.indices) {
out[i] = if (buffer.get() > 0.5f) 1 else 0
}
cont.resume(out) cont.resume(out)
} }
.addOnFailureListener { .addOnFailureListener { cont.resume(null) }
cont.resume(null) }
private fun similarity(mask: ByteArray, ref: SignedMask): Float {
var s = 0f
var i = 0
for (row in ref.mask)
for (v in row)
s += mask[i++] * v
return if (ref.maxValue == 0f) 0f else s / ref.maxValue
} }
} }
/** /* ============================================================= */
* Signed similarity (flat arrays, cache-friendly) /* ALIGNMENT HELPERS */
* Range [-1, 1] /* ============================================================= */
*/
private fun calculateSignedSimilarityFast(
mask: ByteArray,
reference: SignedMask
): Float {
var score = 0f enum class AlignmentIssue { TOO_SMALL, TOO_LARGE, MOVE_LEFT, MOVE_RIGHT, MOVE_UP, MOVE_DOWN, OK }
val maxScore = reference.maxValue
var idx = 0
for (y in reference.mask.indices) { data class AlignmentResult(val issue: AlignmentIssue, val scale: Float, val dx: Float, val dy: Float)
val row = reference.mask[y]
for (x in row.indices) { fun checkAlignment(d: RectF, s: RectF, tol: Float): AlignmentResult {
val r = row[x]
score += mask[idx++] * r val scale = min(d.width() / s.width(), d.height() / s.height())
val dx = d.centerX() - s.centerX()
val dy = d.centerY() - s.centerY()
if (scale < 1f - tol) return AlignmentResult(AlignmentIssue.TOO_SMALL, scale, dx, dy)
if (scale > 1f + tol) return AlignmentResult(AlignmentIssue.TOO_LARGE, scale, dx, dy)
val tx = s.width() * tol
val ty = s.height() * tol
return when {
dx > tx -> AlignmentResult(AlignmentIssue.MOVE_LEFT, scale, dx, dy)
dx < -tx -> AlignmentResult(AlignmentIssue.MOVE_RIGHT, scale, dx, dy)
dy > ty -> AlignmentResult(AlignmentIssue.MOVE_UP, scale, dx, dy)
dy < -ty -> AlignmentResult(AlignmentIssue.MOVE_DOWN, scale, dx, dy)
else -> AlignmentResult(AlignmentIssue.OK, scale, dx, dy)
} }
} }
return if (maxScore == 0f) 0f else score / maxScore fun alignmentToInstruction(a: AlignmentResult) = when (a.issue) {
} AlignmentIssue.TOO_SMALL -> Instruction("Move closer", isValid = false)
AlignmentIssue.TOO_LARGE -> Instruction("Move backward", isValid = false)
AlignmentIssue.MOVE_LEFT -> Instruction("Move right", isValid = false)
AlignmentIssue.MOVE_RIGHT -> Instruction("Move left", isValid = false)
AlignmentIssue.MOVE_UP -> Instruction("Move down", isValid = false)
AlignmentIssue.MOVE_DOWN -> Instruction("Move up", isValid = false)
AlignmentIssue.OK -> Instruction("Hold steady", isValid = true)
} }
/* ============================================================= */
/* CAPTURE + MEASUREMENT (UNCHANGED) */
/* ============================================================= */
class DefaultCaptureHandler : CaptureHandler { class DefaultCaptureHandler : CaptureHandler {
override suspend fun capture(input: PipelineInput, detectionResult: DetectionResult): CaptureData { override suspend fun capture(input: PipelineInput, detectionResult: DetectionResult): CaptureData =
val image = input.image ?: throw IllegalStateException("Image cannot be null during capture") CaptureData(
image = input.image!!,
val segmentationMask = BooleanArray(100) { true } segmentationMask = BooleanArray(0),
animalMetrics = ObjectMetrics(0f, 0f, 1f),
val animalMetrics = ObjectMetrics(
relativeHeight = 0.5f,
relativeWidth = 0.3f,
distance = 1.2f
)
return CaptureData(
image = image,
segmentationMask = segmentationMask,
animalMetrics = animalMetrics,
referenceObjects = detectionResult.referenceObjects referenceObjects = detectionResult.referenceObjects
) )
} }
}
class DefaultMeasurementCalculator : MeasurementCalculator { class DefaultMeasurementCalculator : MeasurementCalculator {
override fun calculateRealMetrics( override fun calculateRealMetrics(
@ -367,18 +335,16 @@ class DefaultMeasurementCalculator : MeasurementCalculator {
referenceObject: ReferenceObject, referenceObject: ReferenceObject,
currentMetrics: ObjectMetrics currentMetrics: ObjectMetrics
): RealWorldMetrics { ): RealWorldMetrics {
if (referenceObject.relativeHeight == 0f) return RealWorldMetrics(0f, 0f, 0f)
if (referenceObject.relativeHeight == 0f)
return RealWorldMetrics(0f, 0f, 0f)
val scale = targetHeight / referenceObject.relativeHeight val scale = targetHeight / referenceObject.relativeHeight
val realHeight = currentMetrics.relativeHeight * scale
val realWidth = currentMetrics.relativeWidth * scale
val realDistance = currentMetrics.distance
return RealWorldMetrics( return RealWorldMetrics(
height = realHeight, height = currentMetrics.relativeHeight * scale,
width = realWidth, width = currentMetrics.relativeWidth * scale,
distance = realDistance distance = currentMetrics.distance
) )
} }
} }

View File

@ -55,7 +55,7 @@ data class ReferenceObject(
val bounds: RectF, val bounds: RectF,
val relativeHeight: Float, val relativeHeight: Float,
val relativeWidth: Float, val relativeWidth: Float,
val distance: Float val distance: Float? = null
) )
enum class CameraOrientation { enum class CameraOrientation {

View File

@ -1,9 +1,7 @@
package com.example.livingai.utils package com.example.livingai.utils
import android.content.Context import android.content.Context
import android.graphics.Bitmap import android.graphics.*
import android.graphics.BitmapFactory
import android.graphics.Color
import android.util.Log import android.util.Log
import com.example.livingai.R import com.example.livingai.R
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
@ -14,20 +12,25 @@ data class SignedMask(
val maxValue: Float val maxValue: Float
) )
data class SilhouetteData(
val croppedBitmap: Bitmap,
val boundingBox: RectF,
val signedMask: SignedMask
)
object SilhouetteManager { object SilhouetteManager {
private val originals = ConcurrentHashMap<String, Bitmap>() private val originals = ConcurrentHashMap<String, Bitmap>()
private val invertedPurple = ConcurrentHashMap<String, Bitmap>() private val silhouettes = ConcurrentHashMap<String, SilhouetteData>()
private val weightedMasks = ConcurrentHashMap<String, SignedMask>()
fun getOriginal(name: String): Bitmap? = originals[name] fun getOriginal(name: String): Bitmap? = originals[name]
fun getInvertedPurple(name: String): Bitmap? = invertedPurple[name] fun getSilhouette(name: String): SilhouetteData? = silhouettes[name]
fun getWeightedMask(name: String): SignedMask? = weightedMasks[name]
fun initialize(context: Context, width: Int, height: Int) { fun initialize(context: Context, screenW: Int, screenH: Int) {
val resources = context.resources
val silhouetteList = mapOf( val res = context.resources
val map = mapOf(
"front" to R.drawable.front_silhouette, "front" to R.drawable.front_silhouette,
"back" to R.drawable.back_silhouette, "back" to R.drawable.back_silhouette,
"left" to R.drawable.left_silhouette, "left" to R.drawable.left_silhouette,
@ -37,58 +40,94 @@ object SilhouetteManager {
"angleview" to R.drawable.angleview_silhouette "angleview" to R.drawable.angleview_silhouette
) )
silhouetteList.forEach { (name, resId) -> map.forEach { (name, resId) ->
val bmp = BitmapFactory.decodeResource(resources, resId)
originals[name] = bmp
// Fit image appropriately (front/back = W/H, others rotated) val src = BitmapFactory.decodeResource(res, resId)
val fitted = if (name == "front" || name == "back") originals[name] = src
createInvertedPurpleBitmap(bmp, width, height)
else
createInvertedPurpleBitmap(bmp, height, width)
invertedPurple[name] = fitted val fitted = Bitmap.createScaledBitmap(
invertToPurple(src),
screenW,
screenH,
true
)
weightedMasks[name] = createSignedWeightedMask(fitted, fadeInside = 10, fadeOutside = 20) val bbox = computeBoundingBox(fitted)
Log.d("Silhouette", "Loaded mask: $name (${fitted.width} x ${fitted.height})") val cropped = Bitmap.createBitmap(
fitted,
bbox.left.toInt(),
bbox.top.toInt(),
bbox.width().toInt(),
bbox.height().toInt()
)
val signedMask = createSignedWeightedMask(cropped)
silhouettes[name] = SilhouetteData(
croppedBitmap = cropped,
boundingBox = bbox,
signedMask = signedMask
)
Log.d("Silhouette", "Loaded $name (${bbox.width()} x ${bbox.height()})")
} }
} }
// ------------------------------------------------------------------------ /* ---------------------------------------------------------- */
// STEP 1: Create "inverted purple" mask (transparent object becomes purple)
// ------------------------------------------------------------------------
private fun createInvertedPurpleBitmap(
src: Bitmap,
targetWidth: Int,
targetHeight: Int
): Bitmap {
private fun invertToPurple(src: Bitmap): Bitmap {
val w = src.width val w = src.width
val h = src.height val h = src.height
val pixels = IntArray(w * h) val pixels = IntArray(w * h)
src.getPixels(pixels, 0, w, 0, 0, w, h) src.getPixels(pixels, 0, w, 0, 0, w, h)
val purple = Color.argb(255, 128, 0, 128) val purple = Color.argb(255, 128, 0, 128)
for (i in pixels.indices) { for (i in pixels.indices) {
val alpha = pixels[i] ushr 24 pixels[i] =
pixels[i] = if (alpha == 0) purple else 0x00000000 if ((pixels[i] ushr 24) == 0) purple
else 0x00000000
} }
val inverted = Bitmap.createBitmap(pixels, w, h, Bitmap.Config.ARGB_8888) return Bitmap.createBitmap(pixels, w, h, Bitmap.Config.ARGB_8888)
return Bitmap.createScaledBitmap(inverted, targetWidth, targetHeight, true)
} }
/** private fun computeBoundingBox(bitmap: Bitmap): RectF {
* Creates a signed weighted mask in range [-1, +1]
* val w = bitmap.width
* +1 : deep inside object val h = bitmap.height
* 0 : object boundary val pixels = IntArray(w * h)
* -1 : far outside object bitmap.getPixels(pixels, 0, w, 0, 0, w, h)
*/
var minX = w
var minY = h
var maxX = 0
var maxY = 0
for (y in 0 until h) {
for (x in 0 until w) {
if ((pixels[y * w + x] ushr 24) > 0) {
minX = min(minX, x)
minY = min(minY, y)
maxX = maxOf(maxX, x)
maxY = maxOf(maxY, y)
}
}
}
return RectF(
minX.toFloat(),
minY.toFloat(),
maxX.toFloat(),
maxY.toFloat()
)
}
/* ---------------------------------------------------------- */
/* SIGNED WEIGHTED MASK */
/* ---------------------------------------------------------- */
fun createSignedWeightedMask( fun createSignedWeightedMask(
bitmap: Bitmap, bitmap: Bitmap,
fadeInside: Int = 10, fadeInside: Int = 10,
@ -101,96 +140,54 @@ object SilhouetteManager {
val pixels = IntArray(w * h) val pixels = IntArray(w * h)
bitmap.getPixels(pixels, 0, w, 0, 0, w, h) bitmap.getPixels(pixels, 0, w, 0, 0, w, h)
val inside = IntArray(w * h)
for (i in pixels.indices)
inside[i] = if ((pixels[i] ushr 24) > 0) 1 else 0
fun idx(x: Int, y: Int) = y * w + x fun idx(x: Int, y: Int) = y * w + x
// -------------------------------------------------------------------- val distIn = IntArray(w * h) { Int.MAX_VALUE }
// 1. Binary mask val distOut = IntArray(w * h) { Int.MAX_VALUE }
// --------------------------------------------------------------------
val inside = IntArray(w * h) for (i in inside.indices) {
for (i in pixels.indices) { if (inside[i] == 0) distIn[i] = 0
inside[i] = if ((pixels[i] ushr 24) > 0) 1 else 0 else distOut[i] = 0
} }
// -------------------------------------------------------------------- for (y in 0 until h)
// 2. Distance transform (inside → outside)
// --------------------------------------------------------------------
val distInside = IntArray(w * h) { Int.MAX_VALUE }
for (i in inside.indices) if (inside[i] == 0) distInside[i] = 0
for (y in 0 until h) {
for (x in 0 until w) { for (x in 0 until w) {
val i = idx(x, y) val i = idx(x, y)
var best = distInside[i] if (x > 0) distIn[i] = min(distIn[i], distIn[idx(x - 1, y)] + 1)
if (x > 0) best = min(best, distInside[idx(x - 1, y)] + 1) if (y > 0) distIn[i] = min(distIn[i], distIn[idx(x, y - 1)] + 1)
if (y > 0) best = min(best, distInside[idx(x, y - 1)] + 1) if (x > 0) distOut[i] = min(distOut[i], distOut[idx(x - 1, y)] + 1)
distInside[i] = best if (y > 0) distOut[i] = min(distOut[i], distOut[idx(x, y - 1)] + 1)
}
} }
for (y in h - 1 downTo 0) { for (y in h - 1 downTo 0)
for (x in w - 1 downTo 0) { for (x in w - 1 downTo 0) {
val i = idx(x, y) val i = idx(x, y)
var best = distInside[i] if (x < w - 1) distIn[i] = min(distIn[i], distIn[idx(x + 1, y)] + 1)
if (x < w - 1) best = min(best, distInside[idx(x + 1, y)] + 1) if (y < h - 1) distIn[i] = min(distIn[i], distIn[idx(x, y + 1)] + 1)
if (y < h - 1) best = min(best, distInside[idx(x, y + 1)] + 1) if (x < w - 1) distOut[i] = min(distOut[i], distOut[idx(x + 1, y)] + 1)
distInside[i] = best if (y < h - 1) distOut[i] = min(distOut[i], distOut[idx(x, y + 1)] + 1)
}
} }
// -------------------------------------------------------------------- val mask = Array(h) { FloatArray(w) }
// 3. Distance transform (outside → inside) var maxVal = Float.NEGATIVE_INFINITY
// --------------------------------------------------------------------
val distOutside = IntArray(w * h) { Int.MAX_VALUE }
for (i in inside.indices) if (inside[i] == 1) distOutside[i] = 0
for (y in 0 until h) { for (y in 0 until h)
for (x in 0 until w) { for (x in 0 until w) {
val i = idx(x, y) val i = idx(x, y)
var best = distOutside[i] val v =
if (x > 0) best = min(best, distOutside[idx(x - 1, y)] + 1) if (inside[i] == 1)
if (y > 0) best = min(best, distOutside[idx(x, y - 1)] + 1) min(1f, distIn[i].toFloat() / fadeInside)
distOutside[i] = best else
} maxOf(-1f, -distOut[i].toFloat() / fadeOutside)
mask[y][x] = v
if (v > maxVal) maxVal = v
} }
for (y in h - 1 downTo 0) { return SignedMask(mask, maxVal)
for (x in w - 1 downTo 0) {
val i = idx(x, y)
var best = distOutside[i]
if (x < w - 1) best = min(best, distOutside[idx(x + 1, y)] + 1)
if (y < h - 1) best = min(best, distOutside[idx(x, y + 1)] + 1)
distOutside[i] = best
} }
} }
// --------------------------------------------------------------------
// 4. Build signed mask + track max value
// --------------------------------------------------------------------
val result = Array(h) { FloatArray(w) }
var maxValue = Float.NEGATIVE_INFINITY
for (y in 0 until h) {
for (x in 0 until w) {
val i = idx(x, y)
val weight = if (inside[i] == 1) {
val d = distInside[i]
if (d >= fadeInside) 1f
else d.toFloat() / fadeInside
} else {
val d = distOutside[i]
(-d.toFloat() / fadeOutside).coerceAtLeast(-1f)
}
result[y][x] = weight
if (weight > maxValue) maxValue = weight
}
}
return SignedMask(
mask = result,
maxValue = maxValue
)
}
}