• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

PythonBFS DFS UCS 贪婪 A*算法解决八数码问题

武飞扬头像
灵烽儿
帮助7

为了完成人工智能与机器学习实验报告 。。。 

本文只需要用到 四个 包

  1.  
    #import 相关包
  2.  
    import copy
  3.  
    import numpy as np
  4.  
    import random
  5.  
    from datetime import datetime

逆序数判断八数码问题是否有解

  1.  
    #逆序数判断:
  2.  
    def solution_or_not(initial,goal):
  3.  
    initial = initial.replace(" ","") #剔除字符串内空格
  4.  
    goal = goal.replace(" ","")
  5.  
    init_num = 0 #initial逆序数
  6.  
    goal_num = 0 #goal逆序数
  7.  
     
  8.  
    for i in range(1,9): #计算initial逆序数
  9.  
    temp = 0
  10.  
    for j in range(0,i):
  11.  
    if initial[j] > initial[i] and initial[i]!='0':
  12.  
    temp=temp 1
  13.  
    init_num = init_num temp
  14.  
     
  15.  
    for i in range(1,9): #计算goal逆序数
  16.  
    temp = 0
  17.  
    for j in range(0,i):
  18.  
    if goal[j] > goal[i] and goal[i]!='0':
  19.  
    temp=temp 1
  20.  
    goal_num = goal_num temp
  21.  
     
  22.  
    if(init_num%2) != (goal_num%2): #判断两逆序数的奇偶性是否相等
  23.  
    return 0
  24.  
    else:
  25.  
    return 1

为了解决八数码问题(移动和判断往哪里移动)需要的一些函数

  1.  
    #寻找target的位置
  2.  
    def find_local(arr,target): #arr是节点的list,target是寻找的目标
  3.  
    for i in arr:
  4.  
    for j in i:
  5.  
    if j == target:
  6.  
    return arr.index(i),i.index(j) #返回target的下标
  7.  
     
  8.  
    #交换位置,并获得子节点
  9.  
    def get_child(arr,e):
  10.  
    arr_new = copy.deepcopy(arr) #深拷贝,复制一份新的节点
  11.  
    r, c = find_local(arr_new,'0') #寻找0的位置的坐标
  12.  
    r1, c1 = find_local(arr_new,e) #寻找可交换的位置的坐标
  13.  
     
  14.  
    #交换位置
  15.  
    arr_new[r][c], arr_new[r1][c1] = arr_new[r1][c1], arr_new[r][c]
  16.  
    return arr_new
  17.  
     
  18.  
    #获取可与0交换的元素
  19.  
    def get_elements(arr):
  20.  
    r,c =find_local(arr,'0') #寻找0的的下标
  21.  
    elements=[]
  22.  
    if r > 0:
  23.  
    elements.append(arr[r - 1][c]) # 上边的数
  24.  
    if r < 2:
  25.  
    elements.append(arr[r 1][c]) # 下边的数
  26.  
    if c > 0:
  27.  
    elements.append(arr[r][c - 1]) # 左边的数
  28.  
    if c < 2:
  29.  
    elements.append(arr[r][c 1]) # 右边的数
  30.  
    return elements
  31.  
     
  32.  
    #曼哈顿距离
  33.  
    def mhd_distance(arr1,arr2):
  34.  
    distance = []
  35.  
    for i in arr1:
  36.  
    for j in i:
  37.  
    loc1 = find_local(arr1,j)
  38.  
    loc2 = find_local(arr2,j)
  39.  
    distance.append(abs(loc1[0] loc2[0]) abs(loc1[1] - loc2[1]))
  40.  
    return sum(distance)
  41.  
     
  42.  
    #不在位数
  43.  
    def not_digits(arr1,arr2):
  44.  
    num=0
  45.  
    for i in range(0,2):
  46.  
    for j in range(0,2):
  47.  
    if arr1[i][j] != arr2[i][j] and arr1[i][j] != 0 and arr2[i][j]!= 0:
  48.  
    num = num 1
  49.  
    return num

