#!/usr/bin/python3
#-*- coding:utf-8 -*-

import json
import os
import shlex
import subprocess
import sys

import yaml


class messageRoute():
    """Route a message to the tasks of a DAG via ``scalebox task add``.

    The DAG is read from ``/dag-yaml/<DAG_NAME>.yml`` (``DAG_NAME`` defaults to
    ``csst-msc-l1-mbi``). Its top-level ``tasks`` list holds mappings whose
    ``name``, ``image`` and ``dependencies`` keys drive the routing; a task's
    ``image`` is used as the scalebox sink-job name.
    """

    def __init__(self):
        # Not read by any method in this class; kept for compatibility.
        self.message = ""
        self.headers = ""

    def route_all(self, message, headers):
        """Dispatch *message* according to the JSON *headers* string.

        ``headers`` must decode to an object carrying ``dag_run_id`` and
        ``sorted_tag``. Without a ``from_job`` key, the message is treated as
        coming from redis-cli and is sent to the DAG's head tasks. Otherwise it
        is sent to the tasks that depend on ``from_job``.

        Invalid JSON, or a header set lacking a required key (e.g. the literal
        ``"null"``), is reported on stdout and the message is dropped: the
        method returns None instead of raising.
        """
        dag = os.getenv("DAG_NAME", "csst-msc-l1-mbi")
        dagfile = "/dag-yaml/" + dag + ".yml"

        try:
            parsed = json.loads(headers)
        except json.JSONDecodeError as e:
            print("Invalid JSON format in headers:", e)
            return
        # "null" (or any non-object JSON) carries no header keys at all.
        if not isinstance(parsed, dict):
            parsed = {}

        try:
            dag_run_id = parsed["dag_run_id"]
            sorted_tag = parsed["sorted_tag"]
        except KeyError as e:
            print(f"Missing required header {e} in headers: {headers}")
            return

        # Check the parsed key, not a substring of the raw header text.
        if "from_job" not in parsed:
            print("received redis-cli message")
            self.appready(dagfile, sorted_tag, dag_run_id, message)
        else:
            from_job = parsed["from_job"]
            print(f"from_job :{from_job} dag_run_id :{dag_run_id} sorted_tag :{sorted_tag}")
            self.sendsinkjobs(dagfile, sorted_tag, dag_run_id, from_job, message)

    @staticmethod
    def _load_tasks(dagfile):
        """Return the ``tasks`` list of the DAG YAML file, or ``[]`` if absent.

        The file is closed before any message is sent.
        """
        with open(dagfile, "r", encoding='utf-8') as f:
            data = yaml.safe_load(f)
        # safe_load returns None for an empty document.
        if not isinstance(data, dict):
            return []
        return data.get('tasks') or []

    @staticmethod
    def _head_images(tasks):
        """Yield the image of every task that has no (or empty) dependencies.

        Tasks without an ``image`` are skipped.
        """
        for task in tasks:
            if task.get('dependencies'):
                continue
            print(f"任务 '{task.get('name')}' 没有 dependencies 字段。")
            image = task.get('image')
            if image:
                yield image

    @staticmethod
    def _sink_images(tasks, from_job):
        """Yield the image of every task whose dependencies include *from_job*.

        *from_job* is first translated from an image to the name of the task
        using that image. If no task uses it, it is matched as a task name.
        Tasks without an ``image`` are skipped.
        """
        for task in tasks:
            if task.get('image') == from_job:
                from_job = task.get('name')
                print(f"The header job is {from_job}")
                break
        for task in tasks:
            # `or []` also covers an explicit `dependencies: null`.
            dependencies = task.get('dependencies') or []
            if from_job in dependencies:
                image = task.get('image')
                if image:
                    yield image

    @classmethod
    def appready(cls, dagfile, sorted_tag, dag_run_id, message):
        """Send *message* to every head task (no dependencies) of *dagfile*."""
        for sink_job in cls._head_images(cls._load_tasks(dagfile)):
            print(f"The header job is {sink_job}")
            cls.sendmsg(sorted_tag, dag_run_id, sink_job, message)

    @classmethod
    def sendsinkjobs(cls, dagfile, sorted_tag, dag_run_id, from_job, message):
        """Send *message* to every task of *dagfile* that depends on *from_job*.

        *from_job* may be a task image or a task name (see ``_sink_images``).
        """
        for sink_job in cls._sink_images(cls._load_tasks(dagfile), from_job):
            cls.sendmsg(sorted_tag, dag_run_id, sink_job, message)

    @classmethod
    def sendmsg(cls, sorted_tag, dag_run_id, job, message):
        """Submit *message* to sink job *job* with ``scalebox task add``.

        The dag_run_id is first appended to /work/extra-attributes.txt.
        Arguments are passed to scalebox as an argv list (no shell). Quotes or
        shell metacharacters in *message* or the header values are therefore
        delivered verbatim and cannot break or inject into the command.

        Returns:
            The scalebox exit status (0 on success). If the executable cannot
            be started at all, returns 127, the status a POSIX shell reports
            for a command it cannot find.
        """
        print(f"sendmsg dag_run_id is {dag_run_id}")
        cls.append_dagrunid(dag_run_id)
        command = [
            "scalebox", "task", "add",
            "--header", f"sorted_tag={sorted_tag}",
            "--header", f"dag_run_id={dag_run_id}",
            "--header", "repeatable=yes",
            "--upsert",
            f"--sink-job={job}",
            message,
        ]
        # Log a shell-quoted rendering so the line can be copy-pasted.
        print("command : " + " ".join(shlex.quote(arg) for arg in command))
        try:
            returncode = subprocess.run(command).returncode
        except OSError as e:
            print(f"failed to start scalebox: {e}")
            returncode = 127
        if returncode == 0:
            print(f"send message '{message}' to {job}")
            print("命令执行成功")
        else:
            print(f"命令执行失败，返回码为: {returncode}")
        return returncode

    @classmethod
    def append_dagrunid(cls, dag_run_id):
        """Append a ``dag_run_id:<id>`` line to /work/extra-attributes.txt.

        Best effort: a write failure is printed and otherwise ignored.
        """
        file_path = '/work/extra-attributes.txt'
        content_to_append = f"dag_run_id:{dag_run_id}\n"
        try:
            with open(file_path, 'a', encoding='utf-8') as file:
                file.write(content_to_append)
                print(f"追加完成,dag_run_id : {dag_run_id}")
        except OSError as e:
            print(f"写入文件/work/extra-attributes.txt时发生错误:{e}")


if __name__ == '__main__':
    # Invoked as: <script> MESSAGE HEADERS_JSON
    # HEADERS_JSON is the second positional argument.
    if len(sys.argv) < 3:
        print(f"usage: {sys.argv[0]} <message> <headers-json>", file=sys.stderr)
        sys.exit(2)
    message = sys.argv[1]
    headers = sys.argv[2]
    print('message ' + message)
    print('headers ' + headers)
    router = messageRoute()
    router.route_all(message, headers)
