
มาลองทำ Object Detection หรือการตรวจจับวัตถุว่าเป็นอะไร โดยให้ Machine Learning ได้ประมวลผลทำกับ Unity ร่วมกับ TensorFlow API for .NET หรือ TensorFlowSharp
API การตรวจจับวัตถุ หรือ Object Detection ของ TensorFlow นับว่าเป็นเครื่องมือที่ทรงพลังตัวหนึ่งที่ทุกคนสามารถเปิดใช้งานการประมวลผล Machie Learning ได้อย่างรวดเร็ว โดยเฉพาะผู้ที่ไม่มีพื้นฐานการเรียนรู้ของการทำงานด้าน AI และ Machine Learning เพื่อสร้างและปรับใช้ซอฟต์แวร์ของพวกประมวลผล หรือจดจำรูปภาพให้ทำงานได้อย่างรวดเร็ว และเป็นประโยชน์
บทเรียนนี้เราจำเป็นต้องใช้ Library Assets ของ Unity ตัวหนึ่งที่ชื่อว่า TensorFlowSharp
- https://github.com/migueldeicaza/TensorFlowSharp
- ดาวน์โหลด Assets สำหรับ Import ที่ TensotFlowSharp Unity Package
หรือแบบที่พร้อมใช้งานเลยถ้าเข้าใจแล้ว ดาวน์โหลด (สำหรับสายขี้เกียจ) เป็น Template Project นั้นทางผมได้เตรียมให้แล้วที่ Github ไปดาวน์โหลดมาได้เลย
- ตัวอย่าง Source Code ที่ทำไว้แล้วแบบ Unity Package
- Git ตัว Project ที่ผมเตรียมไว้ให้ https://github.com/banyapon/TensorFowUnitySample แล้วตามด้วยดาวน์โหลด TensorFlowSharpUnity Package
เริ่มต้นพัฒนา
ทำการเปิด Project ของผมที่เราได้ git clone https://github.com/banyapon/TensorFowUnitySample ลงมาบน Unity หลังจากนั้นติดตั้ง ML Kit ของ TensorFlow ด้วยการ Import Assets ของ TensotFlowSharp Unity Package ที่ดาวน์โหลดมาที่เมนู
Assets->Import Package->Custom Package
รอจนกว่าระบบจะประมวลผลเสร็จก็เรียบร้อย
ขั้นตอนต่อมาเราจำเป็นจะต้องเปลี่ยนแพลตฟอร์มเป็น Android หลังจากนั้น ตั้งค่า Build Setting -> Player Setting โดย
ใน Other Setting ให้เราไปเพิ่ม ENABLE_TENSORFLOW ในช่อง Scripting Define Symbols
เมื่อเสร็จแล้วไปที่ Project -> Assets หา Folder ที่ชื่อว่า “ML-Agents” เลือก Plugins->Android เราจะเห็นไฟล์นามสกุล .dll มากมายปรากฏอยู่ให้ทำการลบ .dll ทุกไฟล์ เหลือไว้เพียง
- Java.Interop
- Mono.Android
- System.Linq
- TensorFlowSharp
- TensorFlowSharp.Android
ทีนี้สังเกตการทำงานของ Unity ของเราคือ Object Detect เราจะมีการบังคับเปิดกล้องมือถือผ่าน C# ที่ชื่อว่า PhoneCamera.cs ไปวางใน MainCamera ซึ่งจะมี Mode ให้เราเลือกคือ Detector จะไปเรียก Class ของ Detector.cs อีกที
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 |
using System; using System.Collections; using System.Collections.Generic; using UnityEngine; using UnityEngine.UI; using System.IO; using System.Linq; using System.Text; using System.Text.RegularExpressions; using TFClassify; using System.Diagnostics; using System.Threading.Tasks; using Debug = UnityEngine.Debug; using TensorFlow; public enum Mode { Detect, Classify, } public class PhoneCamera : MonoBehaviour { private const int detectImageSize = 300; private const int classifyImageSize = 224; private static Texture2D boxOutlineTexture; private static GUIStyle labelStyle; private bool camAvailable; private WebCamTexture backCamera; private Texture defaultBackground; private Classifier classifier; private Detector detector; private List<BoxOutline> boxOutlines; private Vector2 backgroundSize; private Vector2 backgroundOrigin; public Mode mode; public RawImage background; public AspectRatioFitter fitter; public TextAsset modelFile; public TextAsset labelsFile; public Text uiText; private void Start() { LoadWorker(); defaultBackground = background.texture; WebCamDevice[] devices = WebCamTexture.devices; if(devices.Length == 0) { this.uiText.text = "No camera detected"; camAvailable = false; return; } for(int i = 0; i < devices.Length; i++) { if(!devices[i].isFrontFacing) { this.backCamera = new WebCamTexture(devices[i].name, Screen.width, Screen.height); } } if(backCamera == null) { this.uiText.text = "Unable to find back camera"; return; } this.backCamera.Play(); this.background.texture = this.backCamera; this.backgroundSize = new Vector2(this.backCamera.width, this.backCamera.height); camAvailable = true; string func = mode == Mode.Classify ? nameof(TFClassify) : nameof(TFDetect); InvokeRepeating(func, 1f, 1f); } private void Update() { if(!this.camAvailable) { return; } float ratio = (float)backCamera.width / (float)backCamera.height; fitter.aspectRatio = ratio; float scaleY = backCamera.videoVerticallyMirrored ? -1f : 1f; background.rectTransform.localScale = new Vector3(1f, scaleY, 1f); int orient = -backCamera.videoRotationAngle; background.rectTransform.localEulerAngles = new Vector3(0, 0, orient); } public void OnGUI() { if (this.boxOutlines != null && this.boxOutlines.Any()) { foreach (var outline in this.boxOutlines) { DrawBoxOutline(outline); } } } private void LoadWorker() { try { if (mode == Mode.Classify) { LoadClassifier(); } else { LoadDetector(); } } catch (TFException ex) { if (ex.Message.EndsWith("is up to date with your GraphDef-generating binary.).")) { this.uiText.text = "Error: TFException. Make sure you model trained with same version of TensorFlow as in Unity plugin."; } throw; } } private void LoadClassifier() { this.classifier = new Classifier( this.modelFile.bytes, Regex.Split(this.labelsFile.text, "\n|\r|\r\n") .Where(s => !String.IsNullOrEmpty(s)).ToArray(), classifyImageSize); } private void LoadDetector() { this.detector = new Detector( this.modelFile.bytes, Regex.Split(this.labelsFile.text, "\n|\r|\r\n") .Where(s => !String.IsNullOrEmpty(s)).ToArray(), detectImageSize); } private async void TFClassify() { var snap = TakeTextureSnap(); var scaled = Scale(snap, classifyImageSize); var rotated = await RotateAsync(scaled.GetPixels32(), scaled.width, scaled.height); try { var probabilities = await this.classifier.ClassifyAsync(rotated); this.uiText.text = String.Empty; for(int i = 0; i < 3; i++) { this.uiText.text += probabilities[i].Key + ": " + String.Format("{0:0.000}%", probabilities[i].Value) + "\n"; } } catch (NullReferenceException) { this.uiText.text = "Error: NullReferenceException. Make sure you set correct INPUT_NAME and OUTPUT_NAME"; } finally { Destroy(snap); Destroy(scaled); } } private async void TFDetect() { UpdateBackgroundOrigin(); var snap = TakeTextureSnap(); var scaled = Scale(snap, detectImageSize); var rotated = await RotateAsync(scaled.GetPixels32(), scaled.width, scaled.height); this.boxOutlines = await this.detector.DetectAsync(rotated); Destroy(snap); Destroy(scaled); } private void UpdateBackgroundOrigin() { var backgroundPosition = this.background.transform.position; this.backgroundOrigin = new Vector2(backgroundPosition.x - this.backgroundSize.x / 2, backgroundPosition.y - this.backgroundSize.y / 2); } private void DrawBoxOutline(BoxOutline outline) { var xMin = outline.XMin * this.backgroundSize.x + this.backgroundOrigin.x; var xMax = outline.XMax * this.backgroundSize.x + this.backgroundOrigin.x; var yMin = outline.YMin * this.backgroundSize.y + this.backgroundOrigin.y; var yMax = outline.YMax * this.backgroundSize.y + this.backgroundOrigin.y; DrawRectangle(new Rect(xMin, yMin, xMax - xMin, yMax - yMin), 4, Color.green); DrawLabel(new Rect(xMin + 10, yMin + 10, 200, 20), $"{outline.Label}: {(int)(outline.Score * 100)}%"); } public static void DrawRectangle(Rect area, int frameWidth, Color color) { // Create a one pixel texture with the right color if (boxOutlineTexture == null) { var texture = new Texture2D(1, 1); texture.SetPixel(0, 0, color); texture.Apply(); boxOutlineTexture = texture; } Rect lineArea = area; lineArea.height = frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Top line lineArea.y = area.yMax - frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Bottom line lineArea = area; lineArea.width = frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Left line lineArea.x = area.xMax - frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Right line } private static void DrawLabel(Rect position, string text) { if (labelStyle == null) { var style = new GUIStyle(); style.fontSize = 50; style.normal.textColor = Color.red; labelStyle = style; } GUI.Label(position, text, labelStyle); } private Texture2D TakeTextureSnap() { var smallest = backCamera.width < backCamera.height ? backCamera.width : backCamera.height; var snap = TextureTools.CropWithRect(backCamera, new Rect(0, 0, smallest, smallest), TextureTools.RectOptions.Center, 0, 0); return snap; } private Texture2D Scale(Texture2D texture, int imageSize) { var scaled = TextureTools.scaled(texture, imageSize, imageSize, FilterMode.Bilinear); return scaled; } private Task<Color32[]> RotateAsync(Color32[] pixels, int width, int height) { return Task.Run(() => { return TextureTools.RotateImageMatrix( pixels, width, height, -90); }); } private Task<Texture2D> RotateAsync(Texture2D texture) { return Task.Run(() => { return TextureTools.RotateTexture(texture, -90); }); } private void SaveToFile(Texture2D texture) { File.WriteAllBytes( Application.persistentDataPath + "/" + "snap.png", texture.EncodeToPNG()); } } |
เราจะเรียก LoadDetector() และ TFClassify() จาก Library ร่วมกับคลาสในการส่งไฟล์รูปภาพจากกล้องมาประมวลผลเพื่อวาดกรอบข้อมูล พร้อม Label ใน
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
public static void DrawRectangle(Rect area, int frameWidth, Color color) { // Create a one pixel texture with the right color if (boxOutlineTexture == null) { var texture = new Texture2D(1, 1); texture.SetPixel(0, 0, color); texture.Apply(); boxOutlineTexture = texture; } Rect lineArea = area; lineArea.height = frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Top line lineArea.y = area.yMax - frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Bottom line lineArea = area; lineArea.width = frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Left line lineArea.x = area.xMax - frameWidth; GUI.DrawTexture(lineArea, boxOutlineTexture); // Right line } private static void DrawLabel(Rect position, string text) { if (labelStyle == null) { var style = new GUIStyle(); style.fontSize = 50; style.normal.textColor = Color.red; labelStyle = style; } GUI.Label(position, text, labelStyle); } |
ไป implement C# เพิ่มอีกไฟล์คือ Detector.cs ดังนี้:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using TensorFlow; using UnityEngine; namespace TFClassify { public class BoxOutline { public float YMin { get; set; } = 0; public float XMin { get; set; } = 0; public float YMax { get; set; } = 0; public float XMax { get; set; } = 0; public string Label { get; set; } public float Score { get; set; } } public class Detector { private static int IMAGE_MEAN = 117; private static float IMAGE_STD = 1; // Minimum detection confidence to track a detection. private static float MINIMUM_CONFIDENCE = 0.6f; private int inputSize; private TFGraph graph; private string[] labels; public Detector(byte[] model, string[] labels, int inputSize) { #if UNITY_ANDROID TensorFlowSharp.Android.NativeBinding.Init(); #endif this.labels = labels; this.inputSize = inputSize; this.graph = new TFGraph(); this.graph.Import(new TFBuffer(model)); } public Task<List<BoxOutline>> DetectAsync(Color32[] data) { return Task.Run(() => { using (var session = new TFSession(this.graph)) using (var tensor = TransformInput(data, this.inputSize, this.inputSize)) { var runner = session.GetRunner(); runner.AddInput(this.graph["image_tensor"][0], tensor) .Fetch(this.graph["detection_boxes"][0], this.graph["detection_scores"][0], this.graph["detection_classes"][0], this.graph["num_detections"][0]); var output = runner.Run(); var boxes = (float[,,])output[0].GetValue(jagged: false); var scores = (float[,])output[1].GetValue(jagged: false); var classes = (float[,])output[2].GetValue(jagged: false); foreach(var ts in output) { ts.Dispose(); } return GetBoxes(boxes, scores, classes, MINIMUM_CONFIDENCE); } }); } public static TFTensor TransformInput(Color32[] pic, int width, int height) { byte[] floatValues = new byte[width * height * 3]; for (int i = 0; i < pic.Length; ++i) { var color = pic[i]; floatValues [i * 3 + 0] = (byte)((color.r - IMAGE_MEAN) / IMAGE_STD); floatValues [i * 3 + 1] = (byte)((color.g - IMAGE_MEAN) / IMAGE_STD); floatValues [i * 3 + 2] = (byte)((color.b - IMAGE_MEAN) / IMAGE_STD); } TFShape shape = new TFShape(1, width, height, 3); return TFTensor.FromBuffer(shape, floatValues, 0, floatValues.Length); } private List<BoxOutline> GetBoxes(float[,,] boxes, float[,] scores, float[,] classes, double minScore) { var x = boxes.GetLength(0); var y = boxes.GetLength(1); var z = boxes.GetLength(2); float ymin = 0, xmin = 0, ymax = 0, xmax = 0; var results = new List<BoxOutline>(); for (int i = 0; i < x; i++) { for (int j = 0; j < y; j++) { if (scores [i, j] < minScore) continue; for (int k = 0; k < z; k++) { var box = boxes [i, j, k]; switch (k) { case 0: ymin = box; break; case 1: xmin = box; break; case 2: ymax = box; break; case 3: xmax = box; break; } } int value = Convert.ToInt32(classes[i, j]); var label = this.labels[value]; var boxOutline = new BoxOutline { YMin = ymin, XMin = xmin, YMax = ymax, XMax = xmax, Label = label, Score = scores[i, j], }; results.Add(boxOutline); } } return results; } } } |
เพิ่ม Class อีกตัวที่น่าสนใจคือ Classifier สร้าง C# ใหม่ขึ้นมาว่า Classifier.cs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using TensorFlow; using UnityEngine; namespace TFClassify { public class Classifier { private static int IMAGE_MEAN = 117; private static float IMAGE_STD = 1; private static string INPUT_NAME = "input"; private static string OUTPUT_NAME = "output"; private int inputSize; private TFGraph graph; private string[] labels; public Classifier(byte[] model, string[] labels, int inputSize) { #if UNITY_ANDROID TensorFlowSharp.Android.NativeBinding.Init(); #endif this.labels = labels; this.inputSize = inputSize; this.graph = new TFGraph(); this.graph.Import(model, ""); } public Task<List<KeyValuePair<string, float>>> ClassifyAsync(Color32[] data) { return Task.Run(() => { var map = new List<KeyValuePair<string, float>>(); using (var session = new TFSession(this.graph)) using (var tensor = TransformInput(data, this.inputSize, this.inputSize)) { var runner = session.GetRunner(); runner.AddInput(graph[INPUT_NAME][0], tensor).Fetch(graph[OUTPUT_NAME][0]); var output = runner.Run(); // output[0].Value() is a vector containing probabilities of // labels for each image in the "batch". The batch size was 1. // Find the most probably label index. var result = output[0]; var rshape = result.Shape; if (result.NumDims != 2 || rshape [0] != 1) { var shape = ""; foreach (var d in rshape) { shape += $"{d} "; } shape = shape.Trim (); Debug.Log("Error: expected to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape [{shape}]"); Environment.Exit (1); } var probabilities = ((float[][])result.GetValue(jagged: true))[0]; for (int i = 0; i < labels.Length; i++) { map.Add(new KeyValuePair<string, float>(labels[i], probabilities[i] * 100)); } foreach (var ts in output) { ts.Dispose(); } } return map.OrderByDescending(x => x.Value).ToList(); }); } public static TFTensor TransformInput(Color32[] pic, int width, int height) { float[] floatValues = new float[width * height * 3]; for (int i = 0; i < pic.Length; ++i) { var color = pic[i]; floatValues [i * 3 + 0] = (color.r - IMAGE_MEAN) / IMAGE_STD; floatValues [i * 3 + 1] = (color.g - IMAGE_MEAN) / IMAGE_STD; floatValues [i * 3 + 2] = (color.b - IMAGE_MEAN) / IMAGE_STD; } TFShape shape = new TFShape(1, width, height, 3); return TFTensor.FromBuffer(shape, floatValues, 0, floatValues.Length); } } } |
เสียบสาย USB เข้ากับเครื่องคอมพิวเตอร์ของเราหลังจากนั้นให้ Build & run ตัว Android ของเราลงสมาร์ตโฟน เพื่อเริ่มต้นทดสอบ
จะเห็นว่าเราสามารถทำ Object Detector ง่ายๆ ด้วย TensorFlowSharp กับ Unity ได้แล้ว
สำหรับคนที่มี Model ของ TensorFlow เป็นของตัวเอง Model ที่เลือกมาต้องถูก Trained ด้วย TensorFlow 1.4 ขึ้นไปนะครับ เปลี่ยนนามสกุลไฟล์จาก .pb เป็น .bytes ด้วย