创建一个节点类,用来储存每一个八数码的状态

  1.  
    #设置节点类
  2.  
    class state:
  3.  
    #state为八码数的list表示,parent为父节点
  4.  
    #deep为节点深度,cost为节点代价
  5.  
    #distance为曼哈顿距离,nd_nums为不在位数
  6.  
    def __init__(self, state, parent, deep, cost ,distance ,nd_nums):
  7.  
    self.state = state
  8.  
    self.parent = parent
  9.  
    self.deep = deep
  10.  
    self.cost = cost
  11.  
    self.distance = distance
  12.  
    self.nd_nums= nd_nums
  13.  
     
  14.  
    def get_children(self): #获取一层子节点
  15.  
    children=[]
  16.  
    for i in get_elements(self.state):
  17.  
    #逐个元素与0交换位置,生成子节点child
  18.  
    child = state(state=get_child(self.state,i),
  19.  
    parent = self,
  20.  
    deep = self.deep 1,
  21.  
    cost = self.cost 1,
  22.  
    distance = self.distance mhd_distance(self.state,goal_arr),
  23.  
    nd_nums = not_digits(self.state,goal_arr))
  24.  
     
  25.  
    #将每一个交换结果(子节点)都存入children
  26.  
    children.append(child)
  27.  
    return children

为了实现可视化输出

  1.  
    #打印最短路径
  2.  
    def best_path(n):
  3.  
    if n.parent == None:
  4.  
    return
  5.  
    else:
  6.  
    print("↑")
  7.  
    print(np.array(n.parent.state))
  8.  
    best_path(n.parent)
  9.  
     
  10.  
    #画分割线
  11.  
    def draw_line():
  12.  
    print('--' * 20)
  13.  
    print('--' * 20)
  14.  
    print('--' * 20)
  15.  
     
  16.  
    #整一个搜索路径:
  17.  
    def search_line(close):
  18.  
    print('搜索路径如下:')
  19.  
    for i in close[:-1]:
  20.  
    print(np.array(i.state))
  21.  
    print('↓')
  22.  
    print(np.array(close[-1].state))

将字符串输入转化为八数码

  1.  
    #将字符串转化为列表
  2.  
    def string_to_list(str):
  3.  
    str_list=list(str)
  4.  
    return [str_list[i:i 3] for i in range(0,len(str_list),3)]

广度优先搜索BFS

  1.  
    #广度优先搜索
  2.  
    def BFS(initial_arr,goal_arr):
  3.  
    open = [initial_arr]
  4.  
    close = []
  5.  
     
  6.  
    while len(open) > 0: #OPEN表是否为空表
  7.  
    open_1 = [i.state for i in open] #访问open节点内的state
  8.  
    close_1 = [i.state for i in close]
  9.  
     
  10.  
    n = open.pop(0) #删除OPEN队头节点,并且赋值给n
  11.  
    close.append(n) #n注入CLOSE表
  12.  
     
  13.  
    if n.state == goal_arr:
  14.  
    print('最优路径如下:')
  15.  
    print(np.array(n.state)) #转换成矩阵打印最终节点
  16.  
    best_path(n)
  17.  
    break
  18.  
    else:
  19.  
    for i in n.get_children(): #添加子节点进OPEN
  20.  
    if i.state not in open_1:
  21.  
    if i.state not in close_1:
  22.  
    open.append(i)
  23.  
     
  24.  
    draw_line()
  25.  
    search_line(close)
  26.  
    print('搜索步骤为',len(close) - 1)

深度优先搜索DFS

  1.  
    #深度优先搜索
  2.  
    def DFS(initial_arr,goal_arr):
  3.  
    open = [initial_arr]
  4.  
    close = []
  5.  
    # limit = eval(input('请输入要搜索的深度:'))
  6.  
    limit = 20
  7.  
     
  8.  
    while len(open) > 0:
  9.  
    open_2 = [i.state for i in open]
  10.  
    close_2 = [i.state for i in close]
  11.  
     
  12.  
    n = open.pop(0)
  13.  
    close.append(n)
  14.  
     
  15.  
    if n.state == goal_arr:
  16.  
    print('最优路径如下:')
  17.  
    print(np.array(n.state)) #转换成矩阵打印最终节点
  18.  
    best_path(n)
  19.  
    break
  20.  
    else:
  21.  
    if n.deep < limit:
  22.  
    for i in n.get_children():
  23.  
    if i.state not in open_2:
  24.  
    if i.state not in close_2:
  25.  
    open.insert(0, i) #DFS从前端插入
  26.  
    else:
  27.  
    print('该深度下无解') #循环出去后显示无解
  28.  
     
  29.  
    draw_line()
  30.  
    search_line(close)
  31.  
    print('深度为',close[-1].deep,'下的搜索步数为:',len(close) - 2)

一致代价优先搜索UCS

  1.  
    #一致代价优先搜索
  2.  
    def UCS(initial_arr,goal_arr):
  3.  
    open = [initial_arr]
  4.  
    close = []
  5.  
     
  6.  
    while len(open) > 0: #OPEN表是否为空表
  7.  
    open_3 = [i.state for i in open] #访问open节点内的state
  8.  
    close_3 = [i.state for i in close]
  9.  
     
  10.  
    open_4 = [i.cost for i in open] #OPEN内每个节点的cost
  11.  
    min_index = open_4.index(min(open_4))
  12.  
     
  13.  
    n = open.pop(min_index) #删除OPEN队头节点,并且赋值给n
  14.  
    close.append(n) #n注入CLOSE表
  15.  
     
  16.  
    if n.state == goal_arr:
  17.  
    print('最优路径如下:')
  18.  
    print(np.array(n.state)) #转换成矩阵打印最终节点
  19.  
    best_path(n)
  20.  
    break
  21.  
    else:
  22.  
    for i in n.get_children(): #添加子节点进OPEN
  23.  
    if i.state not in open_3:
  24.  
    if i.state not in close_3:
  25.  
    open.append(i)
  26.  
     
  27.  
    draw_line()
  28.  
    search_line(close)
  29.  
    print('搜索步骤为',len(close) - 1, '权重为',close[-1].cost)

贪婪算法

  1.  
    #贪婪算法
  2.  
    def Greedy(initial_arr,goal_arr):
  3.  
    open = [initial_arr]
  4.  
    close = []
  5.  
     
  6.  
    while len(open) > 0: #OPEN表是否为空表
  7.  
    open_1 = [i.state for i in open] #访问open节点内的state
  8.  
    close_1 = [i.state for i in close]
  9.  
     
  10.  
    n = open.pop(0) #删除OPEN队头节点(此点排序后为最小距离和),并且赋值给n
  11.  
    close.append(n) #n注入CLOSE表
  12.  
     
  13.  
    if n.state == goal_arr:
  14.  
    print('最优路径如下:')
  15.  
    print(np.array(n.state)) #转换成矩阵打印最终节点
  16.  
    best_path(n)
  17.  
    break
  18.  
    else:
  19.  
    for i in n.get_children(): #添加子节点进OPEN
  20.  
    if i.state not in open_1:
  21.  
    if i.state not in close_1:
  22.  
    open.insert(0,i)
  23.  
    open.sort(key = lambda x: x.distance) #按曼哈顿距离进行排序
  24.  
     
  25.  
    draw_line()
  26.  
    search_line(close)
  27.  
    print('搜索步骤为',len(close) - 1,'总估价为',close[-1].distance)

A*算法-曼哈顿距离

  1.  
    #A*算法-曼哈顿距离
  2.  
    def AStar_MHT(initial_arr,goal_arr):
  3.  
    open = [initial_arr]
  4.  
    close = []
  5.  
     
  6.  
    while len(open) > 0: #OPEN表是否为空表
  7.  
    open_1 = [i.state for i in open] #访问open节点内的state
  8.  
    close_1 = [i.state for i in close]
  9.  
     
  10.  
    n = open.pop(0) #删除OPEN队头节点(此点排序后为最小距离和),并且赋值给n
  11.  
    close.append(n) #n注入CLOSE表
  12.  
     
  13.  
    if n.state == goal_arr:
  14.  
    print('最优路径如下:')
  15.  
    print(np.array(n.state)) #转换成矩阵打印最终节点
  16.  
    best_path(n)
  17.  
    break
  18.  
    else:
  19.  
    for i in n.get_children(): #添加子节点进OPEN
  20.  
    if i.state not in open_1:
  21.  
    if i.state not in close_1:
  22.  
    open.insert(0,i)
  23.  
    open.sort(key = lambda x: x.distance x.cost) #按曼哈顿距离+cost 进行排序
  24.  
     
  25.  
    draw_line()
  26.  
    search_line(close)
  27.  
    print('搜索步骤为',len(close) - 1,'总估价为',close[-1].cost close[-1].distance)

 A*算法-不在位数

  1.  
    #A*算法-不在位数
  2.  
    def AStar_ND(initial_arr,goal_arr):
  3.  
    open = [initial_arr]
  4.  
    close = []
  5.  
     
  6.  
    while len(open) > 0: #OPEN表是否为空表
  7.  
    open_1 = [i.state for i in open] #访问open节点内的state
  8.  
    close_1 = [i.state for i in close]
  9.  
     
  10.  
    n = open.pop(0) #删除OPEN队头节点(此点排序后为最小距离和),并且赋值给n
  11.  
    close.append(n) #n注入CLOSE表
  12.  
     
  13.  
    if n.state == goal_arr:
  14.  
    print('最优路径如下:')
  15.  
    print(np.array(n.state)) #转换成矩阵打印最终节点
  16.  
    best_path(n)
  17.  
    break
  18.  
    else:
  19.  
    for i in n.get_children(): #添加子节点进OPEN
  20.  
    if i.state not in open_1:
  21.  
    if i.state not in close_1:
  22.  
    open.insert(0,i)
  23.  
    open.sort(key = lambda x: x.nd_nums x.cost) #按不在位数+cost 进行排序
  24.  
     
  25.  
    draw_line()
  26.  
    search_line(close)
  27.  
    print('搜索步骤为',len(close) - 1)

主函数如下

  1.  
    #主函数
  2.  
    if __name__=='__main__':
  3.  
    initial='283 104 765'
  4.  
    goal='123 804 765'
  5.  
    if solution_or_not(initial,goal):
  6.  
    goal_arr = string_to_list(goal)
  7.  
    initial_arr = state(string_to_list(initial),
  8.  
    parent=None,
  9.  
    deep=0,
  10.  
    cost=0,
  11.  
    distance=mhd_distance(string_to_list(initial),goal_arr),
  12.  
    nd_nums=not_digits(string_to_list(initial),goal_arr))
  13.  
     
  14.  
    DFS(initial_arr,goal_arr) #深度优先搜索
  15.  
    BFS(initial_arr,goal_arr) #宽度优先搜索
  16.  
    UCS(initial_arr,goal_arr) #一致代价优先搜索
  17.  
    Greedy(initial_arr,goal_arr) #贪婪算法
  18.  
    AStar_MHT(initial_arr,goal_arr) #A*算法-曼哈顿距离
  19.  
    AStar_ND(initial_arr,goal_arr) #A*算法-不在位数
  20.  
    else:
  21.  
    print("从逆序数判断来看,该八数码问题无解")
  22.  
     

 部分运行结果如下:

学新通

将八数码随机打乱  

  1.  
    #随机打乱八数码
  2.  
    def random_str(str):
  3.  
    str_list = list(str)
  4.  
    random.shuffle(str_list)
  5.  
    return ''.join(str_list)

对比两个A*算法的时间复杂度  

  1.  
    #算法对比
  2.  
    if __name__=='__main__':
  3.  
    initial='283104765'
  4.  
    goal='123804765'
  5.  
    time = []
  6.  
    count = 50
  7.  
    print("A*曼哈顿距离算法:")
  8.  
    for i in range(0,count):
  9.  
    initial = random_str(initial)
  10.  
    goal = random_str(goal)
  11.  
    if solution_or_not(initial,goal):
  12.  
    goal_arr = string_to_list(goal)
  13.  
    initial_arr = state(string_to_list(initial),
  14.  
    parent=None,
  15.  
    deep=0,
  16.  
    cost=0,
  17.  
    distance=mhd_distance(string_to_list(initial),goal_arr),
  18.  
    nd_nums=not_digits(string_to_list(initial),goal_arr))
  19.  
    start = datetime.now()
  20.  
    AStar_MHT(initial_arr,goal_arr) #A*算法-曼哈顿距离
  21.  
    end = datetime.now()
  22.  
    print("第",i 1,"次迭代结束,时间为",end - start)
  23.  
    time.append(end - start)
  24.  
    else:
  25.  
    print("第",i 1,"次迭代结束,该八数码无解")
  26.  
    print('A*曼哈顿距离算法平均耗时:',np.mean(time))
  27.  
     
  28.  
    initial='283104765'
  29.  
    goal='123804765'
  30.  
    time = []
  31.  
    print("A*不在为数算法:")
  32.  
    for i in range(0,count):
  33.  
    initial = random_str(initial)
  34.  
    goal = random_str(goal)
  35.  
    if solution_or_not(initial,goal):
  36.  
    goal_arr = string_to_list(goal)
  37.  
    initial_arr = state(string_to_list(initial),
  38.  
    parent=None,
  39.  
    deep=0,
  40.  
    cost=0,
  41.  
    distance=mhd_distance(string_to_list(initial),goal_arr),
  42.  
    nd_nums=not_digits(string_to_list(initial),goal_arr))
  43.  
    start = datetime.now()
  44.  
    AStar_ND(initial_arr,goal_arr) #A*算法-不在位数
  45.  
    end = datetime.now()
  46.  
    print("第",i 1,"次迭代结束,时间为",end - start)
  47.  
    time.append(end - start)
  48.  
    else:
  49.  
    print("第",i 1,"次迭代结束,该八数码无解")
  50.  
    print('A*不在位数算法平均耗时:',np.mean(time))
  51.  
     

部分运行结果如下 :

学新通

八数码问题的代码基于博主AlphaJun_zero的思路做出部分修改,完善与改进

AlphaJun_zero学新通https://blog.csdn.net/Juuunn

源代码下载(Jupyter)

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhfaeihf
系列文章
更多 icon
同类精品
更多 icon
继续加载